Browse Source

POMDP computation performed for any (single) state, not just the initial one.

accumulation-v4.7
Dave Parker 5 years ago
parent
commit
e594f2309b
  1. 6
      prism-tests/functionality/verify/pomdps/mdp_simple.prism.props
  2. 57
      prism/src/explicit/POMDPModelChecker.java
  3. 4
      prism/src/explicit/ProbModelChecker.java

6
prism-tests/functionality/verify/pomdps/mdp_simple.prism.props

@@ -4,11 +4,17 @@ Pmax=? [ F t=1 ];
// RESULT: 0.1 // RESULT: 0.1
Pmin=? [ F t=1 ]; Pmin=? [ F t=1 ];
// RESULT: 0.3
filter(state, Pmin=? [ F t=1 ], s=2&t=0);
// RESULT: 0.3 // RESULT: 0.3
Pmax=? [ s<=2 U t=1 ]; Pmax=? [ s<=2 U t=1 ];
// RESULT: 3.0 // RESULT: 3.0
Rmax=? [ F t>0 ]; Rmax=? [ F t>0 ];
// RESULT: 2.0
filter(state, Rmax=? [ F t>0 ], s=2&t=0);
// RESULT: 1.0 // RESULT: 1.0
Rmin=? [ F t>0 ]; Rmin=? [ F t>0 ];

57
prism/src/explicit/POMDPModelChecker.java

