diff --git a/prism/src/explicit/POMDPModelChecker.java b/prism/src/explicit/POMDPModelChecker.java index 1e4f8b43..c087d324 100644 --- a/prism/src/explicit/POMDPModelChecker.java +++ b/prism/src/explicit/POMDPModelChecker.java @@ -119,31 +119,25 @@ public class POMDPModelChecker extends ProbModelChecker mainLog.println("Starting fixed-resolution grid approximation (" + (min ? "min" : "max") + ")..."); // Find out the observations for the target states + // And determine set of observations actually need to perform computation for BitSet targetObs = getAndCheckTargetObservations(pomdp, target); - - // Initialise the grid points - ArrayList gridPoints = new ArrayList<>();//the set of grid points (discretized believes) - ArrayList unknownGridPoints = new ArrayList<>();//the set of unknown grid points (discretized believes) - initialiseGridPoints(pomdp, targetObs, gridPoints, unknownGridPoints); - int unK = unknownGridPoints.size(); - mainLog.print("Grid statistics: resolution=" + gridResolution); - mainLog.println(", points=" + gridPoints.size() + ", unknown points=" + unK); - - // Construct grid belief "MDP" (over all unknown grid points) + BitSet unknownObs = new BitSet(); + unknownObs.set(0, pomdp.getNumObservations()); + unknownObs.andNot(targetObs); + + // Initialise the grid points (just for unknown beliefs) + List gridPoints = initialiseGridPoints(pomdp, unknownObs); + mainLog.println("Grid statistics: resolution=" + gridResolution + ", points=" + gridPoints.size()); + // Construct grid belief "MDP" mainLog.println("Building belief space approximation..."); - List>> beliefMDP = buildBeliefMDP(pomdp, unknownGridPoints); + List>> beliefMDP = buildBeliefMDP(pomdp, gridPoints); - // HashMap for storing real time values for the discretized grid belief states + // Initialise hashmaps for storing values for the grid belief states HashMap vhash = new HashMap<>(); HashMap vhash_backUp = new HashMap<>(); - for (Belief g : gridPoints) { - if (unknownGridPoints.contains(g)) { - vhash.put(g, 0.0); - vhash_backUp.put(g, 0.0); - } else { - vhash.put(g, 1.0); - vhash_backUp.put(g, 1.0); - } + for (Belief belief : gridPoints) { + vhash.put(belief, 0.0); + vhash_backUp.put(belief, 0.0); } // Start iterations @@ -154,8 +148,9 @@ public class POMDPModelChecker extends ProbModelChecker boolean done = false; while (!done && iters < maxIters) { // Iterate over all (unknown) grid points + int unK = gridPoints.size(); for (int b = 0; b < unK; b++) { - Belief belief = unknownGridPoints.get(b); + Belief belief = gridPoints.get(b); int numChoices = pomdp.getNumChoicesForObservation(belief.so); chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; @@ -165,7 +160,7 @@ public class POMDPModelChecker extends ProbModelChecker double nextBeliefProb = entry.getValue(); Belief nextBelief = entry.getKey(); // find discretized grid points to approximate the nextBelief - value += nextBeliefProb * interpolateOverGrid(nextBelief, vhash_backUp); + value += nextBeliefProb * approximateReachProb(nextBelief, vhash_backUp, targetObs); } if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) { chosenValue = value; @@ -196,7 +191,7 @@ public class POMDPModelChecker extends ProbModelChecker // Find discretized grid points to approximate the initialBelief // Also get (approximate) accuracy of result from value iteration Belief initialBelief = pomdp.getInitialBelief(); - double outerBound = interpolateOverGrid(initialBelief, vhash_backUp); + double outerBound = approximateReachProb(initialBelief, vhash_backUp, targetObs); double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE); Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE); // Print result @@ -322,23 +317,24 @@ public class POMDPModelChecker extends ProbModelChecker mainLog.println("Starting fixed-resolution grid approximation (" + (min ? "min" : "max") + ")..."); // Find out the observations for the target states + // And determine set of observations actually need to perform computation for BitSet targetObs = getAndCheckTargetObservations(pomdp, target); - - // Initialise the grid points - ArrayList gridPoints = new ArrayList<>();//the set of grid points (discretized believes) - ArrayList unknownGridPoints = new ArrayList<>();//the set of unknown grid points (discretized believes) - initialiseGridPoints(pomdp, targetObs, gridPoints, unknownGridPoints); - int unK = unknownGridPoints.size(); - mainLog.print("Grid statistics: resolution=" + gridResolution); - mainLog.println(", points=" + gridPoints.size() + ", unknown points=" + unK); - - // Construct grid belief "MDP" (over all unknown grid points) + BitSet unknownObs = new BitSet(); + unknownObs.set(0, pomdp.getNumObservations()); + unknownObs.andNot(targetObs); + + // Initialise the grid points (just for unknown beliefs) + List gridPoints = initialiseGridPoints(pomdp, unknownObs); + mainLog.println("Grid statistics: resolution=" + gridResolution + ", points=" + gridPoints.size()); + // Construct grid belief "MDP" mainLog.println("Building belief space approximation..."); - List>> beliefMDP = buildBeliefMDP(pomdp, unknownGridPoints); + List>> beliefMDP = buildBeliefMDP(pomdp, gridPoints); + // Rewards List> rewards = new ArrayList<>(); // memoization for reuse + int unK = gridPoints.size(); for (int b = 0; b < unK; b++) { - Belief belief = unknownGridPoints.get(b); + Belief belief = gridPoints.get(b); int numChoices = pomdp.getNumChoicesForObservation(belief.so); List action_reward = new ArrayList<>();// for memoization for (int i = 0; i < numChoices; i++) { @@ -347,12 +343,12 @@ public class POMDPModelChecker extends ProbModelChecker rewards.add(action_reward); } - // HashMap for storing real time values for the discretized grid belief states + // Initialise hashmaps for storing values for the grid belief states HashMap vhash = new HashMap<>(); HashMap vhash_backUp = new HashMap<>(); - for (Belief g : gridPoints) { - vhash.put(g, 0.0); - vhash_backUp.put(g, 0.0); + for (Belief belief : gridPoints) { + vhash.put(belief, 0.0); + vhash_backUp.put(belief, 0.0); } // Start iterations @@ -364,7 +360,7 @@ public class POMDPModelChecker extends ProbModelChecker while (!done && iters < maxIters) { // Iterate over all (unknown) grid points for (int b = 0; b < unK; b++) { - Belief belief = unknownGridPoints.get(b); + Belief belief = gridPoints.get(b); int numChoices = pomdp.getNumChoicesForObservation(belief.so); chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; for (int i = 0; i < numChoices; i++) { @@ -373,7 +369,7 @@ public class POMDPModelChecker extends ProbModelChecker double nextBeliefProb = entry.getValue(); Belief nextBelief = entry.getKey(); // find discretized grid points to approximate the nextBelief - value += nextBeliefProb * interpolateOverGrid(nextBelief, vhash_backUp); + value += nextBeliefProb * approximateReachReward(nextBelief, vhash_backUp, targetObs); } if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) { chosenValue = value; @@ -404,7 +400,7 @@ public class POMDPModelChecker extends ProbModelChecker // Find discretized grid points to approximate the initialBelief // Also get (approximate) accuracy of result from value iteration Belief initialBelief = pomdp.getInitialBelief(); - double outerBound = interpolateOverGrid(initialBelief, vhash_backUp); + double outerBound = approximateReachReward(initialBelief, vhash_backUp, targetObs); double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE); Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE); // Print result @@ -503,13 +499,17 @@ public class POMDPModelChecker extends ProbModelChecker return targetObs; } - protected void initialiseGridPoints(POMDP pomdp, BitSet targetObs, ArrayList gridPoints, ArrayList unknownGridPoints) + /** + * Construct a list of beliefs for a grid-based approximation of the belief space. + * Only beliefs with observable values from {@code unknownObs) are added. + */ + protected List initialiseGridPoints(POMDP pomdp, BitSet unknownObs) { + List gridPoints = new ArrayList<>(); ArrayList> assignment; - int numObservations = pomdp.getNumObservations(); int numUnobservations = pomdp.getNumUnobservations(); int numStates = pomdp.getNumStates(); - for (int so = 0; so < numObservations; so++) { + for (int so = unknownObs.nextSetBit(0); so >= 0; so = unknownObs.nextSetBit(so + 1)) { ArrayList unobservsForObserv = new ArrayList<>(); for (int s = 0; s < numStates; s++) { if (so == pomdp.getObservation(s)) { @@ -524,14 +524,10 @@ public class POMDPModelChecker extends ProbModelChecker bu[unobservForObserv] = inner.get(k); k++; } - - Belief g = new Belief(so, bu); - gridPoints.add(g); - if (!targetObs.get(so)) { - unknownGridPoints.add(g); - } + gridPoints.add(new Belief(so, bu)); } } + return gridPoints; } /** @@ -573,6 +569,36 @@ public class POMDPModelChecker extends ProbModelChecker return beliefMDPState; } + /** + * Compute the grid-based approximate value for a belief for probabilistic reachability + */ + protected double approximateReachProb(Belief belief, HashMap gridValues, BitSet targetObs) + { + // 1 for target states + if (targetObs.get(belief.so)) { + return 1.0; + } + // Otherwise approximate vie interpolation over grid points + else { + return interpolateOverGrid(belief, gridValues); + } + } + + /** + * Compute the grid-based approximate value for a belief for reward reachability + */ + protected double approximateReachReward(Belief belief, HashMap gridValues, BitSet targetObs) + { + // 0 for target states + if (targetObs.get(belief.so)) { + return 0.0; + } + // Otherwise approximate vie interpolation over grid points + else { + return interpolateOverGrid(belief, gridValues); + } + } + /** * Approximate the value for a belief {@code belief} by interpolating over values {@code gridValues} * for a representative set of beliefs whose convex hull is the full belief space. @@ -582,7 +608,6 @@ public class POMDPModelChecker extends ProbModelChecker ArrayList subSimplex = new ArrayList<>(); double[] lambdas = new double[belief.bu.length]; getSubSimplexAndLambdas(belief.bu, subSimplex, lambdas, gridResolution); - //calculate the approximate value for the belief double val = 0; for (int j = 0; j < lambdas.length; j++) { if (lambdas[j] >= 1e-6) { @@ -633,7 +658,7 @@ public class POMDPModelChecker extends ProbModelChecker if (targetObs.get(b.so)) { mdpTarget.set(src); } else { - extractBestActions(src, b, vhash, pomdp, mdpRewards, min, exploredBelieves, toBeExploredBelives, mdp, stateRewards); + extractBestActions(src, b, vhash, targetObs, pomdp, mdpRewards, min, exploredBelieves, toBeExploredBelives, mdp, stateRewards); } } // Attach a label marking target states @@ -656,7 +681,7 @@ public class POMDPModelChecker extends ProbModelChecker * @param min * @param beliefList */ - protected void extractBestActions(int src, Belief belief, HashMap vhash, POMDP pomdp, MDPRewards mdpRewards, boolean min, + protected void extractBestActions(int src, Belief belief, HashMap vhash, BitSet targetObs, POMDP pomdp, MDPRewards mdpRewards, boolean min, IndexedSet exploredBelieves, LinkedList toBeExploredBelives, MDPSimple mdp, StateRewardsSimple stateRewards) { double chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; @@ -672,7 +697,11 @@ public class POMDPModelChecker extends ProbModelChecker for (Map.Entry entry : beliefMDPState.get(a).entrySet()) { double nextBeliefProb = entry.getValue(); Belief nextBelief = entry.getKey(); - value += nextBeliefProb * interpolateOverGrid(nextBelief, vhash); + if (mdpRewards == null) { + value += nextBeliefProb * approximateReachProb(nextBelief, vhash, targetObs); + } else { + value += nextBeliefProb * approximateReachReward(nextBelief, vhash, targetObs); + } } //select action that minimizes/maximizes Q(a,b), i.e. value