From e594f2309b8d229848b9a149ca03ced0691af9fd Mon Sep 17 00:00:00 2001
From: Dave Parker
Date: Wed, 3 Mar 2021 22:11:27 +0000
Subject: [PATCH] POMDP computation performed for any (single) state, not just
 the initial one.

---
 .../verify/pomdps/mdp_simple.prism.props  |  6 ++
 prism/src/explicit/POMDPModelChecker.java | 57 +++++++++++--------
 prism/src/explicit/ProbModelChecker.java  |  4 +-
 3 files changed, 40 insertions(+), 27 deletions(-)

diff --git a/prism-tests/functionality/verify/pomdps/mdp_simple.prism.props b/prism-tests/functionality/verify/pomdps/mdp_simple.prism.props
index 48dc3057..8fadcee6 100644
--- a/prism-tests/functionality/verify/pomdps/mdp_simple.prism.props
+++ b/prism-tests/functionality/verify/pomdps/mdp_simple.prism.props
@@ -4,11 +4,17 @@ Pmax=? [ F t=1 ];
 // RESULT: 0.1
 Pmin=? [ F t=1 ];
 
+// RESULT: 0.3
+filter(state, Pmin=? [ F t=1 ], s=2&t=0);
+
 // RESULT: 0.3
 Pmax=? [ s<=2 U t=1 ];
 
 // RESULT: 3.0
 Rmax=? [ F t>0 ];
 
+// RESULT: 2.0
+filter(state, Rmax=? [ F t>0 ], s=2&t=0);
+
 // RESULT: 1.0
 Rmin=? [ F t>0 ];
 
diff --git a/prism/src/explicit/POMDPModelChecker.java b/prism/src/explicit/POMDPModelChecker.java
index 0744b26e..e249d2ff 100644
--- a/prism/src/explicit/POMDPModelChecker.java
+++ b/prism/src/explicit/POMDPModelChecker.java
@@ -111,15 +111,18 @@ public class POMDPModelChecker extends ProbModelChecker
 	 * @param target Target states
 	 * @param min Min or max probabilities (true=min, false=max)
 	 */
-	public ModelCheckerResult computeReachProbs(POMDP pomdp, BitSet remain, BitSet target, boolean min) throws PrismException
+	public ModelCheckerResult computeReachProbs(POMDP pomdp, BitSet remain, BitSet target, boolean min, BitSet statesOfInterest) throws PrismException
 	{
 		ModelCheckerResult res = null;
 		long timer;
 		String stratFilename = null;
 
-		// Check for multiple initial states
-		if (pomdp.getNumInitialStates() > 1) {
-			throw new PrismNotSupportedException("POMDP model checking does not yet support multiple initial states");
+		// Check we are only computing for a single state (and use initial state if unspecified)
+		if (statesOfInterest == null) {
+			statesOfInterest = new BitSet();
+			statesOfInterest.set(pomdp.getFirstInitialState());
+		} else if (statesOfInterest.cardinality() > 1) {
+			throw new PrismNotSupportedException("POMDPs can only be solved from a single start state");
 		}
 
 		// Start probabilistic reachability
@@ -132,7 +135,7 @@ public class POMDPModelChecker extends ProbModelChecker
 		}
 
 		// Compute rewards
-		res = computeReachProbsFixedGrid(pomdp, remain, target, min, stratFilename);
+		res = computeReachProbsFixedGrid(pomdp, remain, target, min, statesOfInterest.nextSetBit(0), stratFilename);
 
 		// Finished probabilistic reachability
 		timer = System.currentTimeMillis() - timer;
@@ -148,14 +151,16 @@ public class POMDPModelChecker extends ProbModelChecker
 	 * i.e. compute the min/max probability of reaching a state in {@code target},
 	 * while remaining in those in @{code remain},
 	 * using Lovejoy's fixed-resolution grid approach.
+	 * This only computes the probability from a single start state.
 	 * Optionally, store optimal (memoryless) strategy info.
 	 * @param pomdp The POMDP
 	 * @param remain Remain in these states (optional: null means "all")
 	 * @param target Target states
 	 * @param min Min or max rewards (true=min, false=max)
+	 * @param sInit State to compute for
 	 * @param strat Storage for (memoryless) strategy choice indices (ignored if null)
 	 */
-	protected ModelCheckerResult computeReachProbsFixedGrid(POMDP pomdp, BitSet remain, BitSet target, boolean min, String stratFilename) throws PrismException
+	protected ModelCheckerResult computeReachProbsFixedGrid(POMDP pomdp, BitSet remain, BitSet target, boolean min, int sInit, String stratFilename) throws PrismException
 	{
 		// Start fixed-resolution grid approximation
 		long timer = System.currentTimeMillis();
@@ -231,7 +236,7 @@ public class POMDPModelChecker extends ProbModelChecker
 
 		// Extract (approximate) solution value for the initial belief
 		// Also get (approximate) accuracy of result from value iteration
-		Belief initialBelief = pomdp.getInitialBelief();
+		Belief initialBelief = Belief.pointDistribution(sInit, pomdp);
 		double outerBound = values.apply(initialBelief);
 		double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
 		Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
@@ -240,7 +245,7 @@ public class POMDPModelChecker extends ProbModelChecker
 
 		// Build DTMC to get inner bound (and strategy)
 		mainLog.println("\nBuilding strategy-induced model...");
-		POMDPStrategyModel psm = buildStrategyModel(pomdp, null, targetObs, unknownObs, backup);
+		POMDPStrategyModel psm = buildStrategyModel(pomdp, sInit, null, targetObs, unknownObs, backup);
 		MDP mdp = psm.mdp;
 		mainLog.print("Strategy-induced model: " + mdp.infoString());
 		// Export?
@@ -289,9 +294,7 @@ public class POMDPModelChecker extends ProbModelChecker
 		Accuracy resultAcc = resultValAndAcc.second;
 		mainLog.println("Result bounds: [" + resultAcc.getResultLowerBound(resultVal) + "," + resultAcc.getResultUpperBound(resultVal) + "]");
 		double soln[] = new double[pomdp.getNumStates()];
-		for (int initialState : pomdp.getInitialStates()) {
-			soln[initialState] = resultVal;
-		}
+		soln[sInit] = resultVal;
 
 		// Return results
 		ModelCheckerResult res = new ModelCheckerResult();
@@ -310,15 +313,18 @@ public class POMDPModelChecker extends ProbModelChecker
 	 * @param target Target states
 	 * @param min Min or max rewards (true=min, false=max)
 	 */
-	public ModelCheckerResult computeReachRewards(POMDP pomdp, MDPRewards mdpRewards, BitSet target, boolean min) throws PrismException
+	public ModelCheckerResult computeReachRewards(POMDP pomdp, MDPRewards mdpRewards, BitSet target, boolean min, BitSet statesOfInterest) throws PrismException
 	{
 		ModelCheckerResult res = null;
 		long timer;
 		String stratFilename = null;
 
-		// Check for multiple initial states
-		if (pomdp.getNumInitialStates() > 1) {
-			throw new PrismNotSupportedException("POMDP model checking does not yet support multiple initial states");
+		// Check we are only computing for a single state (and use initial state if unspecified)
+		if (statesOfInterest == null) {
+			statesOfInterest = new BitSet();
+			statesOfInterest.set(pomdp.getFirstInitialState());
+		} else if (statesOfInterest.cardinality() > 1) {
+			throw new PrismNotSupportedException("POMDPs can only be solved from a single start state");
 		}
 
 		// Start expected reachability
@@ -331,7 +337,7 @@ public class POMDPModelChecker extends ProbModelChecker
 		}
 
 		// Compute rewards
-		res = computeReachRewardsFixedGrid(pomdp, mdpRewards, target, min, stratFilename);
+		res = computeReachRewardsFixedGrid(pomdp, mdpRewards, target, min, statesOfInterest.nextSetBit(0), stratFilename);
 
 		// Finished expected reachability
 		timer = System.currentTimeMillis() - timer;
@@ -344,15 +350,17 @@ public class POMDPModelChecker extends ProbModelChecker
 
 	/**
 	 * Compute expected reachability rewards using Lovejoy's fixed-resolution grid approach.
+	 * This only computes the expected reward from a single start state.
 	 * Optionally, store optimal (memoryless) strategy info.
 	 * @param pomdp The POMMDP
 	 * @param mdpRewards The rewards
 	 * @param target Target states
 	 * @param inf States for which reward is infinite
 	 * @param min Min or max rewards (true=min, false=max)
+	 * @param sInit State to compute for
 	 * @param strat Storage for (memoryless) strategy choice indices (ignored if null)
 	 */
-	protected ModelCheckerResult computeReachRewardsFixedGrid(POMDP pomdp, MDPRewards mdpRewards, BitSet target, boolean min, String stratFilename) throws PrismException
+	protected ModelCheckerResult computeReachRewardsFixedGrid(POMDP pomdp, MDPRewards mdpRewards, BitSet target, boolean min, int sInit, String stratFilename) throws PrismException
 	{
 		// Start fixed-resolution grid approximation
 		long timer = System.currentTimeMillis();
@@ -421,7 +429,7 @@ public class POMDPModelChecker extends ProbModelChecker
 
 		// Extract (approximate) solution value for the initial belief
 		// Also get (approximate) accuracy of result from value iteration
-		Belief initialBelief = pomdp.getInitialBelief();
+		Belief initialBelief = Belief.pointDistribution(sInit, pomdp);
 		double outerBound = values.apply(initialBelief);
 		double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
 		Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
@@ -430,7 +438,7 @@ public class POMDPModelChecker extends ProbModelChecker
 		// Build DTMC to get inner bound (and strategy)
 		mainLog.println("\nBuilding strategy-induced model...");
-		POMDPStrategyModel psm = buildStrategyModel(pomdp, mdpRewards, targetObs, unknownObs, backup);
+		POMDPStrategyModel psm = buildStrategyModel(pomdp, sInit, mdpRewards, targetObs, unknownObs, backup);
 		MDP mdp = psm.mdp;
 		MDPRewards mdpRewardsNew = psm.mdpRewards;
 		mainLog.print("Strategy-induced model: " + mdp.infoString());
 		// Export?
@@ -480,9 +488,7 @@ public class POMDPModelChecker extends ProbModelChecker
 		Accuracy resultAcc = resultValAndAcc.second;
 		mainLog.println("Result bounds: [" + resultAcc.getResultLowerBound(resultVal) + "," + resultAcc.getResultUpperBound(resultVal) + "]");
 		double soln[] = new double[pomdp.getNumStates()];
-		for (int initialState : pomdp.getInitialStates()) {
-			soln[initialState] = resultVal;
-		}
+		soln[sInit] = resultVal;
 
 		// Return results
 		ModelCheckerResult res = new ModelCheckerResult();
@@ -704,6 +710,7 @@ public class POMDPModelChecker extends ProbModelChecker
 	 * Build a (Markov chain) model representing the fragment of the belief MDP induced by an optimal strategy.
 	 * The model is stored as an MDP to allow easier attachment of optional actions.
 	 * @param pomdp
+	 * @param sInit
 	 * @param mdpRewards
 	 * @param vhash
 	 * @param vhash_backUp
@@ -711,7 +718,7 @@ public class POMDPModelChecker extends ProbModelChecker
 	 * @param target
 	 * @param min
 	 * @param listBeliefs
 	 */
-	protected POMDPStrategyModel buildStrategyModel(POMDP pomdp, MDPRewards mdpRewards, BitSet targetObs, BitSet unknownObs, BeliefMDPBackUp backup)
+	protected POMDPStrategyModel buildStrategyModel(POMDP pomdp, int sInit, MDPRewards mdpRewards, BitSet targetObs, BitSet unknownObs, BeliefMDPBackUp backup)
 	{
 		// Initialise model/state/rewards storage
 		MDPSimple mdp = new MDPSimple();
@@ -720,7 +727,7 @@ public class POMDPModelChecker extends ProbModelChecker
 		BitSet mdpTarget = new BitSet();
 		StateRewardsSimple stateRewards = new StateRewardsSimple();
 		// Add initial state
-		Belief initialBelief = pomdp.getInitialBelief();
+		Belief initialBelief = Belief.pointDistribution(sInit, pomdp);
 		exploredBeliefs.add(initialBelief);
 		toBeExploredBeliefs.offer(initialBelief);
 		mdp.addState();
@@ -996,7 +1003,7 @@ public class POMDPModelChecker extends ProbModelChecker
 				mc.setPrecomp(false);
 			}
 			pomdp = new POMDPSimple(mdp);
-			res = mc.computeReachRewards(pomdp, null, target, min);
+			res = mc.computeReachRewards(pomdp, null, target, min, null);
 			System.out.println(res.soln[init.nextSetBit(0)]);
 		} catch (PrismException e) {
 			System.out.println(e);
diff --git a/prism/src/explicit/ProbModelChecker.java b/prism/src/explicit/ProbModelChecker.java
index 067eaea0..0fdf6e56 100644
--- a/prism/src/explicit/ProbModelChecker.java
+++ b/prism/src/explicit/ProbModelChecker.java
@@ -852,7 +852,7 @@ public class ProbModelChecker extends NonProbModelChecker
 			res = ((MDPModelChecker) this).computeUntilProbs((MDP) model, remain, target, minMax.isMin());
 			break;
 		case POMDP:
-			res = ((POMDPModelChecker) this).computeReachProbs((POMDP) model, remain, target, minMax.isMin());
+			res = ((POMDPModelChecker) this).computeReachProbs((POMDP) model, remain, target, minMax.isMin(), statesOfInterest);
 			break;
 		case STPG:
 			res = ((STPGModelChecker) this).computeUntilProbs((STPG) model, remain, target, minMax.isMin1(), minMax.isMin2());
@@ -1155,7 +1155,7 @@ public class ProbModelChecker extends NonProbModelChecker
 			res = ((MDPModelChecker) this).computeReachRewards((MDP) model, (MDPRewards) modelRewards, target, minMax.isMin());
 			break;
 		case POMDP:
-			res = ((POMDPModelChecker) this).computeReachRewards((POMDP) model, (MDPRewards) modelRewards, target, minMax.isMin());
+			res = ((POMDPModelChecker) this).computeReachRewards((POMDP) model, (MDPRewards) modelRewards, target, minMax.isMin(), statesOfInterest);
 			break;
 		case STPG:
 			res = ((STPGModelChecker) this).computeReachRewards((STPG) model, (STPGRewards) modelRewards, target, minMax.isMin1(), minMax.isMin2());
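
Note (illustration only, not part of the patch): a minimal sketch of how the new
statesOfInterest argument might be driven from Java code, mirroring the
filter(state, ...) queries added to mdp_simple.prism.props above. The names mc,
pomdp, target and s are hypothetical placeholders for an existing
POMDPModelChecker, a POMDP, a target state set and a state index; imports of
java.util.BitSet and the explicit.* types used below are assumed.

	// Select exactly one state of interest; passing null instead falls back
	// to the POMDP's (single) initial state, as implemented in the patch.
	BitSet statesOfInterest = new BitSet();
	statesOfInterest.set(s);
	// min=false requests the maximum probability; remain=null means no constraint (F rather than U).
	ModelCheckerResult res = mc.computeReachProbs(pomdp, null, target, false, statesOfInterest);
	// The solution vector is only populated at the queried state.
	double pmax = res.soln[s];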