@@ -111,15 +111,18 @@ public class POMDPModelChecker extends ProbModelChecker
* @param target Target states * @param target Target states
* @param min Min or max probabilities (true=min, false=max) * @param min Min or max probabilities (true=min, false=max)
*/ */
public ModelCheckerResult computeReachProbs(POMDP pomdp, BitSet remain, BitSet target, boolean min) throws PrismException
public ModelCheckerResult computeReachProbs(POMDP pomdp, BitSet remain, BitSet target, boolean min, BitSet statesOfInterest) throws PrismException
{ {
ModelCheckerResult res = null; ModelCheckerResult res = null;
long timer; long timer;
String stratFilename = null; String stratFilename = null;
// Check for multiple initial states
if (pomdp.getNumInitialStates() > 1) {
throw new PrismNotSupportedException("POMDP model checking does not yet support multiple initial states");
// Check we are only computing for a single state (and use initial state if unspecified)
if (statesOfInterest == null) {
statesOfInterest = new BitSet();
statesOfInterest.set(pomdp.getFirstInitialState());
} else if (statesOfInterest.cardinality() > 1) {
throw new PrismNotSupportedException("POMDPs can only be solved from a single start state");
} }
// Start probabilistic reachability // Start probabilistic reachability
@@ -132,7 +135,7 @@ public class POMDPModelChecker extends ProbModelChecker
} }
// Compute rewards // Compute rewards
res = computeReachProbsFixedGrid(pomdp, remain, target, min, stratFilename);
res = computeReachProbsFixedGrid(pomdp, remain, target, min, statesOfInterest.nextSetBit(0), stratFilename);
// Finished probabilistic reachability // Finished probabilistic reachability
timer = System.currentTimeMillis() - timer; timer = System.currentTimeMillis() - timer;
@@ -148,14 +151,16 @@ public class POMDPModelChecker extends ProbModelChecker
* i.e. compute the min/max probability of reaching a state in {@code target}, * i.e. compute the min/max probability of reaching a state in {@code target},
* while remaining in those in @{code remain}, * while remaining in those in @{code remain},
* using Lovejoy's fixed-resolution grid approach. * using Lovejoy's fixed-resolution grid approach.
* This only computes the probability from a single start state
* Optionally, store optimal (memoryless) strategy info. * Optionally, store optimal (memoryless) strategy info.
* @param pomdp The POMDP * @param pomdp The POMDP
* @param remain Remain in these states (optional: null means "all") * @param remain Remain in these states (optional: null means "all")
* @param target Target states * @param target Target states
* @param min Min or max rewards (true=min, false=max) * @param min Min or max rewards (true=min, false=max)
* @param sInit State to compute for
* @param strat Storage for (memoryless) strategy choice indices (ignored if null) * @param strat Storage for (memoryless) strategy choice indices (ignored if null)
*/ */
protected ModelCheckerResult computeReachProbsFixedGrid(POMDP pomdp, BitSet remain, BitSet target, boolean min, String stratFilename) throws PrismException
protected ModelCheckerResult computeReachProbsFixedGrid(POMDP pomdp, BitSet remain, BitSet target, boolean min, int sInit, String stratFilename) throws PrismException
{ {
// Start fixed-resolution grid approximation // Start fixed-resolution grid approximation
long timer = System.currentTimeMillis(); long timer = System.currentTimeMillis();
@@ -231,7 +236,7 @@ public class POMDPModelChecker extends ProbModelChecker
// Extract (approximate) solution value for the initial belief // Extract (approximate) solution value for the initial belief
// Also get (approximate) accuracy of result from value iteration // Also get (approximate) accuracy of result from value iteration
Belief initialBelief = pomdp.getInitialBelief();
Belief initialBelief = Belief.pointDistribution(sInit, pomdp);
double outerBound = values.apply(initialBelief); double outerBound = values.apply(initialBelief);
double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE); double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE); Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
@@ -240,7 +245,7 @@ public class POMDPModelChecker extends ProbModelChecker
// Build DTMC to get inner bound (and strategy) // Build DTMC to get inner bound (and strategy)
mainLog.println("\nBuilding strategy-induced model..."); mainLog.println("\nBuilding strategy-induced model...");
POMDPStrategyModel psm = buildStrategyModel(pomdp, null, targetObs, unknownObs, backup);
POMDPStrategyModel psm = buildStrategyModel(pomdp, sInit, null, targetObs, unknownObs, backup);
MDP mdp = psm.mdp; MDP mdp = psm.mdp;
mainLog.print("Strategy-induced model: " + mdp.infoString()); mainLog.print("Strategy-induced model: " + mdp.infoString());
// Export? // Export?
@@ -289,9 +294,7 @@ public class POMDPModelChecker extends ProbModelChecker
Accuracy resultAcc = resultValAndAcc.second; Accuracy resultAcc = resultValAndAcc.second;
mainLog.println("Result bounds: [" + resultAcc.getResultLowerBound(resultVal) + "," + resultAcc.getResultUpperBound(resultVal) + "]"); mainLog.println("Result bounds: [" + resultAcc.getResultLowerBound(resultVal) + "," + resultAcc.getResultUpperBound(resultVal) + "]");
double soln[] = new double[pomdp.getNumStates()]; double soln[] = new double[pomdp.getNumStates()];
for (int initialState : pomdp.getInitialStates()) {
soln[initialState] = resultVal;
}
soln[sInit] = resultVal;
// Return results // Return results
ModelCheckerResult res = new ModelCheckerResult(); ModelCheckerResult res = new ModelCheckerResult();
@@ -310,15 +313,18 @@ public class POMDPModelChecker extends ProbModelChecker
* @param target Target states * @param target Target states
* @param min Min or max rewards (true=min, false=max) * @param min Min or max rewards (true=min, false=max)
*/ */
public ModelCheckerResult computeReachRewards(POMDP pomdp, MDPRewards mdpRewards, BitSet target, boolean min) throws PrismException
public ModelCheckerResult computeReachRewards(POMDP pomdp, MDPRewards mdpRewards, BitSet target, boolean min, BitSet statesOfInterest) throws PrismException
{ {
ModelCheckerResult res = null; ModelCheckerResult res = null;
long timer; long timer;
String stratFilename = null; String stratFilename = null;
// Check for multiple initial states
if (pomdp.getNumInitialStates() > 1) {
throw new PrismNotSupportedException("POMDP model checking does not yet support multiple initial states");
// Check we are only computing for a single state (and use initial state if unspecified)
if (statesOfInterest == null) {
statesOfInterest = new BitSet();
statesOfInterest.set(pomdp.getFirstInitialState());
} else if (statesOfInterest.cardinality() > 1) {
throw new PrismNotSupportedException("POMDPs can only be solved from a single start state");
} }
// Start expected reachability // Start expected reachability
@@ -331,7 +337,7 @@ public class POMDPModelChecker extends ProbModelChecker
} }
// Compute rewards // Compute rewards
res = computeReachRewardsFixedGrid(pomdp, mdpRewards, target, min, stratFilename);
res = computeReachRewardsFixedGrid(pomdp, mdpRewards, target, min, statesOfInterest.nextSetBit(0), stratFilename);
// Finished expected reachability // Finished expected reachability
timer = System.currentTimeMillis() - timer; timer = System.currentTimeMillis() - timer;
@@ -344,15 +350,17 @@ public class POMDPModelChecker extends ProbModelChecker
/** /**
* Compute expected reachability rewards using Lovejoy's fixed-resolution grid approach. * Compute expected reachability rewards using Lovejoy's fixed-resolution grid approach.
* This only computes the expected reward from a single start state
* Optionally, store optimal (memoryless) strategy info. * Optionally, store optimal (memoryless) strategy info.
* @param pomdp The POMDP * @param pomdp The POMDP
* @param mdpRewards The rewards * @param mdpRewards The rewards
* @param target Target states * @param target Target states
* @param inf States for which reward is infinite * @param inf States for which reward is infinite
* @param min Min or max rewards (true=min, false=max) * @param min Min or max rewards (true=min, false=max)
* @param sInit State to compute for
* @param strat Storage for (memoryless) strategy choice indices (ignored if null) * @param strat Storage for (memoryless) strategy choice indices (ignored if null)
*/ */
protected ModelCheckerResult computeReachRewardsFixedGrid(POMDP pomdp, MDPRewards mdpRewards, BitSet target, boolean min, String stratFilename) throws PrismException
protected ModelCheckerResult computeReachRewardsFixedGrid(POMDP pomdp, MDPRewards mdpRewards, BitSet target, boolean min, int sInit, String stratFilename) throws PrismException
{ {
// Start fixed-resolution grid approximation // Start fixed-resolution grid approximation
long timer = System.currentTimeMillis(); long timer = System.currentTimeMillis();
@@ -421,7 +429,7 @@ public class POMDPModelChecker extends ProbModelChecker
// Extract (approximate) solution value for the initial belief // Extract (approximate) solution value for the initial belief
// Also get (approximate) accuracy of result from value iteration // Also get (approximate) accuracy of result from value iteration
Belief initialBelief = pomdp.getInitialBelief();
Belief initialBelief = Belief.pointDistribution(sInit, pomdp);
double outerBound = values.apply(initialBelief); double outerBound = values.apply(initialBelief);
double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE); double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE); Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
@@ -430,7 +438,7 @@ public class POMDPModelChecker extends ProbModelChecker
// Build DTMC to get inner bound (and strategy) // Build DTMC to get inner bound (and strategy)
mainLog.println("\nBuilding strategy-induced model..."); mainLog.println("\nBuilding strategy-induced model...");
POMDPStrategyModel psm = buildStrategyModel(pomdp, mdpRewards, targetObs, unknownObs, backup);
POMDPStrategyModel psm = buildStrategyModel(pomdp, sInit, mdpRewards, targetObs, unknownObs, backup);
MDP mdp = psm.mdp; MDP mdp = psm.mdp;
MDPRewards mdpRewardsNew = psm.mdpRewards; MDPRewards mdpRewardsNew = psm.mdpRewards;
mainLog.print("Strategy-induced model: " + mdp.infoString()); mainLog.print("Strategy-induced model: " + mdp.infoString());
@@ -480,9 +488,7 @@ public class POMDPModelChecker extends ProbModelChecker
Accuracy resultAcc = resultValAndAcc.second; Accuracy resultAcc = resultValAndAcc.second;
mainLog.println("Result bounds: [" + resultAcc.getResultLowerBound(resultVal) + "," + resultAcc.getResultUpperBound(resultVal) + "]"); mainLog.println("Result bounds: [" + resultAcc.getResultLowerBound(resultVal) + "," + resultAcc.getResultUpperBound(resultVal) + "]");
double soln[] = new double[pomdp.getNumStates()]; double soln[] = new double[pomdp.getNumStates()];
for (int initialState : pomdp.getInitialStates()) {
soln[initialState] = resultVal;
}
soln[sInit] = resultVal;
// Return results // Return results
ModelCheckerResult res = new ModelCheckerResult(); ModelCheckerResult res = new ModelCheckerResult();
@@ -704,6 +710,7 @@ public class POMDPModelChecker extends ProbModelChecker
* Build a (Markov chain) model representing the fragment of the belief MDP induced by an optimal strategy. * Build a (Markov chain) model representing the fragment of the belief MDP induced by an optimal strategy.
* The model is stored as an MDP to allow easier attachment of optional actions. * The model is stored as an MDP to allow easier attachment of optional actions.
* @param pomdp * @param pomdp
* @param sInit
* @param mdpRewards * @param mdpRewards
* @param vhash * @param vhash
* @param vhash_backUp * @param vhash_backUp
@@ -711,7 +718,7 @@ public class POMDPModelChecker extends ProbModelChecker
* @param min * @param min
* @param listBeliefs * @param listBeliefs
*/ */
protected POMDPStrategyModel buildStrategyModel(POMDP pomdp, MDPRewards mdpRewards, BitSet targetObs, BitSet unknownObs, BeliefMDPBackUp backup)
protected POMDPStrategyModel buildStrategyModel(POMDP pomdp, int sInit, MDPRewards mdpRewards, BitSet targetObs, BitSet unknownObs, BeliefMDPBackUp backup)
{ {
// Initialise model/state/rewards storage // Initialise model/state/rewards storage
MDPSimple mdp = new MDPSimple(); MDPSimple mdp = new MDPSimple();
@@ -720,7 +727,7 @@ public class POMDPModelChecker extends ProbModelChecker
BitSet mdpTarget = new BitSet(); BitSet mdpTarget = new BitSet();
StateRewardsSimple stateRewards = new StateRewardsSimple(); StateRewardsSimple stateRewards = new StateRewardsSimple();
// Add initial state // Add initial state
Belief initialBelief = pomdp.getInitialBelief();
Belief initialBelief = Belief.pointDistribution(sInit, pomdp);
exploredBeliefs.add(initialBelief); exploredBeliefs.add(initialBelief);
toBeExploredBeliefs.offer(initialBelief); toBeExploredBeliefs.offer(initialBelief);
mdp.addState(); mdp.addState();
@@ -996,7 +1003,7 @@ public class POMDPModelChecker extends ProbModelChecker
mc.setPrecomp(false); mc.setPrecomp(false);
} }
pomdp = new POMDPSimple(mdp); pomdp = new POMDPSimple(mdp);
res = mc.computeReachRewards(pomdp, null, target, min);
res = mc.computeReachRewards(pomdp, null, target, min, null);
System.out.println(res.soln[init.nextSetBit(0)]); System.out.println(res.soln[init.nextSetBit(0)]);
} catch (PrismException e) { } catch (PrismException e) {
System.out.println(e); System.out.println(e);

4
prism/src/explicit/ProbModelChecker.java

@@ -852,7 +852,7 @@ public class ProbModelChecker extends NonProbModelChecker
res = ((MDPModelChecker) this).computeUntilProbs((MDP) model, remain, target, minMax.isMin()); res = ((MDPModelChecker) this).computeUntilProbs((MDP) model, remain, target, minMax.isMin());
break; break;
case POMDP: case POMDP:
res = ((POMDPModelChecker) this).computeReachProbs((POMDP) model, remain, target, minMax.isMin());
res = ((POMDPModelChecker) this).computeReachProbs((POMDP) model, remain, target, minMax.isMin(), statesOfInterest);
break; break;
case STPG: case STPG:
res = ((STPGModelChecker) this).computeUntilProbs((STPG) model, remain, target, minMax.isMin1(), minMax.isMin2()); res = ((STPGModelChecker) this).computeUntilProbs((STPG) model, remain, target, minMax.isMin1(), minMax.isMin2());
@@ -1155,7 +1155,7 @@ public class ProbModelChecker extends NonProbModelChecker
res = ((MDPModelChecker) this).computeReachRewards((MDP) model, (MDPRewards) modelRewards, target, minMax.isMin()); res = ((MDPModelChecker) this).computeReachRewards((MDP) model, (MDPRewards) modelRewards, target, minMax.isMin());
break; break;
case POMDP: case POMDP:
res = ((POMDPModelChecker) this).computeReachRewards((POMDP) model, (MDPRewards) modelRewards, target, minMax.isMin());
res = ((POMDPModelChecker) this).computeReachRewards((POMDP) model, (MDPRewards) modelRewards, target, minMax.isMin(), statesOfInterest);
break; break;
case STPG: case STPG:
res = ((STPGModelChecker) this).computeReachRewards((STPG) model, (STPGRewards) modelRewards, target, minMax.isMin1(), minMax.isMin2()); res = ((STPGModelChecker) this).computeReachRewards((STPG) model, (STPGRewards) modelRewards, target, minMax.isMin1(), minMax.isMin2());

Loading…
Cancel
Save