diff --git a/prism/src/explicit/MDP.java b/prism/src/explicit/MDP.java index eecc63a1..b70724e2 100644 --- a/prism/src/explicit/MDP.java +++ b/prism/src/explicit/MDP.java @@ -173,6 +173,22 @@ public interface MDP extends Model */ public void mvMultRewMinMax(double vect[], MDPRewards mdpRewards, boolean min, double result[], BitSet subset, boolean complement, int adv[]); + /** + * Do a Gauss-Seidel-style matrix-vector multiplication and sum of action reward followed by min/max. + * i.e. for all s: vect[s] = min/max_k { rew(s) + (sum_{j!=s} P_k(s,j)*vect[j]) / (1-P_k(s,s)) } + * and store new values directly in {@code vect} as computed. + * The maximum (absolute/relative) difference between old/new + * elements of {@code vect} is also returned. + * @param vect Vector to multiply by (and store the result in) + * @param mdpRewards The rewards + * @param min Min or max for (true=min, false=max) + * @param subset Only do multiplication for these rows (ignored if null) + * @param complement If true, {@code subset} is taken to be its complement (ignored if {@code subset} is null) + * @param absolute If true, compute absolute, rather than relative, difference + * @return The maximum difference between old/new elements of {@code vect} + */ + public double mvMultRewGSMinMax(double vect[], MDPRewards mdpRewards, boolean min, BitSet subset, boolean complement, boolean absolute); + /** * Do a single row of matrix-vector multiplication and sum of action reward followed by min/max. * i.e. return min/max_k { rew(s) + sum_j P_k(s,j)*vect[j] } @@ -184,6 +200,16 @@ public interface MDP extends Model */ public double mvMultRewMinMaxSingle(int s, double vect[], MDPRewards mdpRewards, boolean min, int adv[]); + /** + * Do a single row of Jacobi-style matrix-vector multiplication and sum of action reward followed by min/max. + * i.e. 
return min/max_k { (rew(s) + sum_{j!=s} P_k(s,j)*vect[j]) / (1-P_k(s,s)) } + * @param s Row index + * @param vect Vector to multiply by + * @param mdpRewards The rewards + * @param min Min or max for (true=min, false=max) + */ + public double mvMultRewJacMinMaxSingle(int s, double vect[], MDPRewards mdpRewards, boolean min); + /** * Determine which choices result in min/max after a single row of matrix-vector multiplication and sum of action reward. * @param s Row index diff --git a/prism/src/explicit/MDPExplicit.java b/prism/src/explicit/MDPExplicit.java index 4bc89cec..567ea3c7 100644 --- a/prism/src/explicit/MDPExplicit.java +++ b/prism/src/explicit/MDPExplicit.java @@ -258,4 +258,44 @@ public abstract class MDPExplicit extends ModelExplicit implements MDP result[s] = mvMultRewMinMaxSingle(s, vect, mdpRewards, min, adv); } } + + @Override + public double mvMultRewGSMinMax(double vect[], MDPRewards mdpRewards, boolean min, BitSet subset, boolean complement, boolean absolute) + { + int s; + double d, diff, maxDiff = 0.0; + // Loop depends on subset/complement arguments + if (subset == null) { + for (s = 0; s < numStates; s++) { + d = mvMultRewJacMinMaxSingle(s, vect, mdpRewards, min); + diff = absolute ? (Math.abs(d - vect[s])) : (Math.abs(d - vect[s]) / d); + maxDiff = diff > maxDiff ? diff : maxDiff; + vect[s] = d; + } + } else if (complement) { + for (s = subset.nextClearBit(0); s < numStates; s = subset.nextClearBit(s + 1)) { + d = mvMultRewJacMinMaxSingle(s, vect, mdpRewards, min); + diff = absolute ? (Math.abs(d - vect[s])) : (Math.abs(d - vect[s]) / d); + maxDiff = diff > maxDiff ? diff : maxDiff; + vect[s] = d; + } + } else { + for (s = subset.nextSetBit(0); s >= 0; s = subset.nextSetBit(s + 1)) { + d = mvMultRewJacMinMaxSingle(s, vect, mdpRewards, min); + diff = absolute ? (Math.abs(d - vect[s])) : (Math.abs(d - vect[s]) / d); + maxDiff = diff > maxDiff ? 
diff : maxDiff; + vect[s] = d; + } + } + // Use this code instead for backwards Gauss-Seidel + /*for (s = numStates - 1; s >= 0; s--) { + if (subset.get(s)) { + d = mvMultRewJacMinMaxSingle(s, vect, mdpRewards, min); + diff = absolute ? (Math.abs(d - vect[s])) : (Math.abs(d - vect[s]) / d); + maxDiff = diff > maxDiff ? diff : maxDiff; + vect[s] = d; + } + }*/ + return maxDiff; + } } diff --git a/prism/src/explicit/MDPModelChecker.java b/prism/src/explicit/MDPModelChecker.java index d3e76c57..ca6a3390 100644 --- a/prism/src/explicit/MDPModelChecker.java +++ b/prism/src/explicit/MDPModelChecker.java @@ -1040,6 +1040,9 @@ public class MDPModelChecker extends ProbModelChecker case VALUE_ITERATION: res = computeReachRewardsValIter(mdp, mdpRewards, target, inf, min, init, known); break; + case GAUSS_SEIDEL: + res = computeReachRewardsGaussSeidel(mdp, mdpRewards, target, inf, min, init, known); + break; default: throw new PrismException("Unknown MDP solution method " + solnMethod); } @@ -1055,6 +1058,91 @@ public class MDPModelChecker extends ProbModelChecker return res; } + /** + * Compute expected reachability rewards using Gauss-Seidel (including Jacobi-style updates). + * @param mdp The MDP + * @param mdpRewards The rewards + * @param target Target states + * @param inf States for which reward is infinite + * @param min Min or max rewards (true=min, false=max) + * @param init Optionally, an initial solution vector (will be overwritten) + * @param known Optionally, a set of states for which the exact answer is known + * Note: if 'known' is specified (i.e. is non-null), 'init' must also be given and is used for the exact values. 
+ */ + protected ModelCheckerResult computeReachRewardsGaussSeidel(MDP mdp, MDPRewards mdpRewards, BitSet target, BitSet inf, boolean min, double init[], BitSet known) throws PrismException + { + ModelCheckerResult res; + BitSet unknown; + int i, n, iters; + double soln[], maxDiff; + boolean done; + long timer; + + // Start value iteration + timer = System.currentTimeMillis(); + mainLog.println("Starting Gauss-Seidel (" + (min ? "min" : "max") + ")..."); + + // Store num states + n = mdp.getNumStates(); + + // Create solution vector(s) + soln = (init == null) ? new double[n] : init; + + // Initialise solution vector. Use (where available) the following in order of preference: + // (1) exact answer, if already known; (2) 0.0/infinity if in target/inf; (3) passed in initial value; (4) 0.0 + if (init != null) { + if (known != null) { + for (i = 0; i < n; i++) + soln[i] = known.get(i) ? init[i] : target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i]; + } else { + for (i = 0; i < n; i++) + soln[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : init[i]; + } + } else { + for (i = 0; i < n; i++) + soln[i] = target.get(i) ? 0.0 : inf.get(i) ? Double.POSITIVE_INFINITY : 0.0; + } + + // Determine set of states actually need to compute values for + unknown = new BitSet(); + unknown.set(0, n); + unknown.andNot(target); + unknown.andNot(inf); + if (known != null) + unknown.andNot(known); + + // Start iterations + iters = 0; + done = false; + while (!done && iters < maxIters) { + //mainLog.println(soln); + iters++; + // Matrix-vector multiply and min/max ops + maxDiff = mdp.mvMultRewGSMinMax(soln, mdpRewards, min, unknown, false, termCrit == TermCrit.ABSOLUTE); + // Check termination + done = maxDiff < termCritParam; + } + + // Finished Gauss-Seidel + timer = System.currentTimeMillis() - timer; + mainLog.print("Gauss-Seidel (" + (min ? 
"min" : "max") + ")"); + mainLog.println(" took " + iters + " iterations and " + timer / 1000.0 + " seconds."); + + // Non-convergence is an error + if (!done) { + String msg = "Iterative method did not converge within " + iters + " iterations."; + msg += "\nConsider using a different numerical method or increasing the maximum number of iterations"; + throw new PrismException(msg); + } + + // Return results + res = new ModelCheckerResult(); + res.soln = soln; + res.numIters = iters; + res.timeTaken = timer / 1000.0; + return res; + } + /** * Compute expected reachability rewards using value iteration. * @param mdp The MDP diff --git a/prism/src/explicit/MDPSimple.java b/prism/src/explicit/MDPSimple.java index ce96f877..13301dba 100644 --- a/prism/src/explicit/MDPSimple.java +++ b/prism/src/explicit/MDPSimple.java @@ -746,6 +746,45 @@ public class MDPSimple extends MDPExplicit implements ModelSimple return minmax; } + @Override + public double mvMultRewJacMinMaxSingle(int s, double vect[], MDPRewards mdpRewards, boolean min) + { + int j, k; + double diag, d, prob, minmax; + boolean first; + List step; + + minmax = 0; + first = true; + j = -1; + step = trans.get(s); + for (Distribution distr : step) { + j++; + diag = 1.0; + // Compute sum for this distribution + d = mdpRewards.getTransitionReward(s, j); + for (Map.Entry e : distr) { + k = (Integer) e.getKey(); + prob = (Double) e.getValue(); + if (k != s) { + d += prob * vect[k]; + } else { + diag -= prob; + } + } + if (diag > 0) + d /= diag; + // Check whether we have exceeded min/max so far + if (first || (min && d < minmax) || (!min && d > minmax)) + minmax = d; + first = false; + } + // Add state reward (doesn't affect min/max) + minmax += mdpRewards.getStateReward(s); + + return minmax; + } + @Override public List mvMultRewMinMaxSingleChoices(int s, double vect[], MDPRewards mdpRewards, boolean min, double val) { diff --git a/prism/src/explicit/MDPSparse.java b/prism/src/explicit/MDPSparse.java index 
fb8f58b7..8774bc36 100644 --- a/prism/src/explicit/MDPSparse.java +++ b/prism/src/explicit/MDPSparse.java @@ -812,6 +812,43 @@ public class MDPSparse extends MDPExplicit return minmax; } + @Override + public double mvMultRewJacMinMaxSingle(int s, double vect[], MDPRewards mdpRewards, boolean min) + { + int j, k, l1, h1, l2, h2; + double diag, d, minmax; + boolean first; + + minmax = 0; + first = true; + l1 = rowStarts[s]; + h1 = rowStarts[s + 1]; + for (j = l1; j < h1; j++) { + diag = 1.0; + // Compute sum for this distribution + d = mdpRewards.getTransitionReward(s, j - l1); + l2 = choiceStarts[j]; + h2 = choiceStarts[j + 1]; + for (k = l2; k < h2; k++) { + if (cols[k] != s) { + d += nonZeros[k] * vect[cols[k]]; + } else { + diag -= nonZeros[k]; + } + } + if (diag > 0) + d /= diag; + // Check whether we have exceeded min/max so far + if (first || (min && d < minmax) || (!min && d > minmax)) + minmax = d; + first = false; + } + // Add state reward (doesn't affect min/max) + minmax += mdpRewards.getStateReward(s); + + return minmax; + } + @Override public List mvMultRewMinMaxSingleChoices(int s, double vect[], MDPRewards mdpRewards, boolean min, double val) {