Browse Source

Expected total rewards (R[C]) implemented for DTMCs in symbolic engine.

git-svn-id: https://www.prismmodelchecker.org/svn/prism/prism/trunk@9511 bbc10eb1-c90d-0410-af57-cb519fbb1720
master
Dave Parker 11 years ago
parent
commit
cc8d3883f6
  1. 29
      prism/src/prism/NondetModelChecker.java
  2. 116
      prism/src/prism/ProbModelChecker.java

29
prism/src/prism/NondetModelChecker.java

@ -213,15 +213,20 @@ public class NondetModelChecker extends NonProbModelChecker
StateValues rewards = null; StateValues rewards = null;
Expression expr2 = expr.getExpression(); Expression expr2 = expr.getExpression();
if (expr2 instanceof ExpressionTemporal) { if (expr2 instanceof ExpressionTemporal) {
switch (((ExpressionTemporal) expr2).getOperator()) {
ExpressionTemporal exprTemp = (ExpressionTemporal) expr2;
switch (exprTemp.getOperator()) {
case ExpressionTemporal.R_C: case ExpressionTemporal.R_C:
rewards = checkRewardCumul((ExpressionTemporal) expr2, stateRewards, transRewards, minMax.isMin());
if (exprTemp.hasBounds()) {
rewards = checkRewardCumul(exprTemp, stateRewards, transRewards, minMax.isMin());
} else {
rewards = checkRewardTotal(exprTemp, stateRewards, transRewards, minMax.isMin());
}
break; break;
case ExpressionTemporal.R_I: case ExpressionTemporal.R_I:
rewards = checkRewardInst((ExpressionTemporal) expr2, stateRewards, transRewards, minMax.isMin());
rewards = checkRewardInst(exprTemp, stateRewards, transRewards, minMax.isMin());
break; break;
case ExpressionTemporal.R_F: case ExpressionTemporal.R_F:
rewards = checkRewardReach((ExpressionTemporal) expr2, stateRewards, transRewards, minMax.isMin());
rewards = checkRewardReach(exprTemp, stateRewards, transRewards, minMax.isMin());
break; break;
} }
} }
@ -1062,6 +1067,15 @@ public class NondetModelChecker extends NonProbModelChecker
return rewards; return rewards;
} }
/**
 * Model check an expected total reward operator (R[C] with no bound).
 * This is a thin wrapper: it hands the model's transition matrix ({@code trans}),
 * its 0-1 version ({@code trans01}) and the supplied reward structures straight to
 * {@link #computeTotalRewards}; the expression itself is not inspected here.
 * @param expr The temporal (R_C) expression being checked (currently unused)
 * @param stateRewards State rewards
 * @param transRewards Transition rewards
 * @param min Passed through unchanged as the min flag of computeTotalRewards
 * @return Whatever computeTotalRewards returns
 * @throws PrismException Propagated from computeTotalRewards
 */
protected StateValues checkRewardTotal(ExpressionTemporal expr, JDDNode stateRewards, JDDNode transRewards, boolean min) throws PrismException
{
	return computeTotalRewards(trans, trans01, stateRewards, transRewards, min);
}
/** /**
* Compute rewards for an instantaneous reward operator. * Compute rewards for an instantaneous reward operator.
*/ */
@ -1496,6 +1510,13 @@ public class NondetModelChecker extends NonProbModelChecker
return rewards; return rewards;
} }
// compute total rewards
// (placeholder: expected total reward is not implemented for MDPs yet,
// so this unconditionally throws and none of the arguments are used)
protected StateValues computeTotalRewards(JDDNode tr, JDDNode tr01, JDDNode sr, JDDNode trr, boolean min) throws PrismException
{
throw new PrismException("Expected total reward (C) is not yet supported for MDPs.");
}
// compute rewards for inst reward // compute rewards for inst reward
protected StateValues computeInstRewards(JDDNode tr, JDDNode sr, int time, boolean min) throws PrismException protected StateValues computeInstRewards(JDDNode tr, JDDNode sr, int time, boolean min) throws PrismException

