From 90c7df8209340bfc8ec5d384f11f66069fbf2973 Mon Sep 17 00:00:00 2001
From: Dave Parker
Date: Thu, 4 Mar 2021 15:38:45 +0000
Subject: [PATCH] More refactoring in POMDP solution.

* Separate data structures for BeliefMDPState and POMDPStrategyModel
* Push reward construction/storage into the code for the belief MDP
* Factor out prob/reward backup operations into methods for re-use
* Store value function + backup using functional interfaces
* Collapse (now simpler) buildStrategyModel into one method
---
 prism/src/explicit/POMDPModelChecker.java | 350 +++++++++++-----------
 1 file changed, 178 insertions(+), 172 deletions(-)

diff --git a/prism/src/explicit/POMDPModelChecker.java b/prism/src/explicit/POMDPModelChecker.java
index 35ce8080..0744b26e 100644
--- a/prism/src/explicit/POMDPModelChecker.java
+++ b/prism/src/explicit/POMDPModelChecker.java
@@ -35,6 +35,8 @@ import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.BiFunction;
+import java.util.function.Function;
 
 import explicit.graphviz.Decoration;
 import explicit.graphviz.Decorator;
@@ -53,6 +55,45 @@ import prism.PrismUtils;
  */
 public class POMDPModelChecker extends ProbModelChecker
 {
+    // Some local data structures for convenience
+
+    /**
+     * Info for a single state of a belief MDP:
+     * (1) a list (over choices in the state) of distributions over beliefs, stored as hashmaps;
+     * (2) optionally, a list (over choices in the state) of rewards
+     */
+    class BeliefMDPState
+    {
+        public List<HashMap<Belief, Double>> trans;
+        public List<Double> rewards;
+        public BeliefMDPState()
+        {
+            trans = new ArrayList<>();
+            rewards = new ArrayList<>();
+        }
+    }
+
+    /**
+     * Value backup function for belief state value iteration:
+     * mapping from a state and its definition (reward + transitions)
+     * to a pair of the optimal value + choice index.
+     */
+    @FunctionalInterface
+    interface BeliefMDPBackUp extends BiFunction<Belief, BeliefMDPState, Pair<Double, Integer>> {}
+
+    /**
+     * A model constructed to represent a fragment of a belief MDP induced by a strategy:
+     * (1) the model (represented as an MDP for ease of storing action labels)
+     * (2) optionally, a reward structure
+     * (3) a list of the beliefs corresponding to each state of the model
+     */
+    class POMDPStrategyModel
+    {
+        public MDP mdp;
+        public MDPRewards mdpRewards;
+        public List<Belief> beliefs;
+    }
+
     /**
      * Create a new POMDPModelChecker, inherit basic state from parent (unless null).
      */
@@ -122,16 +163,12 @@ public class POMDPModelChecker extends ProbModelChecker
 
         // Find out the observations for the target/remain states
         // And determine set of observations actually need to perform computation for
-        BitSet targetObs = null;
-        try {
-            targetObs = getObservationsMatchingStates(pomdp, target);
-        } catch (PrismException e) {
+        BitSet targetObs = getObservationsMatchingStates(pomdp, target);
+        if (targetObs == null) {
             throw new PrismException("Target for reachability is not observable");
         }
-        BitSet remainObs = null;
-        try {
-            remainObs = remain == null ? null : getObservationsMatchingStates(pomdp, remain);
-        } catch (PrismException e) {
+        BitSet remainObs = (remain == null) ? null : getObservationsMatchingStates(pomdp, remain);
+        if (remain != null && remainObs == null) {
             throw new PrismException("Left-hand side of until is not observable");
         }
         BitSet unknownObs = new BitSet();
@@ -146,20 +183,23 @@ public class POMDPModelChecker extends ProbModelChecker
         mainLog.println("Grid statistics: resolution=" + gridResolution + ", points=" + gridPoints.size());
         // Construct grid belief "MDP"
         mainLog.println("Building belief space approximation...");
-        List<List<HashMap<Belief, Double>>> beliefMDP = buildBeliefMDP(pomdp, gridPoints);
+        List<BeliefMDPState> beliefMDP = buildBeliefMDP(pomdp, null, gridPoints);
 
-        // Initialise hashmaps for storing values for the grid belief states
+        // Initialise hashmaps for storing values for the unknown belief states
         HashMap<Belief, Double> vhash = new HashMap<>();
         HashMap<Belief, Double> vhash_backUp = new HashMap<>();
         for (Belief belief : gridPoints) {
             vhash.put(belief, 0.0);
             vhash_backUp.put(belief, 0.0);
         }
-
+        // Define value function for the full set of belief states
+        Function<Belief, Double> values = belief -> approximateReachProb(belief, vhash_backUp, targetObs, unknownObs);
+        // Define value backup function
+        BeliefMDPBackUp backup = (belief, beliefState) -> approximateReachProbBackup(belief, beliefState, values, min);
+
         // Start iterations
         mainLog.println("Solving belief space approximation...");
         long timer2 = System.currentTimeMillis();
-        double value, chosenValue;
         int iters = 0;
         boolean done = false;
         while (!done && iters < maxIters) {
@@ -167,23 +207,8 @@ public class POMDPModelChecker extends ProbModelChecker
             int unK = gridPoints.size();
             for (int b = 0; b < unK; b++) {
                 Belief belief = gridPoints.get(b);
-                int numChoices = pomdp.getNumChoicesForObservation(belief.so);
-
-                chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
-                for (int i = 0; i < numChoices; i++) {
-                    value = 0;
-                    for (Map.Entry<Belief, Double> entry : beliefMDP.get(b).get(i).entrySet()) {
-                        double nextBeliefProb = entry.getValue();
-                        Belief nextBelief = entry.getKey();
-                        // find discretized grid points to approximate the nextBelief
-                        value += nextBeliefProb * approximateReachProb(nextBelief, vhash_backUp, targetObs, unknownObs);
-                    }
-                    if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) {
-                        chosenValue = value;
-                    }
-                }
-                //update V(b) to the chosenValue
-                vhash.put(belief, chosenValue);
+                Pair<Double, Integer> valChoice = backup.apply(belief, beliefMDP.get(b));
+                vhash.put(belief, valChoice.first);
             }
             // Check termination
             done = PrismUtils.doublesAreClose(vhash, vhash_backUp, termCritParam, termCrit == TermCrit.RELATIVE);
@@ -204,10 +229,10 @@ public class POMDPModelChecker extends ProbModelChecker
         mainLog.print("Belief space value iteration (" + (min ? "min" : "max") + ")");
         mainLog.println(" took " + iters + " iterations and " + timer2 / 1000.0 + " seconds.");
 
-        // Find discretized grid points to approximate the initialBelief
+        // Extract (approximate) solution value for the initial belief
         // Also get (approximate) accuracy of result from value iteration
         Belief initialBelief = pomdp.getInitialBelief();
-        double outerBound = approximateReachProb(initialBelief, vhash_backUp, targetObs, unknownObs);
+        double outerBound = values.apply(initialBelief);
         double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
         Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
         // Print result
@@ -215,8 +240,8 @@ public class POMDPModelChecker extends ProbModelChecker
 
         // Build DTMC to get inner bound (and strategy)
         mainLog.println("\nBuilding strategy-induced model...");
-        List<Belief> listBeliefs = new ArrayList<>();
-        MDP mdp = buildStrategyModel(pomdp, null, vhash, targetObs, unknownObs, min, listBeliefs).mdp;
+        POMDPStrategyModel psm = buildStrategyModel(pomdp, null, targetObs, unknownObs, backup);
+        MDP mdp = psm.mdp;
         mainLog.print("Strategy-induced model: " + mdp.infoString());
         // Export?
         if (stratFilename != null) {
@@ -227,7 +252,7 @@ public class POMDPModelChecker extends ProbModelChecker
                 @Override
                 public Decoration decorateState(int state, Decoration d)
                 {
-                    d.labelAddBelow(listBeliefs.get(state).toString(pomdp));
+                    d.labelAddBelow(psm.beliefs.get(state).toString(pomdp));
                     return d;
                 }
             }));
@@ -335,10 +360,8 @@ public class POMDPModelChecker extends ProbModelChecker
 
         // Find out the observations for the target states
         // And determine set of observations actually need to perform computation for
-        BitSet targetObs = null;
-        try {
-            targetObs = getObservationsMatchingStates(pomdp, target);
-        } catch (PrismException e) {
+        BitSet targetObs = getObservationsMatchingStates(pomdp, target);
+        if (targetObs == null) {
             throw new PrismException("Target for expected reachability is not observable");
         }
         BitSet unknownObs = new BitSet();
@@ -350,55 +373,32 @@ public class POMDPModelChecker extends ProbModelChecker
         mainLog.println("Grid statistics: resolution=" + gridResolution + ", points=" + gridPoints.size());
         // Construct grid belief "MDP"
         mainLog.println("Building belief space approximation...");
-        List<List<HashMap<Belief, Double>>> beliefMDP = buildBeliefMDP(pomdp, gridPoints);
-
-        // Rewards
-        List<List<Double>> rewards = new ArrayList<>(); // memoization for reuse
-        int unK = gridPoints.size();
-        for (int b = 0; b < unK; b++) {
-            Belief belief = gridPoints.get(b);
-            int numChoices = pomdp.getNumChoicesForObservation(belief.so);
-            List<Double> action_reward = new ArrayList<>(); // for memoization
-            for (int i = 0; i < numChoices; i++) {
-                action_reward.add(pomdp.getRewardAfterChoice(belief, i, mdpRewards)); // c(a,b)
-            }
-            rewards.add(action_reward);
-        }
+        List<BeliefMDPState> beliefMDP = buildBeliefMDP(pomdp, mdpRewards, gridPoints);
 
-        // Initialise hashmaps for storing values for the grid belief states
+        // Initialise hashmaps for storing values for the unknown belief states
         HashMap<Belief, Double> vhash = new HashMap<>();
         HashMap<Belief, Double> vhash_backUp = new HashMap<>();
         for (Belief belief : gridPoints) {
             vhash.put(belief, 0.0);
             vhash_backUp.put(belief, 0.0);
         }
+        // Define value function for the full set of belief states
+        Function<Belief, Double> values = belief -> approximateReachReward(belief, vhash_backUp, targetObs);
+        // Define value backup function
+        BeliefMDPBackUp backup = (belief, beliefState) -> approximateReachRewardBackup(belief, beliefState, values, min);
         // Start iterations
         mainLog.println("Solving belief space approximation...");
         long timer2 = System.currentTimeMillis();
-        double value, chosenValue;
         int iters = 0;
         boolean done = false;
         while (!done && iters < maxIters) {
             // Iterate over all (unknown) grid points
+            int unK = gridPoints.size();
             for (int b = 0; b < unK; b++) {
                 Belief belief = gridPoints.get(b);
-                int numChoices = pomdp.getNumChoicesForObservation(belief.so);
-                chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
-                for (int i = 0; i < numChoices; i++) {
-                    value = rewards.get(b).get(i);
-                    for (Map.Entry<Belief, Double> entry : beliefMDP.get(b).get(i).entrySet()) {
-                        double nextBeliefProb = entry.getValue();
-                        Belief nextBelief = entry.getKey();
-                        // find discretized grid points to approximate the nextBelief
-                        value += nextBeliefProb * approximateReachReward(nextBelief, vhash_backUp, targetObs);
-                    }
-                    if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) {
-                        chosenValue = value;
-                    }
-                }
-                //update V(b) to the chosenValue
-                vhash.put(belief, chosenValue);
+                Pair<Double, Integer> valChoice = backup.apply(belief, beliefMDP.get(b));
+                vhash.put(belief, valChoice.first);
             }
             // Check termination
             done = PrismUtils.doublesAreClose(vhash, vhash_backUp, termCritParam, termCrit == TermCrit.RELATIVE);
@@ -419,10 +419,10 @@ public class POMDPModelChecker extends ProbModelChecker
         mainLog.print("Belief space value iteration (" + (min ? "min" : "max") + ")");
         mainLog.println(" took " + iters + " iterations and " + timer2 / 1000.0 + " seconds.");
 
-        // Find discretized grid points to approximate the initialBelief
+        // Extract (approximate) solution value for the initial belief
        // Also get (approximate) accuracy of result from value iteration
         Belief initialBelief = pomdp.getInitialBelief();
-        double outerBound = approximateReachReward(initialBelief, vhash_backUp, targetObs);
+        double outerBound = values.apply(initialBelief);
         double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
         Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
         // Print result
@@ -430,8 +430,7 @@ public class POMDPModelChecker extends ProbModelChecker
 
         // Build DTMC to get inner bound (and strategy)
         mainLog.println("\nBuilding strategy-induced model...");
-        List<Belief> listBeliefs = new ArrayList<>();
-        POMDPStrategyModel psm = buildStrategyModel(pomdp, mdpRewards, vhash, targetObs, unknownObs, min, listBeliefs);
+        POMDPStrategyModel psm = buildStrategyModel(pomdp, mdpRewards, targetObs, unknownObs, backup);
         MDP mdp = psm.mdp;
         MDPRewards mdpRewardsNew = psm.mdpRewards;
         mainLog.print("Strategy-induced model: " + mdp.infoString());
@@ -444,7 +443,7 @@ public class POMDPModelChecker extends ProbModelChecker
                 @Override
                 public Decoration decorateState(int state, Decoration d)
                 {
-                    d.labelAddBelow(listBeliefs.get(state).toString(pomdp));
+                    d.labelAddBelow(psm.beliefs.get(state).toString(pomdp));
                     return d;
                 }
             }));
