From 193366cb7d7aa6b01fadef565b2bd29fd686139e Mon Sep 17 00:00:00 2001 From: Dave Parker Date: Tue, 14 Apr 2020 23:50:55 +0100 Subject: [PATCH] Explicit engine: avoid numerical computation where precomputation suffices. --- prism/src/explicit/DTMCModelChecker.java | 104 ++++++++++++----------- prism/src/explicit/MDPModelChecker.java | 63 ++++++++------ prism/src/explicit/Utils.java | 14 ++- 3 files changed, 102 insertions(+), 79 deletions(-) diff --git a/prism/src/explicit/DTMCModelChecker.java b/prism/src/explicit/DTMCModelChecker.java index 41d2fbd1..ef201963 100644 --- a/prism/src/explicit/DTMCModelChecker.java +++ b/prism/src/explicit/DTMCModelChecker.java @@ -689,32 +689,34 @@ public class DTMCModelChecker extends ProbModelChecker numNo = no.cardinality(); mainLog.println("target=" + target.cardinality() + ", yes=" + numYes + ", no=" + numNo + ", maybe=" + (n - (numYes + numNo))); - boolean termCritAbsolute = termCrit == TermCrit.ABSOLUTE; - - // Compute probabilities - IterationMethod iterationMethod = null; - - switch (linEqMethod) { - case POWER: - iterationMethod = new IterationMethodPower(termCritAbsolute, termCritParam); - break; - case JACOBI: - iterationMethod = new IterationMethodJacobi(termCritAbsolute, termCritParam); - break; - case GAUSS_SEIDEL: - case BACKWARDS_GAUSS_SEIDEL: { - boolean backwards = linEqMethod == LinEqMethod.BACKWARDS_GAUSS_SEIDEL; - iterationMethod = new IterationMethodGS(termCritAbsolute, termCritParam, backwards); - break; - } - default: - throw new PrismException("Unknown linear equation solution method " + linEqMethod.fullName()); - } - - if (doIntervalIteration) { - res = doIntervalIterationReachProbs(dtmc, no, yes, init, known, iterationMethod, getDoTopologicalValueIteration()); + // Compute probabilities (if needed) + if (numYes + numNo < n) { + boolean termCritAbsolute = termCrit == TermCrit.ABSOLUTE; + IterationMethod iterationMethod = null; + switch (linEqMethod) { + case POWER: + iterationMethod = new IterationMethodPower(termCritAbsolute, termCritParam); + break; + case JACOBI: + iterationMethod = new IterationMethodJacobi(termCritAbsolute, termCritParam); + break; + case GAUSS_SEIDEL: + case BACKWARDS_GAUSS_SEIDEL: { + boolean backwards = linEqMethod == LinEqMethod.BACKWARDS_GAUSS_SEIDEL; + iterationMethod = new IterationMethodGS(termCritAbsolute, termCritParam, backwards); + break; + } + default: + throw new PrismException("Unknown linear equation solution method " + linEqMethod.fullName()); + } + if (doIntervalIteration) { + res = doIntervalIterationReachProbs(dtmc, no, yes, init, known, iterationMethod, getDoTopologicalValueIteration()); + } else { + res = doValueIterationReachProbs(dtmc, no, yes, init, known, iterationMethod, getDoTopologicalValueIteration()); + } } else { - res = doValueIterationReachProbs(dtmc, no, yes, init, known, iterationMethod, getDoTopologicalValueIteration()); + res = new ModelCheckerResult(); + res.soln = Utils.bitsetToDoubleArray(yes, n); } // Finished probabilistic reachability @@ -1799,32 +1801,34 @@ public class DTMCModelChecker extends ProbModelChecker numInf = inf.cardinality(); mainLog.println("target=" + numTarget + ", inf=" + numInf + ", rest=" + (n - (numTarget + numInf))); - boolean termCritAbsolute = termCrit == TermCrit.ABSOLUTE; - - IterationMethod iterationMethod; - - // Compute rewards - switch (linEqMethod) { - case POWER: - iterationMethod = new IterationMethodPower(termCritAbsolute, termCritParam); - break; - case JACOBI: - iterationMethod = new IterationMethodJacobi(termCritAbsolute, termCritParam); - break; - case GAUSS_SEIDEL: - case BACKWARDS_GAUSS_SEIDEL: { - boolean backwards = linEqMethod == LinEqMethod.BACKWARDS_GAUSS_SEIDEL; - iterationMethod = new IterationMethodGS(termCritAbsolute, termCritParam, backwards); - break; - } - default: - throw new PrismException("Unknown linear equation solution method " + linEqMethod.fullName()); - } - - if (doIntervalIteration) { - res = doIntervalIterationReachRewards(dtmc, mcRewards, target, inf, init, known, iterationMethod, getDoTopologicalValueIteration()); + // Compute rewards (if needed) + if (numTarget + numInf < n) { + boolean termCritAbsolute = termCrit == TermCrit.ABSOLUTE; + IterationMethod iterationMethod; + switch (linEqMethod) { + case POWER: + iterationMethod = new IterationMethodPower(termCritAbsolute, termCritParam); + break; + case JACOBI: + iterationMethod = new IterationMethodJacobi(termCritAbsolute, termCritParam); + break; + case GAUSS_SEIDEL: + case BACKWARDS_GAUSS_SEIDEL: { + boolean backwards = linEqMethod == LinEqMethod.BACKWARDS_GAUSS_SEIDEL; + iterationMethod = new IterationMethodGS(termCritAbsolute, termCritParam, backwards); + break; + } + default: + throw new PrismException("Unknown linear equation solution method " + linEqMethod.fullName()); + } + if (doIntervalIteration) { + res = doIntervalIterationReachRewards(dtmc, mcRewards, target, inf, init, known, iterationMethod, getDoTopologicalValueIteration()); + } else { + res = doValueIterationReachRewards(dtmc, mcRewards, target, inf, init, known, iterationMethod, getDoTopologicalValueIteration()); + } } else { - res = doValueIterationReachRewards(dtmc, mcRewards, target, inf, init, known, iterationMethod, getDoTopologicalValueIteration()); + res = new ModelCheckerResult(); + res.soln = Utils.bitsetToDoubleArray(inf, n, Double.POSITIVE_INFINITY); } // Finished expected reachability diff --git a/prism/src/explicit/MDPModelChecker.java b/prism/src/explicit/MDPModelChecker.java index 1cee5bd7..34360ba0 100644 --- a/prism/src/explicit/MDPModelChecker.java +++ b/prism/src/explicit/MDPModelChecker.java @@ -2166,38 +2166,45 @@ public class MDPModelChecker extends ProbModelChecker } } - ZeroRewardECQuotient quotient = null; - boolean doZeroMECCheckForMin = true; - if (min & doZeroMECCheckForMin) { - StopWatch zeroMECTimer = new StopWatch(mainLog); - zeroMECTimer.start("checking for zero-reward ECs"); - mainLog.println("For Rmin, checking for zero-reward ECs..."); - BitSet unknown = (BitSet) inf.clone(); - unknown.flip(0, mdp.getNumStates()); - unknown.andNot(target); - quotient = ZeroRewardECQuotient.getQuotient(this, mdp, unknown, mdpRewards); - - if (quotient == null) { - zeroMECTimer.stop("no zero-reward ECs found, proceeding normally"); - } else { - zeroMECTimer.stop("built quotient MDP with " + quotient.getNumberOfZeroRewardMECs() + " zero-reward MECs"); - if (strat != null) { - throw new PrismException("Constructing a strategy for Rmin in the presence of zero-reward ECs is currently not supported"); + // Compute rewards (if needed) + if (numTarget + numInf < n) { + + ZeroRewardECQuotient quotient = null; + boolean doZeroMECCheckForMin = true; + if (min & doZeroMECCheckForMin) { + StopWatch zeroMECTimer = new StopWatch(mainLog); + zeroMECTimer.start("checking for zero-reward ECs"); + mainLog.println("For Rmin, checking for zero-reward ECs..."); + BitSet unknown = (BitSet) inf.clone(); + unknown.flip(0, mdp.getNumStates()); + unknown.andNot(target); + quotient = ZeroRewardECQuotient.getQuotient(this, mdp, unknown, mdpRewards); + + if (quotient == null) { + zeroMECTimer.stop("no zero-reward ECs found, proceeding normally"); + } else { + zeroMECTimer.stop("built quotient MDP with " + quotient.getNumberOfZeroRewardMECs() + " zero-reward MECs"); + if (strat != null) { + throw new PrismException("Constructing a strategy for Rmin in the presence of zero-reward ECs is currently not supported"); + } } } - } - - if (quotient != null) { - BitSet newInfStates = (BitSet)inf.clone(); - newInfStates.or(quotient.getNonRepresentativeStates()); - int quotientModelStates = quotient.getModel().getNumStates() - newInfStates.cardinality(); - mainLog.println("Computing Rmin in zero-reward EC quotient model (" + quotientModelStates + " relevant states)..."); - res = computeReachRewardsNumeric(quotient.getModel(), quotient.getRewards(), mdpSolnMethod, target, newInfStates, min, init, known, strat); - quotient.mapResults(res.soln); + + if (quotient != null) { + BitSet newInfStates = (BitSet)inf.clone(); + newInfStates.or(quotient.getNonRepresentativeStates()); + int quotientModelStates = quotient.getModel().getNumStates() - newInfStates.cardinality(); + mainLog.println("Computing Rmin in zero-reward EC quotient model (" + quotientModelStates + " relevant states)..."); + res = computeReachRewardsNumeric(quotient.getModel(), quotient.getRewards(), mdpSolnMethod, target, newInfStates, min, init, known, strat); + quotient.mapResults(res.soln); + } else { + res = computeReachRewardsNumeric(mdp, mdpRewards, mdpSolnMethod, target, inf, min, init, known, strat); + } } else { - res = computeReachRewardsNumeric(mdp, mdpRewards, mdpSolnMethod, target, inf, min, init, known, strat); + res = new ModelCheckerResult(); + res.soln = Utils.bitsetToDoubleArray(inf, n, Double.POSITIVE_INFINITY); } - + // Store strategy if (genStrat) { res.strat = new MDStrategyArray(mdp, strat); diff --git a/prism/src/explicit/Utils.java b/prism/src/explicit/Utils.java index 47e5c4a4..e36a8f90 100644 --- a/prism/src/explicit/Utils.java +++ b/prism/src/explicit/Utils.java @@ -82,11 +82,23 @@ public class Utils * @param n The size of the array. */ public static double[] bitsetToDoubleArray(BitSet bs, int n) + { + return bitsetToDoubleArray(bs, n, 1.0); + } + + /** + * Create an n-element array of doubles from a BitSet, + * setting elements whose index is set in the BitSet to {@code val}, and otherwise 0.0. + * @param bs The bitset specifying set elements + * @param n The size of the array. + * @param val The value for "set" elements + */ + public static double[] bitsetToDoubleArray(BitSet bs, int n, double val) { int i; double res[] = new double[n]; for (i = 0; i < n; i++) - res[i] = bs.get(i) ? 1.0 : 0.0; + res[i] = bs.get(i) ? val : 0.0; return res; }