|
|
|
@ -117,7 +117,7 @@ public class POMDPModelChecker extends ProbModelChecker |
|
|
|
mainLog.println("Starting fixed-resolution grid approximation (" + (min ? "min" : "max") + ")..."); |
|
|
|
|
|
|
|
// Find out the observations for the target states |
|
|
|
LinkedList<Integer> targetObservs = getTargetObservations(pomdp, target); |
|
|
|
LinkedList<Integer> targetObservs = getAndCheckTargetObservations(pomdp, target); |
|
|
|
|
|
|
|
// Initialise the grid points |
|
|
|
ArrayList<Belief> gridPoints = new ArrayList<>();//the set of grid points (discretized believes) |
|
|
|
@ -301,7 +301,7 @@ public class POMDPModelChecker extends ProbModelChecker |
|
|
|
mainLog.println("Starting fixed-resolution grid approximation (" + (min ? "min" : "max") + ")..."); |
|
|
|
|
|
|
|
// Find out the observations for the target states |
|
|
|
LinkedList<Integer> targetObservs = getTargetObservations(pomdp, target); |
|
|
|
LinkedList<Integer> targetObservs = getAndCheckTargetObservations(pomdp, target); |
|
|
|
|
|
|
|
// Initialise the grid points |
|
|
|
ArrayList<Belief> gridPoints = new ArrayList<>();//the set of grid points (discretized believes) |
|
|
|
@ -444,13 +444,32 @@ public class POMDPModelChecker extends ProbModelChecker |
|
|
|
return res; |
|
|
|
} |
|
|
|
|
|
|
|
protected LinkedList<Integer> getTargetObservations(POMDP pomdp, BitSet target) |
|
|
|
/** |
|
|
|
* Get a list of target observations from a set of target states |
|
|
|
* (both are represented by their indices). |
|
|
|
* Also check that the set of target states corresponds to a set |
|
|
|
* of observations, and throw an exception if not. |
|
|
|
*/ |
|
|
|
protected LinkedList<Integer> getAndCheckTargetObservations(POMDP pomdp, BitSet target) throws PrismException |
|
|
|
{ |
|
|
|
// Find observations corresponding to each state in the target |
|
|
|
TreeSet<Integer> targetObservsSet = new TreeSet<>(); |
|
|
|
for (int bit = target.nextSetBit(0); bit >= 0; bit = target.nextSetBit(bit + 1)) { |
|
|
|
targetObservsSet.add(pomdp.getObservation(bit)); |
|
|
|
for (int s = target.nextSetBit(0); s >= 0; s = target.nextSetBit(s + 1)) { |
|
|
|
targetObservsSet.add(pomdp.getObservation(s)); |
|
|
|
} |
|
|
|
LinkedList<Integer> targetObservs = new LinkedList<>(targetObservsSet); |
|
|
|
// Rereate the set of target states from the target observations |
|
|
|
// and make sure it matches |
|
|
|
BitSet target2 = new BitSet(); |
|
|
|
int numStates = pomdp.getNumStates(); |
|
|
|
for (int s = 0; s < numStates; s++) { |
|
|
|
if (targetObservs.contains(pomdp.getObservation(s))) { |
|
|
|
target2.set(s); |
|
|
|
} |
|
|
|
} |
|
|
|
if (!target.equals(target2)) { |
|
|
|
throw new PrismException("Target is not observable"); |
|
|
|
} |
|
|
|
return targetObservs; |
|
|
|
} |
|
|
|
|
|
|
|
|