Browse Source

Tidying, refactoring and commenting in POMDP code.

accumulation-v4.7
Dave Parker 5 years ago
parent
commit
1bf07ddbcf
  1. 63
      prism/src/explicit/POMDP.java
  2. 224
      prism/src/explicit/POMDPModelChecker.java
  3. 125
      prism/src/explicit/POMDPSimple.java

63
prism/src/explicit/POMDP.java

@ -27,6 +27,7 @@
package explicit;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;
@ -38,6 +39,11 @@ import prism.PrismUtils;
/**
* Interface for classes that provide (read) access to an explicit-state POMDP.
* <br><br>
* POMDPs require that states with the same observation have the same set of
* available actions. Class implementing this interface must further ensure
* that these actions appear in the same order (in terms of choice indexing)
* in each observationally equivalent state.
*/
public interface POMDP extends MDP, PartiallyObservableModel
{
@ -105,46 +111,71 @@ public interface POMDP extends MDP, PartiallyObservableModel
// Accessors
/**
* Get initial belief state
* Get the initial belief state, as a {@link Belief} object.
*/
public Belief getInitialBelief();
/**
* Get initial belief state as an distribution over all states (array).
* Get the initial belief state, as an array of probabilities over all states.
*/
public double[] getInitialBeliefInDist();
/**
* Get the updated belief after action {@code action}.
* Get the belief state (as a {@link Belief} object)
* after taking the {@code i}th choice from belief state {@code belief}.
*/
public Belief getBeliefAfterAction(Belief belief, int action);
public Belief getBeliefAfterChoice(Belief belief, int i);
/**
* Get the updated belief after action {@code action} using the distribution over all states belief representation.
* Get the belief state (as an array of probabilities over all states)
* after taking the {@code i}th choice from belief state {@code belief}.
*/
public double[] getBeliefInDistAfterAction(double[] belief, int action);
public double[] getBeliefInDistAfterChoice(double[] belief, int i);
/**
* Get the updated belief after action {@code action} and observation {@code observation}.
* Get the belief state (as a {@link Belief} object)
* after taking the {@code i}th choice from belief state {@code belief}
* and seeing observation {@code o} in the next state.
*/
public Belief getBeliefAfterActionAndObservation(Belief belief, int action, int observation);
public Belief getBeliefAfterChoiceAndObservation(Belief belief, int i, int o);
/**
* Get the updated belief after action {@code action} and observation {@code observation} using the distribution over all states belief representation.
* Get the belief state (as an array of probabilities over all states)
* after taking the {@code i}th choice from belief state {@code belief}
* and seeing observation {@code o} in the next state.
*/
public double[] getBeliefInDistAfterActionAndObservation(double[] belief, int action, int observation);
public double[] getBeliefInDistAfterChoiceAndObservation(double[] belief, int i, int o);
/**
* Get the probability of an observation {@code observation}} after action {@code action} from belief {@code belief}.
* Get the probability of seeing observation {@code o} after taking the
* {@code i}th choice from belief state {@code belief}.
*/
public double getObservationProbAfterAction(Belief belief, int action, int observation);
public double getObservationProbAfterChoice(Belief belief, int i, int o);
public double getObservationProbAfterAction(double[] belief, int action, int observation);
/**
* Get the probability of seeing observation {@code o} after taking the
* {@code i}th choice from belief state {@code belief}.
* The belief state is given as an array of probabilities over all states.
*/
public double getObservationProbAfterChoice(double[] belief, int i, int o);
/**
* Get the cost (reward) of an action {@code action}} from a belief {@code belief}.
* Get the (non-zero) probabilities of seeing each observation after taking the
* {@code i}th choice from belief state {@code belief}.
* The belief state is given as an array of probabilities over all states.
*/
public double getCostAfterAction(Belief belief, int action, MDPRewards mdpRewards);
public HashMap<Integer, Double> computeObservationProbsAfterAction(double[] belief, int i);
/**
* Get the expected (state and transition) reward value when taking the
* {@code i}th choice from belief state {@code belief}.
*/
public double getRewardAfterChoice(Belief belief, int i, MDPRewards mdpRewards);
public double getCostAfterAction(double[] belief, int action, MDPRewards mdpRewards);
/**
* Get the expected (state and transition) reward value when taking the
* {@code i}th choice from belief state {@code belief}.
* The belief state is given as an array of probabilities over all states.
*/
public double getRewardAfterChoice(double[] belief, int i, MDPRewards mdpRewards);
}

