diff --git a/prism/src/explicit/POMDPSimple.java b/prism/src/explicit/POMDPSimple.java
index 32a06d76..9563f271 100644
--- a/prism/src/explicit/POMDPSimple.java
+++ b/prism/src/explicit/POMDPSimple.java
@@ -40,6 +40,12 @@ import prism.PrismUtils;
/**
* Simple explicit-state representation of a POMDP.
* Basically a {@link MDPSimple} with observability info.
+ *
+ * POMDPs require that states with the same observation have
+ * the same set of available actions. This class further requires
+ * that these actions appear in the same order (in terms of
+ * choice indexing) in each equivalent state. This is enforced
+ * when calling setObservation().
*/
public class POMDPSimple extends MDPSimple implements POMDP
{
@@ -261,29 +267,63 @@ public class POMDPSimple extends MDPSimple implements POMDP
// Otherwise, check that the actions for existing states with
// the same observation match this one
else {
- // Get and sort action strings for existing state(s)
- List observationStateActions = new ArrayList<>();
- int numChoices = getNumChoices(observationState);
- for (int i = 0; i < numChoices; i++) {
- Object action = getAction(observationState, i);
- observationStateActions.add(action == null ? "" : action.toString());
- }
- Collections.sort(observationStateActions);
- // Get and sort action strings for the new state
- List sActions = new ArrayList<>();
- numChoices = getNumChoices(s);
- for (int i = 0; i < numChoices; i++) {
- Object action = getAction(s, i);
- sActions.add(action == null ? "" : action.toString());
- }
- Collections.sort(sActions);
- // Check match
- if (!(observationStateActions.equals(sActions))) {
- throw new PrismException("Differing actions found in states: " + observationStateActions + " vs. " + sActions);
+ checkActionsMatchExactly(s, observationState);
+ }
+ }
+
+ /**
+ * Check that the available actions and their ordering
+ * in states s1 and s2 match, and throw an exception if not.
+ */
+ protected void checkActionsMatchExactly(int s1, int s2) throws PrismException
+ {
+ int numChoices = getNumChoices(s1);
+ if (numChoices != getNumChoices(s2)) {
+ throw new PrismException("Differing actions found in states: " + getAvailableActions(s1) + " vs. " + getAvailableActions(s2));
+ }
+ for (int i = 0; i < numChoices; i++) {
+ Object action1 = getAction(s1, i);
+ Object action2 = getAction(s2, i);
+ if (action1 == null) {
+ if (action2 != null) {
+ throw new PrismException("Differing actions found in states: " + getAvailableActions(s1) + " vs. " + getAvailableActions(s2));
+ }
+ } else {
+ if (!action1.equals(action2)) {
+ throw new PrismException("Differing actions found in states: " + getAvailableActions(s1) + " vs. " + getAvailableActions(s2));
+ }
}
}
}
+ /**
+ * Check that the *sets* of available actions in states s1 and s2 match,
+ * and throw an exception if not.
+ */
+ protected void checkActionsMatch(int s1, int s2) throws PrismException
+ {
+ // Get and sort action strings for s1
+ List s1Actions = new ArrayList<>();
+ int numChoices = getNumChoices(s1);
+ for (int i = 0; i < numChoices; i++) {
+ Object action = getAction(s1, i);
+ s1Actions.add(action == null ? "" : action.toString());
+ }
+ Collections.sort(s1Actions);
+ // Get and sort action strings for s2
+ List s2Actions = new ArrayList<>();
+ numChoices = getNumChoices(s2);
+ for (int i = 0; i < numChoices; i++) {
+ Object action = getAction(s2, i);
+ s2Actions.add(action == null ? "" : action.toString());
+ }
+ Collections.sort(s2Actions);
+ // Check match
+ if (!(s1Actions.equals(s2Actions))) {
+ throw new PrismException("Differing actions found in states: " + s1Actions + " vs. " + s2Actions);
+ }
+ }
+
// Accessors (for PartiallyObservableModel)
@Override