From cc8d3883f6d4fa2f262103be57b6079bfbefd79e Mon Sep 17 00:00:00 2001
From: Dave Parker
Date: Thu, 8 Jan 2015 15:44:50 +0000
Subject: [PATCH] Expected total rewards (R[C]) implemented for DTMCs in symbolic engine.

git-svn-id: https://www.prismmodelchecker.org/svn/prism/prism/trunk@9511 bbc10eb1-c90d-0410-af57-cb519fbb1720
---
Notes on this revision of the patch (this section is ignored by git am):
 - ProbModelChecker.computeTotalRewards: take a reference on `reach` / `inf`
   before handing them to JDD operations that consume their arguments, and
   dereference (instead of reference) bsccsNonZero once Prob0 has used it.
 - Comments above the new methods now say "total" rather than "cumulative".
 - The post-image blob ids on the "index" lines are those of the original commit.

 prism/src/prism/NondetModelChecker.java |  29 +++++-
 prism/src/prism/ProbModelChecker.java   | 128 +++++++++++++++++++++++-
 2 files changed, 148 insertions(+), 9 deletions(-)

diff --git a/prism/src/prism/NondetModelChecker.java b/prism/src/prism/NondetModelChecker.java
index 61b1365f..2f4e861e 100644
--- a/prism/src/prism/NondetModelChecker.java
+++ b/prism/src/prism/NondetModelChecker.java
@@ -213,15 +213,20 @@ public class NondetModelChecker extends NonProbModelChecker
 		StateValues rewards = null;
 		Expression expr2 = expr.getExpression();
 		if (expr2 instanceof ExpressionTemporal) {
-			switch (((ExpressionTemporal) expr2).getOperator()) {
+			ExpressionTemporal exprTemp = (ExpressionTemporal) expr2;
+			switch (exprTemp.getOperator()) {
 			case ExpressionTemporal.R_C:
-				rewards = checkRewardCumul((ExpressionTemporal) expr2, stateRewards, transRewards, minMax.isMin());
+				if (exprTemp.hasBounds()) {
+					rewards = checkRewardCumul(exprTemp, stateRewards, transRewards, minMax.isMin());
+				} else {
+					rewards = checkRewardTotal(exprTemp, stateRewards, transRewards, minMax.isMin());
+				}
 				break;
 			case ExpressionTemporal.R_I:
-				rewards = checkRewardInst((ExpressionTemporal) expr2, stateRewards, transRewards, minMax.isMin());
+				rewards = checkRewardInst(exprTemp, stateRewards, transRewards, minMax.isMin());
 				break;
 			case ExpressionTemporal.R_F:
-				rewards = checkRewardReach((ExpressionTemporal) expr2, stateRewards, transRewards, minMax.isMin());
+				rewards = checkRewardReach(exprTemp, stateRewards, transRewards, minMax.isMin());
 				break;
 			}
 		}
@@ -1062,6 +1067,15 @@ public class NondetModelChecker extends NonProbModelChecker
 		return rewards;
 	}
 
+	/**
+	 * Compute rewards for a total reward operator.
+	 */
+	protected StateValues checkRewardTotal(ExpressionTemporal expr, JDDNode stateRewards, JDDNode transRewards, boolean min) throws PrismException
+	{
+		StateValues rewards = computeTotalRewards(trans, trans01, stateRewards, transRewards, min);
+		return rewards;
+	}
+
 	/**
 	 * Compute rewards for an instantaneous reward operator.
 	 */
@@ -1496,6 +1510,13 @@ public class NondetModelChecker extends NonProbModelChecker
 		return rewards;
 	}
 
+	// compute total rewards (always throws: not yet supported for MDPs)
+
+	protected StateValues computeTotalRewards(JDDNode tr, JDDNode tr01, JDDNode sr, JDDNode trr, boolean min) throws PrismException
+	{
+		throw new PrismException("Expected total reward (C) is not yet supported for MDPs.");
+	}
+
 	// compute rewards for inst reward
 
 	protected StateValues computeInstRewards(JDDNode tr, JDDNode sr, int time, boolean min) throws PrismException
