diff --git a/prism/src/explicit/DTMC.java b/prism/src/explicit/DTMC.java index 4ef5ff42..951aa255 100644 --- a/prism/src/explicit/DTMC.java +++ b/prism/src/explicit/DTMC.java @@ -32,6 +32,7 @@ import java.util.PrimitiveIterator.OfInt; import common.IterableStateSet; import prism.Pair; +import prism.PrismException; import explicit.rewards.MCRewards; /** @@ -314,6 +315,52 @@ public interface DTMC extends Model return maxDiff; } + /** + * Do a Gauss-Seidel-style matrix-vector multiplication (in the interval iteration context) for + * the DTMC's transition probability matrix P and the vector {@code vect} passed in, + * storing new values directly in {@code vect} as computed (for use in interval iteration). + * i.e. for all s: vect[s] = (sum_{j!=s} P(s,j)*vect[j]) / (1-P(s,s)) + * @param vect Vector to multiply by (and store the result in) + * @param states Do multiplication for these rows, in this order + * @param ensureMonotonic ensure monotonicity + * @param checkMonotonic check monotonicity + * @param fromBelow iteration from below or from above? (for ensureMonotonicity, checkMonotonicity) + */ + public default void mvMultGSIntervalIter(double vect[], PrimitiveIterator.OfInt states, boolean ensureMonotonic, boolean checkMonotonic, boolean fromBelow) throws PrismException + { + double d; + while (states.hasNext()) { + int s = states.nextInt(); + + d = mvMultJacSingle(s, vect); + if (ensureMonotonic) { + if (fromBelow) { + // from below: do max old and new + if (vect[s] > d) { + d = vect[s]; + } + } else { + // from above: do min old and new + if (vect[s] < d) { + d = vect[s]; + } + } + } + if (checkMonotonic) { + if (fromBelow) { + if (vect[s] > d) { + throw new PrismException("Monotonicity violated (from below): old value " + vect[s] + " > new value " + d); + } + } else { + if (vect[s] < d) { + throw new PrismException("Monotonicity violated (from above): old value " + vect[s] + " < new value " + d); + } + } + } + vect[s] = d; + } + } + /** * Do a Jacobi-style matrix-vector multiplication for the DTMC's transition probability matrix P * and the vector {@code vect} passed in, for the state indices provided by the iterator, @@ -431,6 +478,50 @@ public interface DTMC extends Model return maxDiff; } + /** + * Do a matrix-vector multiplication and sum of action reward (Gauss-Seidel, interval iteration context). + * @param vect Vector to multiply by and store result in + * @param mcRewards The rewards + * @param states Do multiplication for these rows, in the specified order + * @param ensureMonotonic ensure monotonicity + * @param checkMonotonic check monotonicity + * @param fromBelow iteration from below or from above? (for ensureMonotonicity, checkMonotonicity) + */ + public default void mvMultRewGSIntervalIter(double vect[], MCRewards mcRewards, PrimitiveIterator.OfInt states, boolean ensureMonotonic, boolean checkMonotonic, boolean fromBelow) throws PrismException + { + double d; + while (states.hasNext()) { + int s = states.nextInt(); + + d = mvMultRewJacSingle(s, vect, mcRewards); + if (ensureMonotonic) { + if (fromBelow) { + // from below: do max old and new + if (vect[s] > d) { + d = vect[s]; + } + } else { + // from above: do min old and new + if (vect[s] < d) { + d = vect[s]; + } + } + } + if (checkMonotonic) { + if (fromBelow) { + if (vect[s] > d) { + throw new PrismException("Monotonicity violated (from below): old value " + vect[s] + " > new value " + d); + } + } else { + if (vect[s] < d) { + throw new PrismException("Monotonicity violated (from above): old value " + vect[s] + " < new value " + d); + } + } + } + vect[s] = d; + } + } + /** * Do a single row of matrix-vector multiplication and sum of action reward. * @param s Row index diff --git a/prism/src/explicit/IterationMethod.java b/prism/src/explicit/IterationMethod.java index 71b574d2..ea0b01ce 100644 --- a/prism/src/explicit/IterationMethod.java +++ b/prism/src/explicit/IterationMethod.java @@ -27,18 +27,22 @@ package explicit; +import java.util.PrimitiveIterator; + import common.IntSet; import common.PeriodicTimer; import explicit.rewards.MCRewards; import explicit.rewards.MDPRewards; +import prism.OptionsIntervalIteration; import prism.PrismException; +import prism.PrismSettings; import prism.PrismUtils; /** * Abstract class that encapsulates the functionality for the different iteration methods * (e.g., Power, Jacobi, Gauss-Seidel, ...). *
- * Provides methods as well to do the actual work in a (topological) value iteration. + * Provides methods as well to do the actual work in a (topological) value or interval iteration. */ public abstract class IterationMethod { @@ -64,6 +68,30 @@ public abstract class IterationMethod { public Model getModel(); } + /** + * Interface for an object that provides the atomic steps for a value iteration + * in the context of interval iteration, i.e., for the interval iteration + * there are two IterationIntervalIter objects, one from below or from above. + */ + public interface IterationIntervalIter { + /** Initialise the value iteration with the given solution vector */ + public void init(double[] soln); + /** Get the current solution vector */ + public double[] getSolnVector(); + + /** Perform one iteration (over the set of states) */ + public void iterate(IntSet states) throws PrismException; + + /** + * Solve for a given singleton SCC consisting of {@code state} using {@code solver}, + * store the result in the solution vector(s). + */ + public void solveSingletonSCC(int s, SingletonSCCSolver solver); + + /** Return the underlying model */ + public Model getModel(); + } + /** Storage for a single solution vector */ public class IterationBasic { protected final Model model; @@ -104,6 +132,15 @@ public abstract class IterationMethod { } } + /** Abstract base class for an IterationIntervalIter with a single solution vector */ + protected abstract class SingleVectorIterationIntervalIter extends IterationBasic implements IterationIntervalIter + { + public SingleVectorIterationIntervalIter(Model model) + { + super(model); + } + } + /** * Functional interface for a post-processing step after an iteration that involves * a pair of solution vectors. @@ -120,11 +157,11 @@ public abstract class IterationMethod { } /** - * Abstract base class for an IterationValIter that + * Abstract base class for an IterationValIter / IterationIntervalIter that * requires two solution vectors. * Optionally, a post processing step is performed after each iteration. */ - protected abstract class TwoVectorIteration extends IterationBasic implements IterationValIter { + protected abstract class TwoVectorIteration extends IterationBasic implements IterationValIter, IterationIntervalIter { /** The solution vector that serves as the target vector in the iteration step */ protected double[] soln2; /** Post processing, may be null */ @@ -150,6 +187,22 @@ public abstract class IterationMethod { /** Perform one iteration */ public abstract void doIterate(IntSet states) throws PrismException; + @Override + public void iterate(IntSet states) throws PrismException + { + // do the iteration + doIterate(states); + // optionally, post processing + if (postProcessor != null) { + postProcessor.apply(soln, soln2, states); + } + + // switch vectors + double[] tmp = soln; + soln = soln2; + soln2 = tmp; + } + @Override public boolean iterateAndCheckConvergence(IntSet states) throws PrismException { @@ -218,10 +271,26 @@ public abstract class IterationMethod { /** Obtain an Iteration object using mvMult (matrix-vector multiplication) in a DTMC */ public abstract IterationValIter forMvMult(DTMC dtmc) throws PrismException; + /** + * Obtain an Iteration object (for interval iteration) using mvMult + * (matrix-vector multiplication) in a DTMC. + * @param fromBelow for interval iteration from below? + * @param enforceMonotonic enforce element-wise monotonicity of the solution vector + * @param checkMonotonic check the element-wise monotonicity of the solution vector, throw exception if violated + */ + public abstract IterationIntervalIter forMvMultInterval(DTMC dtmc, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity) throws PrismException; + /** Obtain an Iteration object using mvMultRew (matrix-vector multiplication with rewards) in a DTMC */ public abstract IterationValIter forMvMultRew(DTMC dtmc, MCRewards rew) throws PrismException; - + /** + * Obtain an Iteration object (for interval iteration) using mvMultRew + * (matrix-vector multiplication with rewards) in a DTMC. + * @param fromBelow for interval iteration from below? + * @param enforceMonotonic enforce element-wise monotonicity of the solution vector + * @param checkMonotonic check the element-wise monotonicity of the solution vector, throw exception if violated + */ + public abstract IterationIntervalIter forMvMultRewInterval(DTMC dtmc, MCRewards rew, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity) throws PrismException; // ------------ Abstract MDP methods ---------------------------- @@ -234,6 +303,18 @@ public abstract class IterationMethod { */ public abstract IterationValIter forMvMultMinMax(MDP mdp, boolean min, int[] strat) throws PrismException; + /** + * Obtain an Iteration object using mvMultMinMax (matrix-vector multiplication, followed by min/max) + * in an MDP, for interval iteration. + * @param mdp the MDP + * @param min do min? + * @param strat optional, storage for strategy, ignored if null + * @param fromBelow for interval iteration from below? + * @param enforceMonotonic enforce element-wise monotonicity of the solution vector + * @param checkMonotonic check the element-wise monotonicity of the solution vector, throw exception if violated + */ + public abstract IterationIntervalIter forMvMultMinMaxInterval(MDP mdp, boolean min, int[] strat, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity) throws PrismException; + /** * Obtain an Iteration object using mvMultRewMinMax (matrix-vector multiplication with rewards, followed by min/max) * in an MDP. @@ -244,6 +325,18 @@ public abstract class IterationMethod { */ public abstract IterationValIter forMvMultRewMinMax(MDP mdp, MDPRewards rewards, boolean min, int[] strat) throws PrismException; + /** + * Obtain an Iteration object using mvMultRewMinMax (matrix-vector multiplication with rewards, followed by min/max) + * in an MDP, for interval iteration. + * @param mdp the MDP + * @param rewards the reward structure + * @param min do min? + * @param strat optional, storage for strategy, ignored if null + * @param fromBelow for interval iteration from below? + * @param enforceMonotonic enforce element-wise monotonicity of the solution vector + * @param checkMonotonic check the element-wise monotonicity of the solution vector, throw exception if violated + */ + public abstract IterationIntervalIter forMvMultRewMinMaxInterval(MDP mdp, MDPRewards rewards, boolean min, int[] strat, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity) throws PrismException; // ------------ Abstract generic methods ---------------------------- @@ -416,6 +509,289 @@ public abstract class IterationMethod { return res; } + /** + * Perform the actual work of an interval iteration, i.e., iterate until convergence or abort. + * + * @param mc ProbModelChecker (for log and settings) + * @param description Description (for logging) + * @param below The iteration object for the iteration from below + * @param above The iteration object for the iteration from above + * @param unknownStates The set of unknown states, i.e., whose value should be determined + * @param startTime The start time (for logging purposes, obtained from a call to System.currentTimeMillis()) + * @param iterationsExport an ExportIterations object (optional, ignored if null) + * @return a ModelChecker result with the solution vector and statistics + * @throws PrismException on non-convergence (if mc.errorOnNonConverge is set) + */ + public ModelCheckerResult doIntervalIteration(ProbModelChecker mc, String description, IterationIntervalIter below, IterationIntervalIter above, IntSet unknownStates, long timer, ExportIterations iterationsExport) throws PrismException { + try { + // Start iterations + int iters = 0; + final int maxIters = mc.maxIters; + boolean done = false; + + PeriodicTimer updatesTimer = new PeriodicTimer(ProbModelChecker.UPDATE_DELAY); + updatesTimer.start(); + + while (!done && iters < maxIters) { + iters++; + // Matrix-vector multiply + below.iterate(unknownStates); + above.iterate(unknownStates); + + if (iterationsExport != null) { + iterationsExport.exportVector(below.getSolnVector(), 0); + iterationsExport.exportVector(above.getSolnVector(), 1); + } + + intervalIterationCheckForProblems(below.getSolnVector(), above.getSolnVector(), unknownStates.iterator()); + + // Check termination + done = PrismUtils.doublesAreClose(below.getSolnVector(), above.getSolnVector(), termCritParam, absolute); + + if (done) { + double diff = PrismUtils.measureSupNormInterval(below.getSolnVector(), above.getSolnVector(), absolute); + mc.getLog().println("Max " + (!absolute ? "relative ": "") + + "diff between upper and lower bound on convergence: " + PrismUtils.formatDouble(diff)); + done = true; + } + + if (!done && updatesTimer.triggered()) { + double diff = PrismUtils.measureSupNormInterval(below.getSolnVector(), above.getSolnVector(), absolute); + mc.getLog().print("Iteration " + iters + ": "); + mc.getLog().print("max " + (absolute ? "" : "relative ") + "diff=" + PrismUtils.formatDouble(diff)); + mc.getLog().println(", " + PrismUtils.formatDouble2dp(updatesTimer.elapsedMillisTotal() / 1000.0) + " sec so far"); + } + } + + // Finished value iteration + long mvCount = 2 * iters * countTransitions(below.getModel(), unknownStates); + timer = System.currentTimeMillis() - timer; + mc.getLog().print("Interval iteration (" + description + ")"); + mc.getLog().print(" took " + iters + " iterations, "); + mc.getLog().print(mvCount + " MV-multiplications"); + mc.getLog().println(" and " + timer / 1000.0 + " seconds."); + + if (done && OptionsIntervalIteration.from(mc.getSettings()).isSelectMidpointForResult()) { + PrismUtils.selectMidpoint(below.getSolnVector(), above.getSolnVector()); + + if (iterationsExport != null) { + // export midpoint + iterationsExport.exportVector(below.getSolnVector(), 0); + iterationsExport.exportVector(below.getSolnVector(), 1); + } + } + + // Non-convergence is an error (usually) + if (!done && mc.errorOnNonConverge) { + String msg = "Iterative method (interval iteration) did not converge within " + iters + " iterations."; + msg += "\nConsider using a different numerical method or increasing the maximum number of iterations"; + throw new PrismException(msg); + } + + // Return results + ModelCheckerResult res = new ModelCheckerResult(); + res.soln = below.getSolnVector(); + res.numIters = iters; + res.timeTaken = timer / 1000.0; + return res; + } finally { + if (iterationsExport != null) + iterationsExport.close(); + } + } + + /** + * Perform the actual work of a topological interval iteration, i.e., iterate until convergence or abort. + * + * @param mc ProbModelChecker (for log and settings) + * @param description Description (for logging) + * @param sccs The information about the SCCs and topological order + * @param below The iteration object for the value iteration from below + * @param above The iteration object for the value iteration from above + * @param singletonSCCSolver The solver for singleton SCCs + * @param startTime The start time (for logging purposes, obtained from a call to System.currentTimeMillis()) + * @param iterationsExport an ExportIterations object (optional, ignored if null) + * @return a ModelChecker result with the solution vector and statistics + * @throws PrismException on non-convergence (if mc.errorOnNonConverge is set) + */ + public ModelCheckerResult doTopologicalIntervalIteration(ProbModelChecker mc, String description, SCCInfo sccs, IterationIntervalIter below, IterationIntervalIter above, SingletonSCCSolver singletonSCCSolver, long timer, ExportIterations iterationsExport) throws PrismException { + try { + // Start iterations + int iters = 0; + long mvCount = 0; + final int maxIters = mc.maxIters; + + PeriodicTimer updatesTimer = new PeriodicTimer(ProbModelChecker.UPDATE_DELAY); + updatesTimer.start(); + + int numSCCs = sccs.getNumSCCs(); + int numNonSingletonSCCs = sccs.countNonSingletonSCCs(); + int finishedNonSingletonSCCs = 0; + + boolean done = true; + for (int scc = 0; scc < numSCCs; scc++) { + boolean doneSCC; + + if (sccs.isSingletonSCC(scc)) { + // get the single state in this SCC + int state = sccs.getStatesForSCC(scc).iterator().nextInt(); + below.solveSingletonSCC(state, singletonSCCSolver); + above.solveSingletonSCC(state, singletonSCCSolver); + + iters++; + mvCount += 2 * countTransitions(below.getModel(), IntSet.asIntSet(state)); + + if (iterationsExport != null) { + iterationsExport.exportVector(below.getSolnVector(), 0); + iterationsExport.exportVector(above.getSolnVector(), 1); + } + + intervalIterationCheckForProblems(below.getSolnVector(), above.getSolnVector(), IntSet.asIntSet(state).iterator()); + + doneSCC = true; + } else { + // complex SCC: do VI + doneSCC = false; + int itersInSCC = 0; + + IntSet statesForSCC = sccs.getStatesForSCC(scc); + + // Adjust upper bound by adding 2*epsilon, + // adding 1*epsilon would be fine, but we are a bit more conservative. + // TODO: We also don't really need to do adjustment for bottom SCCs... + PrimitiveIterator.OfInt it = statesForSCC.iterator(); + final double[] solnAbove = above.getSolnVector(); + final double adjustment = 2*termCritParam; + while (it.hasNext()) { + solnAbove[it.nextInt()] += adjustment; + } + + // abort on convergence or if iterations *in this SCC* are above maxIters + while (!doneSCC && itersInSCC < maxIters) { + iters++; + itersInSCC++; + + // do iteration step + below.iterate(statesForSCC); + above.iterate(statesForSCC); + + if (iterationsExport != null) { + iterationsExport.exportVector(below.getSolnVector(), 0); + iterationsExport.exportVector(above.getSolnVector(), 1); + } + + intervalIterationCheckForProblems(below.getSolnVector(), above.getSolnVector(), statesForSCC.iterator()); + + // Check termination (inside SCC) + doneSCC = PrismUtils.doublesAreClose(below.getSolnVector(), above.getSolnVector(), statesForSCC.iterator(), termCritParam, absolute); + + if (!doneSCC && updatesTimer.triggered()) { + double diff = PrismUtils.measureSupNormInterval(below.getSolnVector(), above.getSolnVector(), absolute, statesForSCC.iterator()); + mc.getLog().print("Iteration " + iters + ": "); + mc.getLog().print("max " + (absolute ? "" : "relative ") + "diff (for iteration " + itersInSCC + " in current SCC " + (finishedNonSingletonSCCs+1) + " of " + numNonSingletonSCCs + ") = " + PrismUtils.formatDouble(diff)); + mc.getLog().println(", " + PrismUtils.formatDouble2dp(updatesTimer.elapsedMillisTotal() / 1000.0) + " sec so far"); + } + } + + mvCount += 2 * itersInSCC * countTransitions(below.getModel(), statesForSCC); + finishedNonSingletonSCCs++; + } + + if (!doneSCC) { + done = false; + break; + } + } + + if (done) { + double diff = PrismUtils.measureSupNormInterval(below.getSolnVector(), above.getSolnVector(), absolute); + mc.getLog().println("Max " + (absolute ? "" : "relative ") + + "diff between upper and lower bound on convergence: " + PrismUtils.formatDouble(diff)); + done = true; + } + + // Finished value iteration + timer = System.currentTimeMillis() - timer; + mc.getLog().print("Interval iteration (" + description + ", with " + numNonSingletonSCCs + " non-singleton SCCs)"); + mc.getLog().print(" took " + iters + " iterations, "); + mc.getLog().print(mvCount + " MV-multiplications"); + mc.getLog().println(" and " + timer / 1000.0 + " seconds."); + + if (done && OptionsIntervalIteration.from(mc.getSettings()).isSelectMidpointForResult()) { + PrismUtils.selectMidpoint(below.getSolnVector(), above.getSolnVector()); + + if (iterationsExport != null) { + // export midpoint + iterationsExport.exportVector(below.getSolnVector(), 0); + iterationsExport.exportVector(below.getSolnVector(), 1); + } + } + + if (iterationsExport != null) + iterationsExport.close(); + + // Non-convergence is an error (usually) + if (!done && mc.errorOnNonConverge) { + String msg = "Iterative method (interval iteration) did not converge within " + iters + " iterations."; + msg += "\nConsider using a different numerical method or increasing the maximum number of iterations"; + throw new PrismException(msg); + } + + // Return results + ModelCheckerResult res = new ModelCheckerResult(); + res.soln = below.getSolnVector(); + res.numIters = iters; + res.timeTaken = timer / 1000.0; + return res; + } finally { + if (iterationsExport != null) + iterationsExport.close(); + } + } + + /** + * Compares the current lower and upper solution vectors in an interval iteration + * and throws an exception if lower bound values are larger than upper bound values, + * as this indicates problems. + * @param lower the current lower iteration solution vector + * @param upper the current upper iteration solution vector + * @param states iterator over the states in question + */ + private static void intervalIterationCheckForProblems(double[] lower, double[] upper, PrimitiveIterator.OfInt states) throws PrismException + { + while (states.hasNext()) { + int s = states.nextInt(); + if (lower[s] > upper[s]) { + throw new PrismException("In interval iteration, the lower value (" + lower[s] + ") is larger than the upper value (" + upper[s] + ").\n" + + "This indicates either problems with numerical stability (rounding, precision of the floating-point representation) or that the initial bounds (for reward computations) are incorrect"); + } + } + } + + /** + * Perform a post-processing for a two-vector value iteration. + * @param solnOld the previous solution vector + * @param solnNew the newly computed solution vector + * @param states the relevant set of states + * @param fromBelow are we iterating from below? + * @param enforceMonotonicity if true, enforces monotonicity + * @param checkMonotonicity if true, checks for monotonicity (and throws error when non-monotonic) + */ + public static void twoVectorPostProcessing(double[] solnOld, double[] solnNew, IntSet states, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity) throws PrismException + { + // TODO: use IntSet states + if (enforceMonotonicity) + if (fromBelow) { + PrismUtils.ensureMonotonicityFromBelow(solnOld, solnNew); + } else { + PrismUtils.ensureMonotonicityFromAbove(solnOld, solnNew); + } + + if (checkMonotonicity) { + PrismUtils.checkMonotonicity(solnOld, solnNew, !fromBelow); + } + } + protected long countTransitions(Model model, IntSet unknownStates) { if (model instanceof DTMC) { diff --git a/prism/src/explicit/IterationMethodGS.java b/prism/src/explicit/IterationMethodGS.java index e244ae70..349131ae 100644 --- a/prism/src/explicit/IterationMethodGS.java +++ b/prism/src/explicit/IterationMethodGS.java @@ -68,6 +68,23 @@ public class IterationMethodGS extends IterationMethod { }; } + @Override + public IterationIntervalIter forMvMultInterval(DTMC dtmc, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity) + { + return new SingleVectorIterationIntervalIter(dtmc) { + @Override + public void iterate(IntSet states) throws PrismException + { + // Matrix-vector multiply + dtmc.mvMultGSIntervalIter(soln, + backwards ? states.reversedIterator() : states.iterator(), + enforceMonotonicity, + checkMonotonicity, + fromBelow); + } + }; + } + @Override public IterationValIter forMvMultRew(DTMC dtmc, MCRewards rew) { @@ -87,6 +104,24 @@ public class IterationMethodGS extends IterationMethod { }; } + @Override + public IterationIntervalIter forMvMultRewInterval(DTMC dtmc, MCRewards rew, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity) + { + return new SingleVectorIterationIntervalIter(dtmc) { + @Override + public void iterate(IntSet states) throws PrismException + { + // Matrix-vector multiply + dtmc.mvMultRewGSIntervalIter(soln, + rew, + backwards ? states.reversedIterator() : states.iterator(), + enforceMonotonicity, + checkMonotonicity, + fromBelow); + } + }; + } + @Override public IterationValIter forMvMultMinMax(MDP mdp, boolean min, int[] strat) { @@ -107,6 +142,27 @@ public class IterationMethodGS extends IterationMethod { }; } + @Override + public IterationIntervalIter forMvMultMinMaxInterval(MDP mdp, boolean min, int[] strat, boolean fromBelow, boolean enforceMonotonicity, + boolean checkMonotonicity) throws PrismException + { + return new SingleVectorIterationIntervalIter(mdp) { + @Override + public void iterate(IntSet states) + { + // TODO: check monotonic not yet supported + + // Matrix-vector multiply + mdp.mvMultGSMinMaxIntervalIter(soln, + min, + backwards ? states.reversedIterator() : states.iterator(), + strat, + enforceMonotonicity, + fromBelow); + } + }; + } + @Override public IterationValIter forMvMultRewMinMax(MDP mdp, MDPRewards rewards, boolean min, int[] strat) throws PrismException { @@ -128,6 +184,28 @@ public class IterationMethodGS extends IterationMethod { }; } + @Override + public IterationIntervalIter forMvMultRewMinMaxInterval(MDP mdp, MDPRewards rewards, boolean min, int[] strat, boolean fromBelow, + boolean enforceMonotonicity, boolean checkMonotonicity) throws PrismException + { + return new SingleVectorIterationIntervalIter(mdp) { + @Override + public void iterate(IntSet states) + { + // TODO: check monotonic not yet supported + + // Matrix-vector multiply + mdp.mvMultRewGSMinMaxIntervalIter(soln, + rewards, + min, + backwards ? states.reversedIterator() : states.iterator(), + strat, + enforceMonotonicity, + fromBelow); + } + }; + } + @Override public String getDescriptionShort() { diff --git a/prism/src/explicit/IterationMethodJacobi.java b/prism/src/explicit/IterationMethodJacobi.java index c74735c0..83a8c8b1 100644 --- a/prism/src/explicit/IterationMethodJacobi.java +++ b/prism/src/explicit/IterationMethodJacobi.java @@ -59,6 +59,22 @@ class IterationMethodJacobi extends IterationMethod { }; } + @Override + public IterationIntervalIter forMvMultInterval(DTMC dtmc, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity) + { + IterationPostProcessor post = (soln, soln2, states) -> { + twoVectorPostProcessing(soln, soln2, states, fromBelow, enforceMonotonicity, checkMonotonicity); + }; + + return new TwoVectorIteration(dtmc, post) { + @Override + public void doIterate(IntSet states) + { + dtmc.mvMultJac(soln, soln2, states.iterator()); + } + }; + } + @Override public IterationValIter forMvMultRew(DTMC dtmc, MCRewards rew) { @@ -71,18 +87,48 @@ class IterationMethodJacobi extends IterationMethod { }; } + @Override + public IterationIntervalIter forMvMultRewInterval(DTMC dtmc, MCRewards rew, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity) + { + IterationPostProcessor post = (soln, soln2, states) -> { + twoVectorPostProcessing(soln, soln2, states, fromBelow, enforceMonotonicity, checkMonotonicity); + }; + + return new TwoVectorIteration(dtmc, post) { + @Override + public void doIterate(IntSet states) + { + dtmc.mvMultRewJac(soln, rew, soln2, states.iterator()); + } + }; + } + @Override public IterationValIter forMvMultMinMax(MDP mdp, boolean min, int[] strat) throws PrismException { throw new PrismNotSupportedException("Jacobi not supported for MDPs"); } + @Override + public IterationIntervalIter forMvMultMinMaxInterval(MDP mdp, boolean min, int[] strat, boolean fromBelow, boolean enforceMonotonicity, + boolean checkMonotonicity) throws PrismException + { + throw new PrismNotSupportedException("Jacobi not supported for MDPs"); + } + @Override public IterationValIter forMvMultRewMinMax(MDP mdp, MDPRewards rewards, boolean min, int[] strat) throws PrismException { throw new PrismNotSupportedException("Jacobi not supported for MDPs"); } + @Override + public IterationIntervalIter forMvMultRewMinMaxInterval(MDP mdp, MDPRewards rewards, boolean min, int[] strat, boolean fromBelow, + boolean enforceMonotonicity, boolean checkMonotonicity) throws PrismException + { + throw new PrismNotSupportedException("Jacobi not supported for MDPs"); + } + @Override public String getDescriptionShort() { diff --git a/prism/src/explicit/IterationMethodPower.java b/prism/src/explicit/IterationMethodPower.java index 7386eb2f..a066ef97 100644 --- a/prism/src/explicit/IterationMethodPower.java +++ b/prism/src/explicit/IterationMethodPower.java @@ -58,6 +58,22 @@ public class IterationMethodPower extends IterationMethod { }; } + @Override + public IterationIntervalIter forMvMultInterval(DTMC dtmc, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity) + { + IterationPostProcessor post = (soln, soln2, states) -> { + twoVectorPostProcessing(soln, soln2, states, fromBelow, enforceMonotonicity, checkMonotonicity); + }; + + return new TwoVectorIteration(dtmc, post) { + @Override + public void doIterate(IntSet states) + { + dtmc.mvMult(soln, soln2, states.iterator()); + } + }; + } + @Override public IterationValIter forMvMultRew(DTMC dtmc, MCRewards rew) { @@ -70,6 +86,22 @@ public class IterationMethodPower extends IterationMethod { }; } + @Override + public IterationIntervalIter forMvMultRewInterval(DTMC dtmc, MCRewards rew, boolean fromBelow, boolean enforceMonotonicity, boolean checkMonotonicity) + { + IterationPostProcessor post = (soln, soln2, states) -> { + twoVectorPostProcessing(soln, soln2, states, fromBelow, enforceMonotonicity, checkMonotonicity); + }; + + return new TwoVectorIteration(dtmc, post) { + @Override + public void doIterate(IntSet states) + { + dtmc.mvMultRew(soln, rew, soln2, states.iterator()); + } + }; + } + @Override public IterationValIter forMvMultMinMax(MDP mdp, boolean min, int[] strat) { @@ -82,6 +114,23 @@ public class IterationMethodPower extends IterationMethod { }; } + @Override + public IterationIntervalIter forMvMultMinMaxInterval(MDP mdp, boolean min, int[] strat, boolean fromBelow, boolean enforceMonotonicity, + boolean checkMonotonicity) throws PrismException + { + IterationPostProcessor post = (soln, soln2, states) -> { + twoVectorPostProcessing(soln, soln2, states, fromBelow, enforceMonotonicity, checkMonotonicity); + }; + + return new TwoVectorIteration(mdp, post) { + @Override + public void doIterate(IntSet states) + { + mdp.mvMultMinMax(soln, min, soln2, states.iterator(), strat); + } + }; + } + @Override public String getDescriptionShort() { @@ -100,4 +149,21 @@ public class IterationMethodPower extends IterationMethod { }; } + @Override + public IterationIntervalIter forMvMultRewMinMaxInterval(MDP mdp, MDPRewards rewards, boolean min, int[] strat, boolean fromBelow, + boolean enforceMonotonicity, boolean checkMonotonicity) throws PrismException + { + IterationPostProcessor post = (soln, soln2, states) -> { + twoVectorPostProcessing(soln, soln2, states, fromBelow, enforceMonotonicity, checkMonotonicity); + }; + + return new TwoVectorIteration(mdp, post) { + @Override + public void doIterate(IntSet states) + { + mdp.mvMultRewMinMax(soln, rewards, min, soln2, states.iterator(), strat); + } + }; + } + } diff --git a/prism/src/explicit/MDP.java b/prism/src/explicit/MDP.java index b0c76074..1a71fdf9 100644 --- a/prism/src/explicit/MDP.java +++ b/prism/src/explicit/MDP.java @@ -424,6 +424,43 @@ public interface MDP extends NondetModel return maxDiff; } + /** + * Do a Gauss-Seidel-style matrix-vector multiplication followed by min/max in the context of interval iteration. + * i.e. for all s: vect[s] = min/max_k { (sum_{j!=s} P_k(s,j)*vect[j]) / 1-P_k(s,s) } + * and store new values directly in {@code vect} as computed. + * Optionally, store optimal (memoryless) strategy info. + * @param vect Vector to multiply by (and store the result in) + * @param min Min or max for (true=min, false=max) + * @param subset Only do multiplication for these rows (ignored if null) + * @param states Perform computation for these rows, in the iteration order + * @param ensureMonotonic ensure monotonicity + * @param fromBelow iteration from below or from above? (for ensureMonotonicity) + */ + public default void mvMultGSMinMaxIntervalIter(double vect[], boolean min, PrimitiveIterator.OfInt states, int strat[], boolean ensureMonotonic, boolean fromBelow) + { + double d; + while (states.hasNext()) { + final int s = states.nextInt(); + d = mvMultJacMinMaxSingle(s, vect, min, strat); + if (ensureMonotonic) { + if (fromBelow) { + // from below: do max old and new + if (vect[s] > d) { + d = vect[s]; + } + } else { + // from above: do min old and new + if (vect[s] < d) { + d = vect[s]; + } + } + vect[s] = d; + } else { + vect[s] = d; + } + } + } + /** * Do a single row of Jacobi-style matrix-vector multiplication followed by min/max. * i.e. return min/max_k { (sum_{j!=s} P_k(s,j)*vect[j]) / 1-P_k(s,s) } @@ -663,6 +700,45 @@ public interface MDP extends NondetModel return maxDiff; } + /** + * Do a Gauss-Seidel-style matrix-vector multiplication and sum of rewards followed by min/max, + * for interval iteration. + * i.e. for all s: vect[s] = min/max_k { rew(s) + rew_k(s) + (sum_{j!=s} P_k(s,j)*vect[j]) / 1-P_k(s,s) } + * and store new values directly in {@code vect} as computed. + * Optionally, store optimal (memoryless) strategy info. + * @param vect Vector to multiply by (and store the result in) + * @param mdpRewards The rewards + * @param min Min or max for (true=min, false=max) + * @param states Perform computation for these rows, in the iteration order + * @param strat Storage for (memoryless) strategy choice indices (ignored if null) + * @param ensureMonotonic enforce monotonicity? + * @param fromBelow interval iteration from below? (for ensureMonotonic) + */ + public default void mvMultRewGSMinMaxIntervalIter(double vect[], MDPRewards mdpRewards, boolean min, PrimitiveIterator.OfInt states, int strat[], boolean ensureMonotonic, boolean fromBelow) + { + double d; + while (states.hasNext()) { + final int s = states.nextInt(); + d = mvMultRewJacMinMaxSingle(s, vect, mdpRewards, min, strat); + if (ensureMonotonic) { + if (fromBelow) { + // from below: do max old and new + if (vect[s] > d) { + d = vect[s]; + } + } else { + // from above: do min old and new + if (vect[s] < d) { + d = vect[s]; + } + } + vect[s] = d; + } else { + vect[s] = d; + } + } + } + /** * Do a single row of Jacobi-style matrix-vector multiplication and sum of rewards followed by min/max. * i.e. return min/max_k { rew(s) + rew_k(s) + (sum_{j!=s} P_k(s,j)*vect[j]) / 1-P_k(s,s) }