@@ -500,9 +499,9 @@ public class POMDPModelChecker extends ProbModelChecker
      * The states should correspond exactly to a set of observations,
      * i.e., if a state corresponding to an observation is in the set,
      * then all other states corresponding to it should also be.
-     * An exception is thrown if not.
+     * Returns null if not.
      */
-    protected BitSet getObservationsMatchingStates(POMDP pomdp, BitSet set) throws PrismException
+    protected BitSet getObservationsMatchingStates(POMDP pomdp, BitSet set)
     {
         // Find observations corresponding to each state in the set
         BitSet setObs = new BitSet();
@@ -518,7 +517,7 @@ public class POMDPModelChecker extends ProbModelChecker
             }
         }
         if (!set.equals(set2)) {
-            throw new PrismException("Set is not observable");
+            return null;
         }
         return setObs;
     }
@@ -556,43 +555,99 @@ public class POMDPModelChecker extends ProbModelChecker
     /**
      * Construct (part of) a belief MDP, just for the set of passed in belief states.
-     * It is stored as a list (over source beliefs) of lists (over choices)
-     * of distributions over target beliefs, stored as a hashmap.
+     * If provided, also construct a list of rewards for each state.
+     * It is stored as a list (over source beliefs) of BeliefMDPState objects.
      */
