Browse Source

Improved accuracy reporting for POMDPs (individual bounds and final result).

accumulation-v4.7
Dave Parker 5 years ago
parent
commit
c4f5a5867f
  1. 73
      prism/src/explicit/POMDPModelChecker.java
  2. 43
      prism/src/prism/AccuracyFactory.java
  3. 29
      prism/src/prism/PrismUtils.java

73
prism/src/explicit/POMDPModelChecker.java

@ -38,8 +38,9 @@ import java.util.TreeSet;
import explicit.rewards.MDPRewards;
import explicit.rewards.MDPRewardsSimple;
import parser.Observation;
import prism.Accuracy;
import prism.AccuracyFactory;
import prism.Pair;
import prism.Accuracy.AccuracyLevel;
import prism.PrismComponent;
import prism.PrismException;
@ -193,11 +194,15 @@ public class POMDPModelChecker extends ProbModelChecker
timer2 = System.currentTimeMillis() - timer2;
mainLog.print("Belief space value iteration (" + (min ? "min" : "max") + ")");
mainLog.println(" took " + iters + " iterations and " + timer2 / 1000.0 + " seconds.");
// find discretized grid points to approximate the initialBelief
// Find discretized grid points to approximate the initialBelief
// Also get (approximate) accuracy of result from value iteration
Belief initialBelief = pomdp.getInitialBelief();
double outerBound = interpolateOverGrid(initialBelief.so, initialBelief, vhash_backUp);
mainLog.println("Outer bound: " + outerBound);
double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
// Print result
mainLog.println("Outer bound: " + outerBound + " (" + outerBoundAcc.toString(outerBound) + ")");
// Build DTMC to get inner bound (and strategy)
mainLog.println("\nBuilding strategy-induced model...");
@ -219,26 +224,38 @@ public class POMDPModelChecker extends ProbModelChecker
// Solve MDP to get inner bound
ModelCheckerResult mcRes = mcMDP.computeReachProbs(mdp, mdp.getLabelStates("target"), true);
double innerBound = mcRes.soln[0];
mainLog.println("Inner bound: " + innerBound);
Accuracy innerBoundAcc = mcRes.accuracy;
// Print result
String innerBoundStr = "" + innerBound;
if (innerBoundAcc != null) {
innerBoundStr += " (" + innerBoundAcc.toString(innerBound) + ")";
}
mainLog.println("Inner bound: " + innerBoundStr);
// Finished fixed-resolution grid approximation
timer = System.currentTimeMillis() - timer;
mainLog.print("\nFixed-resolution grid approximation (" + (min ? "min" : "max") + ")");
mainLog.println(" took " + timer / 1000.0 + " seconds.");
// Extract Store result
double lowerBound = Math.min(innerBound, outerBound);
double upperBound = Math.max(innerBound, outerBound);
mainLog.println("Result bounds: [" + lowerBound + "," + upperBound + "]");
// Extract and store result
Pair<Double,Accuracy> resultValAndAcc;
if (min) {
resultValAndAcc = AccuracyFactory.valueAndAccuracyFromInterval(outerBound, outerBoundAcc, innerBound, innerBoundAcc);
} else {
resultValAndAcc = AccuracyFactory.valueAndAccuracyFromInterval(innerBound, innerBoundAcc, outerBound, outerBoundAcc);
}
double resultVal = resultValAndAcc.first;
Accuracy resultAcc = resultValAndAcc.second;
mainLog.println("Result bounds: [" + resultAcc.getResultLowerBound(resultVal) + "," + resultAcc.getResultUpperBound(resultVal) + "]");
double soln[] = new double[pomdp.getNumStates()];
for (int initialState : pomdp.getInitialStates()) {
soln[initialState] = (lowerBound + upperBound) / 2.0;
soln[initialState] = resultVal;
}
// Return results
ModelCheckerResult res = new ModelCheckerResult();
res.soln = soln;
res.accuracy = new Accuracy(AccuracyLevel.BOUNDED, (upperBound - lowerBound) / 2.0);
res.accuracy = resultAcc;
res.numIters = iters;
res.timeTaken = timer / 1000.0;
return res;
@ -383,10 +400,14 @@ public class POMDPModelChecker extends ProbModelChecker
mainLog.print("Belief space value iteration (" + (min ? "min" : "max") + ")");
mainLog.println(" took " + iters + " iterations and " + timer2 / 1000.0 + " seconds.");
// find discretized grid points to approximate the initialBelief
// Find discretized grid points to approximate the initialBelief
// Also get (approximate) accuracy of result from value iteration
Belief initialBelief = pomdp.getInitialBelief();
double outerBound = interpolateOverGrid(initialBelief.so, initialBelief, vhash_backUp);
mainLog.println("Outer bound: " + outerBound);
double outerBoundMaxDiff = PrismUtils.measureSupNorm(vhash, vhash_backUp, termCrit == TermCrit.RELATIVE);
Accuracy outerBoundAcc = AccuracyFactory.valueIteration(termCritParam, outerBoundMaxDiff, termCrit == TermCrit.RELATIVE);
// Print result
mainLog.println("Outer bound: " + outerBound + " (" + outerBoundAcc.toString(outerBound) + ")");
// Build DTMC to get inner bound (and strategy)
mainLog.println("\nBuilding strategy-induced model...");
@ -419,26 +440,38 @@ public class POMDPModelChecker extends ProbModelChecker
// Solve MDP to get inner bound
ModelCheckerResult mcRes = mcMDP.computeReachRewards(mdp, mdpRewardsNew, mdp.getLabelStates("target"), true);
double innerBound = mcRes.soln[0];
mainLog.println("Inner bound: " + innerBound);
Accuracy innerBoundAcc = mcRes.accuracy;
// Print result
String innerBoundStr = "" + innerBound;
if (innerBoundAcc != null) {
innerBoundStr += " (" + innerBoundAcc.toString(innerBound) + ")";
}
mainLog.println("Inner bound: " + innerBoundStr);
// Finished fixed-resolution grid approximation
timer = System.currentTimeMillis() - timer;
mainLog.print("\nFixed-resolution grid approximation (" + (min ? "min" : "max") + ")");
mainLog.println(" took " + timer / 1000.0 + " seconds.");
// Extract Store result
double lowerBound = Math.min(innerBound, outerBound);
double upperBound = Math.max(innerBound, outerBound);
mainLog.println("Result bounds: [" + lowerBound + "," + upperBound + "]");
// Extract and store result
Pair<Double,Accuracy> resultValAndAcc;
if (min) {
resultValAndAcc = AccuracyFactory.valueAndAccuracyFromInterval(outerBound, outerBoundAcc, innerBound, innerBoundAcc);
} else {
resultValAndAcc = AccuracyFactory.valueAndAccuracyFromInterval(innerBound, innerBoundAcc, outerBound, outerBoundAcc);
}
double resultVal = resultValAndAcc.first;
Accuracy resultAcc = resultValAndAcc.second;
mainLog.println("Result bounds: [" + resultAcc.getResultLowerBound(resultVal) + "," + resultAcc.getResultUpperBound(resultVal) + "]");
double soln[] = new double[pomdp.getNumStates()];
for (int initialState : pomdp.getInitialStates()) {
soln[initialState] = (lowerBound + upperBound) / 2.0;
soln[initialState] = resultVal;
}
// Return results
ModelCheckerResult res = new ModelCheckerResult();
res.soln = soln;
res.accuracy = new Accuracy(AccuracyLevel.BOUNDED, (upperBound - lowerBound) / 2.0);
res.accuracy = resultAcc;
res.numIters = iters;
res.timeTaken = timer / 1000.0;
return res;

