From c8bde25ab485aabac46ef8b77774f8d4e1b9cc9a Mon Sep 17 00:00:00 2001 From: Dave Parker Date: Sat, 27 Feb 2021 23:18:14 +0000 Subject: [PATCH] Add support for checking probabilistic until for POMDPs. --- .../verify/pomdps/mdp_simple.prism.props | 3 + prism/src/explicit/POMDPModelChecker.java | 104 +++++++++++------- prism/src/explicit/ProbModelChecker.java | 2 +- 3 files changed, 71 insertions(+), 38 deletions(-) diff --git a/prism-tests/functionality/verify/pomdps/mdp_simple.prism.props b/prism-tests/functionality/verify/pomdps/mdp_simple.prism.props index 5acb3363..48dc3057 100644 --- a/prism-tests/functionality/verify/pomdps/mdp_simple.prism.props +++ b/prism-tests/functionality/verify/pomdps/mdp_simple.prism.props @@ -4,6 +4,9 @@ Pmax=? [ F t=1 ]; // RESULT: 0.1 Pmin=? [ F t=1 ]; +// RESULT: 0.3 +Pmax=? [ s<=2 U t=1 ]; + // RESULT: 3.0 Rmax=? [ F t>0 ]; diff --git a/prism/src/explicit/POMDPModelChecker.java b/prism/src/explicit/POMDPModelChecker.java index c087d324..35ce8080 100644 --- a/prism/src/explicit/POMDPModelChecker.java +++ b/prism/src/explicit/POMDPModelChecker.java @@ -70,7 +70,7 @@ public class POMDPModelChecker extends ProbModelChecker * @param target Target states * @param min Min or max probabilities (true=min, false=max) */ - public ModelCheckerResult computeReachProbs(POMDP pomdp, BitSet target, boolean min) throws PrismException + public ModelCheckerResult computeReachProbs(POMDP pomdp, BitSet remain, BitSet target, boolean min) throws PrismException { ModelCheckerResult res = null; long timer; @@ -91,7 +91,7 @@ public class POMDPModelChecker extends ProbModelChecker } // Compute rewards - res = computeReachProbsFixedGrid(pomdp, target, min, stratFilename); + res = computeReachProbsFixedGrid(pomdp, remain, target, min, stratFilename); // Finished probabilistic reachability timer = System.currentTimeMillis() - timer; @@ -103,27 +103,43 @@ public class POMDPModelChecker extends ProbModelChecker } /** - * Compute expected reachability rewards 
using Lovejoy's fixed-resolution grid approach. + * Compute reachability/until probabilities, + * i.e. compute the min/max probability of reaching a state in {@code target}, + * while remaining in those in {@code remain}, + * using Lovejoy's fixed-resolution grid approach. * Optionally, store optimal (memoryless) strategy info. - * @param pomdp The POMMDP - * @param mdpRewards The rewards + * @param pomdp The POMDP + * @param remain Remain in these states (optional: null means "all") * @param target Target states - * @param inf States for which reward is infinite * @param min Min or max rewards (true=min, false=max) * @param strat Storage for (memoryless) strategy choice indices (ignored if null) */ - protected ModelCheckerResult computeReachProbsFixedGrid(POMDP pomdp, BitSet target, boolean min, String stratFilename) throws PrismException + protected ModelCheckerResult computeReachProbsFixedGrid(POMDP pomdp, BitSet remain, BitSet target, boolean min, String stratFilename) throws PrismException { // Start fixed-resolution grid approximation long timer = System.currentTimeMillis(); mainLog.println("Starting fixed-resolution grid approximation (" + (min ? "min" : "max") + ")..."); - // Find out the observations for the target states + // Find out the observations for the target/remain states // And determine set of observations actually need to perform computation for - BitSet targetObs = getAndCheckTargetObservations(pomdp, target); + BitSet targetObs = null; + try { + targetObs = getObservationsMatchingStates(pomdp, target); + } catch (PrismException e) { + throw new PrismException("Target for reachability is not observable"); + } + BitSet remainObs = null; + try { + remainObs = remain == null ? 
null : getObservationsMatchingStates(pomdp, remain); + } catch (PrismException e) { + throw new PrismException("Left-hand side of until is not observable"); + } BitSet unknownObs = new BitSet(); unknownObs.set(0, pomdp.getNumObservations()); unknownObs.andNot(targetObs); + if (remainObs != null) { + unknownObs.and(remainObs); + } // Initialise the grid points (just for unknown beliefs) List gridPoints = initialiseGridPoints(pomdp, unknownObs); @@ -160,7 +176,7 @@ public class POMDPModelChecker extends ProbModelChecker double nextBeliefProb = entry.getValue(); Belief nextBelief = entry.getKey(); // find discretized grid points to approximate the nextBelief - value += nextBeliefProb * approximateReachProb(nextBelief, vhash_backUp, targetObs); + value += nextBeliefProb * approximateReachProb(nextBelief, vhash_backUp, targetObs, unknownObs); } if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) { chosenValue = value; @@ -191,7 +207,7 @@ public class POMDPModelChecker extends ProbModelChecker // Find discretized grid points to approximate the initialBelief // Also get (approximate) accuracy of result from value iteration Belief initialBelief = pomdp.getInitialBelief(); - double outerBound = approximateReachProb(initialBelief, vhash_backUp, targetObs); + double outerBound = approximateReachProb(initialBelief, vhash_backUp, targetObs, unknownObs); double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE); Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE); // Print result @@ -200,7 +216,7 @@ public class POMDPModelChecker extends ProbModelChecker // Build DTMC to get inner bound (and strategy) mainLog.println("\nBuilding strategy-induced model..."); List listBeliefs = new ArrayList<>(); - MDP mdp = buildStrategyModel(pomdp, null, vhash, targetObs, min, listBeliefs).mdp; + MDP mdp = buildStrategyModel(pomdp, null, vhash, 
targetObs, unknownObs, min, listBeliefs).mdp; mainLog.print("Strategy-induced model: " + mdp.infoString()); // Export? if (stratFilename != null) { @@ -221,6 +237,7 @@ public class POMDPModelChecker extends ProbModelChecker mcMDP.setExportAdv(false); mcMDP.setGenStrat(false); // Solve MDP to get inner bound + // (just reachability: can ignore "remain" since violating states are absent) ModelCheckerResult mcRes = mcMDP.computeReachProbs(mdp, mdp.getLabelStates("target"), true); double innerBound = mcRes.soln[0]; Accuracy innerBoundAcc = mcRes.accuracy; @@ -261,7 +278,7 @@ public class POMDPModelChecker extends ProbModelChecker } /** - * Compute expected reachability rewards. + * Compute expected reachability rewards, * i.e. compute the min/max reward accumulated to reach a state in {@code target}. * @param pomdp The POMDP * @param mdpRewards The rewards @@ -318,7 +335,12 @@ public class POMDPModelChecker extends ProbModelChecker // Find out the observations for the target states // And determine set of observations actually need to perform computation for - BitSet targetObs = getAndCheckTargetObservations(pomdp, target); + BitSet targetObs = null; + try { + targetObs = getObservationsMatchingStates(pomdp, target); + } catch (PrismException e) { + throw new PrismException("Target for expected reachability is not observable"); + } BitSet unknownObs = new BitSet(); unknownObs.set(0, pomdp.getNumObservations()); unknownObs.andNot(targetObs); @@ -409,7 +431,7 @@ public class POMDPModelChecker extends ProbModelChecker // Build DTMC to get inner bound (and strategy) mainLog.println("\nBuilding strategy-induced model..."); List listBeliefs = new ArrayList<>(); - POMDPStrategyModel psm = buildStrategyModel(pomdp, mdpRewards, vhash, targetObs, min, listBeliefs); + POMDPStrategyModel psm = buildStrategyModel(pomdp, mdpRewards, vhash, targetObs, unknownObs, min, listBeliefs); MDP mdp = psm.mdp; MDPRewards mdpRewardsNew = psm.mdpRewards; mainLog.print("Strategy-induced model: " 
+ mdp.infoString()); @@ -473,30 +495,32 @@ public class POMDPModelChecker extends ProbModelChecker } /** - * Get a list of target observations from a set of target states + * Get a list of observations from a set of states * (both are represented by BitSets over their indices). - * Also check that the set of target states corresponds to a set - * of observations, and throw an exception if not. + * The states should correspond exactly to a set of observations, + * i.e., if a state corresponding to an observation is in the set, + * then all other states corresponding to it should also be. + * An exception is thrown if not. */ - protected BitSet getAndCheckTargetObservations(POMDP pomdp, BitSet target) throws PrismException + protected BitSet getObservationsMatchingStates(POMDP pomdp, BitSet set) throws PrismException { - // Find observations corresponding to each state in the target - BitSet targetObs = new BitSet(); - for (int s = target.nextSetBit(0); s >= 0; s = target.nextSetBit(s + 1)) { - targetObs.set(pomdp.getObservation(s)); + // Find observations corresponding to each state in the set + BitSet setObs = new BitSet(); + for (int s = set.nextSetBit(0); s >= 0; s = set.nextSetBit(s + 1)) { + setObs.set(pomdp.getObservation(s)); } - // Recreate the set of target states from the target observations and make sure it matches - BitSet target2 = new BitSet(); + // Recreate the set of states from the observations and make sure it matches + BitSet set2 = new BitSet(); int numStates = pomdp.getNumStates(); for (int s = 0; s < numStates; s++) { - if (targetObs.get(pomdp.getObservation(s))) { - target2.set(s); + if (setObs.get(pomdp.getObservation(s))) { + set2.set(s); } } - if (!target.equals(target2)) { - throw new PrismException("Target is not observable"); + if (!set.equals(set2)) { + throw new PrismException("Set is not observable"); } - return targetObs; + return setObs; } /** @@ -572,12 +596,16 @@ public class POMDPModelChecker extends ProbModelChecker /** * 
Compute the grid-based approximate value for a belief for probabilistic reachability */ - protected double approximateReachProb(Belief belief, HashMap gridValues, BitSet targetObs) + protected double approximateReachProb(Belief belief, HashMap gridValues, BitSet targetObs, BitSet unknownObs) { // 1 for target states if (targetObs.get(belief.so)) { return 1.0; } + // 0 for other non-unknown states + else if (!unknownObs.get(belief.so)) { + return 0.0; + } // Otherwise approximate vie interpolation over grid points else { return interpolateOverGrid(belief, gridValues); @@ -634,7 +662,7 @@ public class POMDPModelChecker extends ProbModelChecker * @param min * @param listBeliefs */ - protected POMDPStrategyModel buildStrategyModel(POMDP pomdp, MDPRewards mdpRewards, HashMap vhash, BitSet targetObs, boolean min, List listBeliefs) + protected POMDPStrategyModel buildStrategyModel(POMDP pomdp, MDPRewards mdpRewards, HashMap vhash, BitSet targetObs, BitSet unknownObs, boolean min, List listBeliefs) { // Initialise model/state/rewards storage MDPSimple mdp = new MDPSimple(); @@ -654,11 +682,13 @@ public class POMDPModelChecker extends ProbModelChecker while (!toBeExploredBelives.isEmpty()) { Belief b = toBeExploredBelives.pollFirst(); src++; - + // Remember if this is a target state if (targetObs.get(b.so)) { mdpTarget.set(src); - } else { - extractBestActions(src, b, vhash, targetObs, pomdp, mdpRewards, min, exploredBelieves, toBeExploredBelives, mdp, stateRewards); + } + // Only explore "unknown" states + if (unknownObs.get(b.so)) { + extractBestActions(src, b, vhash, targetObs, unknownObs, pomdp, mdpRewards, min, exploredBelieves, toBeExploredBelives, mdp, stateRewards); } } // Attach a label marking target states @@ -681,7 +711,7 @@ public class POMDPModelChecker extends ProbModelChecker * @param min * @param beliefList */ - protected void extractBestActions(int src, Belief belief, HashMap vhash, BitSet targetObs, POMDP pomdp, MDPRewards mdpRewards, boolean min, + 
protected void extractBestActions(int src, Belief belief, HashMap vhash, BitSet targetObs, BitSet unknownObs, POMDP pomdp, MDPRewards mdpRewards, boolean min, IndexedSet exploredBelieves, LinkedList toBeExploredBelives, MDPSimple mdp, StateRewardsSimple stateRewards) { double chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; @@ -698,7 +728,7 @@ public class POMDPModelChecker extends ProbModelChecker double nextBeliefProb = entry.getValue(); Belief nextBelief = entry.getKey(); if (mdpRewards == null) { - value += nextBeliefProb * approximateReachProb(nextBelief, vhash, targetObs); + value += nextBeliefProb * approximateReachProb(nextBelief, vhash, targetObs, unknownObs); } else { value += nextBeliefProb * approximateReachReward(nextBelief, vhash, targetObs); } diff --git a/prism/src/explicit/ProbModelChecker.java b/prism/src/explicit/ProbModelChecker.java index 17ef973a..067eaea0 100644 --- a/prism/src/explicit/ProbModelChecker.java +++ b/prism/src/explicit/ProbModelChecker.java @@ -852,7 +852,7 @@ public class ProbModelChecker extends NonProbModelChecker res = ((MDPModelChecker) this).computeUntilProbs((MDP) model, remain, target, minMax.isMin()); break; case POMDP: - res = ((POMDPModelChecker) this).computeReachProbs((POMDP) model, target, minMax.isMin()); + res = ((POMDPModelChecker) this).computeReachProbs((POMDP) model, remain, target, minMax.isMin()); break; case STPG: res = ((STPGModelChecker) this).computeUntilProbs((STPG) model, remain, target, minMax.isMin1(), minMax.isMin2());