116
prism/src/prism/ProbModelChecker.java

@ -216,18 +216,23 @@ public class ProbModelChecker extends NonProbModelChecker
StateValues rewards = null; StateValues rewards = null;
Expression expr2 = expr.getExpression(); Expression expr2 = expr.getExpression();
if (expr2 instanceof ExpressionTemporal) { if (expr2 instanceof ExpressionTemporal) {
switch (((ExpressionTemporal) expr2).getOperator()) {
ExpressionTemporal exprTemp = (ExpressionTemporal) expr2;
switch (exprTemp.getOperator()) {
case ExpressionTemporal.R_C: case ExpressionTemporal.R_C:
rewards = checkRewardCumul((ExpressionTemporal) expr2, stateRewards, transRewards);
if (exprTemp.hasBounds()) {
rewards = checkRewardCumul(exprTemp, stateRewards, transRewards);
} else {
rewards = checkRewardTotal(exprTemp, stateRewards, transRewards);
}
break; break;
case ExpressionTemporal.R_I: case ExpressionTemporal.R_I:
rewards = checkRewardInst((ExpressionTemporal) expr2, stateRewards, transRewards);
rewards = checkRewardInst(exprTemp, stateRewards, transRewards);
break; break;
case ExpressionTemporal.R_F: case ExpressionTemporal.R_F:
rewards = checkRewardReach((ExpressionTemporal) expr2, stateRewards, transRewards);
rewards = checkRewardReach(exprTemp, stateRewards, transRewards);
break; break;
case ExpressionTemporal.R_S: case ExpressionTemporal.R_S:
rewards = checkRewardSS((ExpressionTemporal) expr2, stateRewards, transRewards);
rewards = checkRewardSS(exprTemp, stateRewards, transRewards);
break; break;
} }
} }
@ -761,6 +766,14 @@ public class ProbModelChecker extends NonProbModelChecker
return rewards; return rewards;
} }
// total reward
// (delegates to computeTotalRewards with the model's trans/trans01 matrices;
// the expression argument is not inspected)
protected StateValues checkRewardTotal(ExpressionTemporal expr, JDDNode stateRewards, JDDNode transRewards) throws PrismException
{
	return computeTotalRewards(trans, trans01, stateRewards, transRewards);
}
// inst reward // inst reward
protected StateValues checkRewardInst(ExpressionTemporal expr, JDDNode stateRewards, JDDNode transRewards) throws PrismException protected StateValues checkRewardInst(ExpressionTemporal expr, JDDNode stateRewards, JDDNode transRewards) throws PrismException
@ -1418,6 +1431,99 @@ public class ProbModelChecker extends NonProbModelChecker
return rewards; return rewards;
} }
// compute total rewards
//
// tr   = transition matrix, tr01 = its 0-1 version (passed to Prob0)
// sr   = state rewards,     trr  = transition rewards
//
// Outline:
//  1. compute the BSCCs and collect those that contain a positive state or transition reward
//  2. inf   = reachable states that get to such a BSCC with probability > 0 (reward is infinite)
//     maybe = remaining reachable states
//  3. if maybe is empty the result is just "infinity on inf, 0 elsewhere";
//     otherwise solve with the reach-reward routines using an empty target set
//
// Reference counting: sr and trr are Ref'd here before being handed to a JDD operation,
// i.e. JDD operations are treated as consuming their operands. The same rule is applied
// to the shared "reach" node and to "inf" wherever they are still needed afterwards.
protected StateValues computeTotalRewards(JDDNode tr, JDDNode tr01, JDDNode sr, JDDNode trr) throws PrismException
{
	JDDNode rewardsMTBDD;
	DoubleVector rewardsDV;
	StateValues rewards = null;

	// Compute bottom strongly connected components (BSCCs)
	SCCComputer sccComputer = prism.getSCCComputer(model);
	sccComputer.computeBSCCs();
	List<JDDNode> bsccs = sccComputer.getBSCCs();
	JDDNode notInBSCCs = sccComputer.getNotInBSCCs();

	// Find BSCCs with non-zero reward
	JDD.Ref(sr);
	JDDNode srNonZero = JDD.GreaterThan(sr, 0);
	JDD.Ref(trr);
	JDDNode trrNonZero = JDD.GreaterThan(trr, 0);
	JDDNode bsccsNonZero = JDD.Constant(0);
	for (JDDNode bscc : bsccs) {
		if (JDD.AreInterecting(bscc, srNonZero) || JDD.AreInterecting(bscc, trrNonZero)) {
			// Or consumes its operands; keep our own reference to the BSCC for the tidy-up below
			JDD.Ref(bscc);
			bsccsNonZero = JDD.Or(bsccsNonZero, bscc);
		}
	}
	JDD.Deref(srNonZero);
	JDD.Deref(trrNonZero);
	mainLog.print("States in non-zero reward BSCCs: " + JDD.GetNumMintermsString(bsccsNonZero, allDDRowVars.n()));

	// Find states with infinite reward (those reach a non-zero reward BSCC with prob > 0)
	JDDNode inf = PrismMTBDD.Prob0(tr01, reach, allDDRowVars, allDDColVars, reach, bsccsNonZero);
	// "reach" is shared with the rest of the model checker: take a reference before each use
	JDD.Ref(reach);
	inf = JDD.And(reach, JDD.Not(inf));
	JDD.Ref(reach);
	// "inf" is needed again below, so keep a reference across the Not
	JDD.Ref(inf);
	JDDNode maybe = JDD.And(reach, JDD.Not(inf));
	// The union of non-zero reward BSCCs was built locally and is no longer needed
	JDD.Deref(bsccsNonZero);

	// Print out inf/maybe
	mainLog.print("\ninf = " + JDD.GetNumMintermsString(inf, allDDRowVars.n()));
	mainLog.print(", maybe = " + JDD.GetNumMintermsString(maybe, allDDRowVars.n()) + "\n");

	try {
		// If maybe is empty, we have the rewards already
		if (maybe.equals(JDD.ZERO)) {
			JDD.Ref(inf);
			rewards = new StateValuesMTBDD(JDD.ITE(inf, JDD.PlusInfinity(), JDD.Constant(0)), model);
		}
		// Otherwise we compute the actual rewards
		else {
			// Compute the rewards
			// (do this using the functions for "reward reachability" properties but with no targets)
			mainLog.println("\nComputing remaining rewards...");
			mainLog.println("Engine: " + Prism.getEngineString(engine));
			switch (engine) {
			case Prism.MTBDD:
				rewardsMTBDD = PrismMTBDD.ProbReachReward(tr, sr, trr, odd, allDDRowVars, allDDColVars, JDD.ZERO, inf, maybe);
				rewards = new StateValuesMTBDD(rewardsMTBDD, model);
				break;
			case Prism.SPARSE:
				rewardsDV = PrismSparse.ProbReachReward(tr, sr, trr, odd, allDDRowVars, allDDColVars, JDD.ZERO, inf, maybe);
				rewards = new StateValuesDV(rewardsDV, model);
				break;
			case Prism.HYBRID:
				rewardsDV = PrismHybrid.ProbReachReward(tr, sr, trr, odd, allDDRowVars, allDDColVars, JDD.ZERO, inf, maybe);
				rewards = new StateValuesDV(rewardsDV, model);
				break;
			default:
				throw new PrismException("Unknown engine");
			}
		}
	} finally {
		// Tidy up (also on failure, so that the BSCC nodes are not leaked when the solver throws)
		for (JDDNode bscc : bsccs) {
			if (bscc != null)
				JDD.Deref(bscc);
		}
		if (notInBSCCs != null)
			JDD.Deref(notInBSCCs);
		JDD.Deref(inf);
		JDD.Deref(maybe);
	}

	return rewards;
}
// compute rewards for inst reward // compute rewards for inst reward
protected StateValues computeInstRewards(JDDNode tr, JDDNode sr, int time) throws PrismException protected StateValues computeInstRewards(JDDNode tr, JDDNode sr, int time) throws PrismException

Loading…
Cancel
Save