Browse Source

Refactor and improve POMDP approximate value iteration.

Grid points are only used for unknown (non-target) belief states.
accumulation-v4.7
Dave Parker 5 years ago
parent
commit
bc23e4e354
  1. 139
      prism/src/explicit/POMDPModelChecker.java

139
prism/src/explicit/POMDPModelChecker.java

@ -119,31 +119,25 @@ public class POMDPModelChecker extends ProbModelChecker
mainLog.println("Starting fixed-resolution grid approximation (" + (min ? "min" : "max") + ")...");
// Find out the observations for the target states
// and determine the set of observations for which computation is actually needed
BitSet targetObs = getAndCheckTargetObservations(pomdp, target);
// Initialise the grid points
ArrayList<Belief> gridPoints = new ArrayList<>();//the set of grid points (discretized believes)
ArrayList<Belief> unknownGridPoints = new ArrayList<>();//the set of unknown grid points (discretized believes)
initialiseGridPoints(pomdp, targetObs, gridPoints, unknownGridPoints);
int unK = unknownGridPoints.size();
mainLog.print("Grid statistics: resolution=" + gridResolution);
mainLog.println(", points=" + gridPoints.size() + ", unknown points=" + unK);
// Construct grid belief "MDP" (over all unknown grid points)
BitSet unknownObs = new BitSet();
unknownObs.set(0, pomdp.getNumObservations());
unknownObs.andNot(targetObs);
// Initialise the grid points (just for unknown beliefs)
List<Belief> gridPoints = initialiseGridPoints(pomdp, unknownObs);
mainLog.println("Grid statistics: resolution=" + gridResolution + ", points=" + gridPoints.size());
// Construct grid belief "MDP"
mainLog.println("Building belief space approximation...");
List<List<HashMap<Belief, Double>>> beliefMDP = buildBeliefMDP(pomdp, unknownGridPoints);
List<List<HashMap<Belief, Double>>> beliefMDP = buildBeliefMDP(pomdp, gridPoints);
// HashMap for storing real time values for the discretized grid belief states
// Initialise hashmaps for storing values for the grid belief states
HashMap<Belief, Double> vhash = new HashMap<>();
HashMap<Belief, Double> vhash_backUp = new HashMap<>();
for (Belief g : gridPoints) {
if (unknownGridPoints.contains(g)) {
vhash.put(g, 0.0);
vhash_backUp.put(g, 0.0);
} else {
vhash.put(g, 1.0);
vhash_backUp.put(g, 1.0);
}
for (Belief belief : gridPoints) {
vhash.put(belief, 0.0);
vhash_backUp.put(belief, 0.0);
}
// Start iterations
@ -154,8 +148,9 @@ public class POMDPModelChecker extends ProbModelChecker
boolean done = false;
while (!done && iters < maxIters) {
// Iterate over all (unknown) grid points
int unK = gridPoints.size();
for (int b = 0; b < unK; b++) {
Belief belief = unknownGridPoints.get(b);
Belief belief = gridPoints.get(b);
int numChoices = pomdp.getNumChoicesForObservation(belief.so);
chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
@ -165,7 +160,7 @@ public class POMDPModelChecker extends ProbModelChecker
double nextBeliefProb = entry.getValue();
Belief nextBelief = entry.getKey();
// find discretized grid points to approximate the nextBelief
value += nextBeliefProb * interpolateOverGrid(nextBelief, vhash_backUp);
value += nextBeliefProb * approximateReachProb(nextBelief, vhash_backUp, targetObs);
}
if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) {
chosenValue = value;
@ -196,7 +191,7 @@ public class POMDPModelChecker extends ProbModelChecker
// Find discretized grid points to approximate the initialBelief
// Also get (approximate) accuracy of result from value iteration
Belief initialBelief = pomdp.getInitialBelief();
double outerBound = interpolateOverGrid(initialBelief, vhash_backUp);
double outerBound = approximateReachProb(initialBelief, vhash_backUp, targetObs);
double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
// Print result
@ -322,23 +317,24 @@ public class POMDPModelChecker extends ProbModelChecker
mainLog.println("Starting fixed-resolution grid approximation (" + (min ? "min" : "max") + ")...");
// Find out the observations for the target states
// and determine the set of observations for which computation is actually needed
BitSet targetObs = getAndCheckTargetObservations(pomdp, target);
// Initialise the grid points
ArrayList<Belief> gridPoints = new ArrayList<>();//the set of grid points (discretized believes)
ArrayList<Belief> unknownGridPoints = new ArrayList<>();//the set of unknown grid points (discretized believes)
initialiseGridPoints(pomdp, targetObs, gridPoints, unknownGridPoints);
int unK = unknownGridPoints.size();
mainLog.print("Grid statistics: resolution=" + gridResolution);
mainLog.println(", points=" + gridPoints.size() + ", unknown points=" + unK);
// Construct grid belief "MDP" (over all unknown grid points)
BitSet unknownObs = new BitSet();
unknownObs.set(0, pomdp.getNumObservations());
unknownObs.andNot(targetObs);
// Initialise the grid points (just for unknown beliefs)
List<Belief> gridPoints = initialiseGridPoints(pomdp, unknownObs);
mainLog.println("Grid statistics: resolution=" + gridResolution + ", points=" + gridPoints.size());
// Construct grid belief "MDP"
mainLog.println("Building belief space approximation...");
List<List<HashMap<Belief, Double>>> beliefMDP = buildBeliefMDP(pomdp, unknownGridPoints);
List<List<HashMap<Belief, Double>>> beliefMDP = buildBeliefMDP(pomdp, gridPoints);
// Rewards
List<List<Double>> rewards = new ArrayList<>(); // memoization for reuse
int unK = gridPoints.size();
for (int b = 0; b < unK; b++) {
Belief belief = unknownGridPoints.get(b);
Belief belief = gridPoints.get(b);
int numChoices = pomdp.getNumChoicesForObservation(belief.so);
List<Double> action_reward = new ArrayList<>();// for memoization
for (int i = 0; i < numChoices; i++) {
@ -347,12 +343,12 @@ public class POMDPModelChecker extends ProbModelChecker
rewards.add(action_reward);
}
// HashMap for storing real time values for the discretized grid belief states
// Initialise hashmaps for storing values for the grid belief states
HashMap<Belief, Double> vhash = new HashMap<>();
HashMap<Belief, Double> vhash_backUp = new HashMap<>();
for (Belief g : gridPoints) {
vhash.put(g, 0.0);
vhash_backUp.put(g, 0.0);
for (Belief belief : gridPoints) {
vhash.put(belief, 0.0);
vhash_backUp.put(belief, 0.0);
}
// Start iterations
@ -364,7 +360,7 @@ public class POMDPModelChecker extends ProbModelChecker
while (!done && iters < maxIters) {
// Iterate over all (unknown) grid points
for (int b = 0; b < unK; b++) {
Belief belief = unknownGridPoints.get(b);
Belief belief = gridPoints.get(b);
int numChoices = pomdp.getNumChoicesForObservation(belief.so);
chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
for (int i = 0; i < numChoices; i++) {
@ -373,7 +369,7 @@ public class POMDPModelChecker extends ProbModelChecker
double nextBeliefProb = entry.getValue();
Belief nextBelief = entry.getKey();
// find discretized grid points to approximate the nextBelief
value += nextBeliefProb * interpolateOverGrid(nextBelief, vhash_backUp);
value += nextBeliefProb * approximateReachReward(nextBelief, vhash_backUp, targetObs);
}
if ((min && chosenValue - value > 1.0e-6) || (!min && value - chosenValue > 1.0e-6)) {
chosenValue = value;
@ -404,7 +400,7 @@ public class POMDPModelChecker extends ProbModelChecker
// Find discretized grid points to approximate the initialBelief
// Also get (approximate) accuracy of result from value iteration
Belief initialBelief = pomdp.getInitialBelief();
double outerBound = interpolateOverGrid(initialBelief, vhash_backUp);
double outerBound = approximateReachReward(initialBelief, vhash_backUp, targetObs);
double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
// Print result
@ -503,13 +499,17 @@ public class POMDPModelChecker extends ProbModelChecker
return targetObs;
}
protected void initialiseGridPoints(POMDP pomdp, BitSet targetObs, ArrayList<Belief> gridPoints, ArrayList<Belief> unknownGridPoints)
/**
* Construct a list of beliefs for a grid-based approximation of the belief space.
* Only beliefs with observable values from {@code unknownObs} are added.
*/
protected List<Belief> initialiseGridPoints(POMDP pomdp, BitSet unknownObs)
{
List<Belief> gridPoints = new ArrayList<>();
ArrayList<ArrayList<Double>> assignment;
int numObservations = pomdp.getNumObservations();
int numUnobservations = pomdp.getNumUnobservations();
int numStates = pomdp.getNumStates();
for (int so = 0; so < numObservations; so++) {
for (int so = unknownObs.nextSetBit(0); so >= 0; so = unknownObs.nextSetBit(so + 1)) {
ArrayList<Integer> unobservsForObserv = new ArrayList<>();
for (int s = 0; s < numStates; s++) {
if (so == pomdp.getObservation(s)) {
@ -524,14 +524,10 @@ public class POMDPModelChecker extends ProbModelChecker
bu[unobservForObserv] = inner.get(k);
k++;
}
Belief g = new Belief(so, bu);
gridPoints.add(g);
if (!targetObs.get(so)) {
unknownGridPoints.add(g);
}
gridPoints.add(new Belief(so, bu));
}
}
return gridPoints;
}
/**
@ -573,6 +569,36 @@ public class POMDPModelChecker extends ProbModelChecker
return beliefMDPState;
}
/**
 * Compute the grid-based approximate value of a belief for probabilistic reachability.
 * A belief whose observation {@code belief.so} is set in {@code targetObs} has value 1;
 * any other belief is approximated via interpolation over the grid values.
 * @param belief The belief to evaluate
 * @param gridValues Current values stored for the grid beliefs
 * @param targetObs Observations corresponding to target states
 * @return The approximate reachability probability for {@code belief}
 */
protected double approximateReachProb(Belief belief, HashMap<Belief, Double> gridValues, BitSet targetObs)
{
	// Target beliefs: the target has already been reached, so the value is 1
	boolean inTarget = targetObs.get(belief.so);
	if (inTarget) {
		return 1.0;
	}
	// Non-target beliefs: approximate via interpolation over grid points
	return interpolateOverGrid(belief, gridValues);
}
/**
 * Compute the grid-based approximate value of a belief for reward reachability.
 * A belief whose observation {@code belief.so} is set in {@code targetObs} has value 0;
 * any other belief is approximated via interpolation over the grid values.
 * @param belief The belief to evaluate
 * @param gridValues Current values stored for the grid beliefs
 * @param targetObs Observations corresponding to target states
 * @return The approximate expected reward for {@code belief}
 */
protected double approximateReachReward(Belief belief, HashMap<Belief, Double> gridValues, BitSet targetObs)
{
	// Target beliefs: the value is 0
	boolean inTarget = targetObs.get(belief.so);
	if (inTarget) {
		return 0.0;
	}
	// Non-target beliefs: approximate via interpolation over grid points
	return interpolateOverGrid(belief, gridValues);
}
/**
* Approximate the value for a belief {@code belief} by interpolating over values {@code gridValues}
* for a representative set of beliefs whose convex hull is the full belief space.
@ -582,7 +608,6 @@ public class POMDPModelChecker extends ProbModelChecker
ArrayList<double[]> subSimplex = new ArrayList<>();
double[] lambdas = new double[belief.bu.length];
getSubSimplexAndLambdas(belief.bu, subSimplex, lambdas, gridResolution);
//calculate the approximate value for the belief
double val = 0;
for (int j = 0; j < lambdas.length; j++) {
if (lambdas[j] >= 1e-6) {
@ -633,7 +658,7 @@ public class POMDPModelChecker extends ProbModelChecker
if (targetObs.get(b.so)) {
mdpTarget.set(src);
} else {
extractBestActions(src, b, vhash, pomdp, mdpRewards, min, exploredBelieves, toBeExploredBelives, mdp, stateRewards);
extractBestActions(src, b, vhash, targetObs, pomdp, mdpRewards, min, exploredBelieves, toBeExploredBelives, mdp, stateRewards);
}
}
// Attach a label marking target states
@ -656,7 +681,7 @@ public class POMDPModelChecker extends ProbModelChecker
* @param min
* @param beliefList
*/
protected void extractBestActions(int src, Belief belief, HashMap<Belief, Double> vhash, POMDP pomdp, MDPRewards mdpRewards, boolean min,
protected void extractBestActions(int src, Belief belief, HashMap<Belief, Double> vhash, BitSet targetObs, POMDP pomdp, MDPRewards mdpRewards, boolean min,
IndexedSet<Belief> exploredBelieves, LinkedList<Belief> toBeExploredBelives, MDPSimple mdp, StateRewardsSimple stateRewards)
{
double chosenValue = min ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
@ -672,7 +697,11 @@ public class POMDPModelChecker extends ProbModelChecker
for (Map.Entry<Belief, Double> entry : beliefMDPState.get(a).entrySet()) {
double nextBeliefProb = entry.getValue();
Belief nextBelief = entry.getKey();
value += nextBeliefProb * interpolateOverGrid(nextBelief, vhash);
if (mdpRewards == null) {
value += nextBeliefProb * approximateReachProb(nextBelief, vhash, targetObs);
} else {
value += nextBeliefProb * approximateReachReward(nextBelief, vhash, targetObs);
}
}
//select action that minimizes/maximizes Q(a,b), i.e. value

Loading…
Cancel
Save