diff --git a/prism/src/explicit/MDP.java b/prism/src/explicit/MDP.java
index 751216b9..52071968 100644
--- a/prism/src/explicit/MDP.java
+++ b/prism/src/explicit/MDP.java
@@ -220,6 +220,7 @@ public interface MDP extends Model
 	 * and store new values directly in {@code vect} as computed.
 	 * The maximum (absolute/relative) difference between old/new
 	 * elements of {@code vect} is also returned.
+	 * Optionally, store optimal (memoryless) strategy info.
 	 * @param vect Vector to multiply by (and store the result in)
 	 * @param mdpRewards The rewards
 	 * @param min Min or max for (true=min, false=max)
@@ -227,8 +228,9 @@
 	 * @param complement If true, {@code subset} is taken to be its complement (ignored if {@code subset} is null)
 	 * @param absolute If true, compute absolute, rather than relative, difference
+	 * @param strat Storage for (memoryless) strategy choice indices (ignored if null)
 	 * @return The maximum difference between old/new elements of {@code vect}
 	 */
-	public double mvMultRewGSMinMax(double vect[], MDPRewards mdpRewards, boolean min, BitSet subset, boolean complement, boolean absolute);
+	public double mvMultRewGSMinMax(double vect[], MDPRewards mdpRewards, boolean min, BitSet subset, boolean complement, boolean absolute, int strat[]);
 
 	/**
 	 * Do a single row of matrix-vector multiplication and sum of action reward followed by min/max.
@@ -245,12 +247,14 @@
 	/**
 	 * Do a single row of Jacobi-style matrix-vector multiplication and sum of action reward followed by min/max.
 	 * i.e. return min/max_k { (sum_{j!=s} P_k(s,j)*vect[j]) / 1-P_k(s,s) }
+	 * Optionally, store optimal (memoryless) strategy info.
 	 * @param s Row index
 	 * @param vect Vector to multiply by
 	 * @param mdpRewards The rewards
 	 * @param min Min or max for (true=min, false=max)
+	 * @param strat Storage for (memoryless) strategy choice indices (ignored if null)
 	 */
-	public double mvMultRewJacMinMaxSingle(int s, double vect[], MDPRewards mdpRewards, boolean min);
+	public double mvMultRewJacMinMaxSingle(int s, double vect[], MDPRewards mdpRewards, boolean min, int strat[]);
 
 	/**
 	 * Determine which choices result in min/max after a single row of matrix-vector multiplication and sum of action reward.
diff --git a/prism/src/explicit/MDPExplicit.java b/prism/src/explicit/MDPExplicit.java
index 22b2a696..1e1e4bbb 100644
--- a/prism/src/explicit/MDPExplicit.java
+++ b/prism/src/explicit/MDPExplicit.java
@@ -301,28 +301,28 @@ public abstract class MDPExplicit extends ModelExplicit implements MDP
 	}
 
 	@Override
-	public double mvMultRewGSMinMax(double vect[], MDPRewards mdpRewards, boolean min, BitSet subset, boolean complement, boolean absolute)
+	public double mvMultRewGSMinMax(double vect[], MDPRewards mdpRewards, boolean min, BitSet subset, boolean complement, boolean absolute, int strat[])
 	{
 		int s;
 		double d, diff, maxDiff = 0.0;
 		// Loop depends on subset/complement arguments
 		if (subset == null) {
 			for (s = 0; s < numStates; s++) {
-				d = mvMultRewJacMinMaxSingle(s, vect, mdpRewards, min);
+				d = mvMultRewJacMinMaxSingle(s, vect, mdpRewards, min, strat);
 				diff = absolute ? (Math.abs(d - vect[s])) : (Math.abs(d - vect[s]) / d);
 				maxDiff = diff > maxDiff ? diff : maxDiff;
 				vect[s] = d;
 			}
 		} else if (complement) {
 			for (s = subset.nextClearBit(0); s < numStates; s = subset.nextClearBit(s + 1)) {
-				d = mvMultRewJacMinMaxSingle(s, vect, mdpRewards, min);
+				d = mvMultRewJacMinMaxSingle(s, vect, mdpRewards, min, strat);
 				diff = absolute ? (Math.abs(d - vect[s])) : (Math.abs(d - vect[s]) / d);
 				maxDiff = diff > maxDiff ? diff : maxDiff;
 				vect[s] = d;
 			}
 		} else {
 			for (s = subset.nextSetBit(0); s >= 0; s = subset.nextSetBit(s + 1)) {
-				d = mvMultRewJacMinMaxSingle(s, vect, mdpRewards, min);
+				d = mvMultRewJacMinMaxSingle(s, vect, mdpRewards, min, strat);
 				diff = absolute ? (Math.abs(d - vect[s])) : (Math.abs(d - vect[s]) / d);
 				maxDiff = diff > maxDiff ? diff : maxDiff;
 				vect[s] = d;
diff --git a/prism/src/explicit/MDPModelChecker.java b/prism/src/explicit/MDPModelChecker.java
index d7eefa47..351d7a4d 100644
--- a/prism/src/explicit/MDPModelChecker.java
+++ b/prism/src/explicit/MDPModelChecker.java
@@ -393,7 +393,7 @@ public class MDPModelChecker extends ProbModelChecker
 		mainLog.println("target=" + target.cardinality() + ", yes=" + numYes + ", no=" + numNo + ", maybe=" + (n - (numYes + numNo)));
 
 		// If still required, generate strategy for no/yes (0/1) states.
-		// This is not just for the cases max=0 and min=1, where arbitrary choices suffice.
+		// This is just for the cases max=0 and min=1, where arbitrary choices suffice.
 		// So just pick the first choice (0) for all these.
 		if (genStrat) {
 			if (min) {
@@ -1182,6 +1182,8 @@ public class MDPModelChecker extends ProbModelChecker
 		BitSet inf;
 		int i, n, numTarget, numInf;
 		long timer, timerProb1;
+		int strat[] = null;
+		boolean genStrat;
 
 		// Local copy of setting
 		MDPSolnMethod mdpSolnMethod = this.mdpSolnMethod;
@@ -1191,6 +1193,9 @@
 			mainLog.printWarning("Switching to MDP solution method \"" + mdpSolnMethod.fullName() + "\"");
 		}
 
+		// Are we generating an optimal strategy?
+		genStrat = exportAdv;
+
 		// Start expected reachability
 		timer = System.currentTimeMillis();
 		mainLog.println("\nStarting expected reachability (" + (min ? "min" : "max") + ")...");
@@ -1201,6 +1206,15 @@
 		// Store num states
 		n = mdp.getNumStates();
 
+		// If required, create/initialise strategy storage
+		// Set all choices to -1, denoting unknown/arbitrary
+		if (genStrat) {
+			strat = new int[n];
+			for (i = 0; i < n; i++) {
+				strat[i] = -1;
+			}
+		}
+
 		// Optimise by enlarging target set (if more info is available)
 		if (init != null && known != null) {
 			BitSet targetNew = new BitSet(n);
@@ -1224,10 +1238,10 @@
 		// Compute rewards
 		switch (mdpSolnMethod) {
 		case VALUE_ITERATION:
-			res = computeReachRewardsValIter(mdp, mdpRewards, target, inf, min, init, known);
+			res = computeReachRewardsValIter(mdp, mdpRewards, target, inf, min, init, known, strat);
 			break;
 		case GAUSS_SEIDEL:
-			res = computeReachRewardsGaussSeidel(mdp, mdpRewards, target, inf, min, init, known);
+			res = computeReachRewardsGaussSeidel(mdp, mdpRewards, target, inf, min, init, known, strat);
 			break;
 		default:
 			throw new PrismException("Unknown MDP solution method " + mdpSolnMethod.fullName());
@@ -1245,7 +1259,8 @@
 	}
 
 	/**
-	 * Compute expected reachability rewards using Gauss-Seidel (including Jacobi-style updates).
+	 * Compute expected reachability rewards using value iteration.
+	 * Optionally, store optimal (memoryless) strategy info.
 	 * @param mdp The MDP
 	 * @param mdpRewards The rewards
 	 * @param target Target states
@@ -1253,41 +1268,43 @@
 	 * @param min Min or max rewards (true=min, false=max)
 	 * @param init Optionally, an initial solution vector (will be overwritten)
 	 * @param known Optionally, a set of states for which the exact answer is known
+	 * @param strat Storage for (memoryless) strategy choice indices (ignored if null)
 	 * Note: if 'known' is specified (i.e. is non-null, 'init' must also be given and is used for the exact values.
 	 */
-	protected ModelCheckerResult computeReachRewardsGaussSeidel(MDP mdp, MDPRewards mdpRewards, BitSet target, BitSet inf, boolean min, double init[],
-			BitSet known) throws PrismException
+	protected ModelCheckerResult computeReachRewardsValIter(MDP mdp, MDPRewards mdpRewards, BitSet target, BitSet inf, boolean min, double init[], BitSet known, int strat[])
+			throws PrismException
 	{
 		ModelCheckerResult res;
 		BitSet unknown;
 		int i, n, iters;
-		double soln[], maxDiff;
+		double soln[], soln2[], tmpsoln[];
 		boolean done;
 		long timer;
 
 		// Start value iteration
 		timer = System.currentTimeMillis();
-		mainLog.println("Starting Gauss-Seidel (" + (min ? "min" : "max") + ")...");
+		mainLog.println("Starting value iteration (" + (min ? "min" : "max") + ")...");
 
 		// Store num states
 		n = mdp.getNumStates();
 
 		// Create solution vector(s)
-		soln = (init == null) ? new double[n] : init;
+		soln = new double[n];
+		soln2 = (init == null) ? new double[n] : init;
 
-		// Initialise solution vector. Use (where available) the following in order of preference:
+		// Initialise solution vectors. Use (where available) the following in order of preference:
 		// (1) exact answer, if already known; (2) 0.0/infinity if in target/inf; (3) passed in initial value; (4) 0.0
 		if (init != null) {
 			if (known != null) {
 				for (i = 0; i < n; i++)
-					soln[i] = known.get(i) ? init[i] : target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i];
+					soln[i] = soln2[i] = known.get(i) ? init[i] : target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i];
 			} else {
 				for (i = 0; i < n; i++)
-					soln[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i];
+					soln[i] = soln2[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i];
 			}
 		} else {
 			for (i = 0; i < n; i++)
-				soln[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : 0.0;
+				soln[i] = soln2[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : 0.0;
 		}
 
 		// Determine set of states actually need to compute values for
@@ -1305,14 +1322,18 @@
 			//mainLog.println(soln);
 			iters++;
 			// Matrix-vector multiply and min/max ops
-			maxDiff = mdp.mvMultRewGSMinMax(soln, mdpRewards, min, unknown, false, termCrit == TermCrit.ABSOLUTE);
+			mdp.mvMultRewMinMax(soln, mdpRewards, min, soln2, unknown, false, strat);
 			// Check termination
-			done = maxDiff < termCritParam;
+			done = PrismUtils.doublesAreClose(soln, soln2, termCritParam, termCrit == TermCrit.ABSOLUTE);
+			// Swap vectors for next iter
+			tmpsoln = soln;
+			soln = soln2;
+			soln2 = tmpsoln;
 		}
 
-		// Finished Gauss-Seidel
+		// Finished value iteration
 		timer = System.currentTimeMillis() - timer;
-		mainLog.print("Gauss-Seidel (" + (min ? "min" : "max") + ")");
+		mainLog.print("Value iteration (" + (min ? "min" : "max") + ")");
"min" : "max") + ")"); mainLog.println(" took " + iters + " iterations and " + timer / 1000.0 + " seconds."); // Non-convergence is an error @@ -1331,7 +1352,8 @@ public class MDPModelChecker extends ProbModelChecker } /** - * Compute expected reachability rewards using value iteration. + * Compute expected reachability rewards using Gauss-Seidel (including Jacobi-style updates). + * Optionally, store optimal (memoryless) strategy info. * @param mdp The MDP * @param mdpRewards The rewards * @param target Target states @@ -1339,42 +1361,42 @@ public class MDPModelChecker extends ProbModelChecker * @param min Min or max rewards (true=min, false=max) * @param init Optionally, an initial solution vector (will be overwritten) * @param known Optionally, a set of states for which the exact answer is known + * @param strat Storage for (memoryless) strategy choice indices (ignored if null) * Note: if 'known' is specified (i.e. is non-null, 'init' must also be given and is used for the exact values. */ - protected ModelCheckerResult computeReachRewardsValIter(MDP mdp, MDPRewards mdpRewards, BitSet target, BitSet inf, boolean min, double init[], BitSet known) - throws PrismException + protected ModelCheckerResult computeReachRewardsGaussSeidel(MDP mdp, MDPRewards mdpRewards, BitSet target, BitSet inf, boolean min, double init[], + BitSet known, int strat[]) throws PrismException { ModelCheckerResult res; BitSet unknown; int i, n, iters; - double soln[], soln2[], tmpsoln[]; + double soln[], maxDiff; boolean done; long timer; // Start value iteration timer = System.currentTimeMillis(); - mainLog.println("Starting value iteration (" + (min ? "min" : "max") + ")..."); + mainLog.println("Starting Gauss-Seidel (" + (min ? "min" : "max") + ")..."); // Store num states n = mdp.getNumStates(); // Create solution vector(s) - soln = new double[n]; - soln2 = (init == null) ? new double[n] : init; + soln = (init == null) ? new double[n] : init; - // Initialise solution vectors. Use (where available) the following in order of preference: + // Initialise solution vector. Use (where available) the following in order of preference: // (1) exact answer, if already known; (2) 0.0/infinity if in target/inf; (3) passed in initial value; (4) 0.0 if (init != null) { if (known != null) { for (i = 0; i < n; i++) - soln[i] = soln2[i] = known.get(i) ? init[i] : target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i]; + soln[i] = known.get(i) ? init[i] : target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i]; } else { for (i = 0; i < n; i++) - soln[i] = soln2[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i]; + soln[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i]; } } else { for (i = 0; i < n; i++) - soln[i] = soln2[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : 0.0; + soln[i] = target.get(i) ? 0.0 : inf.get(i) ? 
@@ -1392,18 +1414,14 @@
 			//mainLog.println(soln);
 			iters++;
 			// Matrix-vector multiply and min/max ops
-			mdp.mvMultRewMinMax(soln, mdpRewards, min, soln2, unknown, false, null);
+			maxDiff = mdp.mvMultRewGSMinMax(soln, mdpRewards, min, unknown, false, termCrit == TermCrit.ABSOLUTE, strat);
 			// Check termination
-			done = PrismUtils.doublesAreClose(soln, soln2, termCritParam, termCrit == TermCrit.ABSOLUTE);
-			// Swap vectors for next iter
-			tmpsoln = soln;
-			soln = soln2;
-			soln2 = tmpsoln;
+			done = maxDiff < termCritParam;
 		}
 
-		// Finished value iteration
+		// Finished Gauss-Seidel
 		timer = System.currentTimeMillis() - timer;
-		mainLog.print("Value iteration (" + (min ? "min" : "max") + ")");
+		mainLog.print("Gauss-Seidel (" + (min ? "min" : "max") + ")");
 		mainLog.println(" took " + iters + " iterations and " + timer / 1000.0 + " seconds.");
 
 		// Non-convergence is an error
diff --git a/prism/src/explicit/MDPSimple.java b/prism/src/explicit/MDPSimple.java
index f0cf30e5..38ef8765 100644
--- a/prism/src/explicit/MDPSimple.java
+++ b/prism/src/explicit/MDPSimple.java
@@ -862,6 +862,8 @@ public class MDPSimple extends MDPExplicit implements ModelSimple
 			}
 			first = false;
 		}
+		// Add state reward (doesn't affect min/max)
+		minmax += mdpRewards.getStateReward(s);
 		// If strategy generation is enabled, store optimal choice
 		if (strat != null & !first) {
 			// Only remember strictly better choices (required for max)
@@ -869,16 +871,14 @@
 				strat[s] = stratCh;
 			}
 		}
-		// Add state reward (doesn't affect min/max)
-		minmax += mdpRewards.getStateReward(s);
 
 		return minmax;
 	}
 
 	@Override
-	public double mvMultRewJacMinMaxSingle(int s, double vect[], MDPRewards mdpRewards, boolean min)
+	public double mvMultRewJacMinMaxSingle(int s, double vect[], MDPRewards mdpRewards, boolean min, int strat[])
 	{
-		int j, k;
+		int j, k, stratCh = -1;
 		double diag, d, prob, minmax;
 		boolean first;
 		List<Distribution> step;
@@ -904,12 +904,24 @@
 			if (diag > 0)
 				d /= diag;
 			// Check whether we have exceeded min/max so far
-			if (first || (min && d < minmax) || (!min && d > minmax))
+			if (first || (min && d < minmax) || (!min && d > minmax)) {
 				minmax = d;
+				// If strategy generation is enabled, remember optimal choice
+				if (strat != null) {
+					stratCh = j;
+				}
+			}
 			first = false;
 		}
 		// Add state reward (doesn't affect min/max)
 		minmax += mdpRewards.getStateReward(s);
+		// If strategy generation is enabled, store optimal choice
+		if (strat != null & !first) {
+			// Only remember strictly better choices (required for max)
+			if (strat[s] == -1 || (min && minmax < vect[s]) || (!min && minmax > vect[s])) {
+				strat[s] = stratCh;
+			}
+		}
 
 		return minmax;
 	}
diff --git a/prism/src/explicit/MDPSparse.java b/prism/src/explicit/MDPSparse.java
index 1d7ab854..1c11c3f5 100644
--- a/prism/src/explicit/MDPSparse.java
+++ b/prism/src/explicit/MDPSparse.java
@@ -925,6 +925,8 @@ public class MDPSparse extends MDPExplicit
 			}
 			first = false;
 		}
+		// Add state reward (doesn't affect min/max)
+		minmax += mdpRewards.getStateReward(s);
 		// If strategy generation is enabled, store optimal choice
 		if (strat != null & !first) {
 			// Only remember strictly better choices (required for max)
@@ -932,16 +934,14 @@
 				strat[s] = stratCh;
 			}
 		}
-		// Add state reward (doesn't affect min/max)
-		minmax += mdpRewards.getStateReward(s);
 
 		return minmax;
 	}
 
 	@Override
-	public double mvMultRewJacMinMaxSingle(int s, double vect[], MDPRewards mdpRewards, boolean min)
+	public double mvMultRewJacMinMaxSingle(int s, double vect[], MDPRewards mdpRewards, boolean min, int strat[])
 	{
-		int j, k, l1, h1, l2, h2;
+		int j, k, l1, h1, l2, h2, stratCh = -1;
 		double diag, d, minmax;
 		boolean first;
 
@@ -965,12 +965,23 @@
 			if (diag > 0)
 				d /= diag;
 			// Check whether we have exceeded min/max so far
-			if (first || (min && d < minmax) || (!min && d > minmax))
+			if (first || (min && d < minmax) || (!min && d > minmax)) {
 				minmax = d;
+				// If strategy generation is enabled, remember optimal choice
+				if (strat != null)
+					stratCh = j - l1;
+			}
 			first = false;
 		}
 		// Add state reward (doesn't affect min/max)
 		minmax += mdpRewards.getStateReward(s);
+		// If strategy generation is enabled, store optimal choice
+		if (strat != null & !first) {
+			// Only remember strictly better choices (required for max)
+			if (strat[s] == -1 || (min && minmax < vect[s]) || (!min && minmax > vect[s])) {
+				strat[s] = stratCh;
+			}
+		}
 
 		return minmax;
 	}
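
Note on usage (illustrative, not part of the patch): the new strat[] argument is filled in as value iteration runs, so a caller allocates an int array initialised to -1 ("unknown/arbitrary"), passes it to every sweep, and afterwards reads strat[s] as the index of an optimal choice in state s. Below is a minimal sketch of that calling pattern, modelled on computeReachRewardsValIter above. The class StratSketch, the method solveRewardsWithStrategy and the maxIters cap are invented for illustration; the sketch assumes PRISM's explicit-engine classes (explicit.MDP, explicit.rewards.MDPRewards, prism.PrismUtils) are on the classpath and that 'unknown' excludes target/infinity states, whose values the real code seeds separately.

import java.util.Arrays;
import java.util.BitSet;

import explicit.MDP;
import explicit.rewards.MDPRewards;
import prism.PrismUtils;

public class StratSketch
{
	/** Reward value iteration over the states in {@code unknown}, returning optimal choice indices. */
	public static int[] solveRewardsWithStrategy(MDP mdp, MDPRewards mdpRewards, BitSet unknown, boolean min, double epsilon, int maxIters)
	{
		int n = mdp.getNumStates();
		// -1 denotes unknown/arbitrary, matching the convention used by the patch
		int strat[] = new int[n];
		Arrays.fill(strat, -1);
		// Solution vectors (all 0.0 here; the real code also seeds target/inf states)
		double soln[] = new double[n], soln2[] = new double[n], tmpsoln[];
		boolean done = false;
		for (int iters = 0; !done && iters < maxIters; iters++) {
			// One sweep: computes new values into soln2 and records optimising choices in strat
			mdp.mvMultRewMinMax(soln, mdpRewards, min, soln2, unknown, false, strat);
			done = PrismUtils.doublesAreClose(soln, soln2, epsilon, true);
			// Swap vectors for the next sweep
			tmpsoln = soln;
			soln = soln2;
			soln2 = tmpsoln;
		}
		return strat; // strat[s] stays -1 for states never improved (e.g. outside 'unknown')
	}
}

Because the per-state update only overwrites strat[s] when the newly computed value is strictly better than vect[s] (or when strat[s] is still -1), repeated sweeps settle on a stable choice per state rather than oscillating between equally good ones.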