diff --git a/prism/src/explicit/MDPModelChecker.java b/prism/src/explicit/MDPModelChecker.java index d2d5fb0d..b9120698 100644 --- a/prism/src/explicit/MDPModelChecker.java +++ b/prism/src/explicit/MDPModelChecker.java @@ -53,6 +53,7 @@ import acceptance.AcceptanceReach; import acceptance.AcceptanceType; import automata.DA; import automata.LTL2WDBA; +import common.IntSet; import common.IterableBitSet; import explicit.modelviews.EquivalenceRelationInteger; import explicit.modelviews.MDPEquiv; @@ -525,12 +526,13 @@ public class MDPModelChecker extends ProbModelChecker { ModelCheckerResult res = null; + IterationMethod iterationMethod = null; switch (method) { case VALUE_ITERATION: - res = computeReachProbsValIter(mdp, no, yes, min, init, known, strat); + iterationMethod = new IterationMethodPower(termCrit == TermCrit.ABSOLUTE, termCritParam); break; case GAUSS_SEIDEL: - res = computeReachProbsGaussSeidel(mdp, no, yes, min, init, known, strat); + iterationMethod = new IterationMethodGS(termCrit == TermCrit.ABSOLUTE, termCritParam, false); break; case POLICY_ITERATION: res = computeReachProbsPolIter(mdp, no, yes, min, strat); @@ -542,6 +544,10 @@ public class MDPModelChecker extends ProbModelChecker throw new PrismException("Unknown MDP solution method " + mdpSolnMethod.fullName()); } + if (res == null) { // not yet computed, use iterationMethod + res = doValueIterationReachProbs(mdp, no, yes, min, init, known, iterationMethod, false, strat); + } + return res; } @@ -751,29 +757,48 @@ public class MDPModelChecker extends ProbModelChecker protected ModelCheckerResult computeReachProbsValIter(MDP mdp, BitSet no, BitSet yes, boolean min, double init[], BitSet known, int strat[]) throws PrismException { - ModelCheckerResult res; + IterationMethodPower iterationMethod = new IterationMethodPower(termCrit == TermCrit.ABSOLUTE, termCritParam); + return doValueIterationReachProbs(mdp, no, yes, min, init, known, iterationMethod, false, strat); + } + + /** + * Compute reachability probabilities using value iteration. + * Optionally, store optimal (memoryless) strategy info. + * @param mdp The MDP + * @param no Probability 0 states + * @param yes Probability 1 states + * @param min Min or max probabilities (true=min, false=max) + * @param init Optionally, an initial solution vector (will be overwritten) + * @param known Optionally, a set of states for which the exact answer is known + * @param iterationMethod The iteration method + * @param topological Do topological value iteration? + * @param strat Storage for (memoryless) strategy choice indices (ignored if null) + * Note: if 'known' is specified (i.e. is non-null), 'init' must also be given and is used for the exact values. + */ + protected ModelCheckerResult doValueIterationReachProbs(MDP mdp, BitSet no, BitSet yes, boolean min, double init[], BitSet known, IterationMethod iterationMethod, boolean topological, int strat[]) + throws PrismException + { BitSet unknown; - int i, n, iters; - double soln[], soln2[], tmpsoln[], initVal; - boolean done; + int i, n; + double initVal; long timer; // Start value iteration timer = System.currentTimeMillis(); - mainLog.println("Starting value iteration (" + (min ? "min" : "max") + ")..."); + String description = (min ? "min" : "max") + + (topological ? ", topological": "" ) + + ", with " + iterationMethod.getDescriptionShort(); + + mainLog.println("Starting value iteration (" + description + ")..."); ExportIterations iterationsExport = null; if (settings.getBoolean(PrismSettings.PRISM_EXPORT_ITERATIONS)) { - iterationsExport = new ExportIterations("Explicit ReachRewards value iteration"); + iterationsExport = new ExportIterations("Explicit MDP ReachProbs value iteration (" + description + ")"); } // Store num states n = mdp.getNumStates(); - // Create solution vector(s) - soln = new double[n]; - soln2 = (init == null) ? new double[n] : init; - // Initialise solution vectors. Use (where available) the following in order of preference: // (1) exact answer, if already known; (2) 1.0/0.0 if in yes/no; (3) passed in initial value; (4) initVal // where initVal is 0.0 or 1.0, depending on whether we converge from below/above. @@ -781,14 +806,15 @@ public class MDPModelChecker extends ProbModelChecker if (init != null) { if (known != null) { for (i = 0; i < n; i++) - soln[i] = soln2[i] = known.get(i) ? init[i] : yes.get(i) ? 1.0 : no.get(i) ? 0.0 : init[i]; + init[i] = known.get(i) ? init[i] : yes.get(i) ? 1.0 : no.get(i) ? 0.0 : init[i]; } else { for (i = 0; i < n; i++) - soln[i] = soln2[i] = yes.get(i) ? 1.0 : no.get(i) ? 0.0 : init[i]; + init[i] = yes.get(i) ? 1.0 : no.get(i) ? 0.0 : init[i]; } } else { + init = new double[n]; for (i = 0; i < n; i++) - soln[i] = soln2[i] = yes.get(i) ? 1.0 : no.get(i) ? 0.0 : initVal; + init[i] = yes.get(i) ? 1.0 : no.get(i) ? 0.0 : initVal; } // Determine set of states actually need to compute values for @@ -800,48 +826,28 @@ public class MDPModelChecker extends ProbModelChecker unknown.andNot(known); if (iterationsExport != null) - iterationsExport.exportVector(soln, 0); + iterationsExport.exportVector(init, 0); - // Start iterations - iters = 0; - done = false; - while (!done && iters < maxIters) { - iters++; - // Matrix-vector multiply and min/max ops - mdp.mvMultMinMax(soln, min, soln2, unknown, false, strat); + IterationMethod.IterationValIter iteration = iterationMethod.forMvMultMinMax(mdp, min, strat); + iteration.init(init); - if (iterationsExport != null) - iterationsExport.exportVector(soln2, 0); + IntSet unknownStates = IntSet.asIntSet(unknown); - // Check termination - done = PrismUtils.doublesAreClose(soln, soln2, termCritParam, termCrit == TermCrit.ABSOLUTE); - // Swap vectors for next iter - tmpsoln = soln; - soln = soln2; - soln2 = tmpsoln; - } + if (topological) { + // Compute SCCInfo, including trivial SCCs in the subgraph obtained when only considering + // states in unknown + SCCInfo sccs = SCCComputer.computeTopologicalOrdering(this, mdp, true, unknown::get); - // Finished value iteration - timer = System.currentTimeMillis() - timer; - mainLog.print("Value iteration (" + (min ? "min" : "max") + ")"); - mainLog.println(" took " + iters + " iterations and " + timer / 1000.0 + " seconds."); + IterationMethod.SingletonSCCSolver singletonSCCSolver = (int s, double[] soln) -> { + soln[s] = mdp.mvMultJacMinMaxSingle(s, soln, min, strat); + }; - if (iterationsExport != null) - iterationsExport.close(); - - // Non-convergence is an error (usually) - if (!done && errorOnNonConverge) { - String msg = "Iterative method did not converge within " + iters + " iterations."; - msg += "\nConsider using a different numerical method or increasing the maximum number of iterations"; - throw new PrismException(msg); + // run the actual value iteration + return iterationMethod.doTopologicalValueIteration(this, description, sccs, iteration, singletonSCCSolver, timer, iterationsExport); + } else { + // run the actual value iteration + return iterationMethod.doValueIteration(this, description, iteration, unknownStates, timer, iterationsExport); } - - // Return results - res = new ModelCheckerResult(); - res.soln = soln; - res.numIters = iters; - res.timeTaken = timer / 1000.0; - return res; } /** @@ -858,92 +864,8 @@ public class MDPModelChecker extends ProbModelChecker protected ModelCheckerResult computeReachProbsGaussSeidel(MDP mdp, BitSet no, BitSet yes, boolean min, double init[], BitSet known, int strat[]) throws PrismException { - ModelCheckerResult res; - BitSet unknown; - int i, n, iters; - double soln[], initVal, maxDiff; - boolean done; - long timer; - - // Start value iteration - timer = System.currentTimeMillis(); - mainLog.println("Starting Gauss-Seidel (" + (min ? "min" : "max") + ")..."); - - ExportIterations iterationsExport = null; - if (settings.getBoolean(PrismSettings.PRISM_EXPORT_ITERATIONS)) { - iterationsExport = new ExportIterations("Explicit MDP ReachProbs Gauss-Seidel iteration"); - } - - // Store num states - n = mdp.getNumStates(); - - // Create solution vector - soln = (init == null) ? new double[n] : init; - - // Initialise solution vector. Use (where available) the following in order of preference: - // (1) exact answer, if already known; (2) 1.0/0.0 if in yes/no; (3) passed in initial value; (4) initVal - // where initVal is 0.0 or 1.0, depending on whether we converge from below/above. - initVal = (valIterDir == ValIterDir.BELOW) ? 0.0 : 1.0; - if (init != null) { - if (known != null) { - for (i = 0; i < n; i++) - soln[i] = known.get(i) ? init[i] : yes.get(i) ? 1.0 : no.get(i) ? 0.0 : init[i]; - } else { - for (i = 0; i < n; i++) - soln[i] = yes.get(i) ? 1.0 : no.get(i) ? 0.0 : init[i]; - } - } else { - for (i = 0; i < n; i++) - soln[i] = yes.get(i) ? 1.0 : no.get(i) ? 0.0 : initVal; - } - - // Determine set of states actually need to compute values for - unknown = new BitSet(); - unknown.set(0, n); - unknown.andNot(yes); - unknown.andNot(no); - if (known != null) - unknown.andNot(known); - - if (iterationsExport != null) - iterationsExport.exportVector(soln, 0); - - // Start iterations - iters = 0; - done = false; - while (!done && iters < maxIters) { - iters++; - // Matrix-vector multiply - maxDiff = mdp.mvMultGSMinMax(soln, min, unknown, false, termCrit == TermCrit.ABSOLUTE, strat); - - if (iterationsExport != null) - iterationsExport.exportVector(soln, 0); - - // Check termination - done = maxDiff < termCritParam; - } - - // Finished Gauss-Seidel - timer = System.currentTimeMillis() - timer; - mainLog.print("Gauss-Seidel"); - mainLog.println(" took " + iters + " iterations and " + timer / 1000.0 + " seconds."); - - if (iterationsExport != null) - iterationsExport.close(); - - // Non-convergence is an error (usually) - if (!done && errorOnNonConverge) { - String msg = "Iterative method did not converge within " + iters + " iterations."; - msg += "\nConsider using a different numerical method or increasing the maximum number of iterations"; - throw new PrismException(msg); - } - - // Return results - res = new ModelCheckerResult(); - res.soln = soln; - res.numIters = iters; - res.timeTaken = timer / 1000.0; - return res; + IterationMethodGS iterationMethod = new IterationMethodGS(termCrit == TermCrit.ABSOLUTE, termCritParam, false); + return doValueIterationReachProbs(mdp, no, yes, min, init, known, iterationMethod, false, strat); } /** @@ -1628,20 +1550,7 @@ public class MDPModelChecker extends ProbModelChecker } } - // Compute rewards - switch (mdpSolnMethod) { - case VALUE_ITERATION: - res = computeReachRewardsValIter(mdp, mdpRewards, target, inf, min, init, known, strat); - break; - case GAUSS_SEIDEL: - res = computeReachRewardsGaussSeidel(mdp, mdpRewards, target, inf, min, init, known, strat); - break; - case POLICY_ITERATION: - res = computeReachRewardsPolIter(mdp, mdpRewards, target, inf, min, strat); - break; - default: - throw new PrismException("Unknown MDP solution method " + mdpSolnMethod.fullName()); - } + res = computeReachRewardsNumeric(mdp, mdpRewards, mdpSolnMethod, target, inf, min, init, known, strat); // Store strategy if (genStrat) { @@ -1668,6 +1577,32 @@ public class MDPModelChecker extends ProbModelChecker return res; } + protected ModelCheckerResult computeReachRewardsNumeric(MDP mdp, MDPRewards mdpRewards, MDPSolnMethod method, BitSet target, BitSet inf, boolean min, double init[], BitSet known, int strat[]) throws PrismException + { + ModelCheckerResult res = null; + + IterationMethod iterationMethod = null; + switch (method) { + case VALUE_ITERATION: + iterationMethod = new IterationMethodPower(termCrit == TermCrit.ABSOLUTE, termCritParam); + break; + case GAUSS_SEIDEL: + iterationMethod = new IterationMethodGS(termCrit == TermCrit.ABSOLUTE, termCritParam, false); + break; + case POLICY_ITERATION: + res = computeReachRewardsPolIter(mdp, mdpRewards, target, inf, min, strat); + break; + default: + throw new PrismException("Unknown MDP solution method " + method.fullName()); + } + + if (res == null) { // not yet computed, use iterationMethod + res = doValueIterationReachRewards(mdp, mdpRewards, iterationMethod, target, inf, min, init, known, false, strat); + } + + return res; + } + /** * Compute expected reachability rewards using value iteration. * Optionally, store optimal (memoryless) strategy info. @@ -1684,42 +1619,58 @@ public class MDPModelChecker extends ProbModelChecker protected ModelCheckerResult computeReachRewardsValIter(MDP mdp, MDPRewards mdpRewards, BitSet target, BitSet inf, boolean min, double init[], BitSet known, int strat[]) throws PrismException { - ModelCheckerResult res; + IterationMethodPower iterationMethod = new IterationMethodPower(termCrit == TermCrit.ABSOLUTE, termCritParam); + return doValueIterationReachRewards(mdp, mdpRewards, iterationMethod, target, inf, min, init, known, min, strat); + } + + /** + * Compute expected reachability rewards using value iteration. + * Optionally, store optimal (memoryless) strategy info. + * @param mdp The MDP + * @param mdpRewards The rewards + * @param target Target states + * @param inf States for which reward is infinite + * @param min Min or max rewards (true=min, false=max) + * @param init Optionally, an initial solution vector (will be overwritten) + * @param known Optionally, a set of states for which the exact answer is known + * @param topological Do topological value iteration? + * @param strat Storage for (memoryless) strategy choice indices (ignored if null) + * Note: if 'known' is specified (i.e. is non-null, 'init' must also be given and is used for the exact values. + */ + protected ModelCheckerResult doValueIterationReachRewards(MDP mdp, MDPRewards mdpRewards, IterationMethod iterationMethod, BitSet target, BitSet inf, boolean min, double init[], BitSet known, boolean topological, int strat[]) + throws PrismException + { BitSet unknown; - int i, n, iters; - double soln[], soln2[], tmpsoln[]; - boolean done; + int i, n; long timer; // Start value iteration timer = System.currentTimeMillis(); - mainLog.println("Starting value iteration (" + (min ? "min" : "max") + ")..."); + String description = (min ? "min" : "max") + (topological ? ", topological" : "" ) + ", with " + iterationMethod.getDescriptionShort(); + mainLog.println("Starting value iteration (" + description + ")..."); ExportIterations iterationsExport = null; if (settings.getBoolean(PrismSettings.PRISM_EXPORT_ITERATIONS)) { - iterationsExport = new ExportIterations("Explicit MDP ReachProbs value iteration"); + iterationsExport = new ExportIterations("Explicit MDP ReachRewards value iteration (" + description +")"); } // Store num states n = mdp.getNumStates(); - // Create solution vector(s) - soln = new double[n]; - soln2 = (init == null) ? new double[n] : init; - // Initialise solution vectors. Use (where available) the following in order of preference: // (1) exact answer, if already known; (2) 0.0/infinity if in target/inf; (3) passed in initial value; (4) 0.0 if (init != null) { if (known != null) { for (i = 0; i < n; i++) - soln[i] = soln2[i] = known.get(i) ? init[i] : target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i]; + init[i] = known.get(i) ? init[i] : target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i]; } else { for (i = 0; i < n; i++) - soln[i] = soln2[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i]; + init[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i]; } } else { + init = new double[n]; for (i = 0; i < n; i++) - soln[i] = soln2[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : 0.0; + init[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : 0.0; } // Determine set of states actually need to compute values for @@ -1731,49 +1682,28 @@ public class MDPModelChecker extends ProbModelChecker unknown.andNot(known); if (iterationsExport != null) - iterationsExport.exportVector(soln, 0); - - // Start iterations - iters = 0; - done = false; - while (!done && iters < maxIters) { - //mainLog.println(soln); - iters++; - // Matrix-vector multiply and min/max ops - mdp.mvMultRewMinMax(soln, mdpRewards, min, soln2, unknown, false, strat); + iterationsExport.exportVector(init, 0); - if (iterationsExport != null) - iterationsExport.exportVector(soln2, 0); + IterationMethod.IterationValIter forMvMultRewMinMax = iterationMethod.forMvMultRewMinMax(mdp, mdpRewards, min, strat); + forMvMultRewMinMax.init(init); - // Check termination - done = PrismUtils.doublesAreClose(soln, soln2, termCritParam, termCrit == TermCrit.ABSOLUTE); - // Swap vectors for next iter - tmpsoln = soln; - soln = soln2; - soln2 = tmpsoln; - } + IntSet unknownStates = IntSet.asIntSet(unknown); - if (iterationsExport != null) - iterationsExport.close(); + if (topological) { + // Compute SCCInfo, including trivial SCCs in the subgraph obtained when only considering + // states in unknown + SCCInfo sccs = SCCComputer.computeTopologicalOrdering(this, mdp, true, unknown::get); - // Finished value iteration - timer = System.currentTimeMillis() - timer; - mainLog.print("Value iteration (" + (min ? "min" : "max") + ")"); - mainLog.println(" took " + iters + " iterations and " + timer / 1000.0 + " seconds."); + IterationMethod.SingletonSCCSolver singletonSCCSolver = (int s, double[] soln) -> { + soln[s] = mdp.mvMultRewJacMinMaxSingle(s, soln, mdpRewards, min, strat); + }; - // Non-convergence is an error (usually) - if (!done && errorOnNonConverge) { - String msg = "Iterative method did not converge within " + iters + " iterations."; - msg += "\nConsider using a different numerical method or increasing the maximum number of iterations"; - throw new PrismException(msg); + // run the actual value iteration + return iterationMethod.doTopologicalValueIteration(this, description, sccs, forMvMultRewMinMax, singletonSCCSolver, timer, iterationsExport); + } else { + // run the actual value iteration + return iterationMethod.doValueIteration(this, description, forMvMultRewMinMax, unknownStates, timer, iterationsExport); } - - // Return results - res = new ModelCheckerResult(); - res.soln = soln; - res.numIters = iters; - res.timeTaken = timer / 1000.0; - return res; } /** @@ -1792,76 +1722,8 @@ public class MDPModelChecker extends ProbModelChecker protected ModelCheckerResult computeReachRewardsGaussSeidel(MDP mdp, MDPRewards mdpRewards, BitSet target, BitSet inf, boolean min, double init[], BitSet known, int strat[]) throws PrismException { - ModelCheckerResult res; - BitSet unknown; - int i, n, iters; - double soln[], maxDiff; - boolean done; - long timer; - - // Start value iteration - timer = System.currentTimeMillis(); - mainLog.println("Starting Gauss-Seidel (" + (min ? "min" : "max") + ")..."); - - // Store num states - n = mdp.getNumStates(); - - // Create solution vector(s) - soln = (init == null) ? new double[n] : init; - - // Initialise solution vector. Use (where available) the following in order of preference: - // (1) exact answer, if already known; (2) 0.0/infinity if in target/inf; (3) passed in initial value; (4) 0.0 - if (init != null) { - if (known != null) { - for (i = 0; i < n; i++) - soln[i] = known.get(i) ? init[i] : target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i]; - } else { - for (i = 0; i < n; i++) - soln[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i]; - } - } else { - for (i = 0; i < n; i++) - soln[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : 0.0; - } - - // Determine set of states actually need to compute values for - unknown = new BitSet(); - unknown.set(0, n); - unknown.andNot(target); - unknown.andNot(inf); - if (known != null) - unknown.andNot(known); - - // Start iterations - iters = 0; - done = false; - while (!done && iters < maxIters) { - //mainLog.println(soln); - iters++; - // Matrix-vector multiply and min/max ops - maxDiff = mdp.mvMultRewGSMinMax(soln, mdpRewards, min, unknown, false, termCrit == TermCrit.ABSOLUTE, strat); - // Check termination - done = maxDiff < termCritParam; - } - - // Finished Gauss-Seidel - timer = System.currentTimeMillis() - timer; - mainLog.print("Gauss-Seidel (" + (min ? "min" : "max") + ")"); - mainLog.println(" took " + iters + " iterations and " + timer / 1000.0 + " seconds."); - - // Non-convergence is an error (usually) - if (!done && errorOnNonConverge) { - String msg = "Iterative method did not converge within " + iters + " iterations."; - msg += "\nConsider using a different numerical method or increasing the maximum number of iterations"; - throw new PrismException(msg); - } - - // Return results - res = new ModelCheckerResult(); - res.soln = soln; - res.numIters = iters; - res.timeTaken = timer / 1000.0; - return res; + IterationMethodGS iterationMethod = new IterationMethodGS(termCrit == TermCrit.ABSOLUTE, termCritParam, false); + return doValueIterationReachRewards(mdp, mdpRewards, iterationMethod, target, inf, min, init, known, min, strat); } /**