Browse Source

(explicit iteration refactoring) MDPModelChecker: use the new infrastructure for doing the numerical iteration computations.

git-svn-id: https://www.prismmodelchecker.org/svn/prism/prism/trunk@12129 bbc10eb1-c90d-0410-af57-cb519fbb1720
master
Joachim Klein 9 years ago
parent
commit
12e377de4a
  1. 404
      prism/src/explicit/MDPModelChecker.java

404
prism/src/explicit/MDPModelChecker.java

@ -53,6 +53,7 @@ import acceptance.AcceptanceReach;
import acceptance.AcceptanceType;
import automata.DA;
import automata.LTL2WDBA;
import common.IntSet;
import common.IterableBitSet;
import explicit.modelviews.EquivalenceRelationInteger;
import explicit.modelviews.MDPEquiv;
@ -525,12 +526,13 @@ public class MDPModelChecker extends ProbModelChecker
{
ModelCheckerResult res = null;
IterationMethod iterationMethod = null;
switch (method) {
case VALUE_ITERATION:
res = computeReachProbsValIter(mdp, no, yes, min, init, known, strat);
iterationMethod = new IterationMethodPower(termCrit == TermCrit.ABSOLUTE, termCritParam);
break;
case GAUSS_SEIDEL:
res = computeReachProbsGaussSeidel(mdp, no, yes, min, init, known, strat);
iterationMethod = new IterationMethodGS(termCrit == TermCrit.ABSOLUTE, termCritParam, false);
break;
case POLICY_ITERATION:
res = computeReachProbsPolIter(mdp, no, yes, min, strat);
@ -542,6 +544,10 @@ public class MDPModelChecker extends ProbModelChecker
throw new PrismException("Unknown MDP solution method " + mdpSolnMethod.fullName());
}
if (res == null) { // not yet computed, use iterationMethod
res = doValueIterationReachProbs(mdp, no, yes, min, init, known, iterationMethod, false, strat);
}
return res;
}
@ -751,29 +757,48 @@ public class MDPModelChecker extends ProbModelChecker
protected ModelCheckerResult computeReachProbsValIter(MDP mdp, BitSet no, BitSet yes, boolean min, double init[], BitSet known, int strat[])
		throws PrismException
{
	// Thin wrapper: build a power-method iteration using the configured termination
	// criterion (absolute vs. relative) and threshold, then hand over to the shared
	// reachability-probability driver. Topological iteration is not requested here.
	IterationMethodPower iterationMethod = new IterationMethodPower(termCrit == TermCrit.ABSOLUTE, termCritParam);
	return doValueIterationReachProbs(mdp, no, yes, min, init, known, iterationMethod, false, strat);
}
/**
* Compute reachability probabilities using value iteration.
* Optionally, store optimal (memoryless) strategy info.
* @param mdp The MDP
* @param no Probability 0 states
* @param yes Probability 1 states
* @param min Min or max probabilities (true=min, false=max)
* @param init Optionally, an initial solution vector (will be overwritten)
* @param known Optionally, a set of states for which the exact answer is known
* @param iterationMethod The iteration method
* @param topological Do topological value iteration?
* @param strat Storage for (memoryless) strategy choice indices (ignored if null)
* Note: if 'known' is specified (i.e. is non-null), 'init' must also be given and is used for the exact values.
*/
protected ModelCheckerResult doValueIterationReachProbs(MDP mdp, BitSet no, BitSet yes, boolean min, double init[], BitSet known, IterationMethod iterationMethod, boolean topological, int strat[])
throws PrismException
{
BitSet unknown;
int i, n, iters;
double soln[], soln2[], tmpsoln[], initVal;
boolean done;
int i, n;
double initVal;
long timer;
// Start value iteration
timer = System.currentTimeMillis();
mainLog.println("Starting value iteration (" + (min ? "min" : "max") + ")...");
String description = (min ? "min" : "max")
+ (topological ? ", topological": "" )
+ ", with " + iterationMethod.getDescriptionShort();
mainLog.println("Starting value iteration (" + description + ")...");
ExportIterations iterationsExport = null;
if (settings.getBoolean(PrismSettings.PRISM_EXPORT_ITERATIONS)) {
iterationsExport = new ExportIterations("Explicit ReachRewards value iteration");
iterationsExport = new ExportIterations("Explicit MDP ReachProbs value iteration (" + description + ")");
}
// Store num states
n = mdp.getNumStates();
// Create solution vector(s)
soln = new double[n];
soln2 = (init == null) ? new double[n] : init;
// Initialise solution vectors. Use (where available) the following in order of preference:
// (1) exact answer, if already known; (2) 1.0/0.0 if in yes/no; (3) passed in initial value; (4) initVal
// where initVal is 0.0 or 1.0, depending on whether we converge from below/above.
@ -781,14 +806,15 @@ public class MDPModelChecker extends ProbModelChecker
if (init != null) {
if (known != null) {
for (i = 0; i < n; i++)
soln[i] = soln2[i] = known.get(i) ? init[i] : yes.get(i) ? 1.0 : no.get(i) ? 0.0 : init[i];
init[i] = known.get(i) ? init[i] : yes.get(i) ? 1.0 : no.get(i) ? 0.0 : init[i];
} else {
for (i = 0; i < n; i++)
soln[i] = soln2[i] = yes.get(i) ? 1.0 : no.get(i) ? 0.0 : init[i];
init[i] = yes.get(i) ? 1.0 : no.get(i) ? 0.0 : init[i];
}
} else {
init = new double[n];
for (i = 0; i < n; i++)
soln[i] = soln2[i] = yes.get(i) ? 1.0 : no.get(i) ? 0.0 : initVal;
init[i] = yes.get(i) ? 1.0 : no.get(i) ? 0.0 : initVal;
}
// Determine set of states actually need to compute values for
@ -800,48 +826,28 @@ public class MDPModelChecker extends ProbModelChecker
unknown.andNot(known);
if (iterationsExport != null)
iterationsExport.exportVector(soln, 0);
iterationsExport.exportVector(init, 0);
// Start iterations
iters = 0;
done = false;
while (!done && iters < maxIters) {
iters++;
// Matrix-vector multiply and min/max ops
mdp.mvMultMinMax(soln, min, soln2, unknown, false, strat);
IterationMethod.IterationValIter iteration = iterationMethod.forMvMultMinMax(mdp, min, strat);
iteration.init(init);
if (iterationsExport != null)
iterationsExport.exportVector(soln2, 0);
IntSet unknownStates = IntSet.asIntSet(unknown);
// Check termination
done = PrismUtils.doublesAreClose(soln, soln2, termCritParam, termCrit == TermCrit.ABSOLUTE);
// Swap vectors for next iter
tmpsoln = soln;
soln = soln2;
soln2 = tmpsoln;
}
if (topological) {
// Compute SCCInfo, including trivial SCCs in the subgraph obtained when only considering
// states in unknown
SCCInfo sccs = SCCComputer.computeTopologicalOrdering(this, mdp, true, unknown::get);
// Finished value iteration
timer = System.currentTimeMillis() - timer;
mainLog.print("Value iteration (" + (min ? "min" : "max") + ")");
mainLog.println(" took " + iters + " iterations and " + timer / 1000.0 + " seconds.");
IterationMethod.SingletonSCCSolver singletonSCCSolver = (int s, double[] soln) -> {
soln[s] = mdp.mvMultJacMinMaxSingle(s, soln, min, strat);
};
if (iterationsExport != null)
iterationsExport.close();
// Non-convergence is an error (usually)
if (!done && errorOnNonConverge) {
String msg = "Iterative method did not converge within " + iters + " iterations.";
msg += "\nConsider using a different numerical method or increasing the maximum number of iterations";
throw new PrismException(msg);
// run the actual value iteration
return iterationMethod.doTopologicalValueIteration(this, description, sccs, iteration, singletonSCCSolver, timer, iterationsExport);
} else {
// run the actual value iteration
return iterationMethod.doValueIteration(this, description, iteration, unknownStates, timer, iterationsExport);
}
// Return results
res = new ModelCheckerResult();
res.soln = soln;
res.numIters = iters;
res.timeTaken = timer / 1000.0;
return res;
}
/**
@ -858,92 +864,8 @@ public class MDPModelChecker extends ProbModelChecker
protected ModelCheckerResult computeReachProbsGaussSeidel(MDP mdp, BitSet no, BitSet yes, boolean min, double init[], BitSet known, int strat[])
		throws PrismException
{
	// The hand-written Gauss-Seidel loop (vector initialisation, unknown-state set,
	// iteration export, convergence check and non-convergence error) is no longer kept
	// here: leaving it in front of the delegation made the statements below unreachable.
	// Build a Gauss-Seidel iteration using the configured termination criterion
	// (absolute vs. relative) and threshold, and let the shared reachability-probability
	// driver do the work. Topological iteration is not requested here.
	// NOTE(review): the third constructor argument is passed as 'false' exactly as in the
	// other IterationMethodGS call sites in this file; its meaning is not visible here.
	IterationMethodGS iterationMethod = new IterationMethodGS(termCrit == TermCrit.ABSOLUTE, termCritParam, false);
	return doValueIterationReachProbs(mdp, no, yes, min, init, known, iterationMethod, false, strat);
}
/**
@ -1628,20 +1550,7 @@ public class MDPModelChecker extends ProbModelChecker
}
}
// Compute rewards
switch (mdpSolnMethod) {
case VALUE_ITERATION:
res = computeReachRewardsValIter(mdp, mdpRewards, target, inf, min, init, known, strat);
break;
case GAUSS_SEIDEL:
res = computeReachRewardsGaussSeidel(mdp, mdpRewards, target, inf, min, init, known, strat);
break;
case POLICY_ITERATION:
res = computeReachRewardsPolIter(mdp, mdpRewards, target, inf, min, strat);
break;
default:
throw new PrismException("Unknown MDP solution method " + mdpSolnMethod.fullName());
}
res = computeReachRewardsNumeric(mdp, mdpRewards, mdpSolnMethod, target, inf, min, init, known, strat);
// Store strategy
if (genStrat) {
@ -1668,6 +1577,32 @@ public class MDPModelChecker extends ProbModelChecker
return res;
}
/**
 * Compute expected reachability rewards with the numerical method selected by {@code method}.
 * Value iteration and Gauss-Seidel are run through {@code doValueIterationReachRewards}
 * (non-topological); policy iteration is delegated to {@code computeReachRewardsPolIter}.
 * @param mdp The MDP
 * @param mdpRewards The rewards
 * @param method The solution method to use
 * @param target Target states
 * @param inf States for which reward is infinite
 * @param min Min or max rewards (true=min, false=max)
 * @param init Optionally, an initial solution vector (forwarded to the iteration driver)
 * @param known Optionally, a set of states for which the exact answer is known
 * @param strat Storage for (memoryless) strategy choice indices
 * @throws PrismException if {@code method} is not one of the handled solution methods
 */
protected ModelCheckerResult computeReachRewardsNumeric(MDP mdp, MDPRewards mdpRewards, MDPSolnMethod method, BitSet target, BitSet inf, boolean min, double init[], BitSet known, int strat[]) throws PrismException
{
	final boolean absolute = (termCrit == TermCrit.ABSOLUTE);
	IterationMethod solver = null;

	switch (method) {
	case VALUE_ITERATION:
		solver = new IterationMethodPower(absolute, termCritParam);
		break;
	case GAUSS_SEIDEL:
		solver = new IterationMethodGS(absolute, termCritParam, false);
		break;
	case POLICY_ITERATION: {
		// Policy iteration produces the result directly, no iteration method needed
		ModelCheckerResult viaPolicyIteration = computeReachRewardsPolIter(mdp, mdpRewards, target, inf, min, strat);
		if (viaPolicyIteration != null) {
			return viaPolicyIteration;
		}
		break;
	}
	default:
		throw new PrismException("Unknown MDP solution method " + method.fullName());
	}

	// Not yet computed: run the generic (non-topological) iteration with the chosen method
	return doValueIterationReachRewards(mdp, mdpRewards, solver, target, inf, min, init, known, false, strat);
}
/**
* Compute expected reachability rewards using value iteration.
* Optionally, store optimal (memoryless) strategy info.
@ -1684,42 +1619,58 @@ public class MDPModelChecker extends ProbModelChecker
protected ModelCheckerResult computeReachRewardsValIter(MDP mdp, MDPRewards mdpRewards, BitSet target, BitSet inf, boolean min, double init[], BitSet known, int strat[])
		throws PrismException
{
	// Thin wrapper: build a power-method iteration using the configured termination
	// criterion (absolute vs. relative) and threshold, then hand over to the shared
	// reachability-reward driver.
	IterationMethodPower iterationMethod = new IterationMethodPower(termCrit == TermCrit.ABSOLUTE, termCritParam);
	// The argument after 'known' is the driver's 'topological' flag. It must be false
	// here (plain value iteration); passing 'min' in that position would silently switch
	// every minimising computation to topological iteration.
	return doValueIterationReachRewards(mdp, mdpRewards, iterationMethod, target, inf, min, init, known, false, strat);
}
/**
* Compute expected reachability rewards using value iteration.
* Optionally, store optimal (memoryless) strategy info.
* @param mdp The MDP
* @param mdpRewards The rewards
* @param target Target states
* @param inf States for which reward is infinite
* @param min Min or max rewards (true=min, false=max)
* @param init Optionally, an initial solution vector (will be overwritten)
* @param known Optionally, a set of states for which the exact answer is known
* @param topological Do topological value iteration?
* @param strat Storage for (memoryless) strategy choice indices (ignored if null)
* Note: if 'known' is specified (i.e. is non-null), 'init' must also be given and is used for the exact values.
*/
protected ModelCheckerResult doValueIterationReachRewards(MDP mdp, MDPRewards mdpRewards, IterationMethod iterationMethod, BitSet target, BitSet inf, boolean min, double init[], BitSet known, boolean topological, int strat[])
throws PrismException
{
BitSet unknown;
int i, n, iters;
double soln[], soln2[], tmpsoln[];
boolean done;
int i, n;
long timer;
// Start value iteration
timer = System.currentTimeMillis();
mainLog.println("Starting value iteration (" + (min ? "min" : "max") + ")...");
String description = (min ? "min" : "max") + (topological ? ", topological" : "" ) + ", with " + iterationMethod.getDescriptionShort();
mainLog.println("Starting value iteration (" + description + ")...");
ExportIterations iterationsExport = null;
if (settings.getBoolean(PrismSettings.PRISM_EXPORT_ITERATIONS)) {
iterationsExport = new ExportIterations("Explicit MDP ReachProbs value iteration");
iterationsExport = new ExportIterations("Explicit MDP ReachRewards value iteration (" + description +")");
}
// Store num states
n = mdp.getNumStates();
// Create solution vector(s)
soln = new double[n];
soln2 = (init == null) ? new double[n] : init;
// Initialise solution vectors. Use (where available) the following in order of preference:
// (1) exact answer, if already known; (2) 0.0/infinity if in target/inf; (3) passed in initial value; (4) 0.0
if (init != null) {
if (known != null) {
for (i = 0; i < n; i++)
soln[i] = soln2[i] = known.get(i) ? init[i] : target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i];
init[i] = known.get(i) ? init[i] : target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i];
} else {
for (i = 0; i < n; i++)
soln[i] = soln2[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i];
init[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i];
}
} else {
init = new double[n];
for (i = 0; i < n; i++)
soln[i] = soln2[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : 0.0;
init[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : 0.0;
}
// Determine set of states actually need to compute values for
@ -1731,49 +1682,28 @@ public class MDPModelChecker extends ProbModelChecker
unknown.andNot(known);
if (iterationsExport != null)
iterationsExport.exportVector(soln, 0);
// Start iterations
iters = 0;
done = false;
while (!done && iters < maxIters) {
//mainLog.println(soln);
iters++;
// Matrix-vector multiply and min/max ops
mdp.mvMultRewMinMax(soln, mdpRewards, min, soln2, unknown, false, strat);
iterationsExport.exportVector(init, 0);
if (iterationsExport != null)
iterationsExport.exportVector(soln2, 0);
IterationMethod.IterationValIter forMvMultRewMinMax = iterationMethod.forMvMultRewMinMax(mdp, mdpRewards, min, strat);
forMvMultRewMinMax.init(init);
// Check termination
done = PrismUtils.doublesAreClose(soln, soln2, termCritParam, termCrit == TermCrit.ABSOLUTE);
// Swap vectors for next iter
tmpsoln = soln;
soln = soln2;
soln2 = tmpsoln;
}
IntSet unknownStates = IntSet.asIntSet(unknown);
if (iterationsExport != null)
iterationsExport.close();
if (topological) {
// Compute SCCInfo, including trivial SCCs in the subgraph obtained when only considering
// states in unknown
SCCInfo sccs = SCCComputer.computeTopologicalOrdering(this, mdp, true, unknown::get);
// Finished value iteration
timer = System.currentTimeMillis() - timer;
mainLog.print("Value iteration (" + (min ? "min" : "max") + ")");
mainLog.println(" took " + iters + " iterations and " + timer / 1000.0 + " seconds.");
IterationMethod.SingletonSCCSolver singletonSCCSolver = (int s, double[] soln) -> {
soln[s] = mdp.mvMultRewJacMinMaxSingle(s, soln, mdpRewards, min, strat);
};
// Non-convergence is an error (usually)
if (!done && errorOnNonConverge) {
String msg = "Iterative method did not converge within " + iters + " iterations.";
msg += "\nConsider using a different numerical method or increasing the maximum number of iterations";
throw new PrismException(msg);
// run the actual value iteration
return iterationMethod.doTopologicalValueIteration(this, description, sccs, forMvMultRewMinMax, singletonSCCSolver, timer, iterationsExport);
} else {
// run the actual value iteration
return iterationMethod.doValueIteration(this, description, forMvMultRewMinMax, unknownStates, timer, iterationsExport);
}
// Return results
res = new ModelCheckerResult();
res.soln = soln;
res.numIters = iters;
res.timeTaken = timer / 1000.0;
return res;
}
/**
@ -1792,76 +1722,8 @@ public class MDPModelChecker extends ProbModelChecker
protected ModelCheckerResult computeReachRewardsGaussSeidel(MDP mdp, MDPRewards mdpRewards, BitSet target, BitSet inf, boolean min, double init[],
		BitSet known, int strat[]) throws PrismException
{
	// The hand-written Gauss-Seidel loop (vector initialisation, unknown-state set,
	// convergence check and non-convergence error) is no longer kept here: leaving it in
	// front of the delegation made the statements below unreachable.
	// Build a Gauss-Seidel iteration using the configured termination criterion
	// (absolute vs. relative) and threshold, and let the shared reachability-reward
	// driver do the work.
	// NOTE(review): the third constructor argument is passed as 'false' exactly as in the
	// other IterationMethodGS call sites in this file; its meaning is not visible here.
	IterationMethodGS iterationMethod = new IterationMethodGS(termCrit == TermCrit.ABSOLUTE, termCritParam, false);
	// The argument after 'known' is the driver's 'topological' flag. It must be false
	// here; passing 'min' in that position would silently switch every minimising
	// computation to topological iteration.
	return doValueIterationReachRewards(mdp, mdpRewards, iterationMethod, target, inf, min, init, known, false, strat);
}
/**

Loading…
Cancel
Save