From 1bf07ddbcf7fae29038cbb643218dcd1ca3c9eb1 Mon Sep 17 00:00:00 2001 From: Dave Parker Date: Thu, 25 Feb 2021 23:13:50 +0000 Subject: [PATCH] Tidying, refactoring and commenting in POMDP code. --- prism/src/explicit/POMDP.java | 63 ++++-- prism/src/explicit/POMDPModelChecker.java | 224 ++++++++++------------ prism/src/explicit/POMDPSimple.java | 125 ++++++------ 3 files changed, 210 insertions(+), 202 deletions(-) diff --git a/prism/src/explicit/POMDP.java b/prism/src/explicit/POMDP.java index cb5eda1b..5ce6f926 100644 --- a/prism/src/explicit/POMDP.java +++ b/prism/src/explicit/POMDP.java @@ -27,6 +27,7 @@ package explicit; +import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.TreeMap; @@ -38,6 +39,11 @@ import prism.PrismUtils; /** * Interface for classes that provide (read) access to an explicit-state POMDP. + *

+ * POMDPs require that states with the same observation have the same set of + * available actions. Classes implementing this interface must further ensure + * that these actions appear in the same order (in terms of choice indexing) + * in each observationally equivalent state. */ public interface POMDP extends MDP, PartiallyObservableModel { @@ -105,46 +111,71 @@ public interface POMDP extends MDP, PartiallyObservableModel // Accessors /** - * Get initial belief state + * Get the initial belief state, as a {@link Belief} object. */ public Belief getInitialBelief(); /** - * Get initial belief state as an distribution over all states (array). + * Get the initial belief state, as an array of probabilities over all states. */ public double[] getInitialBeliefInDist(); /** - * Get the updated belief after action {@code action}. + * Get the belief state (as a {@link Belief} object) + * after taking the {@code i}th choice from belief state {@code belief}. */ - public Belief getBeliefAfterAction(Belief belief, int action); + public Belief getBeliefAfterChoice(Belief belief, int i); /** - * Get the updated belief after action {@code action} using the distribution over all states belief representation. + * Get the belief state (as an array of probabilities over all states) + * after taking the {@code i}th choice from belief state {@code belief}. */ - public double[] getBeliefInDistAfterAction(double[] belief, int action); + public double[] getBeliefInDistAfterChoice(double[] belief, int i); /** - * Get the updated belief after action {@code action} and observation {@code observation}. + * Get the belief state (as a {@link Belief} object) + * after taking the {@code i}th choice from belief state {@code belief} + * and seeing observation {@code o} in the next state. 
*/ - public Belief getBeliefAfterActionAndObservation(Belief belief, int action, int observation); + public Belief getBeliefAfterChoiceAndObservation(Belief belief, int i, int o); /** - * Get the updated belief after action {@code action} and observation {@code observation} using the distribution over all states belief representation. + * Get the belief state (as an array of probabilities over all states) + * after taking the {@code i}th choice from belief state {@code belief} + * and seeing observation {@code o} in the next state. */ - public double[] getBeliefInDistAfterActionAndObservation(double[] belief, int action, int observation); + public double[] getBeliefInDistAfterChoiceAndObservation(double[] belief, int i, int o); /** - * Get the probability of an observation {@code observation}} after action {@code action} from belief {@code belief}. + * Get the probability of seeing observation {@code o} after taking the + * {@code i}th choice from belief state {@code belief}. */ - public double getObservationProbAfterAction(Belief belief, int action, int observation); + public double getObservationProbAfterChoice(Belief belief, int i, int o); - public double getObservationProbAfterAction(double[] belief, int action, int observation); + /** + * Get the probability of seeing observation {@code o} after taking the + * {@code i}th choice from belief state {@code belief}. + * The belief state is given as an array of probabilities over all states. + */ + public double getObservationProbAfterChoice(double[] belief, int i, int o); /** - * Get the cost (reward) of an action {@code action}} from a belief {@code belief}. + * Get the (non-zero) probabilities of seeing each observation after taking the + * {@code i}th choice from belief state {@code belief}. + * The belief state is given as an array of probabilities over all states. 
*/ - public double getCostAfterAction(Belief belief, int action, MDPRewards mdpRewards); + public HashMap computeObservationProbsAfterAction(double[] belief, int i); + + /** + * Get the expected (state and transition) reward value when taking the + * {@code i}th choice from belief state {@code belief}. + */ + public double getRewardAfterChoice(Belief belief, int i, MDPRewards mdpRewards); - public double getCostAfterAction(double[] belief, int action, MDPRewards mdpRewards); + /** + * Get the expected (state and transition) reward value when taking the + * {@code i}th choice from belief state {@code belief}. + * The belief state is given as an array of probabilities over all states. + */ + public double getRewardAfterChoice(double[] belief, int i, MDPRewards mdpRewards); } diff --git a/prism/src/explicit/POMDPModelChecker.java b/prism/src/explicit/POMDPModelChecker.java index a13e21fd..41b2393b 100644 --- a/prism/src/explicit/POMDPModelChecker.java +++ b/prism/src/explicit/POMDPModelChecker.java @@ -28,7 +28,6 @@ package explicit; import java.util.ArrayList; -import java.util.Arrays; import java.util.BitSet; import java.util.Collections; import java.util.HashMap; @@ -36,7 +35,6 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.TreeSet; import explicit.graphviz.Decoration; import explicit.graphviz.Decorator; @@ -45,7 +43,6 @@ import explicit.rewards.MDPRewardsSimple; import prism.Accuracy; import prism.AccuracyFactory; import prism.Pair; -import prism.Accuracy.AccuracyLevel; import prism.PrismComponent; import prism.PrismException; import prism.PrismNotSupportedException; @@ -122,21 +119,19 @@ public class POMDPModelChecker extends ProbModelChecker mainLog.println("Starting fixed-resolution grid approximation (" + (min ? 
"min" : "max") + ")..."); // Find out the observations for the target states - LinkedList targetObservs = getAndCheckTargetObservations(pomdp, target); + BitSet targetObs = getAndCheckTargetObservations(pomdp, target); // Initialise the grid points ArrayList gridPoints = new ArrayList<>();//the set of grid points (discretized believes) ArrayList unknownGridPoints = new ArrayList<>();//the set of unknown grid points (discretized believes) - initialiseGridPoints(pomdp, targetObservs, gridPoints, unknownGridPoints); + initialiseGridPoints(pomdp, targetObs, gridPoints, unknownGridPoints); int unK = unknownGridPoints.size(); mainLog.print("Grid statistics: resolution=" + gridResolution); mainLog.println(", points=" + gridPoints.size() + ", unknown points=" + unK); - // Construct grid belief "MDP" (over all unknown grid points_) + // Construct grid belief "MDP" (over all unknown grid points) mainLog.println("Building belief space approximation..."); - List>> observationProbs = new ArrayList<>();//memoization for reuse - List>> nextBelieves = new ArrayList<>();//memoization for reuse - buildBeliefMDP(pomdp, unknownGridPoints, observationProbs, nextBelieves); + List>> beliefMDP = buildBeliefMDP(pomdp, unknownGridPoints); // HashMap for storing real time values for the discretized grid belief states HashMap vhash = new HashMap<>(); @@ -159,26 +154,25 @@ public class POMDPModelChecker extends ProbModelChecker boolean done = false; while (!done && iters < maxIters) { // Iterate over all (unknown) grid points - for (int i = 0; i < unK; i++) { - Belief b = unknownGridPoints.get(i); - int numChoices = pomdp.getNumChoicesForObservation(b.so); + for (int b = 0; b < unK; b++) { + Belief belief = unknownGridPoints.get(b); + int numChoices = pomdp.getNumChoicesForObservation(belief.so); chosenValue = min ? 
Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; - for (int a = 0; a < numChoices; a++) { + for (int i = 0; i < numChoices; i++) { value = 0; - for (Map.Entry entry : observationProbs.get(i).get(a).entrySet()) { - int o = entry.getKey(); - double observationProb = entry.getValue(); - Belief nextBelief = nextBelieves.get(i).get(a).get(o); + for (Map.Entry entry : beliefMDP.get(b).get(i).entrySet()) { + double nextBeliefProb = entry.getValue(); + Belief nextBelief = entry.getKey(); // find discretized grid points to approximate the nextBelief - value += observationProb * interpolateOverGrid(o, nextBelief, vhash_backUp); + value += nextBeliefProb * interpolateOverGrid(nextBelief, vhash_backUp); } if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) { chosenValue = value; } } //update V(b) to the chosenValue - vhash.put(b, chosenValue); + vhash.put(belief, chosenValue); } // Check termination done = PrismUtils.doublesAreClose(vhash, vhash_backUp, termCritParam, termCrit == TermCrit.RELATIVE); @@ -202,7 +196,7 @@ public class POMDPModelChecker extends ProbModelChecker // Find discretized grid points to approximate the initialBelief // Also get (approximate) accuracy of result from value iteration Belief initialBelief = pomdp.getInitialBelief(); - double outerBound = interpolateOverGrid(initialBelief.so, initialBelief, vhash_backUp); + double outerBound = interpolateOverGrid(initialBelief, vhash_backUp); double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE); Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE); // Print result @@ -328,29 +322,27 @@ public class POMDPModelChecker extends ProbModelChecker mainLog.println("Starting fixed-resolution grid approximation (" + (min ? 
"min" : "max") + ")..."); // Find out the observations for the target states - LinkedList targetObservs = getAndCheckTargetObservations(pomdp, target); + BitSet targetObs = getAndCheckTargetObservations(pomdp, target); // Initialise the grid points ArrayList gridPoints = new ArrayList<>();//the set of grid points (discretized believes) ArrayList unknownGridPoints = new ArrayList<>();//the set of unknown grid points (discretized believes) - initialiseGridPoints(pomdp, targetObservs, gridPoints, unknownGridPoints); + initialiseGridPoints(pomdp, targetObs, gridPoints, unknownGridPoints); int unK = unknownGridPoints.size(); mainLog.print("Grid statistics: resolution=" + gridResolution); mainLog.println(", points=" + gridPoints.size() + ", unknown points=" + unK); - // Construct grid belief "MDP" (over all unknown grid points_) + // Construct grid belief "MDP" (over all unknown grid points) mainLog.println("Building belief space approximation..."); - List>> observationProbs = new ArrayList<>();// memoization for reuse - List>> nextBelieves = new ArrayList<>();// memoization for reuse - buildBeliefMDP(pomdp, unknownGridPoints, observationProbs, nextBelieves); + List>> beliefMDP = buildBeliefMDP(pomdp, unknownGridPoints); // Rewards List> rewards = new ArrayList<>(); // memoization for reuse - for (int i = 0; i < unK; i++) { - Belief b = unknownGridPoints.get(i); - int numChoices = pomdp.getNumChoicesForObservation(b.so); + for (int b = 0; b < unK; b++) { + Belief belief = unknownGridPoints.get(b); + int numChoices = pomdp.getNumChoicesForObservation(belief.so); List action_reward = new ArrayList<>();// for memoization - for (int a = 0; a < numChoices; a++) { - action_reward.add(pomdp.getCostAfterAction(b, a, mdpRewards)); // c(a,b) + for (int i = 0; i < numChoices; i++) { + action_reward.add(pomdp.getRewardAfterChoice(belief, i, mdpRewards)); // c(a,b) } rewards.add(action_reward); } @@ -371,25 +363,24 @@ public class POMDPModelChecker extends ProbModelChecker boolean 
done = false; while (!done && iters < maxIters) { // Iterate over all (unknown) grid points - for (int i = 0; i < unK; i++) { - Belief b = unknownGridPoints.get(i); - int numChoices = pomdp.getNumChoicesForObservation(b.so); + for (int b = 0; b < unK; b++) { + Belief belief = unknownGridPoints.get(b); + int numChoices = pomdp.getNumChoicesForObservation(belief.so); chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; - for (int a = 0; a < numChoices; a++) { - value = rewards.get(i).get(a); - for (Map.Entry entry : observationProbs.get(i).get(a).entrySet()) { - int o = entry.getKey(); - double observationProb = entry.getValue(); - Belief nextBelief = nextBelieves.get(i).get(a).get(o); + for (int i = 0; i < numChoices; i++) { + value = rewards.get(b).get(i); + for (Map.Entry entry : beliefMDP.get(b).get(i).entrySet()) { + double nextBeliefProb = entry.getValue(); + Belief nextBelief = entry.getKey(); // find discretized grid points to approximate the nextBelief - value += observationProb * interpolateOverGrid(o, nextBelief, vhash_backUp); + value += nextBeliefProb * interpolateOverGrid(nextBelief, vhash_backUp); } if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) { chosenValue = value; } } //update V(b) to the chosenValue - vhash.put(b, chosenValue); + vhash.put(belief, chosenValue); } // Check termination done = PrismUtils.doublesAreClose(vhash, vhash_backUp, termCritParam, termCrit == TermCrit.RELATIVE); @@ -413,7 +404,7 @@ public class POMDPModelChecker extends ProbModelChecker // Find discretized grid points to approximate the initialBelief // Also get (approximate) accuracy of result from value iteration Belief initialBelief = pomdp.getInitialBelief(); - double outerBound = interpolateOverGrid(initialBelief.so, initialBelief, vhash_backUp); + double outerBound = interpolateOverGrid(initialBelief, vhash_backUp); double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == 
TermCrit.RELATIVE); Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE); // Print result @@ -430,7 +421,7 @@ public class POMDPModelChecker extends ProbModelChecker for (int ii = 0; ii < numStates; ii++) { if (mdp.getNumChoices(ii) > 0) { int action = ((Integer) mdp.getAction(ii, 0)); - double rew = pomdp.getCostAfterAction(listBeliefs.get(ii), action, mdpRewards); + double rew = pomdp.getRewardAfterChoice(listBeliefs.get(ii), action, mdpRewards); mdpRewardsNew.addToStateReward(ii, rew); } } @@ -495,37 +486,34 @@ public class POMDPModelChecker extends ProbModelChecker /** * Get a list of target observations from a set of target states - * (both are represented by their indices). + * (both are represented by BitSets over their indices). * Also check that the set of target states corresponds to a set * of observations, and throw an exception if not. */ - protected LinkedList getAndCheckTargetObservations(POMDP pomdp, BitSet target) throws PrismException + protected BitSet getAndCheckTargetObservations(POMDP pomdp, BitSet target) throws PrismException { // Find observations corresponding to each state in the target - TreeSet targetObservsSet = new TreeSet<>(); + BitSet targetObs = new BitSet(); for (int s = target.nextSetBit(0); s >= 0; s = target.nextSetBit(s + 1)) { - targetObservsSet.add(pomdp.getObservation(s)); + targetObs.set(pomdp.getObservation(s)); } - LinkedList targetObservs = new LinkedList<>(targetObservsSet); - // Rereate the set of target states from the target observations - // and make sure it matches + // Recreate the set of target states from the target observations and make sure it matches BitSet target2 = new BitSet(); int numStates = pomdp.getNumStates(); for (int s = 0; s < numStates; s++) { - if (targetObservs.contains(pomdp.getObservation(s))) { + if (targetObs.get(pomdp.getObservation(s))) { target2.set(s); } } if (!target.equals(target2)) { throw new PrismException("Target 
is not observable"); } - return targetObservs; + return targetObs; } - protected void initialiseGridPoints(POMDP pomdp, LinkedList targetObservs, ArrayList gridPoints, ArrayList unknownGridPoints) + protected void initialiseGridPoints(POMDP pomdp, BitSet targetObs, ArrayList gridPoints, ArrayList unknownGridPoints) { ArrayList> assignment; - boolean isTargetObserv; int numObservations = pomdp.getNumObservations(); int numUnobservations = pomdp.getNumUnobservations(); int numStates = pomdp.getNumStates(); @@ -537,12 +525,6 @@ public class POMDPModelChecker extends ProbModelChecker } } assignment = fullAssignment(unobservsForObserv.size(), gridResolution); - - isTargetObserv = targetObservs.isEmpty() ? false : ((Integer) targetObservs.peekFirst() == so); - if (isTargetObserv) { - targetObservs.removeFirst(); - } - for (ArrayList inner : assignment) { double[] bu = new double[numUnobservations]; int k = 0; @@ -553,47 +535,57 @@ public class POMDPModelChecker extends ProbModelChecker Belief g = new Belief(so, bu); gridPoints.add(g); - if (!isTargetObserv) { + if (!targetObs.get(so)) { unknownGridPoints.add(g); } } } } - protected void buildBeliefMDP(POMDP pomdp, ArrayList unknownGridPoints, List>> observationProbs, List>> nextBelieves) + /** + * Construct (part of) a belief MDP, just for the set of passed in belief states. + * It is stored as a list (over source beliefs) of lists (over choices) + * of distributions over target beliefs, stored as a hashmap. 
+ */ + protected List>> buildBeliefMDP(POMDP pomdp, List beliefs) { - int unK = unknownGridPoints.size(); - for (int i = 0; i < unK; i++) { - Belief b = unknownGridPoints.get(i); - double[] beliefInDist = b.toDistributionOverStates(pomdp); - //mainLog.println("Belief " + i + ": " + b); - //mainLog.print("Belief dist:"); - //mainLog.println(beliefInDist); - List> action_observation_probs = new ArrayList<>();// for memoization - List> action_observation_Believes = new ArrayList<>();// for memoization - int numChoices = pomdp.getNumChoicesForObservation(b.so); - for (int a = 0; a < numChoices; a++) { - //mainLog.println(i+"/"+unK+", "+a+"/"+numChoices); - HashMap observation_probs = new HashMap<>();// for memoization - HashMap observation_believes = new HashMap<>();// for memoization - ((POMDPSimple) pomdp).computeObservationProbsAfterAction(beliefInDist, a, observation_probs); - for (Map.Entry entry : observation_probs.entrySet()) { - int o = entry.getKey(); - //mainLog.println(i+"/"+unK+", "+a+"/"+numChoices+", "+o+"/"+numObservations); - Belief nextBelief = pomdp.getBeliefAfterActionAndObservation(b, a, o); - //mainLog.print(i + "/" + unK + ", " + a + "/" + numChoices + ", " + o + "/" + numObservations); - //mainLog.println(" - " + entry.getValue() + ":" + nextBelief); - observation_believes.put(o, nextBelief); - } - action_observation_probs.add(observation_probs); - action_observation_Believes.add(observation_believes); + List>> beliefMDP = new ArrayList<>(); + for (Belief belief: beliefs) { + beliefMDP.add(buildBeliefMDPState(pomdp, belief)); + } + return beliefMDP; + } + + /** + * Construct a single single state (belief) of a belief MDP, stored as a + * list (over choices) of distributions over target beliefs, stored as a hashmap. 
+ */ + protected List> buildBeliefMDPState(POMDP pomdp, Belief belief) + { + double[] beliefInDist = belief.toDistributionOverStates(pomdp); + List> beliefMDPState = new ArrayList<>(); + // And for each choice + int numChoices = pomdp.getNumChoicesForObservation(belief.so); + for (int i = 0; i < numChoices; i++) { + // Get successor observations and their probs + HashMap obsProbs = pomdp.computeObservationProbsAfterAction(beliefInDist, i); + HashMap beliefDist = new HashMap<>(); + // Find the belief for each observation + for (Map.Entry entry : obsProbs.entrySet()) { + int o = entry.getKey(); + Belief nextBelief = pomdp.getBeliefAfterChoiceAndObservation(belief, i, o); + beliefDist.put(nextBelief, entry.getValue()); } - observationProbs.add(action_observation_probs); - nextBelieves.add(action_observation_Believes); + beliefMDPState.add(beliefDist); } + return beliefMDPState; } - protected double interpolateOverGrid(int o, Belief belief, HashMap vhash) + /** + * Approximate the value for a belief {@code belief} by interpolating over values {@code gridValues} + * for a representative set of beliefs whose convex hull is the full belief space. 
+ */ + protected double interpolateOverGrid(Belief belief, HashMap gridValues) { ArrayList subSimplex = new ArrayList<>(); double[] lambdas = new double[belief.bu.length]; @@ -602,7 +594,7 @@ public class POMDPModelChecker extends ProbModelChecker double val = 0; for (int j = 0; j < lambdas.length; j++) { if (lambdas[j] >= 1e-6) { - val += lambdas[j] * vhash.get(new Belief(o, subSimplex.get(j))); + val += lambdas[j] * gridValues.get(new Belief(belief.so, subSimplex.get(j))); } } return val; @@ -627,8 +619,9 @@ public class POMDPModelChecker extends ProbModelChecker src++; if (isTargetBelief(b.toDistributionOverStates(pomdp), target)) { mdpTarget.set(src); + } else { + extractBestActions(src, b, vhash, pomdp, mdpRewards, min, exploredBelieves, toBeExploredBelives, mdp); } - extractBestActions(src, b, vhash, pomdp, mdpRewards, min, exploredBelieves, toBeExploredBelives, target, mdp); } mdp.addLabel("target", mdpTarget); @@ -824,45 +817,25 @@ public class POMDPModelChecker extends ProbModelChecker * @param beliefList */ protected void extractBestActions(int src, Belief belief, HashMap vhash, POMDP pomdp, MDPRewards mdpRewards, boolean min, - IndexedSet exploredBelieves, LinkedList toBeExploredBelives, BitSet target, MDPSimple mdp) + IndexedSet exploredBelieves, LinkedList toBeExploredBelives, MDPSimple mdp) { - if (isTargetBelief(belief.toDistributionOverStates(pomdp), target)) { - // Add self-loop - /*Distribution distr = new Distribution(); - distr.set(src, 1); - mdp.addActionLabelledChoice(src, distr, null);*/ - return; - } - - double[] beliefInDist = belief.toDistributionOverStates(pomdp); double chosenValue = min ? 
Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; int chosenActionIndex = -1; ArrayList bestActions = new ArrayList<>(); - List action_reward = new ArrayList<>(); - List> action_observation_probs = new ArrayList<>(); - List> action_observation_Believes = new ArrayList<>(); //evaluate each action in b int numChoices = pomdp.getNumChoicesForObservation(belief.so); + List> beliefMDPState = buildBeliefMDPState(pomdp, belief); for (int a = 0; a < numChoices; a++) { double value = 0; if (mdpRewards != null) { - value = pomdp.getCostAfterAction(belief, a, mdpRewards); // c(a,b) + value = pomdp.getRewardAfterChoice(belief, a, mdpRewards); // c(a,b) } - // Build/store successor observations, probabilities and resulting beliefs - HashMap observation_probs = new HashMap<>(); - HashMap observation_believes = new HashMap<>(); - ((POMDPSimple) pomdp).computeObservationProbsAfterAction(beliefInDist, a, observation_probs); - for (Map.Entry entry : observation_probs.entrySet()) { - int o = entry.getKey(); - Belief nextBelief = pomdp.getBeliefAfterActionAndObservation(belief, a, o); - observation_believes.put(o, nextBelief); - double observationProb = observation_probs.get(o); - value += observationProb * interpolateOverGrid(o, nextBelief, vhash); + for (Map.Entry entry : beliefMDPState.get(a).entrySet()) { + double nextBeliefProb = entry.getValue(); + Belief nextBelief = entry.getKey(); + value += nextBeliefProb * interpolateOverGrid(nextBelief, vhash); } - // Store the list of observations, probabilities and resulting beliefs for this action - action_observation_probs.add(observation_probs); - action_observation_Believes.add(observation_believes); - + //select action that minimizes/maximizes Q(a,b), i.e. 
value if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6))//value entry : action_observation_probs.get(a).entrySet()) { - int o = entry.getKey(); - double observationProb = entry.getValue(); - Belief nextBelief = action_observation_Believes.get(a).get(o); + for (Map.Entry entry : beliefMDPState.get(a).entrySet()) { + double nextBeliefProb = entry.getValue(); + Belief nextBelief = entry.getKey(); if (exploredBelieves.add(nextBelief)) { // If so, add to the explore list toBeExploredBelives.add(nextBelief); @@ -893,7 +865,7 @@ public class POMDPModelChecker extends ProbModelChecker } // Get index of state in state set int dest = exploredBelieves.getIndexOfLastAdd(); - distr.add(dest, observationProb); + distr.add(dest, nextBeliefProb); } } // Add transition distribution, with choice _index_ encoded as action diff --git a/prism/src/explicit/POMDPSimple.java b/prism/src/explicit/POMDPSimple.java index 9563f271..1f3ffdb1 100644 --- a/prism/src/explicit/POMDPSimple.java +++ b/prism/src/explicit/POMDPSimple.java @@ -228,23 +228,29 @@ public class POMDPSimple extends MDPSimple implements POMDP */ public void setObservation(int s, State observ, State unobserv, List observableNames) throws PrismException { + // See if the observation already exists and add it if not int oIndex = observationsList.indexOf(observ); if (oIndex == -1) { + // Add new observation observationsList.add(observ); - observationStates.add(-1); oIndex = observationsList.size() - 1; + // Also extend the observationStates list, to be filled shortly + observationStates.add(-1); } + // Assign the observation (index) to the state try { setObservation(s, oIndex); } catch (PrismException e) { String sObs = observableNames == null ? 
observ.toString() : observ.toString(observableNames); throw new PrismException("Problem with observation " + sObs + ": " + e.getMessage()); } + // See if the unobservation already exists and add it if not int unobservIndex = unobservationsList.indexOf(unobserv); if (unobservIndex == -1) { unobservationsList.add(unobserv); unobservIndex = unobservationsList.size() - 1; } + // Assign the unobservation (index) to the state unobservablesMap.set(s, unobservIndex); } @@ -381,44 +387,21 @@ public class POMDPSimple extends MDPSimple implements POMDP } @Override - public double getCostAfterAction(Belief belief, int action, MDPRewards mdpRewards) - { - double[] beliefInDist = belief.toDistributionOverStates(this); - double cost = getCostAfterAction(beliefInDist, action, mdpRewards); - return cost; - } - - @Override - public double getCostAfterAction(double[] beliefInDist, int action, MDPRewards mdpRewards) - { - double cost = 0; - for (int i = 0; i < beliefInDist.length; i++) { - if (beliefInDist[i] == 0) { - cost += 0; - } else { - cost += beliefInDist[i] * (mdpRewards.getTransitionReward(i, action) + mdpRewards.getStateReward(i)); - } - - } - return cost; - } - - @Override - public Belief getBeliefAfterAction(Belief belief, int action) + public Belief getBeliefAfterChoice(Belief belief, int i) { double[] beliefInDist = belief.toDistributionOverStates(this); - double[] nextBeliefInDist = getBeliefInDistAfterAction(beliefInDist, action); + double[] nextBeliefInDist = getBeliefInDistAfterChoice(beliefInDist, i); return beliefInDistToBelief(nextBeliefInDist); } @Override - public double[] getBeliefInDistAfterAction(double[] beliefInDist, int action) + public double[] getBeliefInDistAfterChoice(double[] beliefInDist, int i) { int n = beliefInDist.length; double[] nextBeliefInDist = new double[n]; for (int sp = 0; sp < n; sp++) { if (beliefInDist[sp] >= 1.0e-6) { - Distribution distr = getChoice(sp, action); + Distribution distr = getChoice(sp, i); for (Map.Entry e : distr) { 
int s = (Integer) e.getKey(); double prob = (Double) e.getValue(); @@ -429,72 +412,94 @@ public class POMDPSimple extends MDPSimple implements POMDP return nextBeliefInDist; } + @Override + public Belief getBeliefAfterChoiceAndObservation(Belief belief, int i, int o) + { + double[] beliefInDist = belief.toDistributionOverStates(this); + double[] nextBeliefInDist = getBeliefInDistAfterChoiceAndObservation(beliefInDist, i, o); + Belief nextBelief = beliefInDistToBelief(nextBeliefInDist); + assert(nextBelief.so == o); + return nextBelief; + } + + @Override + public double[] getBeliefInDistAfterChoiceAndObservation(double[] beliefInDist, int i, int o) + { + int n = beliefInDist.length; + double[] nextBelief = new double[n]; + double[] beliefAfterAction = this.getBeliefInDistAfterChoice(beliefInDist, i); + double prob; + for (int s = 0; s < n; s++) { + prob = beliefAfterAction[s] * getObservationProb(s, o); + nextBelief[s] = prob; + } + PrismUtils.normalise(nextBelief); + return nextBelief; + } + @Override // SLOW - public double getObservationProbAfterAction(Belief belief, int action, int observation) + public double getObservationProbAfterChoice(Belief belief, int i, int o) { double[] beliefInDist = belief.toDistributionOverStates(this); - double prob = getObservationProbAfterAction(beliefInDist, action, observation); + double prob = getObservationProbAfterChoice(beliefInDist, i, o); return prob; } @Override // SLOW - public double getObservationProbAfterAction(double[] beliefInDist, int action, int observation) + public double getObservationProbAfterChoice(double[] beliefInDist, int i, int o) { - double[] beliefAfterAction = this.getBeliefInDistAfterAction(beliefInDist, action); + double[] beliefAfterAction = this.getBeliefInDistAfterChoice(beliefInDist, i); int s; double prob = 0; for (s = 0; s < beliefAfterAction.length; s++) { - prob += beliefAfterAction[s] * getObservationProb(s, observation); + prob += beliefAfterAction[s] * getObservationProb(s, o); } return 
prob; } - public void computeObservationProbsAfterAction(double[] beliefInDist, int action, HashMap observation_probs) + @Override + public HashMap computeObservationProbsAfterAction(double[] belief, int i) { - double[] beliefAfterAction = this.getBeliefInDistAfterAction(beliefInDist, action); + HashMap probs = new HashMap<>(); + double[] beliefAfterAction = this.getBeliefInDistAfterChoice(belief, i); for (int s = 0; s < beliefAfterAction.length; s++) { int o = getObservation(s); double probToAdd = beliefAfterAction[s]; if (probToAdd > 1e-6) { - Double lookup = observation_probs.get(o); - if (lookup == null) - observation_probs.put(o, probToAdd); - else - observation_probs.put(o, lookup + probToAdd); + Double lookup = probs.get(o); + if (lookup == null) { + probs.put(o, probToAdd); + } else { + probs.put(o, lookup + probToAdd); + } } } + return probs; } @Override - public Belief getBeliefAfterActionAndObservation(Belief belief, int action, int observation) + public double getRewardAfterChoice(Belief belief, int i, MDPRewards mdpRewards) { double[] beliefInDist = belief.toDistributionOverStates(this); - double[] nextBeliefInDist = getBeliefInDistAfterActionAndObservation(beliefInDist, action, observation); - Belief nextBelief = beliefInDistToBelief(nextBeliefInDist); - if (nextBelief.so != observation) { - System.err.println(nextBelief.so + "<--" + observation - + " something wrong with POMDPSimple.getBeliefAfterActionAndObservation(Belief belief, int action, int observation)"); - } - return nextBelief; + double cost = getRewardAfterChoice(beliefInDist, i, mdpRewards); + return cost; } @Override - public double[] getBeliefInDistAfterActionAndObservation(double[] beliefInDist, int action, int observation) + public double getRewardAfterChoice(double[] beliefInDist, int i, MDPRewards mdpRewards) { - int n = beliefInDist.length; - double[] nextBelief = new double[n]; - double[] beliefAfterAction = this.getBeliefInDistAfterAction(beliefInDist, action); - int i; - double 
prob; - for (i = 0; i < n; i++) { - prob = beliefAfterAction[i] * getObservationProb(i, observation); - nextBelief[i] = prob; + double cost = 0; + for (int s = 0; s < beliefInDist.length; s++) { + if (beliefInDist[s] == 0) { + cost += 0; + } else { + cost += beliefInDist[s] * (mdpRewards.getTransitionReward(s, i) + mdpRewards.getStateReward(s)); + } + } - PrismUtils.normalise(nextBelief); - return nextBelief; + return cost; } - // Helpers protected Belief beliefInDistToBelief(double[] beliefInDist)