Browse Source

Add support for checking probabilistic until for POMDPs.

accumulation-v4.7
Dave Parker 5 years ago
parent
commit
c8bde25ab4
  1. 3
      prism-tests/functionality/verify/pomdps/mdp_simple.prism.props
  2. 104
      prism/src/explicit/POMDPModelChecker.java
  3. 2
      prism/src/explicit/ProbModelChecker.java

3
prism-tests/functionality/verify/pomdps/mdp_simple.prism.props

@ -4,6 +4,9 @@ Pmax=? [ F t=1 ];
// RESULT: 0.1
Pmin=? [ F t=1 ];
// RESULT: 0.3
Pmax=? [ s<=2 U t=1 ];
// RESULT: 3.0
Rmax=? [ F t>0 ];

104
prism/src/explicit/POMDPModelChecker.java

@ -70,7 +70,7 @@ public class POMDPModelChecker extends ProbModelChecker
* @param target Target states
* @param min Min or max probabilities (true=min, false=max)
*/
public ModelCheckerResult computeReachProbs(POMDP pomdp, BitSet target, boolean min) throws PrismException
public ModelCheckerResult computeReachProbs(POMDP pomdp, BitSet remain, BitSet target, boolean min) throws PrismException
{
ModelCheckerResult res = null;
long timer;
@ -91,7 +91,7 @@ public class POMDPModelChecker extends ProbModelChecker
}
// Compute rewards
res = computeReachProbsFixedGrid(pomdp, target, min, stratFilename);
res = computeReachProbsFixedGrid(pomdp, remain, target, min, stratFilename);
// Finished probabilistic reachability
timer = System.currentTimeMillis() - timer;
@ -103,27 +103,43 @@ public class POMDPModelChecker extends ProbModelChecker
}
/**
* Compute expected reachability rewards using Lovejoy's fixed-resolution grid approach.
* Compute reachability/until probabilities,
* i.e. compute the min/max probability of reaching a state in {@code target},
* while remaining in those in @{code remain},
* using Lovejoy's fixed-resolution grid approach.
* Optionally, store optimal (memoryless) strategy info.
* @param pomdp The POMMDP
* @param mdpRewards The rewards
* @param pomdp The POMDP
* @param remain Remain in these states (optional: null means "all")
* @param target Target states
* @param inf States for which reward is infinite
* @param min Min or max rewards (true=min, false=max)
* @param strat Storage for (memoryless) strategy choice indices (ignored if null)
*/
protected ModelCheckerResult computeReachProbsFixedGrid(POMDP pomdp, BitSet target, boolean min, String stratFilename) throws PrismException
protected ModelCheckerResult computeReachProbsFixedGrid(POMDP pomdp, BitSet remain, BitSet target, boolean min, String stratFilename) throws PrismException
{
// Start fixed-resolution grid approximation
long timer = System.currentTimeMillis();
mainLog.println("Starting fixed-resolution grid approximation (" + (min ? "min" : "max") + ")...");
// Find out the observations for the target states
// Find out the observations for the target/remain states
// And determine set of observations actually need to perform computation for
BitSet targetObs = getAndCheckTargetObservations(pomdp, target);
BitSet targetObs = null;
try {
targetObs = getObservationsMatchingStates(pomdp, target);
} catch (PrismException e) {
throw new PrismException("Target for reachability is not observable");
}
BitSet remainObs = null;
try {
remainObs = remain == null ? null : getObservationsMatchingStates(pomdp, remain);
} catch (PrismException e) {
throw new PrismException("Left-hand side of until is not observable");
}
BitSet unknownObs = new BitSet();
unknownObs.set(0, pomdp.getNumObservations());
unknownObs.andNot(targetObs);
if (remainObs != null) {
unknownObs.and(remainObs);
}
// Initialise the grid points (just for unknown beliefs)
List<Belief> gridPoints = initialiseGridPoints(pomdp, unknownObs);
@ -160,7 +176,7 @@ public class POMDPModelChecker extends ProbModelChecker
double nextBeliefProb = entry.getValue();
Belief nextBelief = entry.getKey();
// find discretized grid points to approximate the nextBelief
value += nextBeliefProb * approximateReachProb(nextBelief, vhash_backUp, targetObs);
value += nextBeliefProb * approximateReachProb(nextBelief, vhash_backUp, targetObs, unknownObs);
}
if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) {
chosenValue = value;
@ -191,7 +207,7 @@ public class POMDPModelChecker extends ProbModelChecker
// Find discretized grid points to approximate the initialBelief
// Also get (approximate) accuracy of result from value iteration
Belief initialBelief = pomdp.getInitialBelief();
double outerBound = approximateReachProb(initialBelief, vhash_backUp, targetObs);
double outerBound = approximateReachProb(initialBelief, vhash_backUp, targetObs, unknownObs);
double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
// Print result
@ -200,7 +216,7 @@ public class POMDPModelChecker extends ProbModelChecker
// Build DTMC to get inner bound (and strategy)
mainLog.println("\nBuilding strategy-induced model...");
List<Belief> listBeliefs = new ArrayList<>();
MDP mdp = buildStrategyModel(pomdp, null, vhash, targetObs, min, listBeliefs).mdp;
MDP mdp = buildStrategyModel(pomdp, null, vhash, targetObs, unknownObs, min, listBeliefs).mdp;
mainLog.print("Strategy-induced model: " + mdp.infoString());
// Export?
if (stratFilename != null) {
@ -221,6 +237,7 @@ public class POMDPModelChecker extends ProbModelChecker
mcMDP.setExportAdv(false);
mcMDP.setGenStrat(false);
// Solve MDP to get inner bound
// (just reachability: can ignore "remain" since violating states are absent)
ModelCheckerResult mcRes = mcMDP.computeReachProbs(mdp, mdp.getLabelStates("target"), true);
double innerBound = mcRes.soln[0];
Accuracy innerBoundAcc = mcRes.accuracy;
@ -261,7 +278,7 @@ public class POMDPModelChecker extends ProbModelChecker
}
/**
* Compute expected reachability rewards.
* Compute expected reachability rewards,
* i.e. compute the min/max reward accumulated to reach a state in {@code target}.
* @param pomdp The POMDP
* @param mdpRewards The rewards
@ -318,7 +335,12 @@ public class POMDPModelChecker extends ProbModelChecker
// Find out the observations for the target states
// And determine set of observations actually need to perform computation for
BitSet targetObs = getAndCheckTargetObservations(pomdp, target);
BitSet targetObs = null;
try {
targetObs = getObservationsMatchingStates(pomdp, target);
} catch (PrismException e) {
throw new PrismException("Target for expected reachability is not observable");
}
BitSet unknownObs = new BitSet();
unknownObs.set(0, pomdp.getNumObservations());
unknownObs.andNot(targetObs);
@ -409,7 +431,7 @@ public class POMDPModelChecker extends ProbModelChecker
// Build DTMC to get inner bound (and strategy)
mainLog.println("\nBuilding strategy-induced model...");
List<Belief> listBeliefs = new ArrayList<>();
POMDPStrategyModel psm = buildStrategyModel(pomdp, mdpRewards, vhash, targetObs, min, listBeliefs);
POMDPStrategyModel psm = buildStrategyModel(pomdp, mdpRewards, vhash, targetObs, unknownObs, min, listBeliefs);
MDP mdp = psm.mdp;
MDPRewards mdpRewardsNew = psm.mdpRewards;
mainLog.print("Strategy-induced model: " + mdp.infoString());
@ -473,30 +495,32 @@ public class POMDPModelChecker extends ProbModelChecker
}
/**
* Get a list of target observations from a set of target states
* Get a list of observations from a set of states
* (both are represented by BitSets over their indices).
* Also check that the set of target states corresponds to a set
* of observations, and throw an exception if not.
* The states should correspond exactly to a set of observations,
* i.e., if a state corresponding to an observation is in the set,
* then all other states corresponding to it should also be.
* An exception is thrown if not.
*/
protected BitSet getAndCheckTargetObservations(POMDP pomdp, BitSet target) throws PrismException
protected BitSet getObservationsMatchingStates(POMDP pomdp, BitSet set) throws PrismException
{
// Find observations corresponding to each state in the target
BitSet targetObs = new BitSet();
for (int s = target.nextSetBit(0); s >= 0; s = target.nextSetBit(s + 1)) {
targetObs.set(pomdp.getObservation(s));
// Find observations corresponding to each state in the set
BitSet setObs = new BitSet();
for (int s = set.nextSetBit(0); s >= 0; s = set.nextSetBit(s + 1)) {
setObs.set(pomdp.getObservation(s));
}
// Recreate the set of target states from the target observations and make sure it matches
BitSet target2 = new BitSet();
// Recreate the set of states from the observations and make sure it matches
BitSet set2 = new BitSet();
int numStates = pomdp.getNumStates();
for (int s = 0; s < numStates; s++) {
if (targetObs.get(pomdp.getObservation(s))) {
target2.set(s);
if (setObs.get(pomdp.getObservation(s))) {
set2.set(s);
}
}
if (!target.equals(target2)) {
throw new PrismException("Target is not observable");
if (!set.equals(set2)) {
throw new PrismException("Set is not observable");
}
return targetObs;
return setObs;
}
/**
@ -572,12 +596,16 @@ public class POMDPModelChecker extends ProbModelChecker
/**
* Compute the grid-based approximate value for a belief for probabilistic reachability
*/
protected double approximateReachProb(Belief belief, HashMap<Belief, Double> gridValues, BitSet targetObs)
protected double approximateReachProb(Belief belief, HashMap<Belief, Double> gridValues, BitSet targetObs, BitSet unknownObs)
{
// 1 for target states
if (targetObs.get(belief.so)) {
return 1.0;
}
// 0 for other non-unknown states
else if (!unknownObs.get(belief.so)) {
return 0.0;
}
// Otherwise approximate vie interpolation over grid points
else {
return interpolateOverGrid(belief, gridValues);
@ -634,7 +662,7 @@ public class POMDPModelChecker extends ProbModelChecker
* @param min
* @param listBeliefs
*/
protected POMDPStrategyModel buildStrategyModel(POMDP pomdp, MDPRewards mdpRewards, HashMap<Belief, Double> vhash, BitSet targetObs, boolean min, List<Belief> listBeliefs)
protected POMDPStrategyModel buildStrategyModel(POMDP pomdp, MDPRewards mdpRewards, HashMap<Belief, Double> vhash, BitSet targetObs, BitSet unknownObs, boolean min, List<Belief> listBeliefs)
{
// Initialise model/state/rewards storage
MDPSimple mdp = new MDPSimple();
@ -654,11 +682,13 @@ public class POMDPModelChecker extends ProbModelChecker
while (!toBeExploredBelives.isEmpty()) {
Belief b = toBeExploredBelives.pollFirst();
src++;
// Remember if this is a target state
if (targetObs.get(b.so)) {
mdpTarget.set(src);
} else {
extractBestActions(src, b, vhash, targetObs, pomdp, mdpRewards, min, exploredBelieves, toBeExploredBelives, mdp, stateRewards);
}
// Only explore "unknown" states
if (unknownObs.get(b.so)) {
extractBestActions(src, b, vhash, targetObs, unknownObs, pomdp, mdpRewards, min, exploredBelieves, toBeExploredBelives, mdp, stateRewards);
}
}
// Attach a label marking target states
@ -681,7 +711,7 @@ public class POMDPModelChecker extends ProbModelChecker
* @param min
* @param beliefList
*/
protected void extractBestActions(int src, Belief belief, HashMap<Belief, Double> vhash, BitSet targetObs, POMDP pomdp, MDPRewards mdpRewards, boolean min,
protected void extractBestActions(int src, Belief belief, HashMap<Belief, Double> vhash, BitSet targetObs, BitSet unknownObs, POMDP pomdp, MDPRewards mdpRewards, boolean min,
IndexedSet<Belief> exploredBelieves, LinkedList<Belief> toBeExploredBelives, MDPSimple mdp, StateRewardsSimple stateRewards)
{
double chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
@ -698,7 +728,7 @@ public class POMDPModelChecker extends ProbModelChecker
double nextBeliefProb = entry.getValue();
Belief nextBelief = entry.getKey();
if (mdpRewards == null) {
value += nextBeliefProb * approximateReachProb(nextBelief, vhash, targetObs);
value += nextBeliefProb * approximateReachProb(nextBelief, vhash, targetObs, unknownObs);
} else {
value += nextBeliefProb * approximateReachReward(nextBelief, vhash, targetObs);
}

2
prism/src/explicit/ProbModelChecker.java

@ -852,7 +852,7 @@ public class ProbModelChecker extends NonProbModelChecker
res = ((MDPModelChecker) this).computeUntilProbs((MDP) model, remain, target, minMax.isMin());
break;
case POMDP:
res = ((POMDPModelChecker) this).computeReachProbs((POMDP) model, target, minMax.isMin());
res = ((POMDPModelChecker) this).computeReachProbs((POMDP) model, remain, target, minMax.isMin());
break;
case STPG:
res = ((STPGModelChecker) this).computeUntilProbs((STPG) model, remain, target, minMax.isMin1(), minMax.isMin2());

Loading…
Cancel
Save