
More refactoring in POMDP solution.

* Separate data structures for BeliefMDPState and POMDPStrategyModel
* Push reward construction/storage into the code for the belief MDP
* Factor out prob/reward backup operations into methods for re-use
* Store value function + backup using functional interfaces
* Collapse (now simpler) buildStrategyModel into one method
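
In sketch form, the refactoring amounts to the following pattern (illustrative, reusing the names from the diff below):

    // Value function over beliefs, and a backup step defined in terms of it
    Function<Belief, Double> values = belief -> approximateReachProb(belief, vhash_backUp, targetObs, unknownObs);
    BeliefMDPBackUp backup = (belief, state) -> approximateReachProbBackup(belief, state, values, min);
    // One sweep of value iteration; buildStrategyModel reuses the same backup
    for (int b = 0; b < gridPoints.size(); b++) {
        Pair<Double, Integer> valChoice = backup.apply(gridPoints.get(b), beliefMDP.get(b));
        vhash.put(gridPoints.get(b), valChoice.first);
    }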
accumulation-v4.7
Dave Parker, 5 years ago
commit 90c7df8209

1 changed file with 348 changes:
prism/src/explicit/POMDPModelChecker.java

@@ -35,6 +35,8 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import explicit.graphviz.Decoration;
import explicit.graphviz.Decorator;
@@ -53,6 +55,45 @@ import prism.PrismUtils;
*/
public class POMDPModelChecker extends ProbModelChecker
{
// Some local data structures for convenience
/**
* Info for a single state of a belief MDP:
* (1) a list (over choices in the state) of distributions over beliefs, stored as hashmap;
* (2) optionally, a list (over choices in the state) of rewards
*/
class BeliefMDPState
{
public List<HashMap<Belief, Double>> trans;
public List<Double> rewards;
public BeliefMDPState()
{
trans = new ArrayList<>();
rewards = new ArrayList<>();
}
}
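// Access pattern (illustrative): for a BeliefMDPState s and choice i,
// s.trans.get(i).get(b2) is the probability of moving to belief b2 via choice i,
// and s.rewards.get(i) is the corresponding one-step reward, when rewards are stored.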
/**
* Value backup function for belief state value iteration:
* mapping from a state and its definition (reward + transitions)
* to a pair of the optimal value + choice index.
*/
@FunctionalInterface
interface BeliefMDPBackUp extends BiFunction<Belief, BeliefMDPState, Pair<Double, Integer>> {}
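// Usage (illustrative): a backup can be supplied as a lambda over existing methods, e.g.
//   BeliefMDPBackUp backup = (belief, state) -> approximateReachProbBackup(belief, state, values, min);
// backup.apply(belief, beliefMDPState) then yields the optimal (value, choice index) pair.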
/**
* A model constructed to represent a fragment of a belief MDP induced by a strategy:
* (1) the model (represented as an MDP for ease of storing actions labels)
* (2) optionally, a reward structure
* (3) a list of the beliefs corresponding to each state of the model
*/
class POMDPStrategyModel
{
public MDP mdp;
public MDPRewards mdpRewards;
public List<Belief> beliefs;
}
/**
* Create a new POMDPModelChecker, inherit basic state from parent (unless null).
*/
@@ -122,16 +163,12 @@ public class POMDPModelChecker extends ProbModelChecker
// Find out the observations for the target/remain states
// And determine set of observations actually need to perform computation for
BitSet targetObs = null;
try {
targetObs = getObservationsMatchingStates(pomdp, target);
} catch (PrismException e) {
BitSet targetObs = getObservationsMatchingStates(pomdp, target);
if (targetObs == null) {
throw new PrismException("Target for reachability is not observable");
}
BitSet remainObs = null;
try {
remainObs = remain == null ? null : getObservationsMatchingStates(pomdp, remain);
} catch (PrismException e) {
BitSet remainObs = (remain == null) ? null : getObservationsMatchingStates(pomdp, remain);
if (remain != null && remainObs == null) {
throw new PrismException("Left-hand side of until is not observable");
}
BitSet unknownObs = new BitSet();
@@ -146,20 +183,23 @@
mainLog.println("Grid statistics: resolution=" + gridResolution + ", points=" + gridPoints.size());
// Construct grid belief "MDP"
mainLog.println("Building belief space approximation...");
List<List<HashMap<Belief, Double>>> beliefMDP = buildBeliefMDP(pomdp, gridPoints);
List<BeliefMDPState> beliefMDP = buildBeliefMDP(pomdp, null, gridPoints);
// Initialise hashmaps for storing values for the grid belief states
// Initialise hashmaps for storing values for the unknown belief states
HashMap<Belief, Double> vhash = new HashMap<>();
HashMap<Belief, Double> vhash_backUp = new HashMap<>();
for (Belief belief : gridPoints) {
vhash.put(belief, 0.0);
vhash_backUp.put(belief, 0.0);
}
// Define value function for the full set of belief states
Function<Belief, Double> values = belief -> approximateReachProb(belief, vhash_backUp, targetObs, unknownObs);
// Define value backup function
BeliefMDPBackUp backup = (belief, beliefState) -> approximateReachProbBackup(belief, beliefState, values, min);
// Start iterations
mainLog.println("Solving belief space approximation...");
long timer2 = System.currentTimeMillis();
double value, chosenValue;
int iters = 0;
boolean done = false;
while (!done && iters < maxIters) {
@@ -167,23 +207,8 @@
int unK = gridPoints.size();
for (int b = 0; b < unK; b++) {
Belief belief = gridPoints.get(b);
int numChoices = pomdp.getNumChoicesForObservation(belief.so);
chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
for (int i = 0; i < numChoices; i++) {
value = 0;
for (Map.Entry<Belief, Double> entry : beliefMDP.get(b).get(i).entrySet()) {
double nextBeliefProb = entry.getValue();
Belief nextBelief = entry.getKey();
// find discretized grid points to approximate the nextBelief
value += nextBeliefProb * approximateReachProb(nextBelief, vhash_backUp, targetObs, unknownObs);
}
if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) {
chosenValue = value;
}
}
//update V(b) to the chosenValue
vhash.put(belief, chosenValue);
Pair<Double, Integer> valChoice = backup.apply(belief, beliefMDP.get(b));
vhash.put(belief, valChoice.first);
}
// Check termination
done = PrismUtils.doublesAreClose(vhash, vhash_backUp, termCritParam, termCrit == TermCrit.RELATIVE);
@@ -204,10 +229,10 @@
mainLog.print("Belief space value iteration (" + (min ? "min" : "max") + ")");
mainLog.println(" took " + iters + " iterations and " + timer2 / 1000.0 + " seconds.");
// Find discretized grid points to approximate the initialBelief
// Extract (approximate) solution value for the initial belief
// Also get (approximate) accuracy of result from value iteration
Belief initialBelief = pomdp.getInitialBelief();
double outerBound = approximateReachProb(initialBelief, vhash_backUp, targetObs, unknownObs);
double outerBound = values.apply(initialBelief);
double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
// Print result
@@ -215,8 +240,8 @@
// Build DTMC to get inner bound (and strategy)
mainLog.println("\nBuilding strategy-induced model...");
List<Belief> listBeliefs = new ArrayList<>();
MDP mdp = buildStrategyModel(pomdp, null, vhash, targetObs, unknownObs, min, listBeliefs).mdp;
POMDPStrategyModel psm = buildStrategyModel(pomdp, null, targetObs, unknownObs, backup);
MDP mdp = psm.mdp;
mainLog.print("Strategy-induced model: " + mdp.infoString());
// Export?
if (stratFilename != null) {
@@ -227,7 +252,7 @@
@Override
public Decoration decorateState(int state, Decoration d)
{
d.labelAddBelow(listBeliefs.get(state).toString(pomdp));
d.labelAddBelow(psm.beliefs.get(state).toString(pomdp));
return d;
}
}));
@@ -335,10 +360,8 @@
// Find out the observations for the target states
// And determine set of observations actually need to perform computation for
BitSet targetObs = null;
try {
targetObs = getObservationsMatchingStates(pomdp, target);
} catch (PrismException e) {
BitSet targetObs = getObservationsMatchingStates(pomdp, target);
if (targetObs == null) {
throw new PrismException("Target for expected reachability is not observable");
}
BitSet unknownObs = new BitSet();
@@ -350,55 +373,32 @@
mainLog.println("Grid statistics: resolution=" + gridResolution + ", points=" + gridPoints.size());
// Construct grid belief "MDP"
mainLog.println("Building belief space approximation...");
List<List<HashMap<Belief, Double>>> beliefMDP = buildBeliefMDP(pomdp, gridPoints);
// Rewards
List<List<Double>> rewards = new ArrayList<>(); // memoization for reuse
int unK = gridPoints.size();
for (int b = 0; b < unK; b++) {
Belief belief = gridPoints.get(b);
int numChoices = pomdp.getNumChoicesForObservation(belief.so);
List<Double> action_reward = new ArrayList<>();// for memoization
for (int i = 0; i < numChoices; i++) {
action_reward.add(pomdp.getRewardAfterChoice(belief, i, mdpRewards)); // c(a,b)
}
rewards.add(action_reward);
}
List<BeliefMDPState> beliefMDP = buildBeliefMDP(pomdp, mdpRewards, gridPoints);
// Initialise hashmaps for storing values for the grid belief states
// Initialise hashmaps for storing values for the unknown belief states
HashMap<Belief, Double> vhash = new HashMap<>();
HashMap<Belief, Double> vhash_backUp = new HashMap<>();
for (Belief belief : gridPoints) {
vhash.put(belief, 0.0);
vhash_backUp.put(belief, 0.0);
}
// Define value function for the full set of belief states
Function<Belief, Double> values = belief -> approximateReachReward(belief, vhash_backUp, targetObs);
// Define value backup function
BeliefMDPBackUp backup = (belief, beliefState) -> approximateReachRewardBackup(belief, beliefState, values, min);
// Start iterations
mainLog.println("Solving belief space approximation...");
long timer2 = System.currentTimeMillis();
double value, chosenValue;
int iters = 0;
boolean done = false;
while (!done && iters < maxIters) {
// Iterate over all (unknown) grid points
int unK = gridPoints.size();
for (int b = 0; b < unK; b++) {
Belief belief = gridPoints.get(b);
int numChoices = pomdp.getNumChoicesForObservation(belief.so);
chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
for (int i = 0; i < numChoices; i++) {
value = rewards.get(b).get(i);
for (Map.Entry<Belief, Double> entry : beliefMDP.get(b).get(i).entrySet()) {
double nextBeliefProb = entry.getValue();
Belief nextBelief = entry.getKey();
// find discretized grid points to approximate the nextBelief
value += nextBeliefProb * approximateReachReward(nextBelief, vhash_backUp, targetObs);
}
if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) {
chosenValue = value;
}
}
//update V(b) to the chosenValue
vhash.put(belief, chosenValue);
Pair<Double, Integer> valChoice = backup.apply(belief, beliefMDP.get(b));
vhash.put(belief, valChoice.first);
}
// Check termination
done = PrismUtils.doublesAreClose(vhash, vhash_backUp, termCritParam, termCrit == TermCrit.RELATIVE);
@@ -419,10 +419,10 @@
mainLog.print("Belief space value iteration (" + (min ? "min" : "max") + ")");
mainLog.println(" took " + iters + " iterations and " + timer2 / 1000.0 + " seconds.");
// Find discretized grid points to approximate the initialBelief
// Extract (approximate) solution value for the initial belief
// Also get (approximate) accuracy of result from value iteration
Belief initialBelief = pomdp.getInitialBelief();
double outerBound = approximateReachReward(initialBelief, vhash_backUp, targetObs);
double outerBound = values.apply(initialBelief);
double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
// Print result
@@ -430,8 +430,7 @@
// Build DTMC to get inner bound (and strategy)
mainLog.println("\nBuilding strategy-induced model...");
List<Belief> listBeliefs = new ArrayList<>();
POMDPStrategyModel psm = buildStrategyModel(pomdp, mdpRewards, vhash, targetObs, unknownObs, min, listBeliefs);
POMDPStrategyModel psm = buildStrategyModel(pomdp, mdpRewards, targetObs, unknownObs, backup);
MDP mdp = psm.mdp;
MDPRewards mdpRewardsNew = psm.mdpRewards;
mainLog.print("Strategy-induced model: " + mdp.infoString());
@@ -444,7 +443,7 @@
@Override
public Decoration decorateState(int state, Decoration d)
{
d.labelAddBelow(listBeliefs.get(state).toString(pomdp));
d.labelAddBelow(psm.beliefs.get(state).toString(pomdp));
return d;
}
}));
@@ -500,9 +499,9 @@
* The states should correspond exactly to a set of observations,
* i.e., if a state corresponding to an observation is in the set,
* then all other states corresponding to it should also be.
* An exception is thrown if not.
* Returns null if not.
*/
protected BitSet getObservationsMatchingStates(POMDP pomdp, BitSet set) throws PrismException
protected BitSet getObservationsMatchingStates(POMDP pomdp, BitSet set)
{
// Find observations corresponding to each state in the set
BitSet setObs = new BitSet();
@@ -518,7 +517,7 @@
}
}
if (!set.equals(set2)) {
throw new PrismException("Set is not observable");
return null;
}
return setObs;
}
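// Example (illustrative): if states 2 and 3 are exactly the states with observation 5,
// then the set {2,3} yields the observation set {5}, but the set {2} alone yields null,
// since {2} cannot be characterised in terms of observations.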
@@ -556,43 +555,99 @@
/**
* Construct (part of) a belief MDP, just for the set of passed in belief states.
* It is stored as a list (over source beliefs) of lists (over choices)
* of distributions over target beliefs, stored as a hashmap.
* If provided, also construct a list of rewards for each state.
* It is stored as a list (over source beliefs) of BeliefMDPState objects.
*/
protected List<List<HashMap<Belief, Double>>> buildBeliefMDP(POMDP pomdp, List<Belief> beliefs)
protected List<BeliefMDPState> buildBeliefMDP(POMDP pomdp, MDPRewards mdpRewards, List<Belief> beliefs)
{
List<List<HashMap<Belief, Double>>> beliefMDP = new ArrayList<>();
List<BeliefMDPState> beliefMDP = new ArrayList<>();
for (Belief belief: beliefs) {
beliefMDP.add(buildBeliefMDPState(pomdp, belief));
beliefMDP.add(buildBeliefMDPState(pomdp, mdpRewards, belief));
}
return beliefMDP;
}
/**
* Construct a single state (belief) of a belief MDP, stored as a
* list (over choices) of distributions over target beliefs, stored as a hashmap.
* list (over choices) of distributions over target beliefs.
* If provided, also construct a list of rewards for the state.
* It is stored as a BeliefMDPState object.
*/
protected List<HashMap<Belief, Double>> buildBeliefMDPState(POMDP pomdp, Belief belief)
protected BeliefMDPState buildBeliefMDPState(POMDP pomdp, MDPRewards mdpRewards, Belief belief)
{
double[] beliefInDist = belief.toDistributionOverStates(pomdp);
List<HashMap<Belief, Double>> beliefMDPState = new ArrayList<>();
BeliefMDPState beliefMDPState = new BeliefMDPState();
// And for each choice
int numChoices = pomdp.getNumChoicesForObservation(belief.so);
for (int i = 0; i < numChoices; i++) {
// Get successor observations and their probs
HashMap<Integer, Double> obsProbs = pomdp.computeObservationProbsAfterAction(beliefInDist, i);
HashMap<Belief, Double> beliefDist = new HashMap<>();
// Find the belief for each observations
// Find the belief for each observation
for (Map.Entry<Integer, Double> entry : obsProbs.entrySet()) {
int o = entry.getKey();
Belief nextBelief = pomdp.getBeliefAfterChoiceAndObservation(belief, i, o);
beliefDist.put(nextBelief, entry.getValue());
}
beliefMDPState.add(beliefDist);
beliefMDPState.trans.add(beliefDist);
// Store reward too, if required
if (mdpRewards != null) {
beliefMDPState.rewards.add(pomdp.getRewardAfterChoice(belief, i, mdpRewards));
}
}
return beliefMDPState;
}
/**
* Perform a single backup step of (approximate) value iteration for probabilistic reachability
*/
protected Pair<Double, Integer> approximateReachProbBackup(Belief belief, BeliefMDPState beliefMDPState, Function<Belief, Double> values, boolean min)
{
int numChoices = beliefMDPState.trans.size();
double chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
int chosenActionIndex = -1;
for (int i = 0; i < numChoices; i++) {
double value = 0;
for (Map.Entry<Belief, Double> entry : beliefMDPState.trans.get(i).entrySet()) {
double nextBeliefProb = entry.getValue();
Belief nextBelief = entry.getKey();
value += nextBeliefProb * values.apply(nextBelief);
}
if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) {
chosenValue = value;
chosenActionIndex = i;
} else if (Math.abs(value - chosenValue) < 1.0e-6) {
chosenActionIndex = i;
}
}
return new Pair<Double, Integer>(chosenValue, chosenActionIndex);
}
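// In equation form (illustrative), approximateReachProbBackup computes, for belief b:
//   V(b) = opt_a sum_{b'} P(b'|b,a) * V(b')
// where opt is min or max, P(.|b,a) is the belief MDP transition function built above,
// and V is supplied via the values function; the index of an optimising action is also returned.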
/**
* Perform a single backup step of (approximate) value iteration for reward reachability
*/
protected Pair<Double, Integer> approximateReachRewardBackup(Belief belief, BeliefMDPState beliefMDPState, Function<Belief, Double> values, boolean min)
{
int numChoices = beliefMDPState.trans.size();
double chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
int chosenActionIndex = -1;
for (int i = 0; i < numChoices; i++) {
double value = beliefMDPState.rewards.get(i);
for (Map.Entry<Belief, Double> entry : beliefMDPState.trans.get(i).entrySet()) {
double nextBeliefProb = entry.getValue();
Belief nextBelief = entry.getKey();
value += nextBeliefProb * values.apply(nextBelief);
}
if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) {
chosenValue = value;
chosenActionIndex = i;
} else if (Math.abs(value - chosenValue) < 1.0e-6) {
chosenActionIndex = i;
}
}
return new Pair<Double, Integer>(chosenValue, chosenActionIndex);
}
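// In equation form (illustrative), the reward backup additionally includes the one-step reward c(a,b):
//   V(b) = opt_a [ c(a,b) + sum_{b'} P(b'|b,a) * V(b') ]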
/**
* Compute the grid-based approximate value for a belief for probabilistic reachability
*/
@@ -645,12 +700,6 @@
return val;
}
class POMDPStrategyModel
{
public MDP mdp;
public MDPRewards mdpRewards;
}
/**
* Build a (Markov chain) model representing the fragment of the belief MDP induced by an optimal strategy.
* The model is stored as an MDP to allow easier attachment of optional actions.
@@ -662,112 +711,69 @@
* @param min
* @param listBeliefs
*/
protected POMDPStrategyModel buildStrategyModel(POMDP pomdp, MDPRewards mdpRewards, HashMap<Belief, Double> vhash, BitSet targetObs, BitSet unknownObs, boolean min, List<Belief> listBeliefs)
protected POMDPStrategyModel buildStrategyModel(POMDP pomdp, MDPRewards mdpRewards, BitSet targetObs, BitSet unknownObs, BeliefMDPBackUp backup)
{
// Initialise model/state/rewards storage
MDPSimple mdp = new MDPSimple();
IndexedSet<Belief> exploredBelieves = new IndexedSet<>(true);
LinkedList<Belief> toBeExploredBelives = new LinkedList<>();
IndexedSet<Belief> exploredBeliefs = new IndexedSet<>(true);
LinkedList<Belief> toBeExploredBeliefs = new LinkedList<>();
BitSet mdpTarget = new BitSet();
StateRewardsSimple stateRewards = new StateRewardsSimple();
// Add initial state
Belief initialBelief = pomdp.getInitialBelief();
exploredBelieves.add(initialBelief);
toBeExploredBelives.offer(initialBelief);
exploredBeliefs.add(initialBelief);
toBeExploredBeliefs.offer(initialBelief);
mdp.addState();
mdp.addInitialState(0);
// Explore model
int src = -1;
while (!toBeExploredBelives.isEmpty()) {
Belief b = toBeExploredBelives.pollFirst();
while (!toBeExploredBeliefs.isEmpty()) {
Belief belief = toBeExploredBeliefs.pollFirst();
src++;
// Remember if this is a target state
if (targetObs.get(b.so)) {
if (targetObs.get(belief.so)) {
mdpTarget.set(src);
}
// Only explore "unknown" states
if (unknownObs.get(b.so)) {
extractBestActions(src, b, vhash, targetObs, unknownObs, pomdp, mdpRewards, min, exploredBelieves, toBeExploredBelives, mdp, stateRewards);
if (unknownObs.get(belief.so)) {
// Build the belief MDP for this belief state and solve
BeliefMDPState beliefMDPState = buildBeliefMDPState(pomdp, mdpRewards, belief);
Pair<Double, Integer> valChoice = backup.apply(belief, beliefMDPState);
int chosenActionIndex = valChoice.second;
// Build a distribution over successor belief states and add to MDP
Distribution distr = new Distribution();
for (Map.Entry<Belief, Double> entry : beliefMDPState.trans.get(chosenActionIndex).entrySet()) {
double nextBeliefProb = entry.getValue();
Belief nextBelief = entry.getKey();
// Add each successor belief to the MDP and the "to explore" set if new
if (exploredBeliefs.add(nextBelief)) {
toBeExploredBeliefs.add(nextBelief);
mdp.addState();
}
// Get index of state in state set
int dest = exploredBeliefs.getIndexOfLastAdd();
distr.add(dest, nextBeliefProb);
}
// Add transition distribution, with choice _index_ encoded as action
mdp.addActionLabelledChoice(src, distr, pomdp.getActionForObservation(belief.so, chosenActionIndex));
// Store reward too, if needed
if (mdpRewards != null) {
stateRewards.setStateReward(src, pomdp.getRewardAfterChoice(belief, chosenActionIndex, mdpRewards));
}
}
}
// Attach a label marking target states
mdp.addLabel("target", mdpTarget);
listBeliefs.addAll(exploredBelieves.toArrayList());
// Return
POMDPStrategyModel psm = new POMDPStrategyModel();
psm.mdp = mdp;
psm.mdpRewards = stateRewards;
psm.beliefs = new ArrayList<>();
psm.beliefs.addAll(exploredBeliefs.toArrayList());
return psm;
}
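// Usage (illustrative, mirroring the call sites above):
//   POMDPStrategyModel psm = buildStrategyModel(pomdp, mdpRewards, targetObs, unknownObs, backup);
//   MDP mdp = psm.mdp;                 // strategy-induced model; target states carry the label "target"
//   MDPRewards rews = psm.mdpRewards;  // populated only when mdpRewards was non-null
//   Belief b0 = psm.beliefs.get(0);    // belief for MDP state 0 (the initial belief)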
/**
* Find the best action for this belief state, add the belief state to the list
* of ones examined so far, and store the strategy info. We store this as an MDP.
* @param belief Belief state to examine
* @param vhash
* @param pomdp
* @param mdpRewards
* @param min
* @param beliefList
*/
protected void extractBestActions(int src, Belief belief, HashMap<Belief, Double> vhash, BitSet targetObs, BitSet unknownObs, POMDP pomdp, MDPRewards mdpRewards, boolean min,
IndexedSet<Belief> exploredBelieves, LinkedList<Belief> toBeExploredBelives, MDPSimple mdp, StateRewardsSimple stateRewards)
{
double chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
int chosenActionIndex = -1;
//evaluate each action in b
int numChoices = pomdp.getNumChoicesForObservation(belief.so);
List<HashMap<Belief, Double>> beliefMDPState = buildBeliefMDPState(pomdp, belief);
for (int a = 0; a < numChoices; a++) {
double value = 0;
if (mdpRewards != null) {
value = pomdp.getRewardAfterChoice(belief, a, mdpRewards); // c(a,b)
}
for (Map.Entry<Belief, Double> entry : beliefMDPState.get(a).entrySet()) {
double nextBeliefProb = entry.getValue();
Belief nextBelief = entry.getKey();
if (mdpRewards == null) {
value += nextBeliefProb * approximateReachProb(nextBelief, vhash, targetObs, unknownObs);
} else {
value += nextBeliefProb * approximateReachReward(nextBelief, vhash, targetObs);
}
}
//select action that minimizes/maximizes Q(a,b), i.e. value
if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6))//value<bestValue
{
chosenValue = value;
chosenActionIndex = a;
} else if (Math.abs(value - chosenValue) < 1.0e-6)//value==chosenValue
{
//random tie broker
chosenActionIndex = Math.random() < 0.5 ? a : chosenActionIndex;
chosenActionIndex = a;
}
}
// Build a distribution over successor belief states and add to MDP
Distribution distr = new Distribution();
for (Map.Entry<Belief, Double> entry : beliefMDPState.get(chosenActionIndex).entrySet()) {
double nextBeliefProb = entry.getValue();
Belief nextBelief = entry.getKey();
// Add each successor belief to the MDP and the "to explore" set if new
if (exploredBelieves.add(nextBelief)) {
toBeExploredBelives.add(nextBelief);
mdp.addState();
}
// Get index of state in state set
int dest = exploredBelieves.getIndexOfLastAdd();
distr.add(dest, nextBeliefProb);
}
// Add transition distribution, with choice _index_ encoded as action
mdp.addActionLabelledChoice(src, distr, pomdp.getActionForObservation(belief.so, chosenActionIndex));
// Store reward too, if needed
if (mdpRewards != null) {
stateRewards.setStateReward(src, pomdp.getRewardAfterChoice(belief, chosenActionIndex, mdpRewards));
}
}
protected ArrayList<ArrayList<Integer>> assignGPrime(int startIndex, int min, int max, int length)
{
ArrayList<ArrayList<Integer>> result = new ArrayList<ArrayList<Integer>>();
