//============================================================================== // // Copyright (c) 2014- // Authors: // * Xueyi Zou (University of York) // * Dave Parker (University of Birmingham) // //------------------------------------------------------------------------------ // // This file is part of PRISM. // // PRISM is free software; you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation; either version 2 of the License, or // (at your option) any later version. // // PRISM is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with PRISM; if not, write to the Free Software Foundation, // Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // //============================================================================== package explicit; import java.util.ArrayList; import java.util.BitSet; import java.util.Collections; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.BiFunction; import java.util.function.Function; import explicit.graphviz.Decoration; import explicit.graphviz.Decorator; import explicit.rewards.MDPRewards; import explicit.rewards.StateRewardsSimple; import prism.Accuracy; import prism.AccuracyFactory; import prism.Pair; import prism.PrismComponent; import prism.PrismException; import prism.PrismNotSupportedException; import prism.PrismUtils; /** * Explicit-state model checker for partially observable Markov decision processes (POMDPs). */ public class POMDPModelChecker extends ProbModelChecker { // Some local data structures for convenience /** * Info for a single state of a belief MDP: * (1) a list (over choices in the state) of distributions over beliefs, stored as hashmap; * (2) optionally, a list (over choices in the state) of rewards */ class BeliefMDPState { public List> trans; public List rewards; public BeliefMDPState() { trans = new ArrayList<>(); rewards = new ArrayList<>(); } } /** * Value backup function for belief state value iteration: * mapping from a state and its definition (reward + transitions) * to a pair of the optimal value + choice index. */ @FunctionalInterface interface BeliefMDPBackUp extends BiFunction> {} /** * A model constructed to represent a fragment of a belief MDP induced by a strategy: * (1) the model (represented as an MDP for ease of storing actions labels) * (2) optionally, a reward structure * (3) a list of the beliefs corresponding to each state of the model */ class POMDPStrategyModel { public MDP mdp; public MDPRewards mdpRewards; public List beliefs; } /** * Create a new POMDPModelChecker, inherit basic state from parent (unless null). */ public POMDPModelChecker(PrismComponent parent) throws PrismException { super(parent); } /** * Compute reachability/until probabilities. * i.e. compute the min/max probability of reaching a state in {@code target}, * while remaining in those in @{code remain}. * @param pomdp The POMDP * @param remain Remain in these states (optional: null means "all") * @param target Target states * @param min Min or max probabilities (true=min, false=max) */ public ModelCheckerResult computeReachProbs(POMDP pomdp, BitSet remain, BitSet target, boolean min, BitSet statesOfInterest) throws PrismException { ModelCheckerResult res = null; long timer; String stratFilename = null; // Check we are only computing for a single state (and use initial state if unspecified) if (statesOfInterest == null) { statesOfInterest = new BitSet(); statesOfInterest.set(pomdp.getFirstInitialState()); } else if (statesOfInterest.cardinality() > 1) { throw new PrismNotSupportedException("POMDPs can only be solved from a single start state"); } // Start probabilistic reachability timer = System.currentTimeMillis(); mainLog.println("\nStarting probabilistic reachability (" + (min ? "min" : "max") + ")..."); // If required, create/initialise strategy storage if (genStrat || exportAdv) { stratFilename = exportAdvFilename;//"policyGraph.txt"; } // Compute rewards res = computeReachProbsFixedGrid(pomdp, remain, target, min, statesOfInterest.nextSetBit(0), stratFilename); // Finished probabilistic reachability timer = System.currentTimeMillis() - timer; mainLog.println("Probabilistic reachability took " + timer / 1000.0 + " seconds."); // Update time taken res.timeTaken = timer / 1000.0; return res; } /** * Compute reachability/until probabilities, * i.e. compute the min/max probability of reaching a state in {@code target}, * while remaining in those in @{code remain}, * using Lovejoy's fixed-resolution grid approach. * This only computes the probabiity from a single start state * Optionally, store optimal (memoryless) strategy info. * @param pomdp The POMDP * @param remain Remain in these states (optional: null means "all") * @param target Target states * @param min Min or max rewards (true=min, false=max) * @param sInit State to compute for * @param strat Storage for (memoryless) strategy choice indices (ignored if null) */ protected ModelCheckerResult computeReachProbsFixedGrid(POMDP pomdp, BitSet remain, BitSet target, boolean min, int sInit, String stratFilename) throws PrismException { // Start fixed-resolution grid approximation long timer = System.currentTimeMillis(); mainLog.println("Starting fixed-resolution grid approximation (" + (min ? "min" : "max") + ")..."); // Find out the observations for the target/remain states // And determine set of observations actually need to perform computation for BitSet targetObs = getObservationsMatchingStates(pomdp, target);; if (targetObs == null) { throw new PrismException("Target for reachability is not observable"); } BitSet remainObs = (remain == null) ? null : getObservationsMatchingStates(pomdp, remain); if (remain != null && remainObs == null) { throw new PrismException("Left-hand side of until is not observable"); } BitSet unknownObs = new BitSet(); unknownObs.set(0, pomdp.getNumObservations()); unknownObs.andNot(targetObs); if (remainObs != null) { unknownObs.and(remainObs); } // Initialise the grid points (just for unknown beliefs) List gridPoints = initialiseGridPoints(pomdp, unknownObs); mainLog.println("Grid statistics: resolution=" + gridResolution + ", points=" + gridPoints.size()); // Construct grid belief "MDP" mainLog.println("Building belief space approximation..."); List beliefMDP = buildBeliefMDP(pomdp, null, gridPoints); // Initialise hashmaps for storing values for the unknown belief states HashMap vhash = new HashMap<>(); HashMap vhash_backUp = new HashMap<>(); for (Belief belief : gridPoints) { vhash.put(belief, 0.0); vhash_backUp.put(belief, 0.0); } // Define value function for the full set of belief states Function values = belief -> approximateReachProb(belief, vhash_backUp, targetObs, unknownObs); // Define value backup function BeliefMDPBackUp backup = (belief, beliefState) -> approximateReachProbBackup(belief, beliefState, values, min); // Start iterations mainLog.println("Solving belief space approximation..."); long timer2 = System.currentTimeMillis(); int iters = 0; boolean done = false; while (!done && iters < maxIters) { // Iterate over all (unknown) grid points int unK = gridPoints.size(); for (int b = 0; b < unK; b++) { Belief belief = gridPoints.get(b); Pair valChoice = backup.apply(belief, beliefMDP.get(b)); vhash.put(belief, valChoice.first); } // Check termination done = PrismUtils.doublesAreClose(vhash, vhash_backUp, termCritParam, termCrit == TermCrit.RELATIVE); // back up Set> entries = vhash.entrySet(); for (Map.Entry entry : entries) { vhash_backUp.put(entry.getKey(), entry.getValue()); } iters++; } // Non-convergence is an error (usually) if (!done && errorOnNonConverge) { String msg = "Iterative method did not converge within " + iters + " iterations."; msg += "\nConsider using a different numerical method or increasing the maximum number of iterations"; throw new PrismException(msg); } timer2 = System.currentTimeMillis() - timer2; mainLog.print("Belief space value iteration (" + (min ? "min" : "max") + ")"); mainLog.println(" took " + iters + " iterations and " + timer2 / 1000.0 + " seconds."); // Extract (approximate) solution value for the initial belief // Also get (approximate) accuracy of result from value iteration Belief initialBelief = Belief.pointDistribution(sInit, pomdp); double outerBound = values.apply(initialBelief); double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE); Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE); // Print result mainLog.println("Outer bound: " + outerBound + " (" + outerBoundAcc.toString(outerBound) + ")"); // Build DTMC to get inner bound (and strategy) mainLog.println("\nBuilding strategy-induced model..."); POMDPStrategyModel psm = buildStrategyModel(pomdp, sInit, null, targetObs, unknownObs, backup); MDP mdp = psm.mdp; mainLog.print("Strategy-induced model: " + mdp.infoString()); // Export? if (stratFilename != null) { mdp.exportToPrismExplicitTra(stratFilename); //mdp.exportToDotFile(stratFilename + ".dot", mdp.getLabelStates("target")); mdp.exportToDotFile(stratFilename + ".dot", Collections.singleton(new Decorator() { @Override public Decoration decorateState(int state, Decoration d) { d.labelAddBelow(psm.beliefs.get(state).toString(pomdp)); return d; } })); } // Create MDP model checker (disable strat generation - if enabled, we want the POMDP one) MDPModelChecker mcMDP = new MDPModelChecker(this); mcMDP.setExportAdv(false); mcMDP.setGenStrat(false); // Solve MDP to get inner bound // (just reachability: can ignore "remain" since violating states are absent) ModelCheckerResult mcRes = mcMDP.computeReachProbs(mdp, mdp.getLabelStates("target"), true); double innerBound = mcRes.soln[0]; Accuracy innerBoundAcc = mcRes.accuracy; // Print result String innerBoundStr = "" + innerBound; if (innerBoundAcc != null) { innerBoundStr += " (" + innerBoundAcc.toString(innerBound) + ")"; } mainLog.println("Inner bound: " + innerBoundStr); // Finished fixed-resolution grid approximation timer = System.currentTimeMillis() - timer; mainLog.print("\nFixed-resolution grid approximation (" + (min ? "min" : "max") + ")"); mainLog.println(" took " + timer / 1000.0 + " seconds."); // Extract and store result Pair resultValAndAcc; if (min) { resultValAndAcc = AccuracyFactory.valueAndAccuracyFromInterval(outerBound, outerBoundAcc, innerBound, innerBoundAcc); } else { resultValAndAcc = AccuracyFactory.valueAndAccuracyFromInterval(innerBound, innerBoundAcc, outerBound, outerBoundAcc); } double resultVal = resultValAndAcc.first; Accuracy resultAcc = resultValAndAcc.second; mainLog.println("Result bounds: [" + resultAcc.getResultLowerBound(resultVal) + "," + resultAcc.getResultUpperBound(resultVal) + "]"); double soln[] = new double[pomdp.getNumStates()]; soln[sInit] = resultVal; // Return results ModelCheckerResult res = new ModelCheckerResult(); res.soln = soln; res.accuracy = resultAcc; res.numIters = iters; res.timeTaken = timer / 1000.0; return res; } /** * Compute expected reachability rewards, * i.e. compute the min/max reward accumulated to reach a state in {@code target}. * @param pomdp The POMDP * @param mdpRewards The rewards * @param target Target states * @param min Min or max rewards (true=min, false=max) */ public ModelCheckerResult computeReachRewards(POMDP pomdp, MDPRewards mdpRewards, BitSet target, boolean min, BitSet statesOfInterest) throws PrismException { ModelCheckerResult res = null; long timer; String stratFilename = null; // Check we are only computing for a single state (and use initial state if unspecified) if (statesOfInterest == null) { statesOfInterest = new BitSet(); statesOfInterest.set(pomdp.getFirstInitialState()); } else if (statesOfInterest.cardinality() > 1) { throw new PrismNotSupportedException("POMDPs can only be solved from a single start state"); } // Start expected reachability timer = System.currentTimeMillis(); mainLog.println("\nStarting expected reachability (" + (min ? "min" : "max") + ")..."); // If required, create/initialise strategy storage if (genStrat || exportAdv) { stratFilename = exportAdvFilename; } // Compute rewards res = computeReachRewardsFixedGrid(pomdp, mdpRewards, target, min, statesOfInterest.nextSetBit(0), stratFilename); // Finished expected reachability timer = System.currentTimeMillis() - timer; mainLog.println("Expected reachability took " + timer / 1000.0 + " seconds."); // Update time taken res.timeTaken = timer / 1000.0; return res; } /** * Compute expected reachability rewards using Lovejoy's fixed-resolution grid approach. * This only computes the expected reward from a single start state * Optionally, store optimal (memoryless) strategy info. * @param pomdp The POMMDP * @param mdpRewards The rewards * @param target Target states * @param inf States for which reward is infinite * @param min Min or max rewards (true=min, false=max) * @param sInit State to compute for * @param strat Storage for (memoryless) strategy choice indices (ignored if null) */ protected ModelCheckerResult computeReachRewardsFixedGrid(POMDP pomdp, MDPRewards mdpRewards, BitSet target, boolean min, int sInit, String stratFilename) throws PrismException { // Start fixed-resolution grid approximation long timer = System.currentTimeMillis(); mainLog.println("Starting fixed-resolution grid approximation (" + (min ? "min" : "max") + ")..."); // Find out the observations for the target states // And determine set of observations actually need to perform computation for BitSet targetObs = getObservationsMatchingStates(pomdp, target);; if (targetObs == null) { throw new PrismException("Target for expected reachability is not observable"); } BitSet unknownObs = new BitSet(); unknownObs.set(0, pomdp.getNumObservations()); unknownObs.andNot(targetObs); // Initialise the grid points (just for unknown beliefs) List gridPoints = initialiseGridPoints(pomdp, unknownObs); mainLog.println("Grid statistics: resolution=" + gridResolution + ", points=" + gridPoints.size()); // Construct grid belief "MDP" mainLog.println("Building belief space approximation..."); List beliefMDP = buildBeliefMDP(pomdp, mdpRewards, gridPoints); // Initialise hashmaps for storing values for the unknown belief states HashMap vhash = new HashMap<>(); HashMap vhash_backUp = new HashMap<>(); for (Belief belief : gridPoints) { vhash.put(belief, 0.0); vhash_backUp.put(belief, 0.0); } // Define value function for the full set of belief states Function values = belief -> approximateReachReward(belief, vhash_backUp, targetObs); // Define value backup function BeliefMDPBackUp backup = (belief, beliefState) -> approximateReachRewardBackup(belief, beliefState, values, min); // Start iterations mainLog.println("Solving belief space approximation..."); long timer2 = System.currentTimeMillis(); int iters = 0; boolean done = false; while (!done && iters < maxIters) { // Iterate over all (unknown) grid points int unK = gridPoints.size(); for (int b = 0; b < unK; b++) { Belief belief = gridPoints.get(b); Pair valChoice = backup.apply(belief, beliefMDP.get(b)); vhash.put(belief, valChoice.first); } // Check termination done = PrismUtils.doublesAreClose(vhash, vhash_backUp, termCritParam, termCrit == TermCrit.RELATIVE); // back up Set> entries = vhash.entrySet(); for (Map.Entry entry : entries) { vhash_backUp.put(entry.getKey(), entry.getValue()); } iters++; } // Non-convergence is an error (usually) if (!done && errorOnNonConverge) { String msg = "Iterative method did not converge within " + iters + " iterations."; msg += "\nConsider using a different numerical method or increasing the maximum number of iterations"; throw new PrismException(msg); } timer2 = System.currentTimeMillis() - timer2; mainLog.print("Belief space value iteration (" + (min ? "min" : "max") + ")"); mainLog.println(" took " + iters + " iterations and " + timer2 / 1000.0 + " seconds."); // Extract (approximate) solution value for the initial belief // Also get (approximate) accuracy of result from value iteration Belief initialBelief = Belief.pointDistribution(sInit, pomdp); double outerBound = values.apply(initialBelief); double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE); Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE); // Print result mainLog.println("Outer bound: " + outerBound + " (" + outerBoundAcc.toString(outerBound) + ")"); // Build DTMC to get inner bound (and strategy) mainLog.println("\nBuilding strategy-induced model..."); POMDPStrategyModel psm = buildStrategyModel(pomdp, sInit, mdpRewards, targetObs, unknownObs, backup); MDP mdp = psm.mdp; MDPRewards mdpRewardsNew = psm.mdpRewards; mainLog.print("Strategy-induced model: " + mdp.infoString()); // Export? if (stratFilename != null) { mdp.exportToPrismExplicitTra(stratFilename); //mdp.exportToDotFile(stratFilename + ".dot", mdp.getLabelStates("target")); mdp.exportToDotFile(stratFilename + ".dot", Collections.singleton(new Decorator() { @Override public Decoration decorateState(int state, Decoration d) { d.labelAddBelow(psm.beliefs.get(state).toString(pomdp)); return d; } })); } // Create MDP model checker (disable strat generation - if enabled, we want the POMDP one) MDPModelChecker mcMDP = new MDPModelChecker(this); mcMDP.setExportAdv(false); mcMDP.setGenStrat(false); // Solve MDP to get inner bound ModelCheckerResult mcRes = mcMDP.computeReachRewards(mdp, mdpRewardsNew, mdp.getLabelStates("target"), true); double innerBound = mcRes.soln[0]; Accuracy innerBoundAcc = mcRes.accuracy; // Print result String innerBoundStr = "" + innerBound; if (innerBoundAcc != null) { innerBoundStr += " (" + innerBoundAcc.toString(innerBound) + ")"; } mainLog.println("Inner bound: " + innerBoundStr); // Finished fixed-resolution grid approximation timer = System.currentTimeMillis() - timer; mainLog.print("\nFixed-resolution grid approximation (" + (min ? "min" : "max") + ")"); mainLog.println(" took " + timer / 1000.0 + " seconds."); // Extract and store result Pair resultValAndAcc; if (min) { resultValAndAcc = AccuracyFactory.valueAndAccuracyFromInterval(outerBound, outerBoundAcc, innerBound, innerBoundAcc); } else { resultValAndAcc = AccuracyFactory.valueAndAccuracyFromInterval(innerBound, innerBoundAcc, outerBound, outerBoundAcc); } double resultVal = resultValAndAcc.first; Accuracy resultAcc = resultValAndAcc.second; mainLog.println("Result bounds: [" + resultAcc.getResultLowerBound(resultVal) + "," + resultAcc.getResultUpperBound(resultVal) + "]"); double soln[] = new double[pomdp.getNumStates()]; soln[sInit] = resultVal; // Return results ModelCheckerResult res = new ModelCheckerResult(); res.soln = soln; res.accuracy = resultAcc; res.numIters = iters; res.timeTaken = timer / 1000.0; return res; } /** * Get a list of observations from a set of states * (both are represented by BitSets over their indices). * The states should correspond exactly to a set of observations, * i.e., if a state corresponding to an observation is in the set, * then all other states corresponding to it should also be. * Returns null if not. */ protected BitSet getObservationsMatchingStates(POMDP pomdp, BitSet set) { // Find observations corresponding to each state in the set BitSet setObs = new BitSet(); for (int s = set.nextSetBit(0); s >= 0; s = set.nextSetBit(s + 1)) { setObs.set(pomdp.getObservation(s)); } // Recreate the set of states from the observations and make sure it matches BitSet set2 = new BitSet(); int numStates = pomdp.getNumStates(); for (int s = 0; s < numStates; s++) { if (setObs.get(pomdp.getObservation(s))) { set2.set(s); } } if (!set.equals(set2)) { return null; } return setObs; } /** * Construct a list of beliefs for a grid-based approximation of the belief space. * Only beliefs with observable values from {@code unknownObs) are added. */ protected List initialiseGridPoints(POMDP pomdp, BitSet unknownObs) { List gridPoints = new ArrayList<>(); ArrayList> assignment; int numUnobservations = pomdp.getNumUnobservations(); int numStates = pomdp.getNumStates(); for (int so = unknownObs.nextSetBit(0); so >= 0; so = unknownObs.nextSetBit(so + 1)) { ArrayList unobservsForObserv = new ArrayList<>(); for (int s = 0; s < numStates; s++) { if (so == pomdp.getObservation(s)) { unobservsForObserv.add(pomdp.getUnobservation(s)); } } assignment = fullAssignment(unobservsForObserv.size(), gridResolution); for (ArrayList inner : assignment) { double[] bu = new double[numUnobservations]; int k = 0; for (int unobservForObserv : unobservsForObserv) { bu[unobservForObserv] = inner.get(k); k++; } gridPoints.add(new Belief(so, bu)); } } return gridPoints; } /** * Construct (part of) a belief MDP, just for the set of passed in belief states. * If provided, also construct a list of rewards for each state. * It is stored as a list (over source beliefs) of BeliefMDPState objects. */ protected List buildBeliefMDP(POMDP pomdp, MDPRewards mdpRewards, List beliefs) { List beliefMDP = new ArrayList<>(); for (Belief belief: beliefs) { beliefMDP.add(buildBeliefMDPState(pomdp, mdpRewards, belief)); } return beliefMDP; } /** * Construct a single single state (belief) of a belief MDP, stored as a * list (over choices) of distributions over target beliefs. * If provided, also construct a list of rewards for the state. * It is stored as a BeliefMDPState object. */ protected BeliefMDPState buildBeliefMDPState(POMDP pomdp, MDPRewards mdpRewards, Belief belief) { double[] beliefInDist = belief.toDistributionOverStates(pomdp); BeliefMDPState beliefMDPState = new BeliefMDPState(); // And for each choice int numChoices = pomdp.getNumChoicesForObservation(belief.so); for (int i = 0; i < numChoices; i++) { // Get successor observations and their probs HashMap obsProbs = pomdp.computeObservationProbsAfterAction(beliefInDist, i); HashMap beliefDist = new HashMap<>(); // Find the belief for each observation for (Map.Entry entry : obsProbs.entrySet()) { int o = entry.getKey(); Belief nextBelief = pomdp.getBeliefAfterChoiceAndObservation(belief, i, o); beliefDist.put(nextBelief, entry.getValue()); } beliefMDPState.trans.add(beliefDist); // Store reward too, if required if (mdpRewards != null) { beliefMDPState.rewards.add(pomdp.getRewardAfterChoice(belief, i, mdpRewards)); } } return beliefMDPState; } /** * Perform a single backup step of (approximate) value iteration for probabilistic reachability */ protected Pair approximateReachProbBackup(Belief belief, BeliefMDPState beliefMDPState, Function values, boolean min) { int numChoices = beliefMDPState.trans.size(); double chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; int chosenActionIndex = -1; for (int i = 0; i < numChoices; i++) { double value = 0; for (Map.Entry entry : beliefMDPState.trans.get(i).entrySet()) { double nextBeliefProb = entry.getValue(); Belief nextBelief = entry.getKey(); value += nextBeliefProb * values.apply(nextBelief); } if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) { chosenValue = value; chosenActionIndex = i; } else if (Math.abs(value - chosenValue) < 1.0e-6) { chosenActionIndex = i; } } return new Pair(chosenValue, chosenActionIndex); } /** * Perform a single backup step of (approximate) value iteration for reward reachability */ protected Pair approximateReachRewardBackup(Belief belief, BeliefMDPState beliefMDPState, Function values, boolean min) { int numChoices = beliefMDPState.trans.size(); double chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; int chosenActionIndex = -1; for (int i = 0; i < numChoices; i++) { double value = beliefMDPState.rewards.get(i); for (Map.Entry entry : beliefMDPState.trans.get(i).entrySet()) { double nextBeliefProb = entry.getValue(); Belief nextBelief = entry.getKey(); value += nextBeliefProb * values.apply(nextBelief); } if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) { chosenValue = value; chosenActionIndex = i; } else if (Math.abs(value - chosenValue) < 1.0e-6) { chosenActionIndex = i; } } return new Pair(chosenValue, chosenActionIndex); } /** * Compute the grid-based approximate value for a belief for probabilistic reachability */ protected double approximateReachProb(Belief belief, HashMap gridValues, BitSet targetObs, BitSet unknownObs) { // 1 for target states if (targetObs.get(belief.so)) { return 1.0; } // 0 for other non-unknown states else if (!unknownObs.get(belief.so)) { return 0.0; } // Otherwise approximate vie interpolation over grid points else { return interpolateOverGrid(belief, gridValues); } } /** * Compute the grid-based approximate value for a belief for reward reachability */ protected double approximateReachReward(Belief belief, HashMap gridValues, BitSet targetObs) { // 0 for target states if (targetObs.get(belief.so)) { return 0.0; } // Otherwise approximate vie interpolation over grid points else { return interpolateOverGrid(belief, gridValues); } } /** * Approximate the value for a belief {@code belief} by interpolating over values {@code gridValues} * for a representative set of beliefs whose convex hull is the full belief space. */ protected double interpolateOverGrid(Belief belief, HashMap gridValues) { ArrayList subSimplex = new ArrayList<>(); double[] lambdas = new double[belief.bu.length]; getSubSimplexAndLambdas(belief.bu, subSimplex, lambdas, gridResolution); double val = 0; for (int j = 0; j < lambdas.length; j++) { if (lambdas[j] >= 1e-6) { val += lambdas[j] * gridValues.get(new Belief(belief.so, subSimplex.get(j))); } } return val; } /** * Build a (Markov chain) model representing the fragment of the belief MDP induced by an optimal strategy. * The model is stored as an MDP to allow easier attachment of optional actions. * @param pomdp * @param sInit * @param mdpRewards * @param vhash * @param vhash_backUp * @param target * @param min * @param listBeliefs */ protected POMDPStrategyModel buildStrategyModel(POMDP pomdp, int sInit, MDPRewards mdpRewards, BitSet targetObs, BitSet unknownObs, BeliefMDPBackUp backup) { // Initialise model/state/rewards storage MDPSimple mdp = new MDPSimple(); IndexedSet exploredBeliefs = new IndexedSet<>(true); LinkedList toBeExploredBeliefs = new LinkedList<>(); BitSet mdpTarget = new BitSet(); StateRewardsSimple stateRewards = new StateRewardsSimple(); // Add initial state Belief initialBelief = Belief.pointDistribution(sInit, pomdp); exploredBeliefs.add(initialBelief); toBeExploredBeliefs.offer(initialBelief); mdp.addState(); mdp.addInitialState(0); // Explore model int src = -1; while (!toBeExploredBeliefs.isEmpty()) { Belief belief = toBeExploredBeliefs.pollFirst(); src++; // Remember if this is a target state if (targetObs.get(belief.so)) { mdpTarget.set(src); } // Only explore "unknown" states if (unknownObs.get(belief.so)) { // Build the belief MDP for this belief state and solve BeliefMDPState beliefMDPState = buildBeliefMDPState(pomdp, mdpRewards, belief); Pair valChoice = backup.apply(belief, beliefMDPState); int chosenActionIndex = valChoice.second; // Build a distribution over successor belief states and add to MDP Distribution distr = new Distribution(); for (Map.Entry entry : beliefMDPState.trans.get(chosenActionIndex).entrySet()) { double nextBeliefProb = entry.getValue(); Belief nextBelief = entry.getKey(); // Add each successor belief to the MDP and the "to explore" set if new if (exploredBeliefs.add(nextBelief)) { toBeExploredBeliefs.add(nextBelief); mdp.addState(); } // Get index of state in state set int dest = exploredBeliefs.getIndexOfLastAdd(); distr.add(dest, nextBeliefProb); } // Add transition distribution, with choice _index_ encoded as action mdp.addActionLabelledChoice(src, distr, pomdp.getActionForObservation(belief.so, chosenActionIndex)); // Store reward too, if needed if (mdpRewards != null) { stateRewards.setStateReward(src, pomdp.getRewardAfterChoice(belief, chosenActionIndex, mdpRewards)); } } } // Attach a label marking target states mdp.addLabel("target", mdpTarget); // Return POMDPStrategyModel psm = new POMDPStrategyModel(); psm.mdp = mdp; psm.mdpRewards = stateRewards; psm.beliefs = new ArrayList<>(); psm.beliefs.addAll(exploredBeliefs.toArrayList()); return psm; } protected ArrayList> assignGPrime(int startIndex, int min, int max, int length) { ArrayList> result = new ArrayList>(); if (startIndex == length - 1) { for (int i = min; i <= max; i++) { ArrayList innerList = new ArrayList<>(); innerList.add(i); result.add(innerList); } } else { for (int i = min; i <= max; i++) { ArrayList> nextResult = assignGPrime(startIndex + 1, 0, i, length); for (ArrayList nextReulstInner : nextResult) { ArrayList innerList = new ArrayList<>(); innerList.add(i); for (Integer a : nextReulstInner) { innerList.add(a); } result.add(innerList); } } } return result; } private ArrayList> fullAssignment(int length, int resolution) { ArrayList> GPrime = assignGPrime(0, resolution, resolution, length); ArrayList> result = new ArrayList>(); for (ArrayList GPrimeInner : GPrime) { ArrayList innerList = new ArrayList<>(); int i; for (i = 0; i < length - 1; i++) { int temp = GPrimeInner.get(i) - GPrimeInner.get(i + 1); innerList.add((double) temp / resolution); } innerList.add((double) GPrimeInner.get(i) / resolution); result.add(innerList); } return result; } private int[] getSortedPermutation(double[] inputArray) { int n = inputArray.length; double[] inputCopy = new double[n]; int[] permutation = new int[n]; int iState = 0, iIteration = 0; int iNonZeroEntry = 0, iZeroEntry = n - 1; boolean bDone = false; for (iState = n - 1; iState >= 0; iState--) { if (inputArray[iState] == 0.0) { inputCopy[iZeroEntry] = 0.0; permutation[iZeroEntry] = iState; iZeroEntry--; } } for (iState = 0; iState < n; iState++) { if (inputArray[iState] != 0.0) { inputCopy[iNonZeroEntry] = inputArray[iState]; permutation[iNonZeroEntry] = iState; iNonZeroEntry++; } } while (!bDone) { bDone = true; for (iState = 0; iState < iNonZeroEntry - iIteration - 1; iState++) { if (inputCopy[iState] < inputCopy[iState + 1]) { swap(inputCopy, iState, iState + 1); swap(permutation, iState, iState + 1); bDone = false; } } iIteration++; } return permutation; } private void swap(int[] aiArray, int i, int j) { int temp = aiArray[i]; aiArray[i] = aiArray[j]; aiArray[j] = temp; } private void swap(double[] aiArray, int i, int j) { double temp = aiArray[i]; aiArray[i] = aiArray[j]; aiArray[j] = temp; } protected boolean getSubSimplexAndLambdas(double[] b, ArrayList subSimplex, double[] lambdas, int resolution) { int n = b.length; int M = resolution; double[] X = new double[n]; int[] V = new int[n]; double[] D = new double[n]; for (int i = 0; i < n; i++) { X[i] = 0; for (int j = i; j < n; j++) { X[i] += M * b[j]; } X[i] = Math.round(X[i] * 1e6) / 1e6; V[i] = (int) Math.floor(X[i]); D[i] = X[i] - V[i]; } int[] P = getSortedPermutation(D); // mainLog.println("X: "+ Arrays.toString(X)); // mainLog.println("V: "+ Arrays.toString(V)); // mainLog.println("D: "+ Arrays.toString(D)); // mainLog.println("P: "+ Arrays.toString(P)); ArrayList Qs = new ArrayList<>(); for (int i = 0; i < n; i++) { int[] Q = new int[n]; if (i == 0) { for (int j = 0; j < n; j++) { Q[j] = V[j]; } Qs.add(Q); } else { for (int j = 0; j < n; j++) { if (j == P[i - 1]) { Q[j] = Qs.get(i - 1)[j] + 1; } else { Q[j] = Qs.get(i - 1)[j]; } } Qs.add(Q); } // mainLog.println(Arrays.toString(Q)); } for (int[] Q : Qs) { double[] node = new double[n]; int i; for (i = 0; i < n - 1; i++) { int temp = Q[i] - Q[i + 1]; node[i] = (double) temp / M; } node[i] = (double) Q[i] / M; subSimplex.add(node); } double sum = 0; for (int i = 1; i < n; i++) { double lambda = D[P[i - 1]] - D[P[i]]; lambdas[i] = lambda; sum = sum + lambda; } lambdas[0] = 1 - sum; for (int i = 0; i < n; i++) { double sum2 = 0; for (int j = 0; j < n; j++) { sum2 += lambdas[j] * subSimplex.get(j)[i]; } // mainLog.println("b["+i+"]: "+b[i]+" b^[i]:"+sum2); if (Math.abs(b[i] - sum2) > 1e-4) { return false; } } return true; } public static boolean isTargetBelief(double[] belief, BitSet target) { double prob=0; for (int i = target.nextSetBit(0); i >= 0; i = target.nextSetBit(i+1)) { prob+=belief[i]; } if(Math.abs(prob-1.0)<1.0e-6) { return true; } return false; } /** * Simple test program. */ public static void main(String args[]) { POMDPModelChecker mc; POMDPSimple pomdp; ModelCheckerResult res; BitSet init, target; Map labels; boolean min = true; try { mc = new POMDPModelChecker(null); MDPSimple mdp = new MDPSimple(); mdp.buildFromPrismExplicit(args[0]); //mainLog.println(mdp); labels = mc.loadLabelsFile(args[1]); //mainLog.println(labels); init = labels.get("init"); target = labels.get(args[2]); if (target == null) throw new PrismException("Unknown label \"" + args[2] + "\""); for (int i = 3; i < args.length; i++) { if (args[i].equals("-min")) min = true; else if (args[i].equals("-max")) min = false; else if (args[i].equals("-nopre")) mc.setPrecomp(false); } pomdp = new POMDPSimple(mdp); res = mc.computeReachRewards(pomdp, null, target, min, null); System.out.println(res.soln[init.nextSetBit(0)]); } catch (PrismException e) { System.out.println(e); } } }