diff --git a/prism/src/explicit/POMDPSimple.java b/prism/src/explicit/POMDPSimple.java index bcf77f16..afd83c38 100644 --- a/prism/src/explicit/POMDPSimple.java +++ b/prism/src/explicit/POMDPSimple.java @@ -27,6 +27,7 @@ package explicit; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -51,8 +52,8 @@ public class POMDPSimple extends MDPSimple implements POMDP * i.e. the Unobservation object corresponding to each unobservation index. */ protected List unobservationsList; - /** Number of choices for each observation (same in states with that observation) */ - protected List observationNumChoices; + /** One state corresponding to each observation (used to look up info about it) */ + protected List observationStates; /** Observable assigned to each state */ protected List observablesMap; @@ -88,7 +89,7 @@ public class POMDPSimple extends MDPSimple implements POMDP super(pomdp); observationsList = new ArrayList<>(pomdp.observationsList); unobservationsList = new ArrayList<>(pomdp.unobservationsList); - observationNumChoices = new ArrayList<>(pomdp.observationNumChoices); + observationStates = new ArrayList<>(pomdp.observationStates); observablesMap = new ArrayList<>(pomdp.observablesMap); unobservablesMap = new ArrayList<>(pomdp.unobservablesMap); } @@ -102,7 +103,12 @@ public class POMDPSimple extends MDPSimple implements POMDP super(pomdp, permut); observationsList = new ArrayList<>(pomdp.observationsList); unobservationsList = new ArrayList<>(pomdp.unobservationsList); - observationNumChoices = new ArrayList<>(pomdp.observationNumChoices); + int numObservations = pomdp.getNumObservations(); + observationStates = new ArrayList<>(numObservations); + for (int o = 0; o < numObservations; o++) { + int s = pomdp.observationStates.get(o); + observationStates.add(s == -1 ? -1 : permut[s]); + } observablesMap = new ArrayList(getNumStates()); unobservablesMap = new ArrayList(getNumStates()); for (int s = 0; s < numStates; s++) { @@ -141,7 +147,7 @@ public class POMDPSimple extends MDPSimple implements POMDP { observationsList = new ArrayList<>(); unobservationsList = new ArrayList<>(); - observationNumChoices = new ArrayList<>(); + observationStates = new ArrayList<>(); observablesMap = new ArrayList<>(); unobservablesMap = new ArrayList<>(); } @@ -153,7 +159,7 @@ public class POMDPSimple extends MDPSimple implements POMDP { observationsList = new ArrayList<>(); unobservationsList = new ArrayList<>(); - observationNumChoices = new ArrayList<>(); + observationStates = new ArrayList<>(); observablesMap = new ArrayList<>(numStates); unobservablesMap = new ArrayList<>(numStates); for (int i = 0; i < numStates; i++) { @@ -201,17 +207,24 @@ public class POMDPSimple extends MDPSimple implements POMDP } /** - * Set the observation info for a state + * Set the observation info for a state. + * If the actions for existing states with this observation do not match, + * an explanatory exception is thrown (so this should be done after transitions + * have been added to the state). */ public void setObservation(int s, Observation observ, Unobservation unobserv) throws PrismException { int oIndex = observationsList.indexOf(observ); if (oIndex == -1) { observationsList.add(observ); - observationNumChoices.add(-1); + observationStates.add(-1); oIndex = observationsList.size() - 1; } - setObservation(s, oIndex); + try { + setObservation(s, oIndex); + } catch (PrismException e) { + throw new PrismException("Problem with observation " + observ + ": " + e.getMessage()); + } int unobservIndex = unobservationsList.indexOf(unobserv); if (unobservIndex == -1) { unobservationsList.add(unobserv); @@ -221,24 +234,43 @@ public class POMDPSimple extends MDPSimple implements POMDP } /** - * Assign observation o to state s - * (so observation has already been added to the list) - * Also update (and check) info about number of choices for an observation - * (so should be done after transitions have been added) + * Assign observation with index o to state s. + * (assumes observation has already been added to the list) + * If the actions for existing states with this observation do not match, + * an explanatory exception is thrown (so this should be done after transitions + * have been added to the state). */ protected void setObservation(int s, int o) throws PrismException { // Set observation observablesMap.set(s, o); - // Update observation choice counts - // and check that this value matches - int numChoicesObs = observationNumChoices.get(o); - int numChoicesState = getNumChoices(s); - if (numChoicesObs == -1) { - observationNumChoices.set(o, numChoicesState); - } else { - if (numChoicesState != numChoicesObs) { - throw new PrismException("Conflicting numbers of choices for observation " + getObservation(o) + ": " + numChoicesObs + " vs. " + numChoicesState); + // If this is first state with this observation, store its index + int observationState = observationStates.get(o); + if (observationState == -1) { + observationStates.set(o, s); + } + // Otherwise, check that the actions for existing states with + // the same observation match this one + else { + // Get and sort action strings for existing state(s) + List observationStateActions = new ArrayList<>(); + int numChoices = getNumChoices(observationState); + for (int i = 0; i < numChoices; i++) { + Object action = getAction(observationState, i); + observationStateActions.add(action == null ? "" : action.toString()); + } + Collections.sort(observationStateActions); + // Get and sort action strings for the new state + List sActions = new ArrayList<>(); + numChoices = getNumChoices(s); + for (int i = 0; i < numChoices; i++) { + Object action = getAction(s, i); + sActions.add(action == null ? "" : action.toString()); + } + Collections.sort(sActions); + // Check match + if (!(observationStateActions.equals(sActions))) { + throw new PrismException("Differing actions found in states: " + observationStateActions + " vs. " + sActions); } } } @@ -272,7 +304,7 @@ public class POMDPSimple extends MDPSimple implements POMDP @Override public int getNumChoicesForObservation(int o) { - return observationNumChoices.get(o); + return getNumChoices(observationStates.get(o)); } // Accessors (for POMDP)