diff --git a/prism/src/explicit/POMDPSimple.java b/prism/src/explicit/POMDPSimple.java index 32a06d76..9563f271 100644 --- a/prism/src/explicit/POMDPSimple.java +++ b/prism/src/explicit/POMDPSimple.java @@ -40,6 +40,12 @@ import prism.PrismUtils; /** * Simple explicit-state representation of a POMDP. * Basically a {@link MDPSimple} with observability info. + *

+ * POMDPs require that states with the same observation have + * the same set of available actions. This class further requires + * that these actions appear in the same order (in terms of + * choice indexing) in each equivalent state. This is enforced + * when calling setObservation(). */ public class POMDPSimple extends MDPSimple implements POMDP { @@ -261,29 +267,63 @@ public class POMDPSimple extends MDPSimple implements POMDP // Otherwise, check that the actions for existing states with // the same observation match this one else { - // Get and sort action strings for existing state(s) - List observationStateActions = new ArrayList<>(); - int numChoices = getNumChoices(observationState); - for (int i = 0; i < numChoices; i++) { - Object action = getAction(observationState, i); - observationStateActions.add(action == null ? "" : action.toString()); - } - Collections.sort(observationStateActions); - // Get and sort action strings for the new state - List sActions = new ArrayList<>(); - numChoices = getNumChoices(s); - for (int i = 0; i < numChoices; i++) { - Object action = getAction(s, i); - sActions.add(action == null ? "" : action.toString()); - } - Collections.sort(sActions); - // Check match - if (!(observationStateActions.equals(sActions))) { - throw new PrismException("Differing actions found in states: " + observationStateActions + " vs. " + sActions); + checkActionsMatchExactly(s, observationState); + } + } + + /** + * Check that the available actions and their ordering + * in states s1 and s2 match, and throw an exception if not. + */ + protected void checkActionsMatchExactly(int s1, int s2) throws PrismException + { + int numChoices = getNumChoices(s1); + if (numChoices != getNumChoices(s2)) { + throw new PrismException("Differing actions found in states: " + getAvailableActions(s1) + " vs. " + getAvailableActions(s2)); + } + for (int i = 0; i < numChoices; i++) { + Object action1 = getAction(s1, i); + Object action2 = getAction(s2, i); + if (action1 == null) { + if (action2 != null) { + throw new PrismException("Differing actions found in states: " + getAvailableActions(s1) + " vs. " + getAvailableActions(s2)); + } + } else { + if (!action1.equals(action2)) { + throw new PrismException("Differing actions found in states: " + getAvailableActions(s1) + " vs. " + getAvailableActions(s2)); + } } } } + /** + * Check that the *sets* of available actions in states s1 and s2 match, + * and throw an exception if not. + */ + protected void checkActionsMatch(int s1, int s2) throws PrismException + { + // Get and sort action strings for s1 + List s1Actions = new ArrayList<>(); + int numChoices = getNumChoices(s1); + for (int i = 0; i < numChoices; i++) { + Object action = getAction(s1, i); + s1Actions.add(action == null ? "" : action.toString()); + } + Collections.sort(s1Actions); + // Get and sort action strings for s2 + List s2Actions = new ArrayList<>(); + numChoices = getNumChoices(s2); + for (int i = 0; i < numChoices; i++) { + Object action = getAction(s2, i); + s2Actions.add(action == null ? "" : action.toString()); + } + Collections.sort(s2Actions); + // Check match + if (!(s1Actions.equals(s2Actions))) { + throw new PrismException("Differing actions found in states: " + s1Actions + " vs. " + s2Actions); + } + } + // Accessors (for PartiallyObservableModel) @Override