-    protected List<List<HashMap<Belief, Double>>> buildBeliefMDP(POMDP pomdp, List<Belief> beliefs)
+    protected List<BeliefMDPState> buildBeliefMDP(POMDP pomdp, MDPRewards mdpRewards, List<Belief> beliefs)
     {
-        List<List<HashMap<Belief, Double>>> beliefMDP = new ArrayList<>();
+        List<BeliefMDPState> beliefMDP = new ArrayList<>();
         for (Belief belief: beliefs) {
-            beliefMDP.add(buildBeliefMDPState(pomdp, belief));
+            beliefMDP.add(buildBeliefMDPState(pomdp, mdpRewards, belief));
         }
         return beliefMDP;
     }
 
     /**
     * Construct a single state (belief) of a belief MDP, stored as a
-     * list (over choices) of distributions over target beliefs, stored as a hashmap.
+     * list (over choices) of distributions over target beliefs.
+     * If provided, also construct a list of rewards for the state.
+     * It is stored as a BeliefMDPState object.
      */
-    protected List<HashMap<Belief, Double>> buildBeliefMDPState(POMDP pomdp, Belief belief)
+    protected BeliefMDPState buildBeliefMDPState(POMDP pomdp, MDPRewards mdpRewards, Belief belief)
     {
         double[] beliefInDist = belief.toDistributionOverStates(pomdp);
-        List<HashMap<Belief, Double>> beliefMDPState = new ArrayList<>();
+        BeliefMDPState beliefMDPState = new BeliefMDPState();
         // And for each choice
         int numChoices = pomdp.getNumChoicesForObservation(belief.so);
         for (int i = 0; i < numChoices; i++) {
             // Get successor observations and their probs
             HashMap<Integer, Double> obsProbs = pomdp.computeObservationProbsAfterAction(beliefInDist, i);
             HashMap<Belief, Double> beliefDist = new HashMap<>();
-            // Find the belief for each observations
+            // Find the belief for each observation
             for (Map.Entry<Integer, Double> entry : obsProbs.entrySet()) {
                 int o = entry.getKey();
                 Belief nextBelief = pomdp.getBeliefAfterChoiceAndObservation(belief, i, o);
                 beliefDist.put(nextBelief, entry.getValue());
             }
-            beliefMDPState.add(beliefDist);
+            beliefMDPState.trans.add(beliefDist);
+            // Store reward too, if required
+            if (mdpRewards != null) {
+                beliefMDPState.rewards.add(pomdp.getRewardAfterChoice(belief, i, mdpRewards));
+            }
         }
         return beliefMDPState;
     }
 
