|
|
|
@ -27,6 +27,7 @@ |
|
|
|
package explicit; |
|
|
|
|
|
|
|
import java.util.ArrayList; |
|
|
|
import java.util.Collections; |
|
|
|
import java.util.HashMap; |
|
|
|
import java.util.List; |
|
|
|
import java.util.Map; |
|
|
|
@ -51,8 +52,8 @@ public class POMDPSimple extends MDPSimple implements POMDP |
|
|
|
* i.e. the Unobservation object corresponding to each unobservation index. */ |
|
|
|
protected List<Unobservation> unobservationsList; |
|
|
|
|
|
|
|
/** Number of choices for each observation (same in states with that observation) */ |
|
|
|
protected List<Integer> observationNumChoices; |
|
|
|
/** One state corresponding to each observation (used to look up info about it) */ |
|
|
|
protected List<Integer> observationStates; |
|
|
|
|
|
|
|
/** Observable assigned to each state */ |
|
|
|
protected List<Integer> observablesMap; |
|
|
|
@ -88,7 +89,7 @@ public class POMDPSimple extends MDPSimple implements POMDP |
|
|
|
super(pomdp); |
|
|
|
observationsList = new ArrayList<>(pomdp.observationsList); |
|
|
|
unobservationsList = new ArrayList<>(pomdp.unobservationsList); |
|
|
|
observationNumChoices = new ArrayList<>(pomdp.observationNumChoices); |
|
|
|
observationStates = new ArrayList<>(pomdp.observationStates); |
|
|
|
observablesMap = new ArrayList<>(pomdp.observablesMap); |
|
|
|
unobservablesMap = new ArrayList<>(pomdp.unobservablesMap); |
|
|
|
} |
|
|
|
@ -102,7 +103,12 @@ public class POMDPSimple extends MDPSimple implements POMDP |
|
|
|
super(pomdp, permut); |
|
|
|
observationsList = new ArrayList<>(pomdp.observationsList); |
|
|
|
unobservationsList = new ArrayList<>(pomdp.unobservationsList); |
|
|
|
observationNumChoices = new ArrayList<>(pomdp.observationNumChoices); |
|
|
|
int numObservations = pomdp.getNumObservations(); |
|
|
|
observationStates = new ArrayList<>(numObservations); |
|
|
|
for (int o = 0; o < numObservations; o++) { |
|
|
|
int s = pomdp.observationStates.get(o); |
|
|
|
observationStates.add(s == -1 ? -1 : permut[s]); |
|
|
|
} |
|
|
|
observablesMap = new ArrayList<Integer>(getNumStates()); |
|
|
|
unobservablesMap = new ArrayList<Integer>(getNumStates()); |
|
|
|
for (int s = 0; s < numStates; s++) { |
|
|
|
@ -141,7 +147,7 @@ public class POMDPSimple extends MDPSimple implements POMDP |
|
|
|
{ |
|
|
|
observationsList = new ArrayList<>(); |
|
|
|
unobservationsList = new ArrayList<>(); |
|
|
|
observationNumChoices = new ArrayList<>(); |
|
|
|
observationStates = new ArrayList<>(); |
|
|
|
observablesMap = new ArrayList<>(); |
|
|
|
unobservablesMap = new ArrayList<>(); |
|
|
|
} |
|
|
|
@ -153,7 +159,7 @@ public class POMDPSimple extends MDPSimple implements POMDP |
|
|
|
{ |
|
|
|
observationsList = new ArrayList<>(); |
|
|
|
unobservationsList = new ArrayList<>(); |
|
|
|
observationNumChoices = new ArrayList<>(); |
|
|
|
observationStates = new ArrayList<>(); |
|
|
|
observablesMap = new ArrayList<>(numStates); |
|
|
|
unobservablesMap = new ArrayList<>(numStates); |
|
|
|
for (int i = 0; i < numStates; i++) { |
|
|
|
@ -201,17 +207,24 @@ public class POMDPSimple extends MDPSimple implements POMDP |
|
|
|
} |
|
|
|
|
|
|
|
/** |
|
|
|
* Set the observation info for a state |
|
|
|
* Set the observation info for a state. |
|
|
|
* If the actions for existing states with this observation do not match, |
|
|
|
* an explanatory exception is thrown (so this should be done after transitions |
|
|
|
* have been added to the state). |
|
|
|
*/ |
|
|
|
public void setObservation(int s, Observation observ, Unobservation unobserv) throws PrismException |
|
|
|
{ |
|
|
|
int oIndex = observationsList.indexOf(observ); |
|
|
|
if (oIndex == -1) { |
|
|
|
observationsList.add(observ); |
|
|
|
observationNumChoices.add(-1); |
|
|
|
observationStates.add(-1); |
|
|
|
oIndex = observationsList.size() - 1; |
|
|
|
} |
|
|
|
setObservation(s, oIndex); |
|
|
|
try { |
|
|
|
setObservation(s, oIndex); |
|
|
|
} catch (PrismException e) { |
|
|
|
throw new PrismException("Problem with observation " + observ + ": " + e.getMessage()); |
|
|
|
} |
|
|
|
int unobservIndex = unobservationsList.indexOf(unobserv); |
|
|
|
if (unobservIndex == -1) { |
|
|
|
unobservationsList.add(unobserv); |
|
|
|
@ -221,24 +234,43 @@ public class POMDPSimple extends MDPSimple implements POMDP |
|
|
|
} |
|
|
|
|
|
|
|
/** |
|
|
|
* Assign observation o to state s |
|
|
|
* (so observation has already been added to the list) |
|
|
|
* Also update (and check) info about number of choices for an observation |
|
|
|
* (so should be done after transitions have been added) |
|
|
|
* Assign observation with index o to state s. |
|
|
|
* (assumes observation has already been added to the list) |
|
|
|
* If the actions for existing states with this observation do not match, |
|
|
|
* an explanatory exception is thrown (so this should be done after transitions |
|
|
|
* have been added to the state). |
|
|
|
*/ |
|
|
|
protected void setObservation(int s, int o) throws PrismException |
|
|
|
{ |
|
|
|
// Set observation |
|
|
|
observablesMap.set(s, o); |
|
|
|
// Update observation choice counts |
|
|
|
// and check that this value matches |
|
|
|
int numChoicesObs = observationNumChoices.get(o); |
|
|
|
int numChoicesState = getNumChoices(s); |
|
|
|
if (numChoicesObs == -1) { |
|
|
|
observationNumChoices.set(o, numChoicesState); |
|
|
|
} else { |
|
|
|
if (numChoicesState != numChoicesObs) { |
|
|
|
throw new PrismException("Conflicting numbers of choices for observation " + getObservation(o) + ": " + numChoicesObs + " vs. " + numChoicesState); |
|
|
|
// If this is first state with this observation, store its index |
|
|
|
int observationState = observationStates.get(o); |
|
|
|
if (observationState == -1) { |
|
|
|
observationStates.set(o, s); |
|
|
|
} |
|
|
|
// Otherwise, check that the actions for existing states with |
|
|
|
// the same observation match this one |
|
|
|
else { |
|
|
|
// Get and sort action strings for existing state(s) |
|
|
|
List<String> observationStateActions = new ArrayList<>(); |
|
|
|
int numChoices = getNumChoices(observationState); |
|
|
|
for (int i = 0; i < numChoices; i++) { |
|
|
|
Object action = getAction(observationState, i); |
|
|
|
observationStateActions.add(action == null ? "" : action.toString()); |
|
|
|
} |
|
|
|
Collections.sort(observationStateActions); |
|
|
|
// Get and sort action strings for the new state |
|
|
|
List<String> sActions = new ArrayList<>(); |
|
|
|
numChoices = getNumChoices(s); |
|
|
|
for (int i = 0; i < numChoices; i++) { |
|
|
|
Object action = getAction(s, i); |
|
|
|
sActions.add(action == null ? "" : action.toString()); |
|
|
|
} |
|
|
|
Collections.sort(sActions); |
|
|
|
// Check match |
|
|
|
if (!(observationStateActions.equals(sActions))) { |
|
|
|
throw new PrismException("Differing actions found in states: " + observationStateActions + " vs. " + sActions); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@ -272,7 +304,7 @@ public class POMDPSimple extends MDPSimple implements POMDP |
|
|
|
@Override |
|
|
|
public int getNumChoicesForObservation(int o) |
|
|
|
{ |
|
|
|
return observationNumChoices.get(o); |
|
|
|
return getNumChoices(observationStates.get(o)); |
|
|
|
} |
|
|
|
|
|
|
|
// Accessors (for POMDP) |
|
|
|
|