Browse Source

(interval iteration, explicit) provide IterationMethod computations for interval iteration

Add some default computations for interval iteration to DTMC / MDP


git-svn-id: https://www.prismmodelchecker.org/svn/prism/prism/trunk@12141 bbc10eb1-c90d-0410-af57-cb519fbb1720
master
Joachim Klein 9 years ago
parent
commit
d9bad734bd
  1. 91
      prism/src/explicit/DTMC.java
  2. 384
      prism/src/explicit/IterationMethod.java
  3. 78
      prism/src/explicit/IterationMethodGS.java
  4. 46
      prism/src/explicit/IterationMethodJacobi.java
  5. 66
      prism/src/explicit/IterationMethodPower.java
  6. 76
      prism/src/explicit/MDP.java

91
prism/src/explicit/DTMC.java

@ -32,6 +32,7 @@ import java.util.PrimitiveIterator.OfInt;
import common.IterableStateSet;
import prism.Pair;
import prism.PrismException;
import explicit.rewards.MCRewards;
/**
@ -314,6 +315,52 @@ public interface DTMC extends Model
return maxDiff;
}
/**
* Do a Gauss-Seidel-style matrix-vector multiplication (in the interval iteration context) for
* the DTMC's transition probability matrix P and the vector {@code vect} passed in,
* storing new values directly in {@code vect} as computed (for use in interval iteration).
* i.e. for all s: vect[s] = (sum_{j!=s} P(s,j)*vect[j]) / (1-P(s,s))
* @param vect Vector to multiply by (and store the result in)
* @param states Do multiplication for these rows, in this order
* @param ensureMonotonic ensure monotonicity
* @param checkMonotonic check monotonicity
* @param fromBelow iteration from below or from above? (for ensureMonotonicity, checkMonotonicity)
*/
public default void mvMultGSIntervalIter(double vect[], PrimitiveIterator.OfInt states, boolean ensureMonotonic, boolean checkMonotonic, boolean fromBelow) throws PrismException
{
double d;
while (states.hasNext()) {
int s = states.nextInt();
d = mvMultJacSingle(s, vect);
if (ensureMonotonic) {
if (fromBelow) {
// from below: do max old and new
if (vect[s] > d) {
d = vect[s];
}
} else {
// from above: do min old and new
if (vect[s] < d) {
d = vect[s];
}
}
}
if (checkMonotonic) {
if (fromBelow) {
if (vect[s] > d) {
throw new PrismException("Monotonicity violated (from below): old value " + vect[s] + " > new value " + d);
}
} else {
if (vect[s] < d) {
throw new PrismException("Monotonicity violated (from above): old value " + vect[s] + " < new value " + d);
}
}
}
vect[s] = d;
}
}
/**
* Do a Jacobi-style matrix-vector multiplication for the DTMC's transition probability matrix P
* and the vector {@code vect} passed in, for the state indices provided by the iterator,
@ -431,6 +478,50 @@ public interface DTMC extends Model
return maxDiff;
}
/**
* Do a matrix-vector multiplication and sum of action reward (Gauss-Seidel, interval iteration context).
* @param vect Vector to multiply by and store result in
* @param mcRewards The rewards
* @param states Do multiplication for these rows, in the specified order
* @param ensureMonotonic ensure monotonicity
* @param checkMonotonic check monotonicity
* @param fromBelow iteration from below or from above? (for ensureMonotonicity, checkMonotonicity)
*/
public default void mvMultRewGSIntervalIter(double vect[], MCRewards mcRewards, PrimitiveIterator.OfInt states, boolean ensureMonotonic, boolean checkMonotonic, boolean fromBelow) throws PrismException
{
double d;
while (states.hasNext()) {
int s = states.nextInt();
d = mvMultRewJacSingle(s, vect, mcRewards);
if (ensureMonotonic) {
if (fromBelow) {
// from below: do max old and new
if (vect[s] > d) {
d = vect[s];
}
} else {
// from above: do min old and new
if (vect[s] < d) {
d = vect[s];
}
}
}
if (checkMonotonic) {
if (fromBelow) {
if (vect[s] > d) {
throw new PrismException("Monotonicity violated (from below): old value " + vect[s] + " > new value " + d);
}
} else {
if (vect[s] < d) {
throw new PrismException("Monotonicity violated (from above): old value " + vect[s] + " < new value " + d);
}
}
}
vect[s] = d;
}
}
/**
* Do a single row of matrix-vector multiplication and sum of action reward.
* @param s Row index

384
prism/src/explicit/IterationMethod.java

@ -27,18 +27,22 @@
package explicit;
import java.util.PrimitiveIterator;
import common.IntSet;
import common.PeriodicTimer;
import explicit.rewards.MCRewards;
import explicit.rewards.MDPRewards;
import prism.OptionsIntervalIteration;
import prism.PrismException;
import prism.PrismSettings;
import prism.PrismUtils;
/**
* Abstract class that encapsulates the functionality for the different iteration methods
* (e.g., Power, Jacobi, Gauss-Seidel, ...).
* <p>
* Provides methods as well to do the actual work in a (topological) value iteration.
* Provides methods as well to do the actual work in a (topological) value or interval iteration.
*/
public abstract class IterationMethod {
@ -64,6 +68,30 @@ public abstract class IterationMethod {
public Model getModel();
}
/**
* Interface for an object that provides the atomic steps for a value iteration
* in the context of interval iteration, i.e., for the interval iteration
* there are two IterationIntervalIter objects, one from below or from above.
*/
public interface IterationIntervalIter {
/** Initialise the value iteration with the given solution vector */
public void init(double[] soln);
/** Get the current solution vector */
public double[] getSolnVector();
/** Perform one iteration (over the set of states) */
public void iterate(IntSet states) throws PrismException;
/**
* Solve for a given singleton SCC consisting of {@code state} using {@code solver},
* store the result in the solution vector(s).
*/
public void solveSingletonSCC(int s, SingletonSCCSolver solver);
/** Return the underlying model */
public Model getModel();
}
/** Storage for a single solution vector */
public class IterationBasic {
protected final Model model;
@ -104,6 +132,15 @@ public abstract class IterationMethod {
}
}
/** Abstract base class for an IterationIntervalIter with a single solution vector */
protected abstract class SingleVectorIterationIntervalIter extends IterationBasic implements IterationIntervalIter
{
public SingleVectorIterationIntervalIter(Model model)
{
super(model);
}
}
/**
* Functional interface for a post-processing step after an iteration that involves
* a pair of solution vectors.
@ -120,11 +157,11 @@ public abstract class IterationMethod {
}
/**
* Abstract base class for an IterationValIter that
* Abstract base class for an IterationValIter / IterationIntervalIter that
* requires two solution vectors.
* Optionally, a post processing step is performed after each iteration.
*/
protected abstract class TwoVectorIteration extends IterationBasic implements IterationValIter {
protected abstract class TwoVectorIteration extends IterationBasic implements IterationValIter, IterationIntervalIter {
/** The solution vector that serves as the target vector in the iteration step */
protected double[] soln2;
/** Post processing, may be null */
@ -150,6 +187,22 @@ public abstract class IterationMethod {
/** Perform one iteration */
public abstract void doIterate(IntSet states) throws PrismException;
@Override
public void iterate(IntSet states) throws PrismException
{
// do the iteration
doIterate(states);
// optionally, post processing
if (postProcessor != null) {
postProcessor.apply(soln, soln2, states);
}
// switch vectors
double[] tmp = soln;
soln = soln2;
soln2 = tmp;
}
@Override
public boolean iterateAndCheckConvergence(IntSet states) throws PrismException
{
@ -218,10 +271,26 @@ public abstract class IterationMethod {
/** Obtain an Iteration object using mvMult (matrix-vector multiplication) in a DTMC */
public abstract IterationValIter forMvMult(DTMC dtmc) throws PrismException;
/**
* Obtain an Iteration object (for interval iteration) using mvMult
* (matrix-vector multiplication) in a DTMC.
* @param fromBelow for interval iteration from below?
* @param enforceMonotonic enforce element-wise monotonicity of the solution vector
* @param checkMonotonic check the element-wise monotonicity of the solution vector, throw exception if violated
*/
public abstract IterationIntervalIter forMvMultInterval(DTMC dtmc, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity) throws PrismException;
/** Obtain an Iteration object using mvMultRew (matrix-vector multiplication with rewards) in a DTMC */
public abstract IterationValIter forMvMultRew(DTMC dtmc, MCRewards rew) throws PrismException;
/**
* Obtain an Iteration object (for interval iteration) using mvMultRew
* (matrix-vector multiplication with rewards) in a DTMC.
* @param fromBelow for interval iteration from below?
* @param enforceMonotonic enforce element-wise monotonicity of the solution vector
* @param checkMonotonic check the element-wise monotonicity of the solution vector, throw exception if violated
*/
public abstract IterationIntervalIter forMvMultRewInterval(DTMC dtmc, MCRewards rew, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity) throws PrismException;
// ------------ Abstract MDP methods ----------------------------
@ -234,6 +303,18 @@ public abstract class IterationMethod {
*/
public abstract IterationValIter forMvMultMinMax(MDP mdp, boolean min, int[] strat) throws PrismException;
/**
* Obtain an Iteration object using mvMultMinMax (matrix-vector multiplication, followed by min/max)
* in an MDP, for interval iteration.
* @param mdp the MDP
* @param min do min?
* @param strat optional, storage for strategy, ignored if null
* @param fromBelow for interval iteration from below?
* @param enforceMonotonic enforce element-wise monotonicity of the solution vector
* @param checkMonotonic check the element-wise monotonicity of the solution vector, throw exception if violated
*/
public abstract IterationIntervalIter forMvMultMinMaxInterval(MDP mdp, boolean min, int[] strat, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity) throws PrismException;
/**
* Obtain an Iteration object using mvMultRewMinMax (matrix-vector multiplication with rewards, followed by min/max)
* in an MDP.
@ -244,6 +325,18 @@ public abstract class IterationMethod {
*/
public abstract IterationValIter forMvMultRewMinMax(MDP mdp, MDPRewards rewards, boolean min, int[] strat) throws PrismException;
/**
* Obtain an Iteration object using mvMultRewMinMax (matrix-vector multiplication with rewards, followed by min/max)
* in an MDP, for interval iteration.
* @param mdp the MDP
* @param rewards the reward structure
* @param min do min?
* @param strat optional, storage for strategy, ignored if null
* @param fromBelow for interval iteration from below?
* @param enforceMonotonic enforce element-wise monotonicity of the solution vector
* @param checkMonotonic check the element-wise monotonicity of the solution vector, throw exception if violated
*/
public abstract IterationIntervalIter forMvMultRewMinMaxInterval(MDP mdp, MDPRewards rewards, boolean min, int[] strat, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity) throws PrismException;
// ------------ Abstract generic methods ----------------------------
@ -416,6 +509,289 @@ public abstract class IterationMethod {
return res;
}
/**
* Perform the actual work of an interval iteration, i.e., iterate until convergence or abort.
*
* @param mc ProbModelChecker (for log and settings)
* @param description Description (for logging)
* @param below The iteration object for the iteration from below
* @param above The iteration object for the iteration from above
* @param unknownStates The set of unknown states, i.e., whose value should be determined
* @param startTime The start time (for logging purposes, obtained from a call to System.currentTimeMillis())
* @param iterationsExport an ExportIterations object (optional, ignored if null)
* @return a ModelChecker result with the solution vector and statistics
* @throws PrismException on non-convergence (if mc.errorOnNonConverge is set)
*/
public ModelCheckerResult doIntervalIteration(ProbModelChecker mc, String description, IterationIntervalIter below, IterationIntervalIter above, IntSet unknownStates, long timer, ExportIterations iterationsExport) throws PrismException {
try {
// Start iterations
int iters = 0;
final int maxIters = mc.maxIters;
boolean done = false;
PeriodicTimer updatesTimer = new PeriodicTimer(ProbModelChecker.UPDATE_DELAY);
updatesTimer.start();
while (!done && iters < maxIters) {
iters++;
// Matrix-vector multiply
below.iterate(unknownStates);
above.iterate(unknownStates);
if (iterationsExport != null) {
iterationsExport.exportVector(below.getSolnVector(), 0);
iterationsExport.exportVector(above.getSolnVector(), 1);
}
intervalIterationCheckForProblems(below.getSolnVector(), above.getSolnVector(), unknownStates.iterator());
// Check termination
done = PrismUtils.doublesAreClose(below.getSolnVector(), above.getSolnVector(), termCritParam, absolute);
if (done) {
double diff = PrismUtils.measureSupNormInterval(below.getSolnVector(), above.getSolnVector(), absolute);
mc.getLog().println("Max " + (!absolute ? "relative ": "") +
"diff between upper and lower bound on convergence: " + PrismUtils.formatDouble(diff));
done = true;
}
if (!done && updatesTimer.triggered()) {
double diff = PrismUtils.measureSupNormInterval(below.getSolnVector(), above.getSolnVector(), absolute);
mc.getLog().print("Iteration " + iters + ": ");
mc.getLog().print("max " + (absolute ? "" : "relative ") + "diff=" + PrismUtils.formatDouble(diff));
mc.getLog().println(", " + PrismUtils.formatDouble2dp(updatesTimer.elapsedMillisTotal() / 1000.0) + " sec so far");
}
}
// Finished value iteration
long mvCount = 2 * iters * countTransitions(below.getModel(), unknownStates);
timer = System.currentTimeMillis() - timer;
mc.getLog().print("Interval iteration (" + description + ")");
mc.getLog().print(" took " + iters + " iterations, ");
mc.getLog().print(mvCount + " MV-multiplications");
mc.getLog().println(" and " + timer / 1000.0 + " seconds.");
if (done && OptionsIntervalIteration.from(mc.getSettings()).isSelectMidpointForResult()) {
PrismUtils.selectMidpoint(below.getSolnVector(), above.getSolnVector());
if (iterationsExport != null) {
// export midpoint
iterationsExport.exportVector(below.getSolnVector(), 0);
iterationsExport.exportVector(below.getSolnVector(), 1);
}
}
// Non-convergence is an error (usually)
if (!done && mc.errorOnNonConverge) {
String msg = "Iterative method (interval iteration) did not converge within " + iters + " iterations.";
msg += "\nConsider using a different numerical method or increasing the maximum number of iterations";
throw new PrismException(msg);
}
// Return results
ModelCheckerResult res = new ModelCheckerResult();
res.soln = below.getSolnVector();
res.numIters = iters;
res.timeTaken = timer / 1000.0;
return res;
} finally {
if (iterationsExport != null)
iterationsExport.close();
}
}
/**
* Perform the actual work of a topological interval iteration, i.e., iterate until convergence or abort.
*
* @param mc ProbModelChecker (for log and settings)
* @param description Description (for logging)
* @param sccs The information about the SCCs and topological order
* @param below The iteration object for the value iteration from below
* @param above The iteration object for the value iteration from above
* @param singletonSCCSolver The solver for singleton SCCs
* @param startTime The start time (for logging purposes, obtained from a call to System.currentTimeMillis())
* @param iterationsExport an ExportIterations object (optional, ignored if null)
* @return a ModelChecker result with the solution vector and statistics
* @throws PrismException on non-convergence (if mc.errorOnNonConverge is set)
*/
public ModelCheckerResult doTopologicalIntervalIteration(ProbModelChecker mc, String description, SCCInfo sccs, IterationIntervalIter below, IterationIntervalIter above, SingletonSCCSolver singletonSCCSolver, long timer, ExportIterations iterationsExport) throws PrismException {
try {
// Start iterations
int iters = 0;
long mvCount = 0;
final int maxIters = mc.maxIters;
PeriodicTimer updatesTimer = new PeriodicTimer(ProbModelChecker.UPDATE_DELAY);
updatesTimer.start();
int numSCCs = sccs.getNumSCCs();
int numNonSingletonSCCs = sccs.countNonSingletonSCCs();
int finishedNonSingletonSCCs = 0;
boolean done = true;
for (int scc = 0; scc < numSCCs; scc++) {
boolean doneSCC;
if (sccs.isSingletonSCC(scc)) {
// get the single state in this SCC
int state = sccs.getStatesForSCC(scc).iterator().nextInt();
below.solveSingletonSCC(state, singletonSCCSolver);
above.solveSingletonSCC(state, singletonSCCSolver);
iters++;
mvCount += 2 * countTransitions(below.getModel(), IntSet.asIntSet(state));
if (iterationsExport != null) {
iterationsExport.exportVector(below.getSolnVector(), 0);
iterationsExport.exportVector(above.getSolnVector(), 1);
}
intervalIterationCheckForProblems(below.getSolnVector(), above.getSolnVector(), IntSet.asIntSet(state).iterator());
doneSCC = true;
} else {
// complex SCC: do VI
doneSCC = false;
int itersInSCC = 0;
IntSet statesForSCC = sccs.getStatesForSCC(scc);
// Adjust upper bound by adding 2*epsilon,
// adding 1*epsilon would be fine, but we are a bit more conservative.
// TODO: We also don't really need to do adjustment for bottom SCCs...
PrimitiveIterator.OfInt it = statesForSCC.iterator();
final double[] solnAbove = above.getSolnVector();
final double adjustment = 2*termCritParam;
while (it.hasNext()) {
solnAbove[it.nextInt()] += adjustment;
}
// abort on convergence or if iterations *in this SCC* are above maxIters
while (!doneSCC && itersInSCC < maxIters) {
iters++;
itersInSCC++;
// do iteration step
below.iterate(statesForSCC);
above.iterate(statesForSCC);
if (iterationsExport != null) {
iterationsExport.exportVector(below.getSolnVector(), 0);
iterationsExport.exportVector(above.getSolnVector(), 1);
}
intervalIterationCheckForProblems(below.getSolnVector(), above.getSolnVector(), statesForSCC.iterator());
// Check termination (inside SCC)
doneSCC = PrismUtils.doublesAreClose(below.getSolnVector(), above.getSolnVector(), statesForSCC.iterator(), termCritParam, absolute);
if (!doneSCC && updatesTimer.triggered()) {
double diff = PrismUtils.measureSupNormInterval(below.getSolnVector(), above.getSolnVector(), absolute, statesForSCC.iterator());
mc.getLog().print("Iteration " + iters + ": ");
mc.getLog().print("max " + (absolute ? "" : "relative ") + "diff (for iteration " + itersInSCC + " in current SCC " + (finishedNonSingletonSCCs+1) + " of " + numNonSingletonSCCs + ") = " + PrismUtils.formatDouble(diff));
mc.getLog().println(", " + PrismUtils.formatDouble2dp(updatesTimer.elapsedMillisTotal() / 1000.0) + " sec so far");
}
}
mvCount += 2 * itersInSCC * countTransitions(below.getModel(), statesForSCC);
finishedNonSingletonSCCs++;
}
if (!doneSCC) {
done = false;
break;
}
}
if (done) {
double diff = PrismUtils.measureSupNormInterval(below.getSolnVector(), above.getSolnVector(), absolute);
mc.getLog().println("Max " + (absolute ? "" : "relative ") +
"diff between upper and lower bound on convergence: " + PrismUtils.formatDouble(diff));
done = true;
}
// Finished value iteration
timer = System.currentTimeMillis() - timer;
mc.getLog().print("Interval iteration (" + description + ", with " + numNonSingletonSCCs + " non-singleton SCCs)");
mc.getLog().print(" took " + iters + " iterations, ");
mc.getLog().print(mvCount + " MV-multiplications");
mc.getLog().println(" and " + timer / 1000.0 + " seconds.");
if (done && OptionsIntervalIteration.from(mc.getSettings()).isSelectMidpointForResult()) {
PrismUtils.selectMidpoint(below.getSolnVector(), above.getSolnVector());
if (iterationsExport != null) {
// export midpoint
iterationsExport.exportVector(below.getSolnVector(), 0);
iterationsExport.exportVector(below.getSolnVector(), 1);
}
}
if (iterationsExport != null)
iterationsExport.close();
// Non-convergence is an error (usually)
if (!done && mc.errorOnNonConverge) {
String msg = "Iterative method (interval iteration) did not converge within " + iters + " iterations.";
msg += "\nConsider using a different numerical method or increasing the maximum number of iterations";
throw new PrismException(msg);
}
// Return results
ModelCheckerResult res = new ModelCheckerResult();
res.soln = below.getSolnVector();
res.numIters = iters;
res.timeTaken = timer / 1000.0;
return res;
} finally {
if (iterationsExport != null)
iterationsExport.close();
}
}
/**
* Compares the current lower and upper solution vectors in an interval iteration
* and throws an exception if lower bound values are larger than upper bound values,
* as this indicates problems.
* @param lower the current lower iteration solution vector
* @param upper the current upper iteration solution vector
* @param states iterator over the states in question
*/
private static void intervalIterationCheckForProblems(double[] lower, double[] upper, PrimitiveIterator.OfInt states) throws PrismException
{
while (states.hasNext()) {
int s = states.nextInt();
if (lower[s] > upper[s]) {
throw new PrismException("In interval iteration, the lower value (" + lower[s] + ") is larger than the upper value (" + upper[s] + ").\n"
+ "This indicates either problems with numerical stability (rounding, precision of the floating-point representation) or that the initial bounds (for reward computations) are incorrect");
}
}
}
/**
* Perform a post-processing for a two-vector value iteration.
* @param solnOld the previous solution vector
* @param solnNew the newly computed solution vector
* @param states the relevant set of states
* @param fromBelow are we iterating from below?
* @param enforceMonotonicity if true, enforces monotonicity
* @param checkMonotonicity if true, checks for monotonicity (and throws error when non-monotonic)
*/
public static void twoVectorPostProcessing(double[] solnOld, double[] solnNew, IntSet states, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity) throws PrismException
{
// TODO: use IntSet states
if (enforceMonotonicity)
if (fromBelow) {
PrismUtils.ensureMonotonicityFromBelow(solnOld, solnNew);
} else {
PrismUtils.ensureMonotonicityFromAbove(solnOld, solnNew);
}
if (checkMonotonicity) {
PrismUtils.checkMonotonicity(solnOld, solnNew, !fromBelow);
}
}
protected long countTransitions(Model model, IntSet unknownStates)
{
if (model instanceof DTMC) {

78
prism/src/explicit/IterationMethodGS.java

@ -68,6 +68,23 @@ public class IterationMethodGS extends IterationMethod {
};
}
@Override
public IterationIntervalIter forMvMultInterval(DTMC dtmc, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity)
{
return new SingleVectorIterationIntervalIter(dtmc) {
@Override
public void iterate(IntSet states) throws PrismException
{
// Matrix-vector multiply
dtmc.mvMultGSIntervalIter(soln,
backwards ? states.reversedIterator() : states.iterator(),
enforceMonotonicity,
checkMonotonicity,
fromBelow);
}
};
}
@Override
public IterationValIter forMvMultRew(DTMC dtmc, MCRewards rew)
{
@ -87,6 +104,24 @@ public class IterationMethodGS extends IterationMethod {
};
}
@Override
public IterationIntervalIter forMvMultRewInterval(DTMC dtmc, MCRewards rew, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity)
{
return new SingleVectorIterationIntervalIter(dtmc) {
@Override
public void iterate(IntSet states) throws PrismException
{
// Matrix-vector multiply
dtmc.mvMultRewGSIntervalIter(soln,
rew,
backwards ? states.reversedIterator() : states.iterator(),
enforceMonotonicity,
checkMonotonicity,
fromBelow);
}
};
}
@Override
public IterationValIter forMvMultMinMax(MDP mdp, boolean min, int[] strat)
{
@ -107,6 +142,27 @@ public class IterationMethodGS extends IterationMethod {
};
}
@Override
public IterationIntervalIter forMvMultMinMaxInterval(MDP mdp, boolean min, int[] strat, boolean fromBelow, boolean enforceMonotonicity,
boolean checkMonotonicity) throws PrismException
{
return new SingleVectorIterationIntervalIter(mdp) {
@Override
public void iterate(IntSet states)
{
// TODO: check monotonic not yet supported
// Matrix-vector multiply
mdp.mvMultGSMinMaxIntervalIter(soln,
min,
backwards ? states.reversedIterator() : states.iterator(),
strat,
enforceMonotonicity,
fromBelow);
}
};
}
@Override
public IterationValIter forMvMultRewMinMax(MDP mdp, MDPRewards rewards, boolean min, int[] strat) throws PrismException
{
@ -128,6 +184,28 @@ public class IterationMethodGS extends IterationMethod {
};
}
@Override
public IterationIntervalIter forMvMultRewMinMaxInterval(MDP mdp, MDPRewards rewards, boolean min, int[] strat, boolean fromBelow,
boolean enforceMonotonicity, boolean checkMonotonicity) throws PrismException
{
return new SingleVectorIterationIntervalIter(mdp) {
@Override
public void iterate(IntSet states)
{
// TODO: check monotonic not yet supported
// Matrix-vector multiply
mdp.mvMultRewGSMinMaxIntervalIter(soln,
rewards,
min,
backwards ? states.reversedIterator() : states.iterator(),
strat,
enforceMonotonicity,
fromBelow);
}
};
}
@Override
public String getDescriptionShort()
{

46
prism/src/explicit/IterationMethodJacobi.java

@ -59,6 +59,22 @@ class IterationMethodJacobi extends IterationMethod {
};
}
@Override
public IterationIntervalIter forMvMultInterval(DTMC dtmc, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity)
{
IterationPostProcessor post = (soln, soln2, states) -> {
twoVectorPostProcessing(soln, soln2, states, fromBelow, enforceMonotonicity, checkMonotonicity);
};
return new TwoVectorIteration(dtmc, post) {
@Override
public void doIterate(IntSet states)
{
dtmc.mvMultJac(soln, soln2, states.iterator());
}
};
}
@Override
public IterationValIter forMvMultRew(DTMC dtmc, MCRewards rew)
{
@ -71,18 +87,48 @@ class IterationMethodJacobi extends IterationMethod {
};
}
@Override
public IterationIntervalIter forMvMultRewInterval(DTMC dtmc, MCRewards rew, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity)
{
IterationPostProcessor post = (soln, soln2, states) -> {
twoVectorPostProcessing(soln, soln2, states, fromBelow, enforceMonotonicity, checkMonotonicity);
};
return new TwoVectorIteration(dtmc, post) {
@Override
public void doIterate(IntSet states)
{
dtmc.mvMultRewJac(soln, rew, soln2, states.iterator());
}
};
}
@Override
public IterationValIter forMvMultMinMax(MDP mdp, boolean min, int[] strat) throws PrismException
{
throw new PrismNotSupportedException("Jacobi not supported for MDPs");
}
@Override
public IterationIntervalIter forMvMultMinMaxInterval(MDP mdp, boolean min, int[] strat, boolean fromBelow, boolean enforceMonotonicity,
boolean checkMonotonicity) throws PrismException
{
throw new PrismNotSupportedException("Jacobi not supported for MDPs");
}
@Override
public IterationValIter forMvMultRewMinMax(MDP mdp, MDPRewards rewards, boolean min, int[] strat) throws PrismException
{
throw new PrismNotSupportedException("Jacobi not supported for MDPs");
}
@Override
public IterationIntervalIter forMvMultRewMinMaxInterval(MDP mdp, MDPRewards rewards, boolean min, int[] strat, boolean fromBelow,
boolean enforceMonotonicity, boolean checkMonotonicity) throws PrismException
{
throw new PrismNotSupportedException("Jacobi not supported for MDPs");
}
@Override
public String getDescriptionShort()
{

66
prism/src/explicit/IterationMethodPower.java

@ -58,6 +58,22 @@ public class IterationMethodPower extends IterationMethod {
};
}
@Override
public IterationIntervalIter forMvMultInterval(DTMC dtmc, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity)
{
IterationPostProcessor post = (soln, soln2, states) -> {
twoVectorPostProcessing(soln, soln2, states, fromBelow, enforceMonotonicity, checkMonotonicity);
};
return new TwoVectorIteration(dtmc, post) {
@Override
public void doIterate(IntSet states)
{
dtmc.mvMult(soln, soln2, states.iterator());
}
};
}
@Override
public IterationValIter forMvMultRew(DTMC dtmc, MCRewards rew)
{
@ -70,6 +86,22 @@ public class IterationMethodPower extends IterationMethod {
};
}
@Override
public IterationIntervalIter forMvMultRewInterval(DTMC dtmc, MCRewards rew, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity)
{
IterationPostProcessor post = (soln, soln2, states) -> {
twoVectorPostProcessing(soln, soln2, states, fromBelow, enforceMonotonicity, checkMonotonicity);
};
return new TwoVectorIteration(dtmc, post) {
@Override
public void doIterate(IntSet states)
{
dtmc.mvMultRew(soln, rew, soln2, states.iterator());
}
};
}
@Override
public IterationValIter forMvMultMinMax(MDP mdp, boolean min, int[] strat)
{
@ -82,6 +114,23 @@ public class IterationMethodPower extends IterationMethod {
};
}
@Override
public IterationIntervalIter forMvMultMinMaxInterval(MDP mdp, boolean min, int[] strat, boolean fromBelow, boolean enforceMonotonicity,
boolean checkMonotonicity) throws PrismException
{
IterationPostProcessor post = (soln, soln2, states) -> {
twoVectorPostProcessing(soln, soln2, states, fromBelow, enforceMonotonicity, checkMonotonicity);
};
return new TwoVectorIteration(mdp, post) {
@Override
public void doIterate(IntSet states)
{
mdp.mvMultMinMax(soln, min, soln2, states.iterator(), strat);
}
};
}
@Override
public String getDescriptionShort()
{
@ -100,4 +149,21 @@ public class IterationMethodPower extends IterationMethod {
};
}
@Override
public IterationIntervalIter forMvMultRewMinMaxInterval(MDP mdp, MDPRewards rewards, boolean min, int[] strat, boolean fromBelow,
boolean enforceMonotonicity, boolean checkMonotonicity) throws PrismException
{
IterationPostProcessor post = (soln, soln2, states) -> {
twoVectorPostProcessing(soln, soln2, states, fromBelow, enforceMonotonicity, checkMonotonicity);
};
return new TwoVectorIteration(mdp, post) {
@Override
public void doIterate(IntSet states)
{
mdp.mvMultRewMinMax(soln, rewards, min, soln2, states.iterator(), strat);
}
};
}
}

76
prism/src/explicit/MDP.java

@ -424,6 +424,43 @@ public interface MDP extends NondetModel
return maxDiff;
}
/**
* Do a Gauss-Seidel-style matrix-vector multiplication followed by min/max in the context of interval iteration.
* i.e. for all s: vect[s] = min/max_k { (sum_{j!=s} P_k(s,j)*vect[j]) / 1-P_k(s,s) }
* and store new values directly in {@code vect} as computed.
* Optionally, store optimal (memoryless) strategy info.
* @param vect Vector to multiply by (and store the result in)
* @param min Min or max for (true=min, false=max)
* @param subset Only do multiplication for these rows (ignored if null)
* @param states Perform computation for these rows, in the iteration order
* @param ensureMonotonic ensure monotonicity
* @param fromBelow iteration from below or from above? (for ensureMonotonicity)
*/
public default void mvMultGSMinMaxIntervalIter(double vect[], boolean min, PrimitiveIterator.OfInt states, int strat[], boolean ensureMonotonic, boolean fromBelow)
{
double d;
while (states.hasNext()) {
final int s = states.nextInt();
d = mvMultJacMinMaxSingle(s, vect, min, strat);
if (ensureMonotonic) {
if (fromBelow) {
// from below: do max old and new
if (vect[s] > d) {
d = vect[s];
}
} else {
// from above: do min old and new
if (vect[s] < d) {
d = vect[s];
}
}
vect[s] = d;
} else {
vect[s] = d;
}
}
}
/**
* Do a single row of Jacobi-style matrix-vector multiplication followed by min/max.
* i.e. return min/max_k { (sum_{j!=s} P_k(s,j)*vect[j]) / 1-P_k(s,s) }
@ -663,6 +700,45 @@ public interface MDP extends NondetModel
return maxDiff;
}
/**
* Do a Gauss-Seidel-style matrix-vector multiplication and sum of rewards followed by min/max,
* for interval iteration.
* i.e. for all s: vect[s] = min/max_k { rew(s) + rew_k(s) + (sum_{j!=s} P_k(s,j)*vect[j]) / 1-P_k(s,s) }
* and store new values directly in {@code vect} as computed.
* Optionally, store optimal (memoryless) strategy info.
* @param vect Vector to multiply by (and store the result in)
* @param mdpRewards The rewards
* @param min Min or max for (true=min, false=max)
* @param states Perform computation for these rows, in the iteration order
* @param strat Storage for (memoryless) strategy choice indices (ignored if null)
* @param ensureMonotonic enforce monotonicity?
* @param fromBelow interval iteration from below? (for ensureMonotonic)
*/
public default void mvMultRewGSMinMaxIntervalIter(double vect[], MDPRewards mdpRewards, boolean min, PrimitiveIterator.OfInt states, int strat[], boolean ensureMonotonic, boolean fromBelow)
{
double d;
while (states.hasNext()) {
final int s = states.nextInt();
d = mvMultRewJacMinMaxSingle(s, vect, mdpRewards, min, strat);
if (ensureMonotonic) {
if (fromBelow) {
// from below: do max old and new
if (vect[s] > d) {
d = vect[s];
}
} else {
// from above: do min old and new
if (vect[s] < d) {
d = vect[s];
}
}
vect[s] = d;
} else {
vect[s] = d;
}
}
}
/**
* Do a single row of Jacobi-style matrix-vector multiplication and sum of rewards followed by min/max.
* i.e. return min/max_k { rew(s) + rew_k(s) + (sum_{j!=s} P_k(s,j)*vect[j]) / 1-P_k(s,s) }

Loading…
Cancel
Save