+    /**
+     * Perform a single backup step of (approximate) value iteration for probabilistic reachability
+     */
+    protected Pair<Double, Integer> approximateReachProbBackup(Belief belief, BeliefMDPState beliefMDPState, Function<Belief, Double> values, boolean min)
+    {
+        int numChoices = beliefMDPState.trans.size();
+        double chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
+        int chosenActionIndex = -1;
+        for (int i = 0; i < numChoices; i++) {
+            double value = 0;
+            for (Map.Entry<Belief, Double> entry : beliefMDPState.trans.get(i).entrySet()) {
+                double nextBeliefProb = entry.getValue();
+                Belief nextBelief = entry.getKey();
+                value += nextBeliefProb * values.apply(nextBelief);
+            }
+            if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) {
+                chosenValue = value;
+                chosenActionIndex = i;
+            } else if (Math.abs(value - chosenValue) < 1.0e-6) {
+                chosenActionIndex = i;
+            }
+        }
+        return new Pair<Double, Integer>(chosenValue, chosenActionIndex);
+    }
+
+    /**
+     * Perform a single backup step of (approximate) value iteration for reward reachability
+     */
+    protected Pair<Double, Integer> approximateReachRewardBackup(Belief belief, BeliefMDPState beliefMDPState, Function<Belief, Double> values, boolean min)
+    {
+        int numChoices = beliefMDPState.trans.size();
+        double chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
+        int chosenActionIndex = -1;
+        for (int i = 0; i < numChoices; i++) {
+            double value = beliefMDPState.rewards.get(i);
+            for (Map.Entry<Belief, Double> entry : beliefMDPState.trans.get(i).entrySet()) {
+                double nextBeliefProb = entry.getValue();
+                Belief nextBelief = entry.getKey();
+                value += nextBeliefProb * values.apply(nextBelief);
+            }
+            if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) {
+                chosenValue = value;
+                chosenActionIndex = i;
+            } else if (Math.abs(value - chosenValue) < 1.0e-6) {
+                chosenActionIndex = i;
+            }
+        }
+        return new Pair<Double, Integer>(chosenValue, chosenActionIndex);
+    }
+
     /**
      * Compute the grid-based approximate value for a belief for probabilistic reachability
      */