diff --git a/prism/src/prism/ProbModelChecker.java b/prism/src/prism/ProbModelChecker.java
index 3ebb00d0..82449b29 100644
--- a/prism/src/prism/ProbModelChecker.java
+++ b/prism/src/prism/ProbModelChecker.java
@@ -216,18 +216,23 @@ public class ProbModelChecker extends NonProbModelChecker
 		StateValues rewards = null;
 		Expression expr2 = expr.getExpression();
 		if (expr2 instanceof ExpressionTemporal) {
-			switch (((ExpressionTemporal) expr2).getOperator()) {
+			ExpressionTemporal exprTemp = (ExpressionTemporal) expr2;
+			switch (exprTemp.getOperator()) {
 			case ExpressionTemporal.R_C:
-				rewards = checkRewardCumul((ExpressionTemporal) expr2, stateRewards, transRewards);
+				if (exprTemp.hasBounds()) {
+					rewards = checkRewardCumul(exprTemp, stateRewards, transRewards);
+				} else {
+					rewards = checkRewardTotal(exprTemp, stateRewards, transRewards);
+				}
 				break;
 			case ExpressionTemporal.R_I:
-				rewards = checkRewardInst((ExpressionTemporal) expr2, stateRewards, transRewards);
+				rewards = checkRewardInst(exprTemp, stateRewards, transRewards);
 				break;
 			case ExpressionTemporal.R_F:
-				rewards = checkRewardReach((ExpressionTemporal) expr2, stateRewards, transRewards);
+				rewards = checkRewardReach(exprTemp, stateRewards, transRewards);
 				break;
 			case ExpressionTemporal.R_S:
-				rewards = checkRewardSS((ExpressionTemporal) expr2, stateRewards, transRewards);
+				rewards = checkRewardSS(exprTemp, stateRewards, transRewards);
 				break;
 			}
 		}
@@ -761,6 +766,14 @@ public class ProbModelChecker extends NonProbModelChecker
 		return rewards;
 	}
 
+	// total reward
+
+	protected StateValues checkRewardTotal(ExpressionTemporal expr, JDDNode stateRewards, JDDNode transRewards) throws PrismException
+	{
+		StateValues rewards = computeTotalRewards(trans, trans01, stateRewards, transRewards);
+		return rewards;
+	}
+
 	// inst reward
 
 	protected StateValues checkRewardInst(ExpressionTemporal expr, JDDNode stateRewards, JDDNode transRewards) throws PrismException
@@ -1418,6 +1431,111 @@ public class ProbModelChecker extends NonProbModelChecker
 		return rewards;
 	}
 
