diff --git a/prism/src/explicit/DTMCModelChecker.java b/prism/src/explicit/DTMCModelChecker.java index 2a465706..478de7be 100644 --- a/prism/src/explicit/DTMCModelChecker.java +++ b/prism/src/explicit/DTMCModelChecker.java @@ -46,6 +46,7 @@ import prism.PrismUtils; import acceptance.AcceptanceReach; import acceptance.AcceptanceType; import automata.DA; +import common.IntSet; import common.IterableBitSet; import common.StopWatch; import explicit.LTLModelChecker.LTLProduct; @@ -592,7 +593,14 @@ public class DTMCModelChecker extends ProbModelChecker LinEqMethod linEqMethod = this.linEqMethod; // Switch to a supported method, if necessary - if (!(linEqMethod == LinEqMethod.POWER || linEqMethod == LinEqMethod.GAUSS_SEIDEL)) { + switch (linEqMethod) + { + case POWER: + case GAUSS_SEIDEL: + case BACKWARDS_GAUSS_SEIDEL: + case JACOBI: + break; // supported + default: linEqMethod = LinEqMethod.GAUSS_SEIDEL; mainLog.printWarning("Switching to linear equation solution method \"" + linEqMethod.fullName() + "\""); } @@ -663,18 +671,30 @@ public class DTMCModelChecker extends ProbModelChecker numNo = no.cardinality(); mainLog.println("target=" + target.cardinality() + ", yes=" + numYes + ", no=" + numNo + ", maybe=" + (n - (numYes + numNo))); + boolean termCritAbsolute = termCrit == TermCrit.ABSOLUTE; + // Compute probabilities + IterationMethod iterationMethod = null; + switch (linEqMethod) { case POWER: - res = computeReachProbsValIter(dtmc, no, yes, init, known); + iterationMethod = new IterationMethodPower(termCritAbsolute, termCritParam); + break; + case JACOBI: + iterationMethod = new IterationMethodJacobi(termCritAbsolute, termCritParam); break; case GAUSS_SEIDEL: - res = computeReachProbsGaussSeidel(dtmc, no, yes, init, known); + case BACKWARDS_GAUSS_SEIDEL: { + boolean backwards = linEqMethod == LinEqMethod.BACKWARDS_GAUSS_SEIDEL; + iterationMethod = new IterationMethodGS(termCritAbsolute, termCritParam, backwards); break; + } default: throw new PrismException("Unknown linear equation solution method " + linEqMethod.fullName()); } + res = doValueIterationReachProbs(dtmc, no, yes, init, known, iterationMethod, false); + // Finished probabilistic reachability timer = System.currentTimeMillis() - timer; mainLog.println("Probabilistic reachability took " + timer / 1000.0 + " seconds."); @@ -945,35 +965,31 @@ public class DTMCModelChecker extends ProbModelChecker * @param dtmc The DTMC * @param no Probability 0 states * @param yes Probability 1 states - * @param init Optionally, an initial solution vector (will be overwritten) + * @param init Optionally, an initial solution vector (will be overwritten) * @param known Optionally, a set of states for which the exact answer is known - * Note: if 'known' is specified (i.e. is non-null, 'init' must also be given and is used for the exact values. + * Note: if 'known' is specified (i.e. is non-null), 'init' must also be given and is used for the exact values. + * @param topological do topological value iteration? */ - protected ModelCheckerResult computeReachProbsValIter(DTMC dtmc, BitSet no, BitSet yes, double init[], BitSet known) throws PrismException + protected ModelCheckerResult doValueIterationReachProbs(DTMC dtmc, BitSet no, BitSet yes, double init[], BitSet known, IterationMethod iterationMethod, boolean topological) throws PrismException { - ModelCheckerResult res; BitSet unknown; - int i, n, iters; - double soln[], soln2[], tmpsoln[], initVal; - boolean done; + int i, n; + double initVal; long timer; // Start value iteration timer = System.currentTimeMillis(); - mainLog.println("Starting value iteration..."); + String description = (topological ? "topological, " : "" ) + "with " + iterationMethod.getDescriptionShort(); + mainLog.println("Starting value iteration (" + description + ")..."); ExportIterations iterationsExport = null; if (settings.getBoolean(PrismSettings.PRISM_EXPORT_ITERATIONS)) { - iterationsExport = new ExportIterations("Explicit DTMC ReachProbs value iteration"); + iterationsExport = new ExportIterations("Explicit DTMC ReachProbs value iteration (" + description + ")"); } // Store num states n = dtmc.getNumStates(); - // Create solution vector(s) - soln = new double[n]; - soln2 = (init == null) ? new double[n] : init; - // Initialise solution vectors. Use (where available) the following in order of preference: // (1) exact answer, if already known; (2) 1.0/0.0 if in yes/no; (3) passed in initial value; (4) initVal // where initVal is 0.0 or 1.0, depending on whether we converge from below/above. @@ -981,14 +997,15 @@ public class DTMCModelChecker extends ProbModelChecker if (init != null) { if (known != null) { for (i = 0; i < n; i++) - soln[i] = soln2[i] = known.get(i) ? init[i] : yes.get(i) ? 1.0 : no.get(i) ? 0.0 : init[i]; + init[i] = known.get(i) ? init[i] : yes.get(i) ? 1.0 : no.get(i) ? 0.0 : init[i]; } else { for (i = 0; i < n; i++) - soln[i] = soln2[i] = yes.get(i) ? 1.0 : no.get(i) ? 0.0 : init[i]; + init[i] = yes.get(i) ? 1.0 : no.get(i) ? 0.0 : init[i]; } } else { + init = new double[n]; for (i = 0; i < n; i++) - soln[i] = soln2[i] = yes.get(i) ? 1.0 : no.get(i) ? 0.0 : initVal; + init[i] = yes.get(i) ? 1.0 : no.get(i) ? 0.0 : initVal; } // Determine set of states actually need to compute values for @@ -999,49 +1016,59 @@ public class DTMCModelChecker extends ProbModelChecker if (known != null) unknown.andNot(known); + IterationMethod.IterationValIter iterationReachProbs = iterationMethod.forMvMult(dtmc); + iterationReachProbs.init(init); + if (iterationsExport != null) - iterationsExport.exportVector(soln, 0); + iterationsExport.exportVector(init, 0); - // Start iterations - iters = 0; - done = false; - while (!done && iters < maxIters) { - iters++; - // Matrix-vector multiply - dtmc.mvMult(soln, soln2, unknown, false); + IntSet unknownStates = IntSet.asIntSet(unknown); - if (iterationsExport != null) - iterationsExport.exportVector(soln, 0); + if (topological) { + // Compute SCCInfo, including trivial SCCs in the subgraph obtained when only considering + // states in unknown + SCCInfo sccs = SCCComputer.computeTopologicalOrdering(this, dtmc, true, unknown::get); - // Check termination - done = PrismUtils.doublesAreClose(soln, soln2, termCritParam, termCrit == TermCrit.ABSOLUTE); - // Swap vectors for next iter - tmpsoln = soln; - soln = soln2; - soln2 = tmpsoln; - } + IterationMethod.SingletonSCCSolver singletonSCCSolver = (int s, double[] soln) -> { + soln[s] = dtmc.mvMultJacSingle(s, soln); + }; - // Finished value iteration - timer = System.currentTimeMillis() - timer; - mainLog.print("Value iteration"); - mainLog.println(" took " + iters + " iterations and " + timer / 1000.0 + " seconds."); + // run the actual value iteration + return iterationMethod.doTopologicalValueIteration(this, description, sccs, iterationReachProbs, singletonSCCSolver, timer, iterationsExport); + } else { + // run the actual value iteration + return iterationMethod.doValueIteration(this, description, iterationReachProbs, unknownStates, timer, iterationsExport); + } + } - if (iterationsExport != null) - iterationsExport.close(); - // Non-convergence is an error (usually) - if (!done && errorOnNonConverge) { - String msg = "Iterative method did not converge within " + iters + " iterations."; - msg += "\nConsider using a different numerical method or increasing the maximum number of iterations"; - throw new PrismException(msg); - } + /** + * Compute reachability probabilities using value iteration. + * @param dtmc The DTMC + * @param no Probability 0 states + * @param yes Probability 1 states + * @param init Optionally, an initial solution vector (will be overwritten) + * @param known Optionally, a set of states for which the exact answer is known + * Note: if 'known' is specified (i.e. is non-null), 'init' must also be given and is used for the exact values. + */ + protected ModelCheckerResult computeReachProbsValIter(DTMC dtmc, BitSet no, BitSet yes, double init[], BitSet known) throws PrismException + { + IterationMethodPower iterationMethod = new IterationMethodPower(termCrit == TermCrit.ABSOLUTE, termCritParam); + return doValueIterationReachProbs(dtmc, no, yes, init, known, iterationMethod, false); + } - // Return results - res = new ModelCheckerResult(); - res.soln = soln; - res.numIters = iters; - res.timeTaken = timer / 1000.0; - return res; + /** + * Compute reachability probabilities using Gauss-Seidel (forward). + * @param dtmc The DTMC + * @param no Probability 0 states + * @param yes Probability 1 states + * @param init Optionally, an initial solution vector (will be overwritten) + * @param known Optionally, a set of states for which the exact answer is known + * Note: if 'known' is specified (i.e. is non-null), 'init' must also be given and is used for the exact values. + */ + protected ModelCheckerResult computeReachProbsGaussSeidel(DTMC dtmc, BitSet no, BitSet yes, double init[], BitSet known) throws PrismException + { + return computeReachProbsGaussSeidel(dtmc, no, yes, init, known, false); } /** @@ -1051,96 +1078,13 @@ public class DTMCModelChecker extends ProbModelChecker * @param yes Probability 1 states * @param init Optionally, an initial solution vector (will be overwritten) * @param known Optionally, a set of states for which the exact answer is known - * Note: if 'known' is specified (i.e. is non-null, 'init' must also be given and is used for the exact values. + * Note: if 'known' is specified (i.e. is non-null, 'init' must also be given and is used for the exact values. + * @param backwards do backward Gauss-Seidel? */ - protected ModelCheckerResult computeReachProbsGaussSeidel(DTMC dtmc, BitSet no, BitSet yes, double init[], BitSet known) throws PrismException + protected ModelCheckerResult computeReachProbsGaussSeidel(DTMC dtmc, BitSet no, BitSet yes, double init[], BitSet known, boolean backwards) throws PrismException { - ModelCheckerResult res; - BitSet unknown; - int i, n, iters; - double soln[], initVal, maxDiff; - boolean done; - long timer; - - // Start value iteration - timer = System.currentTimeMillis(); - mainLog.println("Starting Gauss-Seidel..."); - - ExportIterations iterationsExport = null; - if (settings.getBoolean(PrismSettings.PRISM_EXPORT_ITERATIONS)) { - iterationsExport = new ExportIterations("Explicit DTMC ReachProbs Gauss Seidel value iteration"); - } - - // Store num states - n = dtmc.getNumStates(); - - // Create solution vector - soln = (init == null) ? new double[n] : init; - - // Initialise solution vector. Use (where available) the following in order of preference: - // (1) exact answer, if already known; (2) 1.0/0.0 if in yes/no; (3) passed in initial value; (4) initVal - // where initVal is 0.0 or 1.0, depending on whether we converge from below/above. - initVal = (valIterDir == ValIterDir.BELOW) ? 0.0 : 1.0; - if (init != null) { - if (known != null) { - for (i = 0; i < n; i++) - soln[i] = known.get(i) ? init[i] : yes.get(i) ? 1.0 : no.get(i) ? 0.0 : init[i]; - } else { - for (i = 0; i < n; i++) - soln[i] = yes.get(i) ? 1.0 : no.get(i) ? 0.0 : init[i]; - } - } else { - for (i = 0; i < n; i++) - soln[i] = yes.get(i) ? 1.0 : no.get(i) ? 0.0 : initVal; - } - - // Determine set of states actually need to compute values for - unknown = new BitSet(); - unknown.set(0, n); - unknown.andNot(yes); - unknown.andNot(no); - if (known != null) - unknown.andNot(known); - - if (iterationsExport != null) - iterationsExport.exportVector(soln, 0); - - // Start iterations - iters = 0; - done = false; - while (!done && iters < maxIters) { - iters++; - // Matrix-vector multiply - maxDiff = dtmc.mvMultGS(soln, unknown, false, termCrit == TermCrit.ABSOLUTE); - - if (iterationsExport != null) - iterationsExport.exportVector(soln, 0); - - // Check termination - done = maxDiff < termCritParam; - } - - // Finished Gauss-Seidel - timer = System.currentTimeMillis() - timer; - mainLog.print("Gauss-Seidel"); - mainLog.println(" took " + iters + " iterations and " + timer / 1000.0 + " seconds."); - - if (iterationsExport != null) - iterationsExport.close(); - - // Non-convergence is an error (usually) - if (!done && errorOnNonConverge) { - String msg = "Iterative method did not converge within " + iters + " iterations."; - msg += "\nConsider using a different numerical method or increasing the maximum number of iterations"; - throw new PrismException(msg); - } - - // Return results - res = new ModelCheckerResult(); - res.soln = soln; - res.numIters = iters; - res.timeTaken = timer / 1000.0; - return res; + IterationMethodGS iterationMethod = new IterationMethodGS(termCrit == TermCrit.ABSOLUTE, termCritParam, backwards); + return doValueIterationReachProbs(dtmc, no, yes, init, known, iterationMethod, false); } /** @@ -1271,7 +1215,7 @@ public class DTMCModelChecker extends ProbModelChecker * @param dtmc The DTMC * @param mcRewards The rewards * @param target Target states - * @param init Optionally, an initial solution vector (may be overwritten) + * @param init Optionally, an initial solution vector (may be overwritten) * @param known Optionally, a set of states for which the exact answer is known * Note: if 'known' is specified (i.e. is non-null, 'init' must also be given and is used for the exact values. */ @@ -1285,8 +1229,15 @@ public class DTMCModelChecker extends ProbModelChecker LinEqMethod linEqMethod = this.linEqMethod; // Switch to a supported method, if necessary - if (!(linEqMethod == LinEqMethod.POWER)) { - linEqMethod = LinEqMethod.POWER; + switch (linEqMethod) + { + case POWER: + case GAUSS_SEIDEL: + case BACKWARDS_GAUSS_SEIDEL: + case JACOBI: + break; // supported + default: + linEqMethod = LinEqMethod.GAUSS_SEIDEL; mainLog.printWarning("Switching to linear equation solution method \"" + linEqMethod.fullName() + "\""); } @@ -1329,15 +1280,30 @@ public class DTMCModelChecker extends ProbModelChecker numInf = inf.cardinality(); mainLog.println("target=" + numTarget + ", inf=" + numInf + ", rest=" + (n - (numTarget + numInf))); + boolean termCritAbsolute = termCrit == TermCrit.ABSOLUTE; + + IterationMethod iterationMethod; + // Compute rewards switch (linEqMethod) { case POWER: - res = computeReachRewardsValIter(dtmc, mcRewards, target, inf, init, known); + iterationMethod = new IterationMethodPower(termCritAbsolute, termCritParam); + break; + case JACOBI: + iterationMethod = new IterationMethodJacobi(termCritAbsolute, termCritParam); + break; + case GAUSS_SEIDEL: + case BACKWARDS_GAUSS_SEIDEL: { + boolean backwards = linEqMethod == LinEqMethod.BACKWARDS_GAUSS_SEIDEL; + iterationMethod = new IterationMethodGS(termCritAbsolute, termCritParam, backwards); break; + } default: throw new PrismException("Unknown linear equation solution method " + linEqMethod.fullName()); } + res = doValueIterationReachRewards(dtmc, mcRewards, target, inf, init, known, iterationMethod, false); + // Finished expected reachability timer = System.currentTimeMillis() - timer; mainLog.println("Expected reachability took " + timer / 1000.0 + " seconds."); @@ -1355,7 +1321,7 @@ public class DTMCModelChecker extends ProbModelChecker * @param mcRewards The rewards * @param target Target states * @param inf States for which reward is infinite - * @param init Optionally, an initial solution vector (will be overwritten) + * @param init Optionally, an initial solution vector (will be overwritten) * @param known Optionally, a set of states for which the exact answer is known * Note: if 'known' is specified (i.e. is non-null, 'init' must also be given and is used for the exact values. */ @@ -1454,6 +1420,84 @@ public class DTMCModelChecker extends ProbModelChecker return res; } + /** + * Compute expected reachability rewards using value iteration. + * @param dtmc The DTMC + * @param mcRewards The rewards + * @param target Target states + * @param inf States for which reward is infinite + * @param init Optionally, an initial solution vector (will be overwritten) + * @param known Optionally, a set of states for which the exact answer is known + * Note: if 'known' is specified (i.e. is non-null, 'init' must also be given and is used for the exact values. + * @param topological do topological value iteration? + */ + protected ModelCheckerResult doValueIterationReachRewards(DTMC dtmc, final MCRewards mcRewards, BitSet target, BitSet inf, double init[], BitSet known, IterationMethod iterationMethod, boolean topological) throws PrismException + { + BitSet unknown; + int i, n; + long timer; + + // Start value iteration + timer = System.currentTimeMillis(); + String description = (topological ? "topological, " : "" ) + "with " + iterationMethod.getDescriptionShort(); + mainLog.println("Starting value iteration (" + description + ") ..."); + + ExportIterations iterationsExport = null; + if (settings.getBoolean(PrismSettings.PRISM_EXPORT_ITERATIONS)) { + iterationsExport = new ExportIterations("Explicit DTMC ReachRewards value iteration (" + description + ")"); + } + + // Store num states + n = dtmc.getNumStates(); + + // Initialise solution vector. Use (where available) the following in order of preference: + // (1) exact answer, if already known; (2) 0.0/infinity if in target/inf; (3) passed in initial value; (4) 0.0 + if (init != null) { + if (known != null) { + for (i = 0; i < n; i++) + init[i] = known.get(i) ? init[i] : target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i]; + } else { + for (i = 0; i < n; i++) + init[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i]; + } + } else { + init = new double[n]; + for (i = 0; i < n; i++) + init[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : 0.0; + } + + // Determine set of states actually need to compute values for + unknown = new BitSet(); + unknown.set(0, n); + unknown.andNot(target); + unknown.andNot(inf); + if (known != null) + unknown.andNot(known); + + if (iterationsExport != null) + iterationsExport.exportVector(init, 0); + + IntSet unknownStates = IntSet.asIntSet(unknown); + IterationMethod.IterationValIter forMvMultRew = iterationMethod.forMvMultRew(dtmc, mcRewards); + forMvMultRew.init(init); + + if (topological) { + SCCInfo sccs = new SCCInfo(n); + SCCComputer sccComputer = SCCComputer.createSCCComputer(this, dtmc, sccs); + // Compute SCCInfo, including trivial SCCs in the subgraph obtained when only considering + // states in unknown + sccComputer.computeSCCs(false, unknown::get); + + IterationMethod.SingletonSCCSolver singletonSCCSolver = (int s, double[] soln) -> { + soln[s] = dtmc.mvMultRewJacSingle(s, soln, mcRewards); + }; + + return iterationMethod.doTopologicalValueIteration(this, description, sccs, forMvMultRew, singletonSCCSolver, timer, iterationsExport); + } else { + return iterationMethod.doValueIteration(this, description, forMvMultRew, unknownStates, timer, iterationsExport); + } + } + /** * Compute (forwards) steady-state probabilities * i.e. compute the long-run probability of being in each state,