|
|
|
@ -32,10 +32,8 @@ import java.util.List; |
|
|
|
|
|
|
|
public class STPGRewardsSimple extends MDPRewardsSimple implements STPGRewards |
|
|
|
{ |
|
|
|
/** Nested transition rewards (level 1) */ |
|
|
|
protected List<List<Double>> nestedTransRewards1; |
|
|
|
/** Nested transition rewards (level 2) */ |
|
|
|
protected List<List<List<Double>>> nestedTransRewards2; |
|
|
|
/** Nested transition rewards */ |
|
|
|
protected List<List<List<Double>>> nestedTransRewards; |
|
|
|
|
|
|
|
/** |
|
|
|
* Constructor: all zero rewards. |
|
|
|
@ -44,46 +42,12 @@ public class STPGRewardsSimple extends MDPRewardsSimple implements STPGRewards |
|
|
|
public STPGRewardsSimple(int numStates) |
|
|
|
{ |
|
|
|
super(numStates); |
|
|
|
// Initially lists are just null (denoting all 0) |
|
|
|
nestedTransRewards1 = null; |
|
|
|
nestedTransRewards2 = null; |
|
|
|
// Initially list is just null (denoting all 0) |
|
|
|
nestedTransRewards = null; |
|
|
|
} |
|
|
|
|
|
|
|
// Mutators |
|
|
|
|
|
|
|
/** |
|
|
|
* Set the reward for the {@code i}th nested transition of state {@code s} to {@code r}. |
|
|
|
*/ |
|
|
|
public void setNestedTransitionReward(int s, int i, double r) |
|
|
|
{ |
|
|
|
List<Double> list; |
|
|
|
// Nothing to do for zero reward |
|
|
|
if (r == 0.0) |
|
|
|
return; |
|
|
|
// If no rewards array created yet, create it |
|
|
|
if (nestedTransRewards1 == null) { |
|
|
|
nestedTransRewards1 = new ArrayList<List<Double>>(numStates); |
|
|
|
for (int j = 0; j < numStates; j++) |
|
|
|
nestedTransRewards1.add(null); |
|
|
|
} |
|
|
|
// If no rewards for state s yet, create list |
|
|
|
if (nestedTransRewards1.get(s) == null) { |
|
|
|
list = new ArrayList<Double>(); |
|
|
|
nestedTransRewards1.set(s, list); |
|
|
|
} else { |
|
|
|
list = nestedTransRewards1.get(s); |
|
|
|
} |
|
|
|
// If list not big enough, extend |
|
|
|
int n = i - list.size() + 1; |
|
|
|
if (n > 0) { |
|
|
|
for (int j = 0; j < n; j++) { |
|
|
|
list.add(0.0); |
|
|
|
} |
|
|
|
} |
|
|
|
// Set reward |
|
|
|
list.set(i, r); |
|
|
|
} |
|
|
|
|
|
|
|
/** |
|
|
|
* Set the reward for the {@code i},{@code j}th nested transition of state {@code s} to {@code r}. |
|
|
|
*/ |
|
|
|
@ -95,17 +59,17 @@ public class STPGRewardsSimple extends MDPRewardsSimple implements STPGRewards |
|
|
|
if (r == 0.0) |
|
|
|
return; |
|
|
|
// If no rewards array created yet, create it |
|
|
|
if (nestedTransRewards2 == null) { |
|
|
|
nestedTransRewards2 = new ArrayList<List<List<Double>>>(numStates); |
|
|
|
if (nestedTransRewards == null) { |
|
|
|
nestedTransRewards = new ArrayList<List<List<Double>>>(numStates); |
|
|
|
for (int k = 0; k < numStates; k++) |
|
|
|
nestedTransRewards2.add(null); |
|
|
|
nestedTransRewards.add(null); |
|
|
|
} |
|
|
|
// If no rewards for state s yet, create list1 |
|
|
|
if (nestedTransRewards2.get(s) == null) { |
|
|
|
if (nestedTransRewards.get(s) == null) { |
|
|
|
list1 = new ArrayList<List<Double>>(); |
|
|
|
nestedTransRewards2.set(s, list1); |
|
|
|
nestedTransRewards.set(s, list1); |
|
|
|
} else { |
|
|
|
list1 = nestedTransRewards2.get(s); |
|
|
|
list1 = nestedTransRewards.get(s); |
|
|
|
} |
|
|
|
// If list1 not big enough, extend |
|
|
|
int n1 = i - list1.size() + 1; |
|
|
|
@ -138,33 +102,19 @@ public class STPGRewardsSimple extends MDPRewardsSimple implements STPGRewards |
|
|
|
public void clearRewards(int s) |
|
|
|
{ |
|
|
|
super.clearRewards(s); |
|
|
|
if (nestedTransRewards1 != null && nestedTransRewards1.size() > s) { |
|
|
|
nestedTransRewards1.set(s, null); |
|
|
|
} |
|
|
|
if (nestedTransRewards2 != null && nestedTransRewards2.size() > s) { |
|
|
|
nestedTransRewards2.set(s, null); |
|
|
|
if (nestedTransRewards != null && nestedTransRewards.size() > s) { |
|
|
|
nestedTransRewards.set(s, null); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Accessors |
|
|
|
|
|
|
|
@Override |
|
|
|
public double getNestedTransitionReward(int s, int i) |
|
|
|
{ |
|
|
|
List<Double> list; |
|
|
|
if (nestedTransRewards1 == null || (list = nestedTransRewards1.get(s)) == null) |
|
|
|
return 0.0; |
|
|
|
if (list.size() <= i) |
|
|
|
return 0.0; |
|
|
|
return list.get(i); |
|
|
|
} |
|
|
|
|
|
|
|
@Override |
|
|
|
public double getNestedTransitionReward(int s, int i, int j) |
|
|
|
{ |
|
|
|
List<List<Double>> list1; |
|
|
|
List<Double> list2; |
|
|
|
if (nestedTransRewards2 == null || (list1 = nestedTransRewards2.get(s)) == null) |
|
|
|
if (nestedTransRewards == null || (list1 = nestedTransRewards.get(s)) == null) |
|
|
|
return 0.0; |
|
|
|
if (list1.size() <= i || (list2 = list1.get(i)) == null) |
|
|
|
return 0.0; |
|
|
|
@ -176,6 +126,6 @@ public class STPGRewardsSimple extends MDPRewardsSimple implements STPGRewards |
|
|
|
@Override |
|
|
|
public String toString() |
|
|
|
{ |
|
|
|
return super.toString() + "; ntr1: " + nestedTransRewards1 + "; ntr2:" + nestedTransRewards2; |
|
|
|
return super.toString() + "; ntr:" + nestedTransRewards; |
|
|
|
} |
|
|
|
} |