@@ -645,12 +700,6 @@ public class POMDPModelChecker extends ProbModelChecker
         return val;
     }
 
-    class POMDPStrategyModel
-    {
-        public MDP mdp;
-        public MDPRewards mdpRewards;
-    }
-
     /**
      * Build a (Markov chain) model representing the fragment of the belief MDP induced by an optimal strategy.
      * The model is stored as an MDP to allow easier attachment of optional actions.
@@ -662,112 +711,69 @@ public class POMDPModelChecker extends ProbModelChecker
      * @param min
      * @param listBeliefs
      */
-    protected POMDPStrategyModel buildStrategyModel(POMDP pomdp, MDPRewards mdpRewards, HashMap<Belief, Double> vhash, BitSet targetObs, BitSet unknownObs, boolean min, List<Belief> listBeliefs)
+    protected POMDPStrategyModel buildStrategyModel(POMDP pomdp, MDPRewards mdpRewards, BitSet targetObs, BitSet unknownObs, BeliefMDPBackUp backup)
     {
         // Initialise model/state/rewards storage
         MDPSimple mdp = new MDPSimple();
-        IndexedSet<Belief> exploredBelieves = new IndexedSet<>(true);
-        LinkedList<Belief> toBeExploredBelives = new LinkedList<>();
+        IndexedSet<Belief> exploredBeliefs = new IndexedSet<>(true);
+        LinkedList<Belief> toBeExploredBeliefs = new LinkedList<>();
         BitSet mdpTarget = new BitSet();
         StateRewardsSimple stateRewards = new StateRewardsSimple();
         // Add initial state
         Belief initialBelief = pomdp.getInitialBelief();
-        exploredBelieves.add(initialBelief);
-        toBeExploredBelives.offer(initialBelief);
+        exploredBeliefs.add(initialBelief);
+        toBeExploredBeliefs.offer(initialBelief);
         mdp.addState();
         mdp.addInitialState(0);
         // Explore model
         int src = -1;
-        while (!toBeExploredBelives.isEmpty()) {
-            Belief b = toBeExploredBelives.pollFirst();
+        while (!toBeExploredBeliefs.isEmpty()) {
+            Belief belief = toBeExploredBeliefs.pollFirst();
             src++;
             // Remember if this is a target state
-            if (targetObs.get(b.so)) {
+            if (targetObs.get(belief.so)) {
                 mdpTarget.set(src);
             }
             // Only explore "unknown" states
-            if (unknownObs.get(b.so)) {
-                extractBestActions(src, b, vhash, targetObs, unknownObs, pomdp, mdpRewards, min, exploredBelieves, toBeExploredBelives, mdp, stateRewards);
+            if (unknownObs.get(belief.so)) {
+                // Build the belief MDP for this belief state and solve
+                BeliefMDPState beliefMDPState = buildBeliefMDPState(pomdp, mdpRewards, belief);
+                Pair<Double, Integer> valChoice = backup.apply(belief, beliefMDPState);
+                int chosenActionIndex = valChoice.second;
+                // Build a distribution over successor belief states and add to MDP
+                Distribution distr = new Distribution();
+                for (Map.Entry<Belief, Double> entry : beliefMDPState.trans.get(chosenActionIndex).entrySet()) {
+                    double nextBeliefProb = entry.getValue();
+                    Belief nextBelief = entry.getKey();
+                    // Add each successor belief to the MDP and the "to explore" set if new
+                    if (exploredBeliefs.add(nextBelief)) {
+                        toBeExploredBeliefs.add(nextBelief);
+                        mdp.addState();
+                    }
+                    // Get index of state in state set
+                    int dest = exploredBeliefs.getIndexOfLastAdd();
+                    distr.add(dest, nextBeliefProb);
+                }
+                // Add transition distribution, with choice _index_ encoded as action
+                mdp.addActionLabelledChoice(src, distr, pomdp.getActionForObservation(belief.so, chosenActionIndex));
+                // Store reward too, if needed
+                if (mdpRewards != null) {
+                    stateRewards.setStateReward(src, pomdp.getRewardAfterChoice(belief, chosenActionIndex, mdpRewards));
+                }
             }
         }
         // Attach a label marking target states
         mdp.addLabel("target", mdpTarget);
-        listBeliefs.addAll(exploredBelieves.toArrayList());
         // Return
         POMDPStrategyModel psm = new POMDPStrategyModel();
         psm.mdp = mdp;
         psm.mdpRewards = stateRewards;
+        psm.beliefs = new ArrayList<>();
+        psm.beliefs.addAll(exploredBeliefs.toArrayList());
         return psm;
     }
 
-    /**
-     * Find the best action for this belief state, add the belief state to the list
-     * of ones examined so far, and store the strategy info. We store this as an MDP.
-     * @param belief Belief state to examine
-     * @param vhash
-     * @param pomdp
-     * @param mdpRewards
-     * @param min
-     * @param beliefList
-     */
-    protected void extractBestActions(int src, Belief belief, HashMap<Belief, Double> vhash, BitSet targetObs, BitSet unknownObs, POMDP pomdp, MDPRewards mdpRewards, boolean min,
-            IndexedSet<Belief> exploredBelieves, LinkedList<Belief> toBeExploredBelives, MDPSimple mdp, StateRewardsSimple stateRewards)
-    {
-        double chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
-        int chosenActionIndex = -1;
-        //evaluate each action in b
-        int numChoices = pomdp.getNumChoicesForObservation(belief.so);
-        List<HashMap<Belief, Double>> beliefMDPState = buildBeliefMDPState(pomdp, belief);
-        for (int a = 0; a < numChoices; a++) {
-            double value = 0;
-            if (mdpRewards != null) {
-                value = pomdp.getRewardAfterChoice(belief, a, mdpRewards); // c(a,b)
-            }
-            for (Map.Entry<Belief, Double> entry : beliefMDPState.get(a).entrySet()) {
-                double nextBeliefProb = entry.getValue();
-                Belief nextBelief = entry.getKey();
-                if (mdpRewards == null) {
-                    value += nextBeliefProb * approximateReachProb(nextBelief, vhash, targetObs, unknownObs);
-                } else {
-                    value += nextBeliefProb * approximateReachReward(nextBelief, vhash, targetObs);
-                }
-            }
-
-            //select action that minimizes/maximizes Q(a,b), i.e. value
-            if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) { //value < chosenValue
-                chosenValue = value;
-                chosenActionIndex = a;
-            } else if (Math.abs(value - chosenValue) < 1.0e-6) {
-                chosenActionIndex = a;
-            }
-        }
-        // Build a distribution over successor belief states and add to MDP
-        Distribution distr = new Distribution();
-        for (Map.Entry<Belief, Double> entry : beliefMDPState.get(chosenActionIndex).entrySet()) {
-            double nextBeliefProb = entry.getValue();
-            Belief nextBelief = entry.getKey();
-            // Add each successor belief to the MDP and the "to explore" set if new
-            if (exploredBelieves.add(nextBelief)) {
-                toBeExploredBelives.add(nextBelief);
-                mdp.addState();
-            }
-            // Get index of state in state set
-            int dest = exploredBelieves.getIndexOfLastAdd();
-            distr.add(dest, nextBeliefProb);
-        }
-        // Add transition distribution, with choice _index_ encoded as action
-        mdp.addActionLabelledChoice(src, distr, pomdp.getActionForObservation(belief.so, chosenActionIndex));
-        // Store reward too, if needed
-        if (mdpRewards != null) {
-            stateRewards.setStateReward(src, pomdp.getRewardAfterChoice(belief, chosenActionIndex, mdpRewards));
-        }
-    }
-
     protected ArrayList<ArrayList<Integer>> assignGPrime(int startIndex, int min, int max, int length)
     {
         ArrayList<ArrayList<Integer>> result = new ArrayList<ArrayList<Integer>>();