Browse Source

Add check that POMDP states with the same observation have the same actions.

Previously, we just compared the numbers of choices.
accumulation-v4.7
Dave Parker 5 years ago
parent
commit
55bb529f48
  1. 78
      prism/src/explicit/POMDPSimple.java

78
prism/src/explicit/POMDPSimple.java

@ -27,6 +27,7 @@
package explicit;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -51,8 +52,8 @@ public class POMDPSimple extends MDPSimple implements POMDP
* i.e. the Unobservation object corresponding to each unobservation index. */
protected List<Unobservation> unobservationsList;
/** Number of choices for each observation (same in states with that observation) */
protected List<Integer> observationNumChoices;
/** One state corresponding to each observation (used to look up info about it) */
protected List<Integer> observationStates;
/** Observable assigned to each state */
protected List<Integer> observablesMap;
@ -88,7 +89,7 @@ public class POMDPSimple extends MDPSimple implements POMDP
super(pomdp);
observationsList = new ArrayList<>(pomdp.observationsList);
unobservationsList = new ArrayList<>(pomdp.unobservationsList);
observationNumChoices = new ArrayList<>(pomdp.observationNumChoices);
observationStates = new ArrayList<>(pomdp.observationStates);
observablesMap = new ArrayList<>(pomdp.observablesMap);
unobservablesMap = new ArrayList<>(pomdp.unobservablesMap);
}
@ -102,7 +103,12 @@ public class POMDPSimple extends MDPSimple implements POMDP
super(pomdp, permut);
observationsList = new ArrayList<>(pomdp.observationsList);
unobservationsList = new ArrayList<>(pomdp.unobservationsList);
observationNumChoices = new ArrayList<>(pomdp.observationNumChoices);
int numObservations = pomdp.getNumObservations();
observationStates = new ArrayList<>(numObservations);
for (int o = 0; o < numObservations; o++) {
int s = pomdp.observationStates.get(o);
observationStates.add(s == -1 ? -1 : permut[s]);
}
observablesMap = new ArrayList<Integer>(getNumStates());
unobservablesMap = new ArrayList<Integer>(getNumStates());
for (int s = 0; s < numStates; s++) {
@ -141,7 +147,7 @@ public class POMDPSimple extends MDPSimple implements POMDP
{
observationsList = new ArrayList<>();
unobservationsList = new ArrayList<>();
observationNumChoices = new ArrayList<>();
observationStates = new ArrayList<>();
observablesMap = new ArrayList<>();
unobservablesMap = new ArrayList<>();
}
@ -153,7 +159,7 @@ public class POMDPSimple extends MDPSimple implements POMDP
{
observationsList = new ArrayList<>();
unobservationsList = new ArrayList<>();
observationNumChoices = new ArrayList<>();
observationStates = new ArrayList<>();
observablesMap = new ArrayList<>(numStates);
unobservablesMap = new ArrayList<>(numStates);
for (int i = 0; i < numStates; i++) {
@ -201,17 +207,24 @@ public class POMDPSimple extends MDPSimple implements POMDP
}
/**
* Set the observation info for a state
* Set the observation info for a state.
* If the actions for existing states with this observation do not match,
* an explanatory exception is thrown (so this should be done after transitions
* have been added to the state).
*/
public void setObservation(int s, Observation observ, Unobservation unobserv) throws PrismException
{
int oIndex = observationsList.indexOf(observ);
if (oIndex == -1) {
observationsList.add(observ);
observationNumChoices.add(-1);
observationStates.add(-1);
oIndex = observationsList.size() - 1;
}
setObservation(s, oIndex);
try {
setObservation(s, oIndex);
} catch (PrismException e) {
throw new PrismException("Problem with observation " + observ + ": " + e.getMessage());
}
int unobservIndex = unobservationsList.indexOf(unobserv);
if (unobservIndex == -1) {
unobservationsList.add(unobserv);
@ -221,24 +234,43 @@ public class POMDPSimple extends MDPSimple implements POMDP
}
/**
* Assign observation o to state s
* (so observation has already been added to the list)
* Also update (and check) info about number of choices for an observation
* (so should be done after transitions have been added)
* Assign observation with index o to state s.
* (assumes observation has already been added to the list)
* If the actions for existing states with this observation do not match,
* an explanatory exception is thrown (so this should be done after transitions
* have been added to the state).
*/
protected void setObservation(int s, int o) throws PrismException
{
// Set observation
observablesMap.set(s, o);
// Update observation choice counts
// and check that this value matches
int numChoicesObs = observationNumChoices.get(o);
int numChoicesState = getNumChoices(s);
if (numChoicesObs == -1) {
observationNumChoices.set(o, numChoicesState);
} else {
if (numChoicesState != numChoicesObs) {
throw new PrismException("Conflicting numbers of choices for observation " + getObservation(o) + ": " + numChoicesObs + " vs. " + numChoicesState);
// If this is first state with this observation, store its index
int observationState = observationStates.get(o);
if (observationState == -1) {
observationStates.set(o, s);
}
// Otherwise, check that the actions for existing states with
// the same observation match this one
else {
// Get and sort action strings for existing state(s)
List<String> observationStateActions = new ArrayList<>();
int numChoices = getNumChoices(observationState);
for (int i = 0; i < numChoices; i++) {
Object action = getAction(observationState, i);
observationStateActions.add(action == null ? "" : action.toString());
}
Collections.sort(observationStateActions);
// Get and sort action strings for the new state
List<String> sActions = new ArrayList<>();
numChoices = getNumChoices(s);
for (int i = 0; i < numChoices; i++) {
Object action = getAction(s, i);
sActions.add(action == null ? "" : action.toString());
}
Collections.sort(sActions);
// Check match
if (!(observationStateActions.equals(sActions))) {
throw new PrismException("Differing actions found in states: " + observationStateActions + " vs. " + sActions);
}
}
}
@ -272,7 +304,7 @@ public class POMDPSimple extends MDPSimple implements POMDP
/**
 * Get the number of choices available in states with observation o.
 * This is looked up from the state stored for the observation in
 * {@code observationStates}, rather than from a separately maintained count.
 * NOTE(review): an observation with no state assigned yet is stored as -1,
 * which is passed straight to getNumChoices here - confirm that callers
 * only query observations that have at least one state.
 * @param o Index of the observation
 */
@Override
public int getNumChoicesForObservation(int o)
{
	return getNumChoices(observationStates.get(o));
}
// Accessors (for POMDP)

Loading…
Cancel
Save