diff --git a/prism/src/explicit/POMDP.java b/prism/src/explicit/POMDP.java
index cb5eda1b..5ce6f926 100644
--- a/prism/src/explicit/POMDP.java
+++ b/prism/src/explicit/POMDP.java
@@ -27,6 +27,7 @@
package explicit;
+import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;
@@ -38,6 +39,11 @@ import prism.PrismUtils;
/**
* Interface for classes that provide (read) access to an explicit-state POMDP.
+ *
+ * POMDPs require that states with the same observation have the same set of
+ * available actions. Classes implementing this interface must further ensure
+ * that these actions appear in the same order (in terms of choice indexing)
+ * in each observationally equivalent state.
*/
public interface POMDP extends MDP, PartiallyObservableModel
{
@@ -105,46 +111,71 @@ public interface POMDP extends MDP, PartiallyObservableModel
// Accessors
/**
- * Get initial belief state
+ * Get the initial belief state, as a {@link Belief} object.
*/
public Belief getInitialBelief();
/**
- * Get initial belief state as an distribution over all states (array).
+ * Get the initial belief state, as an array of probabilities over all states.
*/
public double[] getInitialBeliefInDist();
/**
- * Get the updated belief after action {@code action}.
+ * Get the belief state (as a {@link Belief} object)
+ * after taking the {@code i}th choice from belief state {@code belief}.
*/
- public Belief getBeliefAfterAction(Belief belief, int action);
+ public Belief getBeliefAfterChoice(Belief belief, int i);
/**
- * Get the updated belief after action {@code action} using the distribution over all states belief representation.
+ * Get the belief state (as an array of probabilities over all states)
+ * after taking the {@code i}th choice from belief state {@code belief}.
*/
- public double[] getBeliefInDistAfterAction(double[] belief, int action);
+ public double[] getBeliefInDistAfterChoice(double[] belief, int i);
/**
- * Get the updated belief after action {@code action} and observation {@code observation}.
+ * Get the belief state (as a {@link Belief} object)
+ * after taking the {@code i}th choice from belief state {@code belief}
+ * and seeing observation {@code o} in the next state.
*/
- public Belief getBeliefAfterActionAndObservation(Belief belief, int action, int observation);
+ public Belief getBeliefAfterChoiceAndObservation(Belief belief, int i, int o);
/**
- * Get the updated belief after action {@code action} and observation {@code observation} using the distribution over all states belief representation.
+ * Get the belief state (as an array of probabilities over all states)
+ * after taking the {@code i}th choice from belief state {@code belief}
+ * and seeing observation {@code o} in the next state.
*/
- public double[] getBeliefInDistAfterActionAndObservation(double[] belief, int action, int observation);
+ public double[] getBeliefInDistAfterChoiceAndObservation(double[] belief, int i, int o);
/**
- * Get the probability of an observation {@code observation}} after action {@code action} from belief {@code belief}.
+ * Get the probability of seeing observation {@code o} after taking the
+ * {@code i}th choice from belief state {@code belief}.
*/
- public double getObservationProbAfterAction(Belief belief, int action, int observation);
+ public double getObservationProbAfterChoice(Belief belief, int i, int o);
- public double getObservationProbAfterAction(double[] belief, int action, int observation);
+ /**
+ * Get the probability of seeing observation {@code o} after taking the
+ * {@code i}th choice from belief state {@code belief}.
+ * The belief state is given as an array of probabilities over all states.
+ */
+ public double getObservationProbAfterChoice(double[] belief, int i, int o);
/**
- * Get the cost (reward) of an action {@code action}} from a belief {@code belief}.
+ * Get the (non-zero) probabilities of seeing each observation after taking the
+ * {@code i}th choice from belief state {@code belief}.
+ * The belief state is given as an array of probabilities over all states.
*/
- public double getCostAfterAction(Belief belief, int action, MDPRewards mdpRewards);
+ public HashMap<Integer, Double> computeObservationProbsAfterAction(double[] belief, int i);
+
+ /**
+ * Get the expected (state and transition) reward value when taking the
+ * {@code i}th choice from belief state {@code belief}.
+ */
+ public double getRewardAfterChoice(Belief belief, int i, MDPRewards mdpRewards);
- public double getCostAfterAction(double[] belief, int action, MDPRewards mdpRewards);
+ /**
+ * Get the expected (state and transition) reward value when taking the
+ * {@code i}th choice from belief state {@code belief}.
+ * The belief state is given as an array of probabilities over all states.
+ */
+ public double getRewardAfterChoice(double[] belief, int i, MDPRewards mdpRewards);
}
diff --git a/prism/src/explicit/POMDPModelChecker.java b/prism/src/explicit/POMDPModelChecker.java
index a13e21fd..41b2393b 100644
--- a/prism/src/explicit/POMDPModelChecker.java
+++ b/prism/src/explicit/POMDPModelChecker.java
@@ -28,7 +28,6 @@
package explicit;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
@@ -36,7 +35,6 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.TreeSet;
import explicit.graphviz.Decoration;
import explicit.graphviz.Decorator;
@@ -45,7 +43,6 @@ import explicit.rewards.MDPRewardsSimple;
import prism.Accuracy;
import prism.AccuracyFactory;
import prism.Pair;
-import prism.Accuracy.AccuracyLevel;
import prism.PrismComponent;
import prism.PrismException;
import prism.PrismNotSupportedException;
@@ -122,21 +119,19 @@ public class POMDPModelChecker extends ProbModelChecker
mainLog.println("Starting fixed-resolution grid approximation (" + (min ? "min" : "max") + ")...");
// Find out the observations for the target states
- LinkedList<Integer> targetObservs = getAndCheckTargetObservations(pomdp, target);
+ BitSet targetObs = getAndCheckTargetObservations(pomdp, target);
// Initialise the grid points
ArrayList<Belief> gridPoints = new ArrayList<>();//the set of grid points (discretized believes)
ArrayList<Belief> unknownGridPoints = new ArrayList<>();//the set of unknown grid points (discretized believes)
- initialiseGridPoints(pomdp, targetObservs, gridPoints, unknownGridPoints);
+ initialiseGridPoints(pomdp, targetObs, gridPoints, unknownGridPoints);
int unK = unknownGridPoints.size();
mainLog.print("Grid statistics: resolution=" + gridResolution);
mainLog.println(", points=" + gridPoints.size() + ", unknown points=" + unK);
- // Construct grid belief "MDP" (over all unknown grid points_)
+ // Construct grid belief "MDP" (over all unknown grid points)
mainLog.println("Building belief space approximation...");
- List<List<HashMap<Integer, Double>>> observationProbs = new ArrayList<>();//memoization for reuse
- List<List<HashMap<Integer, Belief>>> nextBelieves = new ArrayList<>();//memoization for reuse
- buildBeliefMDP(pomdp, unknownGridPoints, observationProbs, nextBelieves);
+ List<List<HashMap<Belief, Double>>> beliefMDP = buildBeliefMDP(pomdp, unknownGridPoints);
// HashMap for storing real time values for the discretized grid belief states
HashMap<Belief, Double> vhash = new HashMap<>();
@@ -159,26 +154,25 @@ public class POMDPModelChecker extends ProbModelChecker
boolean done = false;
while (!done && iters < maxIters) {
// Iterate over all (unknown) grid points
- for (int i = 0; i < unK; i++) {
- Belief b = unknownGridPoints.get(i);
- int numChoices = pomdp.getNumChoicesForObservation(b.so);
+ for (int b = 0; b < unK; b++) {
+ Belief belief = unknownGridPoints.get(b);
+ int numChoices = pomdp.getNumChoicesForObservation(belief.so);
chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
- for (int a = 0; a < numChoices; a++) {
+ for (int i = 0; i < numChoices; i++) {
value = 0;
- for (Map.Entry<Integer, Double> entry : observationProbs.get(i).get(a).entrySet()) {
- int o = entry.getKey();
- double observationProb = entry.getValue();
- Belief nextBelief = nextBelieves.get(i).get(a).get(o);
+ for (Map.Entry<Belief, Double> entry : beliefMDP.get(b).get(i).entrySet()) {
+ double nextBeliefProb = entry.getValue();
+ Belief nextBelief = entry.getKey();
// find discretized grid points to approximate the nextBelief
- value += observationProb * interpolateOverGrid(o, nextBelief, vhash_backUp);
+ value += nextBeliefProb * interpolateOverGrid(nextBelief, vhash_backUp);
}
if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) {
chosenValue = value;
}
}
//update V(b) to the chosenValue
- vhash.put(b, chosenValue);
+ vhash.put(belief, chosenValue);
}
// Check termination
done = PrismUtils.doublesAreClose(vhash, vhash_backUp, termCritParam, termCrit == TermCrit.RELATIVE);
@@ -202,7 +196,7 @@ public class POMDPModelChecker extends ProbModelChecker
// Find discretized grid points to approximate the initialBelief
// Also get (approximate) accuracy of result from value iteration
Belief initialBelief = pomdp.getInitialBelief();
- double outerBound = interpolateOverGrid(initialBelief.so, initialBelief, vhash_backUp);
+ double outerBound = interpolateOverGrid(initialBelief, vhash_backUp);
double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
// Print result
@@ -328,29 +322,27 @@ public class POMDPModelChecker extends ProbModelChecker
mainLog.println("Starting fixed-resolution grid approximation (" + (min ? "min" : "max") + ")...");
// Find out the observations for the target states
- LinkedList<Integer> targetObservs = getAndCheckTargetObservations(pomdp, target);
+ BitSet targetObs = getAndCheckTargetObservations(pomdp, target);
// Initialise the grid points
ArrayList<Belief> gridPoints = new ArrayList<>();//the set of grid points (discretized believes)
ArrayList<Belief> unknownGridPoints = new ArrayList<>();//the set of unknown grid points (discretized believes)
- initialiseGridPoints(pomdp, targetObservs, gridPoints, unknownGridPoints);
+ initialiseGridPoints(pomdp, targetObs, gridPoints, unknownGridPoints);
int unK = unknownGridPoints.size();
mainLog.print("Grid statistics: resolution=" + gridResolution);
mainLog.println(", points=" + gridPoints.size() + ", unknown points=" + unK);
- // Construct grid belief "MDP" (over all unknown grid points_)
+ // Construct grid belief "MDP" (over all unknown grid points)
mainLog.println("Building belief space approximation...");
- List<List<HashMap<Integer, Double>>> observationProbs = new ArrayList<>();// memoization for reuse
- List<List<HashMap<Integer, Belief>>> nextBelieves = new ArrayList<>();// memoization for reuse
- buildBeliefMDP(pomdp, unknownGridPoints, observationProbs, nextBelieves);
+ List<List<HashMap<Belief, Double>>> beliefMDP = buildBeliefMDP(pomdp, unknownGridPoints);
// Rewards
List<List<Double>> rewards = new ArrayList<>(); // memoization for reuse
- for (int i = 0; i < unK; i++) {
- Belief b = unknownGridPoints.get(i);
- int numChoices = pomdp.getNumChoicesForObservation(b.so);
+ for (int b = 0; b < unK; b++) {
+ Belief belief = unknownGridPoints.get(b);
+ int numChoices = pomdp.getNumChoicesForObservation(belief.so);
List<Double> action_reward = new ArrayList<>();// for memoization
- for (int a = 0; a < numChoices; a++) {
- action_reward.add(pomdp.getCostAfterAction(b, a, mdpRewards)); // c(a,b)
+ for (int i = 0; i < numChoices; i++) {
+ action_reward.add(pomdp.getRewardAfterChoice(belief, i, mdpRewards)); // c(a,b)
}
rewards.add(action_reward);
}
@@ -371,25 +363,24 @@ public class POMDPModelChecker extends ProbModelChecker
boolean done = false;
while (!done && iters < maxIters) {
// Iterate over all (unknown) grid points
- for (int i = 0; i < unK; i++) {
- Belief b = unknownGridPoints.get(i);
- int numChoices = pomdp.getNumChoicesForObservation(b.so);
+ for (int b = 0; b < unK; b++) {
+ Belief belief = unknownGridPoints.get(b);
+ int numChoices = pomdp.getNumChoicesForObservation(belief.so);
chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
- for (int a = 0; a < numChoices; a++) {
- value = rewards.get(i).get(a);
- for (Map.Entry<Integer, Double> entry : observationProbs.get(i).get(a).entrySet()) {
- int o = entry.getKey();
- double observationProb = entry.getValue();
- Belief nextBelief = nextBelieves.get(i).get(a).get(o);
+ for (int i = 0; i < numChoices; i++) {
+ value = rewards.get(b).get(i);
+ for (Map.Entry<Belief, Double> entry : beliefMDP.get(b).get(i).entrySet()) {
+ double nextBeliefProb = entry.getValue();
+ Belief nextBelief = entry.getKey();
// find discretized grid points to approximate the nextBelief
- value += observationProb * interpolateOverGrid(o, nextBelief, vhash_backUp);
+ value += nextBeliefProb * interpolateOverGrid(nextBelief, vhash_backUp);
}
if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) {
chosenValue = value;
}
}
//update V(b) to the chosenValue
- vhash.put(b, chosenValue);
+ vhash.put(belief, chosenValue);
}
// Check termination
done = PrismUtils.doublesAreClose(vhash, vhash_backUp, termCritParam, termCrit == TermCrit.RELATIVE);
@@ -413,7 +404,7 @@ public class POMDPModelChecker extends ProbModelChecker
// Find discretized grid points to approximate the initialBelief
// Also get (approximate) accuracy of result from value iteration
Belief initialBelief = pomdp.getInitialBelief();
- double outerBound = interpolateOverGrid(initialBelief.so, initialBelief, vhash_backUp);
+ double outerBound = interpolateOverGrid(initialBelief, vhash_backUp);
double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
// Print result
@@ -430,7 +421,7 @@ public class POMDPModelChecker extends ProbModelChecker
for (int ii = 0; ii < numStates; ii++) {
if (mdp.getNumChoices(ii) > 0) {
int action = ((Integer) mdp.getAction(ii, 0));
- double rew = pomdp.getCostAfterAction(listBeliefs.get(ii), action, mdpRewards);
+ double rew = pomdp.getRewardAfterChoice(listBeliefs.get(ii), action, mdpRewards);
mdpRewardsNew.addToStateReward(ii, rew);
}
}
@@ -495,37 +486,34 @@ public class POMDPModelChecker extends ProbModelChecker
/**
* Get a list of target observations from a set of target states
- * (both are represented by their indices).
+ * (both are represented by BitSets over their indices).
* Also check that the set of target states corresponds to a set
* of observations, and throw an exception if not.
*/
- protected LinkedList<Integer> getAndCheckTargetObservations(POMDP pomdp, BitSet target) throws PrismException
+ protected BitSet getAndCheckTargetObservations(POMDP pomdp, BitSet target) throws PrismException
{
// Find observations corresponding to each state in the target
- TreeSet<Integer> targetObservsSet = new TreeSet<>();
+ BitSet targetObs = new BitSet();
for (int s = target.nextSetBit(0); s >= 0; s = target.nextSetBit(s + 1)) {
- targetObservsSet.add(pomdp.getObservation(s));
+ targetObs.set(pomdp.getObservation(s));
}
- LinkedList<Integer> targetObservs = new LinkedList<>(targetObservsSet);
- // Rereate the set of target states from the target observations
- // and make sure it matches
+ // Recreate the set of target states from the target observations and make sure it matches
BitSet target2 = new BitSet();
int numStates = pomdp.getNumStates();
for (int s = 0; s < numStates; s++) {
- if (targetObservs.contains(pomdp.getObservation(s))) {
+ if (targetObs.get(pomdp.getObservation(s))) {
target2.set(s);
}
}
if (!target.equals(target2)) {
throw new PrismException("Target is not observable");
}
- return targetObservs;
+ return targetObs;
}
- protected void initialiseGridPoints(POMDP pomdp, LinkedList<Integer> targetObservs, ArrayList<Belief> gridPoints, ArrayList<Belief> unknownGridPoints)
+ protected void initialiseGridPoints(POMDP pomdp, BitSet targetObs, ArrayList<Belief> gridPoints, ArrayList<Belief> unknownGridPoints)
{
ArrayList<ArrayList<Double>> assignment;
- boolean isTargetObserv;
int numObservations = pomdp.getNumObservations();
int numUnobservations = pomdp.getNumUnobservations();
int numStates = pomdp.getNumStates();
@@ -537,12 +525,6 @@ public class POMDPModelChecker extends ProbModelChecker
}
}
assignment = fullAssignment(unobservsForObserv.size(), gridResolution);
-
- isTargetObserv = targetObservs.isEmpty() ? false : ((Integer) targetObservs.peekFirst() == so);
- if (isTargetObserv) {
- targetObservs.removeFirst();
- }
-
for (ArrayList<Double> inner : assignment) {
double[] bu = new double[numUnobservations];
int k = 0;
@@ -553,47 +535,57 @@ public class POMDPModelChecker extends ProbModelChecker
Belief g = new Belief(so, bu);
gridPoints.add(g);
- if (!isTargetObserv) {
+ if (!targetObs.get(so)) {
unknownGridPoints.add(g);
}
}
}
}
- protected void buildBeliefMDP(POMDP pomdp, ArrayList<Belief> unknownGridPoints, List<List<HashMap<Integer, Double>>> observationProbs, List<List<HashMap<Integer, Belief>>> nextBelieves)
+ /**
+ * Construct (part of) a belief MDP, just for the set of passed in belief states.
+ * It is stored as a list (over source beliefs) of lists (over choices)
+ * of distributions over target beliefs, stored as a hashmap.
+ */
+ protected List<List<HashMap<Belief, Double>>> buildBeliefMDP(POMDP pomdp, List<Belief> beliefs)
{
- int unK = unknownGridPoints.size();
- for (int i = 0; i < unK; i++) {
- Belief b = unknownGridPoints.get(i);
- double[] beliefInDist = b.toDistributionOverStates(pomdp);
- //mainLog.println("Belief " + i + ": " + b);
- //mainLog.print("Belief dist:");
- //mainLog.println(beliefInDist);
- List> action_observation_probs = new ArrayList<>();// for memoization
- List> action_observation_Believes = new ArrayList<>();// for memoization
- int numChoices = pomdp.getNumChoicesForObservation(b.so);
- for (int a = 0; a < numChoices; a++) {
- //mainLog.println(i+"/"+unK+", "+a+"/"+numChoices);
- HashMap observation_probs = new HashMap<>();// for memoization
- HashMap observation_believes = new HashMap<>();// for memoization
- ((POMDPSimple) pomdp).computeObservationProbsAfterAction(beliefInDist, a, observation_probs);
- for (Map.Entry entry : observation_probs.entrySet()) {
- int o = entry.getKey();
- //mainLog.println(i+"/"+unK+", "+a+"/"+numChoices+", "+o+"/"+numObservations);
- Belief nextBelief = pomdp.getBeliefAfterActionAndObservation(b, a, o);
- //mainLog.print(i + "/" + unK + ", " + a + "/" + numChoices + ", " + o + "/" + numObservations);
- //mainLog.println(" - " + entry.getValue() + ":" + nextBelief);
- observation_believes.put(o, nextBelief);
- }
- action_observation_probs.add(observation_probs);
- action_observation_Believes.add(observation_believes);
+ List<List<HashMap<Belief, Double>>> beliefMDP = new ArrayList<>();
+ for (Belief belief: beliefs) {
+ beliefMDP.add(buildBeliefMDPState(pomdp, belief));
+ }
+ return beliefMDP;
+ }
+
+ /**
+ * Construct a single state (belief) of a belief MDP, stored as a
+ * list (over choices) of distributions over target beliefs, stored as a hashmap.
+ */
+ protected List<HashMap<Belief, Double>> buildBeliefMDPState(POMDP pomdp, Belief belief)
+ {
+ double[] beliefInDist = belief.toDistributionOverStates(pomdp);
+ List<HashMap<Belief, Double>> beliefMDPState = new ArrayList<>();
+ // And for each choice
+ int numChoices = pomdp.getNumChoicesForObservation(belief.so);
+ for (int i = 0; i < numChoices; i++) {
+ // Get successor observations and their probs
+ HashMap<Integer, Double> obsProbs = pomdp.computeObservationProbsAfterAction(beliefInDist, i);
+ HashMap<Belief, Double> beliefDist = new HashMap<>();
+ // Find the belief for each observation
+ for (Map.Entry<Integer, Double> entry : obsProbs.entrySet()) {
+ int o = entry.getKey();
+ Belief nextBelief = pomdp.getBeliefAfterChoiceAndObservation(belief, i, o);
+ beliefDist.put(nextBelief, entry.getValue());
}
- observationProbs.add(action_observation_probs);
- nextBelieves.add(action_observation_Believes);
+ beliefMDPState.add(beliefDist);
}
+ return beliefMDPState;
}
- protected double interpolateOverGrid(int o, Belief belief, HashMap<Belief, Double> vhash)
+ /**
+ * Approximate the value for a belief {@code belief} by interpolating over values {@code gridValues}
+ * for a representative set of beliefs whose convex hull is the full belief space.
+ */
+ protected double interpolateOverGrid(Belief belief, HashMap<Belief, Double> gridValues)
{
ArrayList<double[]> subSimplex = new ArrayList<>();
double[] lambdas = new double[belief.bu.length];
@@ -602,7 +594,7 @@ public class POMDPModelChecker extends ProbModelChecker
double val = 0;
for (int j = 0; j < lambdas.length; j++) {
if (lambdas[j] >= 1e-6) {
- val += lambdas[j] * vhash.get(new Belief(o, subSimplex.get(j)));
+ val += lambdas[j] * gridValues.get(new Belief(belief.so, subSimplex.get(j)));
}
}
return val;
@@ -627,8 +619,9 @@ public class POMDPModelChecker extends ProbModelChecker
src++;
if (isTargetBelief(b.toDistributionOverStates(pomdp), target)) {
mdpTarget.set(src);
+ } else {
+ extractBestActions(src, b, vhash, pomdp, mdpRewards, min, exploredBelieves, toBeExploredBelives, mdp);
}
- extractBestActions(src, b, vhash, pomdp, mdpRewards, min, exploredBelieves, toBeExploredBelives, target, mdp);
}
mdp.addLabel("target", mdpTarget);
@@ -824,45 +817,25 @@ public class POMDPModelChecker extends ProbModelChecker
* @param beliefList
*/
protected void extractBestActions(int src, Belief belief, HashMap<Belief, Double> vhash, POMDP pomdp, MDPRewards mdpRewards, boolean min,
- IndexedSet<Belief> exploredBelieves, LinkedList<Belief> toBeExploredBelives, BitSet target, MDPSimple mdp)
+ IndexedSet<Belief> exploredBelieves, LinkedList<Belief> toBeExploredBelives, MDPSimple mdp)
{
- if (isTargetBelief(belief.toDistributionOverStates(pomdp), target)) {
- // Add self-loop
- /*Distribution distr = new Distribution();
- distr.set(src, 1);
- mdp.addActionLabelledChoice(src, distr, null);*/
- return;
- }
-
- double[] beliefInDist = belief.toDistributionOverStates(pomdp);
double chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
int chosenActionIndex = -1;
ArrayList bestActions = new ArrayList<>();
- List action_reward = new ArrayList<>();
- List> action_observation_probs = new ArrayList<>();
- List> action_observation_Believes = new ArrayList<>();
//evaluate each action in b
int numChoices = pomdp.getNumChoicesForObservation(belief.so);
+ List<HashMap<Belief, Double>> beliefMDPState = buildBeliefMDPState(pomdp, belief);
for (int a = 0; a < numChoices; a++) {
double value = 0;
if (mdpRewards != null) {
- value = pomdp.getCostAfterAction(belief, a, mdpRewards); // c(a,b)
+ value = pomdp.getRewardAfterChoice(belief, a, mdpRewards); // c(a,b)
}
- // Build/store successor observations, probabilities and resulting beliefs
- HashMap observation_probs = new HashMap<>();
- HashMap observation_believes = new HashMap<>();
- ((POMDPSimple) pomdp).computeObservationProbsAfterAction(beliefInDist, a, observation_probs);
- for (Map.Entry entry : observation_probs.entrySet()) {
- int o = entry.getKey();
- Belief nextBelief = pomdp.getBeliefAfterActionAndObservation(belief, a, o);
- observation_believes.put(o, nextBelief);
- double observationProb = observation_probs.get(o);
- value += observationProb * interpolateOverGrid(o, nextBelief, vhash);
+ for (Map.Entry<Belief, Double> entry : beliefMDPState.get(a).entrySet()) {
+ double nextBeliefProb = entry.getValue();
+ Belief nextBelief = entry.getKey();
+ value += nextBeliefProb * interpolateOverGrid(nextBelief, vhash);
}
- // Store the list of observations, probabilities and resulting beliefs for this action
- action_observation_probs.add(observation_probs);
- action_observation_Believes.add(observation_believes);
-
+
//select action that minimizes/maximizes Q(a,b), i.e. value
if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6))//value entry : action_observation_probs.get(a).entrySet()) {
- int o = entry.getKey();
- double observationProb = entry.getValue();
- Belief nextBelief = action_observation_Believes.get(a).get(o);
+ for (Map.Entry<Belief, Double> entry : beliefMDPState.get(a).entrySet()) {
+ double nextBeliefProb = entry.getValue();
+ Belief nextBelief = entry.getKey();
if (exploredBelieves.add(nextBelief)) {
// If so, add to the explore list
toBeExploredBelives.add(nextBelief);
@@ -893,7 +865,7 @@ public class POMDPModelChecker extends ProbModelChecker
}
// Get index of state in state set
int dest = exploredBelieves.getIndexOfLastAdd();
- distr.add(dest, observationProb);
+ distr.add(dest, nextBeliefProb);
}
}
// Add transition distribution, with choice _index_ encoded as action
diff --git a/prism/src/explicit/POMDPSimple.java b/prism/src/explicit/POMDPSimple.java
index 9563f271..1f3ffdb1 100644
--- a/prism/src/explicit/POMDPSimple.java
+++ b/prism/src/explicit/POMDPSimple.java
@@ -228,23 +228,29 @@ public class POMDPSimple extends MDPSimple implements POMDP
*/
public void setObservation(int s, State observ, State unobserv, List observableNames) throws PrismException
{
+ // See if the observation already exists and add it if not
int oIndex = observationsList.indexOf(observ);
if (oIndex == -1) {
+ // Add new observation
observationsList.add(observ);
- observationStates.add(-1);
oIndex = observationsList.size() - 1;
+ // Also extend the observationStates list, to be filled shortly
+ observationStates.add(-1);
}
+ // Assign the observation (index) to the state
try {
setObservation(s, oIndex);
} catch (PrismException e) {
String sObs = observableNames == null ? observ.toString() : observ.toString(observableNames);
throw new PrismException("Problem with observation " + sObs + ": " + e.getMessage());
}
+ // See if the unobservation already exists and add it if not
int unobservIndex = unobservationsList.indexOf(unobserv);
if (unobservIndex == -1) {
unobservationsList.add(unobserv);
unobservIndex = unobservationsList.size() - 1;
}
+ // Assign the unobservation (index) to the state
unobservablesMap.set(s, unobservIndex);
}
@@ -381,44 +387,21 @@ public class POMDPSimple extends MDPSimple implements POMDP
}
@Override
- public double getCostAfterAction(Belief belief, int action, MDPRewards mdpRewards)
- {
- double[] beliefInDist = belief.toDistributionOverStates(this);
- double cost = getCostAfterAction(beliefInDist, action, mdpRewards);
- return cost;
- }
-
- @Override
- public double getCostAfterAction(double[] beliefInDist, int action, MDPRewards mdpRewards)
- {
- double cost = 0;
- for (int i = 0; i < beliefInDist.length; i++) {
- if (beliefInDist[i] == 0) {
- cost += 0;
- } else {
- cost += beliefInDist[i] * (mdpRewards.getTransitionReward(i, action) + mdpRewards.getStateReward(i));
- }
-
- }
- return cost;
- }
-
- @Override
- public Belief getBeliefAfterAction(Belief belief, int action)
+ public Belief getBeliefAfterChoice(Belief belief, int i)
{
double[] beliefInDist = belief.toDistributionOverStates(this);
- double[] nextBeliefInDist = getBeliefInDistAfterAction(beliefInDist, action);
+ double[] nextBeliefInDist = getBeliefInDistAfterChoice(beliefInDist, i);
return beliefInDistToBelief(nextBeliefInDist);
}
@Override
- public double[] getBeliefInDistAfterAction(double[] beliefInDist, int action)
+ public double[] getBeliefInDistAfterChoice(double[] beliefInDist, int i)
{
int n = beliefInDist.length;
double[] nextBeliefInDist = new double[n];
for (int sp = 0; sp < n; sp++) {
if (beliefInDist[sp] >= 1.0e-6) {
- Distribution distr = getChoice(sp, action);
+ Distribution distr = getChoice(sp, i);
for (Map.Entry e : distr) {
int s = (Integer) e.getKey();
double prob = (Double) e.getValue();
@@ -429,72 +412,94 @@ public class POMDPSimple extends MDPSimple implements POMDP
return nextBeliefInDist;
}
+ @Override
+ public Belief getBeliefAfterChoiceAndObservation(Belief belief, int i, int o)
+ {
+ double[] beliefInDist = belief.toDistributionOverStates(this);
+ double[] nextBeliefInDist = getBeliefInDistAfterChoiceAndObservation(beliefInDist, i, o);
+ Belief nextBelief = beliefInDistToBelief(nextBeliefInDist);
+ assert(nextBelief.so == o);
+ return nextBelief;
+ }
+
+ @Override
+ public double[] getBeliefInDistAfterChoiceAndObservation(double[] beliefInDist, int i, int o)
+ {
+ int n = beliefInDist.length;
+ double[] nextBelief = new double[n];
+ double[] beliefAfterAction = this.getBeliefInDistAfterChoice(beliefInDist, i);
+ double prob;
+ for (int s = 0; s < n; s++) {
+ prob = beliefAfterAction[s] * getObservationProb(s, o);
+ nextBelief[s] = prob;
+ }
+ PrismUtils.normalise(nextBelief);
+ return nextBelief;
+ }
+
@Override // SLOW
- public double getObservationProbAfterAction(Belief belief, int action, int observation)
+ public double getObservationProbAfterChoice(Belief belief, int i, int o)
{
double[] beliefInDist = belief.toDistributionOverStates(this);
- double prob = getObservationProbAfterAction(beliefInDist, action, observation);
+ double prob = getObservationProbAfterChoice(beliefInDist, i, o);
return prob;
}
@Override // SLOW
- public double getObservationProbAfterAction(double[] beliefInDist, int action, int observation)
+ public double getObservationProbAfterChoice(double[] beliefInDist, int i, int o)
{
- double[] beliefAfterAction = this.getBeliefInDistAfterAction(beliefInDist, action);
+ double[] beliefAfterAction = this.getBeliefInDistAfterChoice(beliefInDist, i);
int s;
double prob = 0;
for (s = 0; s < beliefAfterAction.length; s++) {
- prob += beliefAfterAction[s] * getObservationProb(s, observation);
+ prob += beliefAfterAction[s] * getObservationProb(s, o);
}
return prob;
}
- public void computeObservationProbsAfterAction(double[] beliefInDist, int action, HashMap<Integer, Double> observation_probs)
+ @Override
+ public HashMap<Integer, Double> computeObservationProbsAfterAction(double[] belief, int i)
{
- double[] beliefAfterAction = this.getBeliefInDistAfterAction(beliefInDist, action);
+ HashMap<Integer, Double> probs = new HashMap<>();
+ double[] beliefAfterAction = this.getBeliefInDistAfterChoice(belief, i);
for (int s = 0; s < beliefAfterAction.length; s++) {
int o = getObservation(s);
double probToAdd = beliefAfterAction[s];
if (probToAdd > 1e-6) {
- Double lookup = observation_probs.get(o);
- if (lookup == null)
- observation_probs.put(o, probToAdd);
- else
- observation_probs.put(o, lookup + probToAdd);
+ Double lookup = probs.get(o);
+ if (lookup == null) {
+ probs.put(o, probToAdd);
+ } else {
+ probs.put(o, lookup + probToAdd);
+ }
}
}
+ return probs;
}
@Override
- public Belief getBeliefAfterActionAndObservation(Belief belief, int action, int observation)
+ public double getRewardAfterChoice(Belief belief, int i, MDPRewards mdpRewards)
{
double[] beliefInDist = belief.toDistributionOverStates(this);
- double[] nextBeliefInDist = getBeliefInDistAfterActionAndObservation(beliefInDist, action, observation);
- Belief nextBelief = beliefInDistToBelief(nextBeliefInDist);
- if (nextBelief.so != observation) {
- System.err.println(nextBelief.so + "<--" + observation
- + " something wrong with POMDPSimple.getBeliefAfterActionAndObservation(Belief belief, int action, int observation)");
- }
- return nextBelief;
+ double cost = getRewardAfterChoice(beliefInDist, i, mdpRewards);
+ return cost;
}
@Override
- public double[] getBeliefInDistAfterActionAndObservation(double[] beliefInDist, int action, int observation)
+ public double getRewardAfterChoice(double[] beliefInDist, int i, MDPRewards mdpRewards)
{
- int n = beliefInDist.length;
- double[] nextBelief = new double[n];
- double[] beliefAfterAction = this.getBeliefInDistAfterAction(beliefInDist, action);
- int i;
- double prob;
- for (i = 0; i < n; i++) {
- prob = beliefAfterAction[i] * getObservationProb(i, observation);
- nextBelief[i] = prob;
+ double cost = 0;
+ for (int s = 0; s < beliefInDist.length; s++) {
+ if (beliefInDist[s] == 0) {
+ cost += 0;
+ } else {
+ cost += beliefInDist[s] * (mdpRewards.getTransitionReward(s, i) + mdpRewards.getStateReward(s));
+ }
+
}
- PrismUtils.normalise(nextBelief);
- return nextBelief;
+ return cost;
}
-
// Helpers
protected Belief beliefInDistToBelief(double[] beliefInDist)