From 969d7a4caf535d7f8f6fd2275bc93520af0db19c Mon Sep 17 00:00:00 2001 From: Dave Parker Date: Fri, 26 Feb 2021 00:23:17 +0000 Subject: [PATCH] Stricter check in POMDPSimple that actions match for observations. It requires that the available actions are the same and appear in the same order (in terms of choice indexing) in each equivalent state. This always happens for PRISM models built via ConstructModel since the model construction considers actions one by one in the same order in each state (and actions must be unique for POMDPs). --- prism/src/explicit/POMDPSimple.java | 78 ++++++++++++++++++++++------- 1 file changed, 59 insertions(+), 19 deletions(-) diff --git a/prism/src/explicit/POMDPSimple.java b/prism/src/explicit/POMDPSimple.java index 32a06d76..9563f271 100644 --- a/prism/src/explicit/POMDPSimple.java +++ b/prism/src/explicit/POMDPSimple.java @@ -40,6 +40,12 @@ import prism.PrismUtils; /** * Simple explicit-state representation of a POMDP. * Basically a {@link MDPSimple} with observability info. + *

+ * POMDPs require that states with the same observation have + * the same set of available actions. This class further requires + * that these actions appear in the same order (in terms of + * choice indexing) in each equivalent state. This is enforced + * when calling setObservation(). */ public class POMDPSimple extends MDPSimple implements POMDP { @@ -261,29 +267,63 @@ public class POMDPSimple extends MDPSimple implements POMDP // Otherwise, check that the actions for existing states with // the same observation match this one else { - // Get and sort action strings for existing state(s) - List observationStateActions = new ArrayList<>(); - int numChoices = getNumChoices(observationState); - for (int i = 0; i < numChoices; i++) { - Object action = getAction(observationState, i); - observationStateActions.add(action == null ? "" : action.toString()); - } - Collections.sort(observationStateActions); - // Get and sort action strings for the new state - List sActions = new ArrayList<>(); - numChoices = getNumChoices(s); - for (int i = 0; i < numChoices; i++) { - Object action = getAction(s, i); - sActions.add(action == null ? "" : action.toString()); - } - Collections.sort(sActions); - // Check match - if (!(observationStateActions.equals(sActions))) { - throw new PrismException("Differing actions found in states: " + observationStateActions + " vs. " + sActions); + checkActionsMatchExactly(s, observationState); + } + } + + /** + * Check that the available actions and their ordering + * in states s1 and s2 match, and throw an exception if not. + */ + protected void checkActionsMatchExactly(int s1, int s2) throws PrismException + { + int numChoices = getNumChoices(s1); + if (numChoices != getNumChoices(s2)) { + throw new PrismException("Differing actions found in states: " + getAvailableActions(s1) + " vs. " + getAvailableActions(s2)); + } + for (int i = 0; i < numChoices; i++) { + Object action1 = getAction(s1, i); + Object action2 = getAction(s2, i); + if (action1 == null) { + if (action2 != null) { + throw new PrismException("Differing actions found in states: " + getAvailableActions(s1) + " vs. " + getAvailableActions(s2)); + } + } else { + if (!action1.equals(action2)) { + throw new PrismException("Differing actions found in states: " + getAvailableActions(s1) + " vs. " + getAvailableActions(s2)); + } } } } + /** + * Check that the *sets* of available actions in states s1 and s2 match, + * and throw an exception if not. + */ + protected void checkActionsMatch(int s1, int s2) throws PrismException + { + // Get and sort action strings for s1 + List s1Actions = new ArrayList<>(); + int numChoices = getNumChoices(s1); + for (int i = 0; i < numChoices; i++) { + Object action = getAction(s1, i); + s1Actions.add(action == null ? "" : action.toString()); + } + Collections.sort(s1Actions); + // Get and sort action strings for s2 + List s2Actions = new ArrayList<>(); + numChoices = getNumChoices(s2); + for (int i = 0; i < numChoices; i++) { + Object action = getAction(s2, i); + s2Actions.add(action == null ? "" : action.toString()); + } + Collections.sort(s2Actions); + // Check match + if (!(s1Actions.equals(s2Actions))) { + throw new PrismException("Differing actions found in states: " + s1Actions + " vs. " + s2Actions); + } + } + // Accessors (for PartiallyObservableModel) @Override