diff --git a/prism-tests/functionality/verify/pomdps/guess-multi.prism.props b/prism-tests/functionality/verify/pomdps/guess-multi.prism.props index c17a6dcb..8b9b5180 100644 --- a/prism-tests/functionality/verify/pomdps/guess-multi.prism.props +++ b/prism-tests/functionality/verify/pomdps/guess-multi.prism.props @@ -3,3 +3,9 @@ // RESULT (N=3): 1.0 // RESULT (N=4): 1.0 Pmax=? [ F "correct" ]; + +// RESULT (N=1): Infinity +// RESULT (N=2): Infinity +// RESULT (N=3): 1.5 +// RESULT (N=4): 1.5 +R{"guesses"}min=? [ F "correct" ]; diff --git a/prism/src/explicit/POMDPModelChecker.java b/prism/src/explicit/POMDPModelChecker.java index 6a845162..8b35e837 100644 --- a/prism/src/explicit/POMDPModelChecker.java +++ b/prism/src/explicit/POMDPModelChecker.java @@ -159,7 +159,6 @@ public class POMDPModelChecker extends ProbModelChecker mainLog.println("Starting fixed-resolution grid approximation (" + (min ? "min" : "max") + ")..."); // Find out the observations for the target/remain states - // And determine set of observations actually need to perform computation for BitSet targetObs = getObservationsMatchingStates(pomdp, target);; if (targetObs == null) { throw new PrismException("Target for reachability is not observable"); @@ -168,6 +167,9 @@ public class POMDPModelChecker extends ProbModelChecker if (remain != null && remainObs == null) { throw new PrismException("Left-hand side of until is not observable"); } + mainLog.println("target obs=" + targetObs.cardinality() + ", remain obs=" + remainObs.cardinality()); + + // Determine set of observations actually need to perform computation for BitSet unknownObs = new BitSet(); unknownObs.set(0, pomdp.getNumObservations()); unknownObs.andNot(targetObs); @@ -359,14 +361,26 @@ public class POMDPModelChecker extends ProbModelChecker mainLog.println("Starting fixed-resolution grid approximation (" + (min ? "min" : "max") + ")..."); // Find out the observations for the target states - // And determine set of observations actually need to perform computation for BitSet targetObs = getObservationsMatchingStates(pomdp, target);; if (targetObs == null) { throw new PrismException("Target for expected reachability is not observable"); } + + // Find _some_ of the states with infinite reward + // (those from which *every* MDP strategy has prob<1 of reaching the target, + // and therefore so does every POMDP strategy) + MDPModelChecker mcProb1 = new MDPModelChecker(this); + BitSet inf = mcProb1.prob1(pomdp, null, target, false, null); + inf.flip(0, pomdp.getNumStates()); + // Find observations for which all states are known to have inf reward + BitSet infObs = getObservationsCoveredByStates(pomdp, inf); + mainLog.println("target obs=" + targetObs.cardinality() + ", inf obs=" + infObs.cardinality()); + + // Determine set of observations actually need to perform computation for BitSet unknownObs = new BitSet(); unknownObs.set(0, pomdp.getNumObservations()); unknownObs.andNot(targetObs); + unknownObs.andNot(infObs); // Initialise the grid points (just for unknown beliefs) List gridPoints = initialiseGridPoints(pomdp, unknownObs); @@ -383,7 +397,7 @@ public class POMDPModelChecker extends ProbModelChecker vhash_backUp.put(belief, 0.0); } // Define value function for the full set of belief states - Function values = belief -> approximateReachReward(belief, vhash_backUp, targetObs); + Function values = belief -> approximateReachReward(belief, vhash_backUp, targetObs, infObs); // Define value backup function BeliefMDPBackUp backup = (belief, beliefState) -> approximateReachRewardBackup(belief, beliefState, values, min); @@ -528,6 +542,33 @@ public class POMDPModelChecker extends ProbModelChecker return setObs; } + /** + * Get a list of observations from a set of states + * (both are represented by BitSets over their indices). + * Observations are included only if all their corresponding states + * are included in the passed in set. + */ + protected BitSet getObservationsCoveredByStates(POMDP pomdp, BitSet set) throws PrismException + { + // Find observations corresponding to each state in the set + BitSet setObs = new BitSet(); + for (int s = set.nextSetBit(0); s >= 0; s = set.nextSetBit(s + 1)) { + setObs.set(pomdp.getObservation(s)); + } + // Find observations for which not all states are in the set + // and remove them from the observation set to be returned + int numStates = pomdp.getNumStates(); + for (int o = setObs.nextSetBit(0); o >= 0; o = set.nextSetBit(o + 1)) { + for (int s = 0; s < numStates; s++) { + if (pomdp.getObservation(s) == o && !set.get(s)) { + setObs.set(o, false); + break; + } + } + } + return setObs; + } + /** * Construct a list of beliefs for a grid-based approximation of the belief space. * Only beliefs with observable values from {@code unknownObs) are added. @@ -636,7 +677,7 @@ public class POMDPModelChecker extends ProbModelChecker { int numChoices = beliefMDPState.trans.size(); double chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; - int chosenActionIndex = -1; + int chosenActionIndex = 0; for (int i = 0; i < numChoices; i++) { double value = beliefMDPState.rewards.get(i); for (Map.Entry entry : beliefMDPState.trans.get(i).entrySet()) { @@ -676,12 +717,16 @@ public class POMDPModelChecker extends ProbModelChecker /** * Compute the grid-based approximate value for a belief for reward reachability */ - protected double approximateReachReward(Belief belief, HashMap gridValues, BitSet targetObs) + protected double approximateReachReward(Belief belief, HashMap gridValues, BitSet targetObs, BitSet infObs) { // 0 for target states if (targetObs.get(belief.so)) { return 0.0; } + // +Inf for states in "inf" + else if (infObs.get(belief.so)) { + return Double.POSITIVE_INFINITY; + } // Otherwise approximate vie interpolation over grid points else { return interpolateOverGrid(belief, gridValues); @@ -718,7 +763,7 @@ public class POMDPModelChecker extends ProbModelChecker * @param min * @param listBeliefs */ - protected POMDPStrategyModel buildStrategyModel(POMDP pomdp, int sInit, MDPRewards mdpRewards, BitSet targetObs, BitSet unknownObs, BeliefMDPBackUp backup) + protected POMDPStrategyModel buildStrategyModel(POMDP pomdp, int sInit, MDPRewards mdpRewards, BitSet targetObs, BitSet unknownObs, BeliefMDPBackUp backup) throws PrismException { // Initialise model/state/rewards storage MDPSimple mdp = new MDPSimple(); @@ -767,9 +812,15 @@ public class POMDPModelChecker extends ProbModelChecker // Store reward too, if needed if (mdpRewards != null) { stateRewards.setStateReward(src, pomdp.getRewardAfterChoice(belief, chosenActionIndex, mdpRewards)); + } else { + stateRewards.setStateReward(src, 0.0); } + } else { + stateRewards.setStateReward(src, 0.0); } } + // Add deadlocks to unexplored (known-value) states + mdp.findDeadlocks(true); // Attach a label marking target states mdp.addLabel("target", mdpTarget); // Return