+	/**
+	 * Compute expected total rewards (R[C]).
+	 * States that reach, with probability > 0, a BSCC containing some non-zero
+	 * state or transition reward get +infinity; the remaining states are solved
+	 * with the "reward reachability" routines, using no target states.
+	 * @param tr transitions handed to the ProbReachReward solvers
+	 * @param tr01 transitions handed to the Prob0 precomputation
+	 * @param sr state rewards
+	 * @param trr transition rewards
+	 */
+	protected StateValues computeTotalRewards(JDDNode tr, JDDNode tr01, JDDNode sr, JDDNode trr) throws PrismException
+	{
+		JDDNode rewardsMTBDD;
+		DoubleVector rewardsDV;
+		StateValues rewards = null;
+		// BSCC stuff
+		List<JDDNode> bsccs = null;
+		JDDNode notInBSCCs = null;
+		int numBSCCs = 0;
+
+		// Compute bottom strongly connected components (BSCCs)
+		SCCComputer sccComputer = prism.getSCCComputer(model);
+		sccComputer.computeBSCCs();
+		bsccs = sccComputer.getBSCCs();
+		notInBSCCs = sccComputer.getNotInBSCCs();
+		numBSCCs = bsccs.size();
+
+		// Find BSCCs with non-zero reward
+		JDD.Ref(sr);
+		JDDNode srNonZero = JDD.GreaterThan(sr, 0);
+		JDD.Ref(trr);
+		JDDNode trrNonZero = JDD.GreaterThan(trr, 0);
+		JDDNode bsccsNonZero = JDD.Constant(0);
+		for (int b = 0; b < numBSCCs; b++) {
+			if (JDD.AreInterecting(bsccs.get(b), srNonZero) || JDD.AreInterecting(bsccs.get(b), trrNonZero)) {
+				JDD.Ref(bsccs.get(b));
+				bsccsNonZero = JDD.Or(bsccsNonZero, bsccs.get(b));
+			}
+		}
+		JDD.Deref(srNonZero);
+		JDD.Deref(trrNonZero);
+		mainLog.print("States in non-zero reward BSCCs: " + JDD.GetNumMintermsString(bsccsNonZero, allDDRowVars.n()));
+
+		// Find states with infinite reward (those reach a non-zero reward BSCC with prob > 0)
+		// (JDD.And/JDD.Not consume their arguments, so reference anything that is still needed afterwards)
+		JDDNode inf = PrismMTBDD.Prob0(tr01, reach, allDDRowVars, allDDColVars, reach, bsccsNonZero);
+		JDD.Ref(reach);
+		inf = JDD.And(reach, JDD.Not(inf));
+		JDD.Ref(reach);
+		JDD.Ref(inf);
+		JDDNode maybe = JDD.And(reach, JDD.Not(inf));
+		JDD.Deref(bsccsNonZero);
+
+		// Print out inf/maybe
+		mainLog.print("\ninf = " + JDD.GetNumMintermsString(inf, allDDRowVars.n()));
+		mainLog.print(", maybe = " + JDD.GetNumMintermsString(maybe, allDDRowVars.n()) + "\n");
+
+		// If maybe is empty, we have the rewards already
+		if (maybe.equals(JDD.ZERO)) {
+			JDD.Ref(inf);
+			rewards = new StateValuesMTBDD(JDD.ITE(inf, JDD.PlusInfinity(), JDD.Constant(0)), model);
+		}
+		// Otherwise we compute the actual rewards
+		else {
+			// Compute the rewards
+			// (do this using the functions for "reward reachability" properties but with no targets)
+			mainLog.println("\nComputing remaining rewards...");
+			mainLog.println("Engine: " + Prism.getEngineString(engine));
+			try {
+				switch (engine) {
+				case Prism.MTBDD:
+					rewardsMTBDD = PrismMTBDD.ProbReachReward(tr, sr, trr, odd, allDDRowVars, allDDColVars, JDD.ZERO, inf, maybe);
+					rewards = new StateValuesMTBDD(rewardsMTBDD, model);
+					break;
+				case Prism.SPARSE:
+					rewardsDV = PrismSparse.ProbReachReward(tr, sr, trr, odd, allDDRowVars, allDDColVars, JDD.ZERO, inf, maybe);
+					rewards = new StateValuesDV(rewardsDV, model);
+					break;
+				case Prism.HYBRID:
+					rewardsDV = PrismHybrid.ProbReachReward(tr, sr, trr, odd, allDDRowVars, allDDColVars, JDD.ZERO, inf, maybe);
+					rewards = new StateValuesDV(rewardsDV, model);
+					break;
+				default:
+					throw new PrismException("Unknown engine");
+				}
+			} catch (PrismException e) {
+				JDD.Deref(inf);
+				JDD.Deref(maybe);
+				throw e;
+			}
+		}
+
+		// Tidy up
+		for (int b = 0; b < numBSCCs; b++) {
+			if (bsccs.get(b) != null)
+				JDD.Deref(bsccs.get(b));
+		}
+		if (start != notInBSCCs)
+			JDD.Deref(notInBSCCs);
+		JDD.Deref(inf);
+		JDD.Deref(maybe);
+
+		return rewards;
+	}
+
 	// compute rewards for inst reward
 
 	protected StateValues computeInstRewards(JDDNode tr, JDDNode sr, int time) throws PrismException