43
prism/src/prism/AccuracyFactory.java

@ -92,4 +92,47 @@ public class AccuracyFactory
{
return new Accuracy(AccuracyLevel.EXACT_FLOATING_POINT);
}
/**
* Create a pair of a (double) value and an associated {@link Accuracy} object
* representing an interval of values, specified by its lower and upper bounds.
* Optionally, the accuracy of each bound can be specified.
* @param loVal Lower bound
* @param loAcc Lower bound accuracy
* @param hiVal Upper bound
* @param hiAcc Upper bound accuracy
*/
public static Pair<Double,Accuracy> valueAndAccuracyFromInterval(double loVal, Accuracy loAcc, double hiVal, Accuracy hiAcc) throws PrismException
{
// Don't support probabilistic accuracy bounds
if (loAcc != null && loAcc.getLevel() == AccuracyLevel.PROBABLY_BOUNDED) {
throw new PrismException("Cannot create interval accuracy from probabilistic bounds");
}
if (loAcc != null && loAcc.getLevel() == AccuracyLevel.PROBABLY_BOUNDED) {
throw new PrismException("Cannot create interval accuracy from probabilistic bounds");
}
// Extract lower/upper bounds (taking into account accuracy if present)
double lo = loAcc == null ? loVal : loAcc.getResultLowerBound(loVal);
double hi = hiAcc == null ? hiVal : hiAcc.getResultLowerBound(hiVal);
// Compute new mid point value and error bound
double mid = (lo + hi) / 2.0;
double err = (hi - lo) / 2.0;
// Compute accuracy of new result value:
// "bounded" if lower/upper bounds were provided with bounded accuracy;
// "estimated bounded" if either bound was estimated or missing;
// "exactfp" if "bounded" with error 0
AccuracyLevel accLev;
if (loAcc == null || loAcc.getLevel() == AccuracyLevel.ESTIMATED_BOUNDED) {
accLev = AccuracyLevel.ESTIMATED_BOUNDED;
} else if (hiAcc == null || hiAcc.getLevel() == AccuracyLevel.ESTIMATED_BOUNDED) {
accLev = AccuracyLevel.ESTIMATED_BOUNDED;
} else if (err == 0.0) {
accLev = AccuracyLevel.EXACT_FLOATING_POINT;
} else {
accLev = AccuracyLevel.BOUNDED;
}
// Return pair
Accuracy acc = new Accuracy(accLev, err, true);
return new Pair<>(mid, acc);
}
}

29
prism/src/prism/PrismUtils.java

@ -199,6 +199,35 @@ public class PrismUtils
return true;
}
/**
* Measure supremum norm, either absolute or relative,
* return the maximum difference.
*/
public static <X> double measureSupNorm(HashMap<X,Double> map1, HashMap<X,Double> map2, boolean abs)
{
assert(map1.size() == map2.size());
double value = 0;
Set<Entry<X,Double>> entries = map1.entrySet();
for (Entry<X,Double> entry : entries) {
double diff;
double d1 = entry.getValue();
if (map2.get(entry.getKey()) != null) {
double d2 = map2.get(entry.getKey());
if (abs) {
diff = measureSupNormAbs(d1, d2);
} else {
diff = measureSupNormRel(d1, d2);
}
if (diff > value) {
value = diff;
}
} else {
diff = Double.POSITIVE_INFINITY;
}
}
return value;
}
/**
* Measure supremum norm, either absolute or relative,
* return the maximum difference.

Loading…
Cancel
Save