diff --git a/prism/src/explicit/MDPModelChecker.java b/prism/src/explicit/MDPModelChecker.java index a385eb24..78549e10 100644 --- a/prism/src/explicit/MDPModelChecker.java +++ b/prism/src/explicit/MDPModelChecker.java @@ -26,16 +26,19 @@ package explicit; -import java.util.*; +import java.util.BitSet; +import java.util.List; +import java.util.Map; import parser.ast.Expression; import parser.ast.ExpressionTemporal; -import java.util.Map.Entry; - +import prism.PrismDevNullLog; +import prism.PrismException; +import prism.PrismFileLog; +import prism.PrismLog; +import prism.PrismUtils; import explicit.rewards.MDPRewards; -import prism.*; - /** * Explicit-state model checker for Markov decision processes (MDPs). */ @@ -248,9 +251,17 @@ public class MDPModelChecker extends ProbModelChecker int i, n, numYes, numNo; long timer, timerProb0, timerProb1; boolean genAdv; + // Local copy of setting + MDPSolnMethod mdpSolnMethod = this.mdpSolnMethod; + // Switch to a supported method, if necessary + if (mdpSolnMethod == MDPSolnMethod.LINEAR_PROGRAMMING) { + mdpSolnMethod = MDPSolnMethod.GAUSS_SEIDEL; + mainLog.printWarning("Switching to MDP solution method \"" + mdpSolnMethod.fullName() + "\""); + } + // Check for some unsupported combinations - if (solnMethod == SolnMethod.VALUE_ITERATION && valIterDir == ValIterDir.ABOVE) { + if (mdpSolnMethod == MDPSolnMethod.VALUE_ITERATION && valIterDir == ValIterDir.ABOVE) { if (!(precomp && prob0)) throw new PrismException("Precomputation (Prob0) must be enabled for value iteration from above"); if (!min) @@ -301,7 +312,7 @@ public class MDPModelChecker extends ProbModelChecker mainLog.println("target=" + target.cardinality() + ", yes=" + numYes + ", no=" + numNo + ", maybe=" + (n - (numYes + numNo))); // Compute probabilities - switch (solnMethod) { + switch (mdpSolnMethod) { case VALUE_ITERATION: res = computeReachProbsValIter(mdp, no, yes, min, init, known); break; @@ -315,7 +326,7 @@ public class MDPModelChecker extends ProbModelChecker res = computeReachProbsModPolIter(mdp, no, yes, min); break; default: - throw new PrismException("Unknown MDP solution method " + solnMethod); + throw new PrismException("Unknown MDP solution method " + mdpSolnMethod.fullName()); } // Finished probabilistic reachability @@ -1004,6 +1015,14 @@ public class MDPModelChecker extends ProbModelChecker BitSet inf; int i, n, numTarget, numInf; long timer, timerProb1; + // Local copy of setting + MDPSolnMethod mdpSolnMethod = this.mdpSolnMethod; + + // Switch to a supported method, if necessary + if (!(mdpSolnMethod == MDPSolnMethod.VALUE_ITERATION || mdpSolnMethod == MDPSolnMethod.GAUSS_SEIDEL)) { + mdpSolnMethod = MDPSolnMethod.GAUSS_SEIDEL; + mainLog.printWarning("Switching to MDP solution method \"" + mdpSolnMethod.fullName() + "\""); + } // Start expected reachability timer = System.currentTimeMillis(); @@ -1036,7 +1055,7 @@ public class MDPModelChecker extends ProbModelChecker mainLog.println("target=" + numTarget + ", inf=" + numInf + ", rest=" + (n - (numTarget + numInf))); // Compute rewards - switch (solnMethod) { + switch (mdpSolnMethod) { case VALUE_ITERATION: res = computeReachRewardsValIter(mdp, mdpRewards, target, inf, min, init, known); break; @@ -1044,7 +1063,7 @@ public class MDPModelChecker extends ProbModelChecker res = computeReachRewardsGaussSeidel(mdp, mdpRewards, target, inf, min, init, known); break; default: - throw new PrismException("Unknown MDP solution method " + solnMethod); + throw new PrismException("Unknown MDP solution method " + mdpSolnMethod.fullName()); } // Finished expected reachability diff --git a/prism/src/explicit/ProbModelChecker.java b/prism/src/explicit/ProbModelChecker.java index 58589106..7ae0a3eb 100644 --- a/prism/src/explicit/ProbModelChecker.java +++ b/prism/src/explicit/ProbModelChecker.java @@ -49,6 +49,8 @@ public class ProbModelChecker extends StateModelChecker // Method used to solve linear equation systems protected LinEqMethod linEqMethod = LinEqMethod.GAUSS_SEIDEL; + // Method used to solve MDPs + protected MDPSolnMethod mdpSolnMethod = MDPSolnMethod.GAUSS_SEIDEL; // Iterative numerical method termination criteria protected TermCrit termCrit = TermCrit.RELATIVE; // Parameter for iterative numerical method termination criteria @@ -95,6 +97,28 @@ public class ProbModelChecker extends StateModelChecker } }; + // Method used for solving MDPs + public enum MDPSolnMethod { + VALUE_ITERATION, GAUSS_SEIDEL, POLICY_ITERATION, MODIFIED_POLICY_ITERATION, LINEAR_PROGRAMMING; + public String fullName() + { + switch (this) { + case VALUE_ITERATION: + return "Value iteration"; + case GAUSS_SEIDEL: + return "Gauss-Seidel"; + case POLICY_ITERATION: + return "Policy iteration"; + case MODIFIED_POLICY_ITERATION: + return "Modified policy iteration"; + case LINEAR_PROGRAMMING: + return "Linear programming"; + default: + return this.toString(); + } + } + }; + // Iterative numerical method termination criteria public enum TermCrit { ABSOLUTE, RELATIVE @@ -107,7 +131,7 @@ public class ProbModelChecker extends StateModelChecker // Method used for numerical solution public enum SolnMethod { - VALUE_ITERATION, GAUSS_SEIDEL, POLICY_ITERATION, MODIFIED_POLICY_ITERATION + VALUE_ITERATION, GAUSS_SEIDEL, POLICY_ITERATION, MODIFIED_POLICY_ITERATION, LINEAR_PROGRAMMING }; // Settings methods @@ -137,6 +161,21 @@ public class ProbModelChecker extends StateModelChecker } else { throw new PrismException("Explicit engine does not support linear equation solution method \"" + s + "\""); } + // PRISM_MDP_SOLN_METHOD + s = settings.getString(PrismSettings.PRISM_MDP_SOLN_METHOD); + if (s.equals("Value iteration")) { + setSolnMethod(SolnMethod.VALUE_ITERATION); + } else if (s.equals("Gauss-Seidel")) { + setSolnMethod(SolnMethod.GAUSS_SEIDEL); + } else if (s.equals("Policy iteration")) { + setSolnMethod(SolnMethod.POLICY_ITERATION); + } else if (s.equals("Modified policy iteration")) { + setSolnMethod(SolnMethod.MODIFIED_POLICY_ITERATION); + } else if (s.equals("Linear programming")) { + setSolnMethod(SolnMethod.LINEAR_PROGRAMMING); + } else { + throw new PrismException("Explicit engine does not support MDP solution method \"" + s + "\""); + } s = settings.getString(PrismSettings.PRISM_TERM_CRIT); if (s.equals("Absolute")) { @@ -150,16 +189,6 @@ public class ProbModelChecker extends StateModelChecker setProb0(settings.getBoolean(PrismSettings.PRISM_PROB0)); setProb1(settings.getBoolean(PrismSettings.PRISM_PROB1)); // valiterdir - s = settings.getString(PrismSettings.PRISM_MDP_SOLN_METHOD); - if (s.equals("Gauss-Seidel")) { - setSolnMethod(SolnMethod.GAUSS_SEIDEL); - } else if (s.equals("Policy iteration")) { - setSolnMethod(SolnMethod.POLICY_ITERATION); - } else if (s.equals("Modified policy iteration")) { - setSolnMethod(SolnMethod.MODIFIED_POLICY_ITERATION); - } else { - setSolnMethod(SolnMethod.VALUE_ITERATION); - } s = settings.getString(PrismSettings.PRISM_EXPORT_ADV); if (!(s.equals("None"))) exportAdv = true; @@ -189,6 +218,7 @@ public class ProbModelChecker extends StateModelChecker { super.printSettings(); mainLog.print("linEqMethod = " + linEqMethod + " "); + mainLog.print("mdpSolnMethod = " + mdpSolnMethod + " "); mainLog.print("termCrit = " + termCrit + " "); mainLog.print("termCritParam = " + termCritParam + " "); mainLog.print("maxIters = " + maxIters + " "); @@ -217,6 +247,14 @@ public class ProbModelChecker extends StateModelChecker this.linEqMethod = linEqMethod; } + /** + * Set method used to solve MDPs. + */ + public void setMDPSolnMethod(MDPSolnMethod mdpSolnMethod) + { + this.mdpSolnMethod = mdpSolnMethod; + } + /** * Set termination criteria type for numerical iterative methods. */ @@ -293,6 +331,11 @@ public class ProbModelChecker extends StateModelChecker return linEqMethod; } + public MDPSolnMethod getMDPSolnMethod() + { + return mdpSolnMethod; + } + public TermCrit getTermCrit() { return termCrit;