224
prism/src/explicit/POMDPModelChecker.java

@ -28,7 +28,6 @@
package explicit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
@ -36,7 +35,6 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import explicit.graphviz.Decoration;
import explicit.graphviz.Decorator;
@ -45,7 +43,6 @@ import explicit.rewards.MDPRewardsSimple;
import prism.Accuracy;
import prism.AccuracyFactory;
import prism.Pair;
import prism.Accuracy.AccuracyLevel;
import prism.PrismComponent;
import prism.PrismException;
import prism.PrismNotSupportedException;
@ -122,21 +119,19 @@ public class POMDPModelChecker extends ProbModelChecker
mainLog.println("Starting fixed-resolution grid approximation (" + (min ? "min" : "max") + ")...");
// Find out the observations for the target states
LinkedList<Integer> targetObservs = getAndCheckTargetObservations(pomdp, target);
BitSet targetObs = getAndCheckTargetObservations(pomdp, target);
// Initialise the grid points
ArrayList<Belief> gridPoints = new ArrayList<>();//the set of grid points (discretized believes)
ArrayList<Belief> unknownGridPoints = new ArrayList<>();//the set of unknown grid points (discretized believes)
initialiseGridPoints(pomdp, targetObservs, gridPoints, unknownGridPoints);
initialiseGridPoints(pomdp, targetObs, gridPoints, unknownGridPoints);
int unK = unknownGridPoints.size();
mainLog.print("Grid statistics: resolution=" + gridResolution);
mainLog.println(", points=" + gridPoints.size() + ", unknown points=" + unK);
// Construct grid belief "MDP" (over all unknown grid points_)
// Construct grid belief "MDP" (over all unknown grid points)
mainLog.println("Building belief space approximation...");
List<List<HashMap<Integer, Double>>> observationProbs = new ArrayList<>();//memoization for reuse
List<List<HashMap<Integer, Belief>>> nextBelieves = new ArrayList<>();//memoization for reuse
buildBeliefMDP(pomdp, unknownGridPoints, observationProbs, nextBelieves);
List<List<HashMap<Belief, Double>>> beliefMDP = buildBeliefMDP(pomdp, unknownGridPoints);
// HashMap for storing real time values for the discretized grid belief states
HashMap<Belief, Double> vhash = new HashMap<>();
@ -159,26 +154,25 @@ public class POMDPModelChecker extends ProbModelChecker
boolean done = false;
while (!done && iters < maxIters) {
// Iterate over all (unknown) grid points
for (int i = 0; i < unK; i++) {
Belief b = unknownGridPoints.get(i);
int numChoices = pomdp.getNumChoicesForObservation(b.so);
for (int b = 0; b < unK; b++) {
Belief belief = unknownGridPoints.get(b);
int numChoices = pomdp.getNumChoicesForObservation(belief.so);
chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
for (int a = 0; a < numChoices; a++) {
for (int i = 0; i < numChoices; i++) {
value = 0;
for (Map.Entry<Integer, Double> entry : observationProbs.get(i).get(a).entrySet()) {
int o = entry.getKey();
double observationProb = entry.getValue();
Belief nextBelief = nextBelieves.get(i).get(a).get(o);
for (Map.Entry<Belief, Double> entry : beliefMDP.get(b).get(i).entrySet()) {
double nextBeliefProb = entry.getValue();
Belief nextBelief = entry.getKey();
// find discretized grid points to approximate the nextBelief
value += observationProb * interpolateOverGrid(o, nextBelief, vhash_backUp);
value += nextBeliefProb * interpolateOverGrid(nextBelief, vhash_backUp);
}
if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) {
chosenValue = value;
}
}
//update V(b) to the chosenValue
vhash.put(b, chosenValue);
vhash.put(belief, chosenValue);
}
// Check termination
done = PrismUtils.doublesAreClose(vhash, vhash_backUp, termCritParam, termCrit == TermCrit.RELATIVE);
@ -202,7 +196,7 @@ public class POMDPModelChecker extends ProbModelChecker
// Find discretized grid points to approximate the initialBelief
// Also get (approximate) accuracy of result from value iteration
Belief initialBelief = pomdp.getInitialBelief();
double outerBound = interpolateOverGrid(initialBelief.so, initialBelief, vhash_backUp);
double outerBound = interpolateOverGrid(initialBelief, vhash_backUp);
double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
// Print result
@ -328,29 +322,27 @@ public class POMDPModelChecker extends ProbModelChecker
mainLog.println("Starting fixed-resolution grid approximation (" + (min ? "min" : "max") + ")...");
// Find out the observations for the target states
LinkedList<Integer> targetObservs = getAndCheckTargetObservations(pomdp, target);
BitSet targetObs = getAndCheckTargetObservations(pomdp, target);
// Initialise the grid points
ArrayList<Belief> gridPoints = new ArrayList<>();//the set of grid points (discretized believes)
ArrayList<Belief> unknownGridPoints = new ArrayList<>();//the set of unknown grid points (discretized believes)
initialiseGridPoints(pomdp, targetObservs, gridPoints, unknownGridPoints);
initialiseGridPoints(pomdp, targetObs, gridPoints, unknownGridPoints);
int unK = unknownGridPoints.size();
mainLog.print("Grid statistics: resolution=" + gridResolution);
mainLog.println(", points=" + gridPoints.size() + ", unknown points=" + unK);
// Construct grid belief "MDP" (over all unknown grid points_)
// Construct grid belief "MDP" (over all unknown grid points)
mainLog.println("Building belief space approximation...");
List<List<HashMap<Integer, Double>>> observationProbs = new ArrayList<>();// memoization for reuse
List<List<HashMap<Integer, Belief>>> nextBelieves = new ArrayList<>();// memoization for reuse
buildBeliefMDP(pomdp, unknownGridPoints, observationProbs, nextBelieves);
List<List<HashMap<Belief, Double>>> beliefMDP = buildBeliefMDP(pomdp, unknownGridPoints);
// Rewards
List<List<Double>> rewards = new ArrayList<>(); // memoization for reuse
for (int i = 0; i < unK; i++) {
Belief b = unknownGridPoints.get(i);
int numChoices = pomdp.getNumChoicesForObservation(b.so);
for (int b = 0; b < unK; b++) {
Belief belief = unknownGridPoints.get(b);
int numChoices = pomdp.getNumChoicesForObservation(belief.so);
List<Double> action_reward = new ArrayList<>();// for memoization
for (int a = 0; a < numChoices; a++) {
action_reward.add(pomdp.getCostAfterAction(b, a, mdpRewards)); // c(a,b)
for (int i = 0; i < numChoices; i++) {
action_reward.add(pomdp.getRewardAfterChoice(belief, i, mdpRewards)); // c(a,b)
}
rewards.add(action_reward);
}
@ -371,25 +363,24 @@ public class POMDPModelChecker extends ProbModelChecker
boolean done = false;
while (!done && iters < maxIters) {
// Iterate over all (unknown) grid points
for (int i = 0; i < unK; i++) {
Belief b = unknownGridPoints.get(i);
int numChoices = pomdp.getNumChoicesForObservation(b.so);
for (int b = 0; b < unK; b++) {
Belief belief = unknownGridPoints.get(b);
int numChoices = pomdp.getNumChoicesForObservation(belief.so);
chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
for (int a = 0; a < numChoices; a++) {
value = rewards.get(i).get(a);
for (Map.Entry<Integer, Double> entry : observationProbs.get(i).get(a).entrySet()) {
int o = entry.getKey();
double observationProb = entry.getValue();
Belief nextBelief = nextBelieves.get(i).get(a).get(o);
for (int i = 0; i < numChoices; i++) {
value = rewards.get(b).get(i);
for (Map.Entry<Belief, Double> entry : beliefMDP.get(b).get(i).entrySet()) {
double nextBeliefProb = entry.getValue();
Belief nextBelief = entry.getKey();
// find discretized grid points to approximate the nextBelief
value += observationProb * interpolateOverGrid(o, nextBelief, vhash_backUp);
value += nextBeliefProb * interpolateOverGrid(nextBelief, vhash_backUp);
}
if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) {
chosenValue = value;
}
}
//update V(b) to the chosenValue
vhash.put(b, chosenValue);
vhash.put(belief, chosenValue);
}
// Check termination
done = PrismUtils.doublesAreClose(vhash, vhash_backUp, termCritParam, termCrit == TermCrit.RELATIVE);
@ -413,7 +404,7 @@ public class POMDPModelChecker extends ProbModelChecker
// Find discretized grid points to approximate the initialBelief
// Also get (approximate) accuracy of result from value iteration
Belief initialBelief = pomdp.getInitialBelief();
double outerBound = interpolateOverGrid(initialBelief.so, initialBelief, vhash_backUp);
double outerBound = interpolateOverGrid(initialBelief, vhash_backUp);
double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
// Print result
@ -430,7 +421,7 @@ public class POMDPModelChecker extends ProbModelChecker
for (int ii = 0; ii < numStates; ii++) {
if (mdp.getNumChoices(ii) > 0) {
int action = ((Integer) mdp.getAction(ii, 0));
double rew = pomdp.getCostAfterAction(listBeliefs.get(ii), action, mdpRewards);
double rew = pomdp.getRewardAfterChoice(listBeliefs.get(ii), action, mdpRewards);
mdpRewardsNew.addToStateReward(ii, rew);
}
}
@ -495,37 +486,34 @@ public class POMDPModelChecker extends ProbModelChecker
/**
* Get a list of target observations from a set of target states
* (both are represented by their indices).
* (both are represented by BitSets over their indices).
* Also check that the set of target states corresponds to a set
* of observations, and throw an exception if not.
*/
protected LinkedList<Integer> getAndCheckTargetObservations(POMDP pomdp, BitSet target) throws PrismException
protected BitSet getAndCheckTargetObservations(POMDP pomdp, BitSet target) throws PrismException
{
// Find observations corresponding to each state in the target
TreeSet<Integer> targetObservsSet = new TreeSet<>();
BitSet targetObs = new BitSet();
for (int s = target.nextSetBit(0); s >= 0; s = target.nextSetBit(s + 1)) {
targetObservsSet.add(pomdp.getObservation(s));
targetObs.set(pomdp.getObservation(s));
}
LinkedList<Integer> targetObservs = new LinkedList<>(targetObservsSet);
// Rereate the set of target states from the target observations
// and make sure it matches
// Recreate the set of target states from the target observations and make sure it matches
BitSet target2 = new BitSet();
int numStates = pomdp.getNumStates();
for (int s = 0; s < numStates; s++) {
if (targetObservs.contains(pomdp.getObservation(s))) {
if (targetObs.get(pomdp.getObservation(s))) {
target2.set(s);
}
}
if (!target.equals(target2)) {
throw new PrismException("Target is not observable");
}
return targetObservs;
return targetObs;
}
protected void initialiseGridPoints(POMDP pomdp, LinkedList<Integer> targetObservs, ArrayList<Belief> gridPoints, ArrayList<Belief> unknownGridPoints)
protected void initialiseGridPoints(POMDP pomdp, BitSet targetObs, ArrayList<Belief> gridPoints, ArrayList<Belief> unknownGridPoints)
{
ArrayList<ArrayList<Double>> assignment;
boolean isTargetObserv;
int numObservations = pomdp.getNumObservations();
int numUnobservations = pomdp.getNumUnobservations();
int numStates = pomdp.getNumStates();
@ -537,12 +525,6 @@ public class POMDPModelChecker extends ProbModelChecker
}
}
assignment = fullAssignment(unobservsForObserv.size(), gridResolution);
isTargetObserv = targetObservs.isEmpty() ? false : ((Integer) targetObservs.peekFirst() == so);
if (isTargetObserv) {
targetObservs.removeFirst();
}
for (ArrayList<Double> inner : assignment) {
double[] bu = new double[numUnobservations];
int k = 0;
@ -553,47 +535,57 @@ public class POMDPModelChecker extends ProbModelChecker
Belief g = new Belief(so, bu);
gridPoints.add(g);
if (!isTargetObserv) {
if (!targetObs.get(so)) {
unknownGridPoints.add(g);
}
}
}
}
protected void buildBeliefMDP(POMDP pomdp, ArrayList<Belief> unknownGridPoints, List<List<HashMap<Integer, Double>>> observationProbs, List<List<HashMap<Integer, Belief>>> nextBelieves)
/**
* Construct (part of) a belief MDP, just for the set of passed-in belief states.
* It is stored as a list (over source beliefs) of lists (over choices)
* of distributions over target beliefs, each stored as a hashmap.
*/
protected List<List<HashMap<Belief, Double>>> buildBeliefMDP(POMDP pomdp, List<Belief> beliefs)
{
int unK = unknownGridPoints.size();
for (int i = 0; i < unK; i++) {
Belief b = unknownGridPoints.get(i);
double[] beliefInDist = b.toDistributionOverStates(pomdp);
//mainLog.println("Belief " + i + ": " + b);
//mainLog.print("Belief dist:");
//mainLog.println(beliefInDist);
List<HashMap<Integer, Double>> action_observation_probs = new ArrayList<>();// for memoization
List<HashMap<Integer, Belief>> action_observation_Believes = new ArrayList<>();// for memoization
int numChoices = pomdp.getNumChoicesForObservation(b.so);
for (int a = 0; a < numChoices; a++) {
//mainLog.println(i+"/"+unK+", "+a+"/"+numChoices);
HashMap<Integer, Double> observation_probs = new HashMap<>();// for memoization
HashMap<Integer, Belief> observation_believes = new HashMap<>();// for memoization
((POMDPSimple) pomdp).computeObservationProbsAfterAction(beliefInDist, a, observation_probs);
for (Map.Entry<Integer, Double> entry : observation_probs.entrySet()) {
int o = entry.getKey();
//mainLog.println(i+"/"+unK+", "+a+"/"+numChoices+", "+o+"/"+numObservations);
Belief nextBelief = pomdp.getBeliefAfterActionAndObservation(b, a, o);
//mainLog.print(i + "/" + unK + ", " + a + "/" + numChoices + ", " + o + "/" + numObservations);
//mainLog.println(" - " + entry.getValue() + ":" + nextBelief);
observation_believes.put(o, nextBelief);
}
action_observation_probs.add(observation_probs);
action_observation_Believes.add(observation_believes);
List<List<HashMap<Belief, Double>>> beliefMDP = new ArrayList<>();
for (Belief belief: beliefs) {
beliefMDP.add(buildBeliefMDPState(pomdp, belief));
}
return beliefMDP;
}
/**
* Construct a single state (belief) of a belief MDP, stored as a
* list (over choices) of distributions over target beliefs, each stored as a hashmap.
*/
protected List<HashMap<Belief, Double>> buildBeliefMDPState(POMDP pomdp, Belief belief)
{
double[] beliefInDist = belief.toDistributionOverStates(pomdp);
List<HashMap<Belief, Double>> beliefMDPState = new ArrayList<>();
// And for each choice
int numChoices = pomdp.getNumChoicesForObservation(belief.so);
for (int i = 0; i < numChoices; i++) {
// Get successor observations and their probs
HashMap<Integer, Double> obsProbs = pomdp.computeObservationProbsAfterAction(beliefInDist, i);
HashMap<Belief, Double> beliefDist = new HashMap<>();
// Find the belief for each observation
for (Map.Entry<Integer, Double> entry : obsProbs.entrySet()) {
int o = entry.getKey();
Belief nextBelief = pomdp.getBeliefAfterChoiceAndObservation(belief, i, o);
beliefDist.put(nextBelief, entry.getValue());
}
observationProbs.add(action_observation_probs);
nextBelieves.add(action_observation_Believes);
beliefMDPState.add(beliefDist);
}
return beliefMDPState;
}
protected double interpolateOverGrid(int o, Belief belief, HashMap<Belief, Double> vhash)
/**
* Approximate the value for a belief {@code belief} by interpolating over values {@code gridValues}
* for a representative set of beliefs whose convex hull is the full belief space.
*/
protected double interpolateOverGrid(Belief belief, HashMap<Belief, Double> gridValues)
{
ArrayList<double[]> subSimplex = new ArrayList<>();
double[] lambdas = new double[belief.bu.length];
@ -602,7 +594,7 @@ public class POMDPModelChecker extends ProbModelChecker
double val = 0;
for (int j = 0; j < lambdas.length; j++) {
if (lambdas[j] >= 1e-6) {
val += lambdas[j] * vhash.get(new Belief(o, subSimplex.get(j)));
val += lambdas[j] * gridValues.get(new Belief(belief.so, subSimplex.get(j)));
}
}
return val;
@ -627,8 +619,9 @@ public class POMDPModelChecker extends ProbModelChecker
src++;
if (isTargetBelief(b.toDistributionOverStates(pomdp), target)) {
mdpTarget.set(src);
} else {
extractBestActions(src, b, vhash, pomdp, mdpRewards, min, exploredBelieves, toBeExploredBelives, mdp);
}
extractBestActions(src, b, vhash, pomdp, mdpRewards, min, exploredBelieves, toBeExploredBelives, target, mdp);
}
mdp.addLabel("target", mdpTarget);
@ -824,45 +817,25 @@ public class POMDPModelChecker extends ProbModelChecker
* @param beliefList
*/
protected void extractBestActions(int src, Belief belief, HashMap<Belief, Double> vhash, POMDP pomdp, MDPRewards mdpRewards, boolean min,
IndexedSet<Belief> exploredBelieves, LinkedList<Belief> toBeExploredBelives, BitSet target, MDPSimple mdp)
IndexedSet<Belief> exploredBelieves, LinkedList<Belief> toBeExploredBelives, MDPSimple mdp)
{
if (isTargetBelief(belief.toDistributionOverStates(pomdp), target)) {
// Add self-loop
/*Distribution distr = new Distribution();
distr.set(src, 1);
mdp.addActionLabelledChoice(src, distr, null);*/
return;
}
double[] beliefInDist = belief.toDistributionOverStates(pomdp);
double chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
int chosenActionIndex = -1;
ArrayList<Integer> bestActions = new ArrayList<>();
List<Double> action_reward = new ArrayList<>();
List<HashMap<Integer, Double>> action_observation_probs = new ArrayList<>();
List<HashMap<Integer, Belief>> action_observation_Believes = new ArrayList<>();
//evaluate each action in b
int numChoices = pomdp.getNumChoicesForObservation(belief.so);
List<HashMap<Belief, Double>> beliefMDPState = buildBeliefMDPState(pomdp, belief);
for (int a = 0; a < numChoices; a++) {
double value = 0;
if (mdpRewards != null) {
value = pomdp.getCostAfterAction(belief, a, mdpRewards); // c(a,b)
value = pomdp.getRewardAfterChoice(belief, a, mdpRewards); // c(a,b)
}
// Build/store successor observations, probabilities and resulting beliefs
HashMap<Integer, Double> observation_probs = new HashMap<>();
HashMap<Integer, Belief> observation_believes = new HashMap<>();
((POMDPSimple) pomdp).computeObservationProbsAfterAction(beliefInDist, a, observation_probs);
for (Map.Entry<Integer, Double> entry : observation_probs.entrySet()) {
int o = entry.getKey();
Belief nextBelief = pomdp.getBeliefAfterActionAndObservation(belief, a, o);
observation_believes.put(o, nextBelief);
double observationProb = observation_probs.get(o);
value += observationProb * interpolateOverGrid(o, nextBelief, vhash);
for (Map.Entry<Belief, Double> entry : beliefMDPState.get(a).entrySet()) {
double nextBeliefProb = entry.getValue();
Belief nextBelief = entry.getKey();
value += nextBeliefProb * interpolateOverGrid(nextBelief, vhash);
}
// Store the list of observations, probabilities and resulting beliefs for this action
action_observation_probs.add(observation_probs);
action_observation_Believes.add(observation_believes);
//select action that minimizes/maximizes Q(a,b), i.e. value
if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6))//value<bestValue
{
@ -881,10 +854,9 @@ public class POMDPModelChecker extends ProbModelChecker
Distribution distr = new Distribution();
for (Integer a : bestActions) {
for (Map.Entry<Integer, Double> entry : action_observation_probs.get(a).entrySet()) {
int o = entry.getKey();
double observationProb = entry.getValue();
Belief nextBelief = action_observation_Believes.get(a).get(o);
for (Map.Entry<Belief, Double> entry : beliefMDPState.get(a).entrySet()) {
double nextBeliefProb = entry.getValue();
Belief nextBelief = entry.getKey();
if (exploredBelieves.add(nextBelief)) {
// If so, add to the explore list
toBeExploredBelives.add(nextBelief);
@ -893,7 +865,7 @@ public class POMDPModelChecker extends ProbModelChecker
}
// Get index of state in state set
int dest = exploredBelieves.getIndexOfLastAdd();
distr.add(dest, observationProb);
distr.add(dest, nextBeliefProb);
}
}
// Add transition distribution, with choice _index_ encoded as action

125
prism/src/explicit/POMDPSimple.java

@ -228,23 +228,29 @@ public class POMDPSimple extends MDPSimple implements POMDP
*/
public void setObservation(int s, State observ, State unobserv, List<String> observableNames) throws PrismException
{
// See if the observation already exists and add it if not
int oIndex = observationsList.indexOf(observ);
if (oIndex == -1) {
// Add new observation
observationsList.add(observ);
observationStates.add(-1);
oIndex = observationsList.size() - 1;
// Also extend the observationStates list, to be filled shortly
observationStates.add(-1);
}
// Assign the observation (index) to the state
try {
setObservation(s, oIndex);
} catch (PrismException e) {
String sObs = observableNames == null ? observ.toString() : observ.toString(observableNames);
throw new PrismException("Problem with observation " + sObs + ": " + e.getMessage());
}
// See if the unobservation already exists and add it if not
int unobservIndex = unobservationsList.indexOf(unobserv);
if (unobservIndex == -1) {
unobservationsList.add(unobserv);
unobservIndex = unobservationsList.size() - 1;
}
// Assign the unobservation (index) to the state
unobservablesMap.set(s, unobservIndex);
}
@ -381,44 +387,21 @@ public class POMDPSimple extends MDPSimple implements POMDP
}
@Override
public double getCostAfterAction(Belief belief, int action, MDPRewards mdpRewards)
{
double[] beliefInDist = belief.toDistributionOverStates(this);
double cost = getCostAfterAction(beliefInDist, action, mdpRewards);
return cost;
}
@Override
public double getCostAfterAction(double[] beliefInDist, int action, MDPRewards mdpRewards)
{
double cost = 0;
for (int i = 0; i < beliefInDist.length; i++) {
if (beliefInDist[i] == 0) {
cost += 0;
} else {
cost += beliefInDist[i] * (mdpRewards.getTransitionReward(i, action) + mdpRewards.getStateReward(i));
}
}
return cost;
}
@Override
public Belief getBeliefAfterAction(Belief belief, int action)
public Belief getBeliefAfterChoice(Belief belief, int i)
{
double[] beliefInDist = belief.toDistributionOverStates(this);
double[] nextBeliefInDist = getBeliefInDistAfterAction(beliefInDist, action);
double[] nextBeliefInDist = getBeliefInDistAfterChoice(beliefInDist, i);
return beliefInDistToBelief(nextBeliefInDist);
}
@Override
public double[] getBeliefInDistAfterAction(double[] beliefInDist, int action)
public double[] getBeliefInDistAfterChoice(double[] beliefInDist, int i)
{
int n = beliefInDist.length;
double[] nextBeliefInDist = new double[n];
for (int sp = 0; sp < n; sp++) {
if (beliefInDist[sp] >= 1.0e-6) {
Distribution distr = getChoice(sp, action);
Distribution distr = getChoice(sp, i);
for (Map.Entry<Integer, Double> e : distr) {
int s = (Integer) e.getKey();
double prob = (Double) e.getValue();
@ -429,72 +412,94 @@ public class POMDPSimple extends MDPSimple implements POMDP
return nextBeliefInDist;
}
@Override
public Belief getBeliefAfterChoiceAndObservation(Belief belief, int i, int o)
{
double[] beliefInDist = belief.toDistributionOverStates(this);
double[] nextBeliefInDist = getBeliefInDistAfterChoiceAndObservation(beliefInDist, i, o);
Belief nextBelief = beliefInDistToBelief(nextBeliefInDist);
assert(nextBelief.so == o);
return nextBelief;
}
@Override
public double[] getBeliefInDistAfterChoiceAndObservation(double[] beliefInDist, int i, int o)
{
int n = beliefInDist.length;
double[] nextBelief = new double[n];
double[] beliefAfterAction = this.getBeliefInDistAfterChoice(beliefInDist, i);
double prob;
for (int s = 0; s < n; s++) {
prob = beliefAfterAction[s] * getObservationProb(s, o);
nextBelief[s] = prob;
}
PrismUtils.normalise(nextBelief);
return nextBelief;
}
@Override // SLOW
public double getObservationProbAfterAction(Belief belief, int action, int observation)
public double getObservationProbAfterChoice(Belief belief, int i, int o)
{
double[] beliefInDist = belief.toDistributionOverStates(this);
double prob = getObservationProbAfterAction(beliefInDist, action, observation);
double prob = getObservationProbAfterChoice(beliefInDist, i, o);
return prob;
}
@Override // SLOW
public double getObservationProbAfterAction(double[] beliefInDist, int action, int observation)
public double getObservationProbAfterChoice(double[] beliefInDist, int i, int o)
{
double[] beliefAfterAction = this.getBeliefInDistAfterAction(beliefInDist, action);
double[] beliefAfterAction = this.getBeliefInDistAfterChoice(beliefInDist, i);
int s;
double prob = 0;
for (s = 0; s < beliefAfterAction.length; s++) {
prob += beliefAfterAction[s] * getObservationProb(s, observation);
prob += beliefAfterAction[s] * getObservationProb(s, o);
}
return prob;
}
public void computeObservationProbsAfterAction(double[] beliefInDist, int action, HashMap<Integer, Double> observation_probs)
@Override
public HashMap<Integer, Double> computeObservationProbsAfterAction(double[] belief, int i)
{
double[] beliefAfterAction = this.getBeliefInDistAfterAction(beliefInDist, action);
HashMap<Integer, Double> probs = new HashMap<>();
double[] beliefAfterAction = this.getBeliefInDistAfterChoice(belief, i);
for (int s = 0; s < beliefAfterAction.length; s++) {
int o = getObservation(s);
double probToAdd = beliefAfterAction[s];
if (probToAdd > 1e-6) {
Double lookup = observation_probs.get(o);
if (lookup == null)
observation_probs.put(o, probToAdd);
else
observation_probs.put(o, lookup + probToAdd);
Double lookup = probs.get(o);
if (lookup == null) {
probs.put(o, probToAdd);
} else {
probs.put(o, lookup + probToAdd);
}
}
}
return probs;
}
@Override
public Belief getBeliefAfterActionAndObservation(Belief belief, int action, int observation)
public double getRewardAfterChoice(Belief belief, int i, MDPRewards mdpRewards)
{
double[] beliefInDist = belief.toDistributionOverStates(this);
double[] nextBeliefInDist = getBeliefInDistAfterActionAndObservation(beliefInDist, action, observation);
Belief nextBelief = beliefInDistToBelief(nextBeliefInDist);
if (nextBelief.so != observation) {
System.err.println(nextBelief.so + "<--" + observation
+ " something wrong with POMDPSimple.getBeliefAfterActionAndObservation(Belief belief, int action, int observation)");
}
return nextBelief;
double cost = getRewardAfterChoice(beliefInDist, i, mdpRewards);
return cost;
}
@Override
public double[] getBeliefInDistAfterActionAndObservation(double[] beliefInDist, int action, int observation)
public double getRewardAfterChoice(double[] beliefInDist, int i, MDPRewards mdpRewards)
{
int n = beliefInDist.length;
double[] nextBelief = new double[n];
double[] beliefAfterAction = this.getBeliefInDistAfterAction(beliefInDist, action);
int i;
double prob;
for (i = 0; i < n; i++) {
prob = beliefAfterAction[i] * getObservationProb(i, observation);
nextBelief[i] = prob;
double cost = 0;
for (int s = 0; s < beliefInDist.length; s++) {
if (beliefInDist[s] == 0) {
cost += 0;
} else {
cost += beliefInDist[s] * (mdpRewards.getTransitionReward(s, i) + mdpRewards.getStateReward(s));
}
}
PrismUtils.normalise(nextBelief);
return nextBelief;
return cost;
}
// Helpers
protected Belief beliefInDistToBelief(double[] beliefInDist)

Loading…
Cancel
Save