diff --git a/prism/src/explicit/MDPModelChecker.java b/prism/src/explicit/MDPModelChecker.java
index 4f2c16b5..47b686a2 100644
--- a/prism/src/explicit/MDPModelChecker.java
+++ b/prism/src/explicit/MDPModelChecker.java
@@ -32,6 +32,7 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 
+import common.IterableStateSet;
 import parser.VarList;
 import parser.ast.Declaration;
 import parser.ast.DeclarationIntUnbounded;
@@ -42,6 +43,7 @@ import prism.PrismDevNullLog;
 import prism.PrismException;
 import prism.PrismFileLog;
 import prism.PrismLog;
+import prism.PrismNotSupportedException;
 import prism.PrismUtils;
 import strat.MDStrategyArray;
 import acceptance.AcceptanceReach;
@@ -1292,6 +1294,126 @@ public class MDPModelChecker extends ProbModelChecker
         return res;
     }
 
+    /**
+     * Compute total expected rewards.
+     * @param mdp The MDP
+     * @param mdpRewards The rewards
+     * @param min Min or max rewards (true=min, false=max)
+     */
+    public ModelCheckerResult computeTotalRewards(MDP mdp, MDPRewards mdpRewards, boolean min) throws PrismException
+    {
+        if (min) {
+            throw new PrismNotSupportedException("Minimum total expected reward not supported in explicit engine");
+        } else {
+            // max. We don't know if there are positive ECs, so we can't skip precomputation
+            return computeTotalRewardsMax(mdp, mdpRewards, false);
+        }
+    }
+
+    /**
+     * Compute maximal total expected rewards.
+     * @param mdp The MDP
+     * @param mdpRewards The rewards
+     * @param noPositiveECs if true, there are no positive ECs, i.e., all states have finite values (skip precomputation)
+     */
+    public ModelCheckerResult computeTotalRewardsMax(MDP mdp, MDPRewards mdpRewards, boolean noPositiveECs) throws PrismException
+    {
+        ModelCheckerResult res = null;
+        int n;
+        long timer;
+        BitSet inf;
+
+        // Local copy of setting
+        MDPSolnMethod mdpSolnMethod = this.mdpSolnMethod;
+
+        // Switch to a supported method, if necessary
+        if (!(mdpSolnMethod == MDPSolnMethod.VALUE_ITERATION || mdpSolnMethod == MDPSolnMethod.GAUSS_SEIDEL || mdpSolnMethod == MDPSolnMethod.POLICY_ITERATION)) {
+            mdpSolnMethod = MDPSolnMethod.GAUSS_SEIDEL;
+            mainLog.printWarning("Switching to MDP solution method \"" + mdpSolnMethod.fullName() + "\"");
+        }
+
+        // Start expected total reward
+        timer = System.currentTimeMillis();
+        mainLog.println("\nStarting total expected reward (max)...");
+
+        // Store num states
+        n = mdp.getNumStates();
+
+        long timerPre;
+
+        if (noPositiveECs) {
+            // no inf states
+            inf = new BitSet();
+            timerPre = 0;
+        } else {
+            mainLog.println("Precomputation: Find positive end components...");
+
+            timerPre = System.currentTimeMillis();
+
+            ECComputer ecs = ECComputer.createECComputer(this, mdp);
+            ecs.computeMECStates();
+            BitSet positiveECs = new BitSet();
+            for (BitSet ec : ecs.getMECStates()) {
+                // check if this MEC is positive
+                boolean positiveEC = false;
+                for (int state : new IterableStateSet(ec, n)) {
+                    if (mdpRewards.getStateReward(state) > 0) {
+                        // state with positive reward in this MEC
+                        positiveEC = true;
+                        break;
+                    }
+                    for (int choice = 0, numChoices = mdp.getNumChoices(state); choice < numChoices; choice++) {
+                        if (mdpRewards.getTransitionReward(state, choice) > 0 &&
+                            mdp.allSuccessorsInSet(state, choice, ec)) {
+                            // choice from this state with positive reward back into this MEC
+                            positiveEC = true;
+                            break;
+                        }
+                    }
+                }
+                if (positiveEC) {
+                    positiveECs.or(ec);
+                }
+            }
+
+            // inf = Pmax[ <> positiveECs ] > 0
+            //     = ! (Pmax[ <> positiveECs ] = 0)
+            inf = prob0(mdp, null, positiveECs, false, null);  // Pmax[ <> positiveECs ] = 0
+            inf.flip(0, n);  // !(Pmax[ <> positiveECs ] = 0) = Pmax[ <> positiveECs ] > 0
+
+            timerPre = System.currentTimeMillis() - timerPre;
+            mainLog.println("Precomputation took " + timerPre / 1000.0 + " seconds, " + inf.cardinality() + " infinite states, " + (n - inf.cardinality()) + " states remaining.");
+        }
+
+        // Compute rewards
+        // do standard max reward calculation, but with empty target set
+        switch (mdpSolnMethod) {
+        case VALUE_ITERATION:
+            res = computeReachRewardsValIter(mdp, mdpRewards, new BitSet(), inf, false, null, null, null);
+            break;
+        case GAUSS_SEIDEL:
+            res = computeReachRewardsGaussSeidel(mdp, mdpRewards, new BitSet(), inf, false, null, null, null);
+            break;
+        case POLICY_ITERATION:
+            res = computeReachRewardsPolIter(mdp, mdpRewards, new BitSet(), inf, false, null);
+            break;
+        default:
+            throw new PrismException("Unknown MDP solution method " + mdpSolnMethod.fullName());
+        }
+
+        // Finished expected total reward
+        timer = System.currentTimeMillis() - timer;
+        mainLog.println("Expected total reward took " + timer / 1000.0 + " seconds.");
+
+        // Update time taken
+        res.timeTaken = timer / 1000.0;
+        res.timePre = timerPre / 1000.0;
+
+        // Return results
+        return res;
+    }
+
+
     /**
      * Compute expected reachability rewards.
      * @param mdp The MDP
diff --git a/prism/src/explicit/ProbModelChecker.java b/prism/src/explicit/ProbModelChecker.java
index 488e2c0d..b1334ae2 100644
--- a/prism/src/explicit/ProbModelChecker.java
+++ b/prism/src/explicit/ProbModelChecker.java
@@ -1080,6 +1080,8 @@ public class ProbModelChecker extends NonProbModelChecker
             res = ((CTMCModelChecker) this).computeTotalRewards((CTMC) model, (MCRewards) modelRewards);
             break;
         case MDP:
+            res = ((MDPModelChecker) this).computeTotalRewards((MDP) model, (MDPRewards) modelRewards, minMax.isMin());
+            break;
         default:
             throw new PrismNotSupportedException("Explicit engine does not yet handle the " + expr.getOperatorSymbol() + " reward operator for " + model.getModelType() + "s");
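
For reference (not part of the patch): a minimal sketch of how the new computeTotalRewards entry point might be exercised directly on a hand-built explicit-engine model. It assumes the MDPSimple, Distribution, MDPRewardsSimple and MDPModelChecker classes from PRISM's explicit / explicit.rewards packages with the signatures used below; the exact constructors and setters, and whether a null parent component is acceptable here, should be checked against the PRISM source before relying on this.

import explicit.Distribution;
import explicit.MDPModelChecker;
import explicit.MDPSimple;
import explicit.ModelCheckerResult;
import explicit.rewards.MDPRewardsSimple;
import prism.PrismException;

public class TotalRewardSketch
{
    public static void main(String[] args) throws PrismException
    {
        // Hypothetical two-state MDP: state 0 has two choices, both moving to the
        // absorbing state 1 (which has a zero-reward self-loop, so no positive ECs).
        MDPSimple mdp = new MDPSimple(2);
        mdp.addInitialState(0);
        Distribution toAbsorbing1 = new Distribution();
        toAbsorbing1.add(1, 1.0);
        mdp.addChoice(0, toAbsorbing1);   // choice 0 of state 0
        Distribution toAbsorbing2 = new Distribution();
        toAbsorbing2.add(1, 1.0);
        mdp.addChoice(0, toAbsorbing2);   // choice 1 of state 0
        Distribution selfLoop = new Distribution();
        selfLoop.add(1, 1.0);
        mdp.addChoice(1, selfLoop);       // zero-reward loop in state 1

        // Transition rewards: choice 0 earns 5, choice 1 earns 2,
        // so the maximal total expected reward from state 0 should be 5.
        MDPRewardsSimple rewards = new MDPRewardsSimple(2);
        rewards.setTransitionReward(0, 0, 5.0);
        rewards.setTransitionReward(0, 1, 2.0);

        // Run the new total-reward computation (min=false, i.e. maximise).
        MDPModelChecker mc = new MDPModelChecker(null);
        ModelCheckerResult res = mc.computeTotalRewards(mdp, rewards, false);
        System.out.println("Max total expected reward from state 0: " + res.soln[0]);
    }
}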