|
|
|
@ -44,6 +44,9 @@ public class MDP extends Model |
|
|
|
// Transition function (Steps) |
|
|
|
protected List<List<Distribution>> trans; |
|
|
|
|
|
|
|
// Action labels |
|
|
|
protected List<List<Object>> actions; |
|
|
|
|
|
|
|
// Rewards |
|
|
|
protected List<List<Double>> transRewards; |
|
|
|
protected Double transRewardsConstant; |
|
|
|
@ -85,6 +88,7 @@ public class MDP extends Model |
|
|
|
for (int i = 0; i < numStates; i++) { |
|
|
|
trans.add(new ArrayList<Distribution>()); |
|
|
|
} |
|
|
|
actions = null; |
|
|
|
clearAllRewards(); |
|
|
|
} |
|
|
|
|
|
|
|
@ -104,6 +108,8 @@ public class MDP extends Model |
|
|
|
} |
|
|
|
maxNumDistrsOk = false; |
|
|
|
trans.get(s).clear(); |
|
|
|
if (actions != null && actions.get(s) != null) |
|
|
|
actions.get(s).clear(); |
|
|
|
if (transRewards != null && transRewards.get(s) != null) |
|
|
|
transRewards.get(s).clear(); |
|
|
|
} |
|
|
|
@ -124,6 +130,8 @@ public class MDP extends Model |
|
|
|
{ |
|
|
|
for (int i = 0; i < numToAdd; i++) { |
|
|
|
trans.add(new ArrayList<Distribution>()); |
|
|
|
if (actions != null) |
|
|
|
actions.add(null); |
|
|
|
if (transRewards != null) |
|
|
|
transRewards.add(null); |
|
|
|
numStates++; |
|
|
|
@ -151,6 +159,9 @@ public class MDP extends Model |
|
|
|
return i; |
|
|
|
} |
|
|
|
set.add(distr); |
|
|
|
// Add null action if necessary |
|
|
|
if (actions != null && actions.get(s) != null) |
|
|
|
actions.get(s).add(null); |
|
|
|
// Add zero reward if necessary |
|
|
|
if (transRewards != null && transRewards.get(s) != null) |
|
|
|
transRewards.get(s).add(0.0); |
|
|
|
@ -170,6 +181,30 @@ public class MDP extends Model |
|
|
|
transRewardsConstant = null; |
|
|
|
} |
|
|
|
|
|
|
|
/** |
|
|
|
* Set the action label for choice i in some state s. |
|
|
|
*/ |
|
|
|
public void setAction(int s, int i, Object o) |
|
|
|
{ |
|
|
|
// If no actions array created yet, create it |
|
|
|
if (actions == null) { |
|
|
|
actions = new ArrayList<List<Object>>(numStates); |
|
|
|
for (int j = 0; j < numStates; j++) |
|
|
|
actions.add(null); |
|
|
|
} |
|
|
|
// If no actions for state i yet, create list |
|
|
|
if (actions.get(s) == null) { |
|
|
|
int n = trans.get(s).size(); |
|
|
|
List<Object> list = new ArrayList<Object>(n); |
|
|
|
for (int j = 0; j < n; j++) { |
|
|
|
list.add(0.0); |
|
|
|
} |
|
|
|
actions.set(s, list); |
|
|
|
} |
|
|
|
// Set actions |
|
|
|
actions.get(s).set(i, o); |
|
|
|
} |
|
|
|
|
|
|
|
/** |
|
|
|
* Set a constant reward for all transitions |
|
|
|
*/ |
|
|
|
@ -231,6 +266,17 @@ public class MDP extends Model |
|
|
|
return trans.get(s).get(i); |
|
|
|
} |
|
|
|
|
|
|
|
/** |
|
|
|
* Get the action (if any) for choice i of state s. |
|
|
|
*/ |
|
|
|
public Object getAction(int s, int i) |
|
|
|
{ |
|
|
|
List<Object> list; |
|
|
|
if (actions == null || (list = actions.get(s)) == null) |
|
|
|
return null; |
|
|
|
return list.get(i); |
|
|
|
} |
|
|
|
|
|
|
|
/** |
|
|
|
* Get the transition reward (if any) for choice i of state s. |
|
|
|
*/ |
|
|
|
@ -615,9 +661,10 @@ public class MDP extends Model |
|
|
|
j = -1; |
|
|
|
for (Distribution distr : trans.get(i)) { |
|
|
|
j++; |
|
|
|
out.write(i + " -> " + i + "." + j + " [ arrowhead=none,label=\"" + j + "\" ];\n"); |
|
|
|
out.write(i + "." + j + " [ shape=circle,width=0.1,height=0.1,label=\"\" ];\n"); |
|
|
|
for (Map.Entry<Integer, Double> e : distr) { |
|
|
|
out.write(i + " -> " + e.getKey() + " [ label=\""); |
|
|
|
out.write(j + ":" + e.getValue() + "\" ];\n"); |
|
|
|
out.write(i + "." + j + " -> " + e.getKey() + " [ label=\"" + e.getValue() + "\" ];\n"); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@ -657,6 +704,8 @@ public class MDP extends Model |
|
|
|
else |
|
|
|
s += ", "; |
|
|
|
s += i + ": " + trans.get(i); |
|
|
|
if (actions != null) |
|
|
|
s += actions.get(i); |
|
|
|
if (transRewards != null) |
|
|
|
s += transRewards.get(i); |
|
|
|
} |
|
|
|
@ -678,6 +727,7 @@ public class MDP extends Model |
|
|
|
return false; |
|
|
|
if (!trans.equals(mdp.trans)) |
|
|
|
return false; |
|
|
|
// TODO: compare actions (complicated: null = null,null,null,...) |
|
|
|
// TODO: compare rewards (complicated: null = 0,0,0,0) |
|
|
|
return true; |
|
|
|
} |
|
|
|
|