Browse Source

Refactoring in ConstructRewards. Separate out repeated code.

accumulation-v4.7
Dave Parker 6 years ago
parent
commit
ea1af800d7
  1. 200
      prism/src/explicit/rewards/ConstructRewards.java

200
prism/src/explicit/rewards/ConstructRewards.java

@ -38,6 +38,7 @@ import explicit.MDP;
import explicit.Model;
import parser.State;
import parser.Values;
import parser.ast.ASTElement;
import parser.ast.Expression;
import parser.ast.RewardStruct;
import prism.PrismComponent;
@ -90,37 +91,15 @@ public class ConstructRewards extends PrismComponent
*/
public MCRewards buildMCRewardStructure(DTMC mc, RewardGenerator rewardGen, int r) throws PrismException
{
if (rewardGen == null) {
throw new PrismException("No reward generator to build reward structure");
}
// TODO: Transition rewards for Markov chains not supported yet
if (rewardGen.rewardStructHasTransitionRewards(r)) {
throw new PrismNotSupportedException("Explicit engine does not yet handle transition rewards for D/CTMCs");
}
int numStates = mc.getNumStates();
List<State> statesList = mc.getStatesList();
StateRewardsArray rewSA = new StateRewardsArray(numStates);
for (int s = 0; s < numStates; s++) {
// State rewards
double rew = 0;
Object stateIndex = null;
if (rewardGen.rewardStructHasStateRewards(r)) {
if (rewardGen.isRewardLookupSupported(RewardLookup.BY_STATE)) {
State state = statesList.get(s);
stateIndex = state;
rew = rewardGen.getStateReward(r, state);
} else if (rewardGen.isRewardLookupSupported(RewardLookup.BY_STATE_INDEX)) {
stateIndex = s;
rew = rewardGen.getStateReward(r, s);
} else {
throw new PrismException("Unknown state lookup mechanism for reward generator");
}
if (Double.isNaN(rew))
throw new PrismException("State reward evaluates to NaN at state " + stateIndex);
if (!allowNegative && rew < 0)
throw new PrismException("State reward is negative (" + rew + ") at state " + stateIndex + "");
double rew = getAndCheckStateReward(s, rewardGen, r, statesList);
rewSA.addToStateReward(s, rew);
}
}
@ -139,10 +118,37 @@ public class ConstructRewards extends PrismComponent
List<State> statesList = mdp.getStatesList();
MDPRewardsSimple rewSimple = new MDPRewardsSimple(numStates);
for (int s = 0; s < numStates; s++) {
// State rewards
if (rewardGen.rewardStructHasStateRewards(r)) {
double rew = getAndCheckStateReward(s, rewardGen, r, statesList);
rewSimple.addToStateReward(s, rew);
}
if (rewardGen.rewardStructHasTransitionRewards(r)) {
// Don't add rewards to transitions added to "fix" deadlock states
if (mdp.isDeadlockState(s)) {
continue;
}
int numChoices = mdp.getNumChoices(s);
for (int k = 0; k < numChoices; k++) {
double rew = getAndCheckStateActionReward(s, mdp.getAction(s, k), rewardGen, r, statesList);
rewSimple.addToTransitionReward(s, k, rew);
}
}
}
return rewSimple;
}
/**
* Get a state reward for a specific state and reward structure from a RewardGenerator.
* Also check that the state reward is legal. Throw an exception if not.
* @param s The index of the state
* @param rewardGen The RewardGenerator defining the rewards
* @param r The index of the reward structure to build
* @param statesLists List of states (maybe needed for state look up)
*/
private double getAndCheckStateReward(int s, RewardGenerator rewardGen, int r, List<State> statesList) throws PrismException
{
double rew = 0;
Object stateIndex = null;
if (rewardGen.rewardStructHasStateRewards(r)) {
if (rewardGen.isRewardLookupSupported(RewardLookup.BY_STATE)) {
State state = statesList.get(s);
stateIndex = state;
@ -151,43 +157,36 @@ public class ConstructRewards extends PrismComponent
stateIndex = s;
rew = rewardGen.getStateReward(r, s);
} else {
throw new PrismException("Unknown state lookup mechanism for reward generator");
}
if (Double.isNaN(rew))
throw new PrismException("State reward evaluates to NaN at state " + stateIndex);
if (!allowNegative && rew < 0)
throw new PrismException("State reward is negative (" + rew + ") at state " + stateIndex + "");
rewSimple.addToStateReward(s, rew);
throw new PrismException("Unknown reward lookup mechanism for reward generator");
}
// State-action rewards
if (rewardGen.rewardStructHasTransitionRewards(r)) {
if (mdp.isDeadlockState(s)) {
// As state s is a deadlock state, any outgoing transition
// was added to "fix" the deadlock and thus does not get a reward.
// Skip to next state
continue;
checkStateReward(rew, stateIndex, null);
return rew;
}
int numChoices = mdp.getNumChoices(s);
for (int k = 0; k < numChoices; k++) {
/**
* Get a state reward for a specific state and reward structure from a RewardGenerator.
* Also check that the state reward is legal. Throw an exception if not.
* @param s The index of the state
* @param rewardGen The RewardGenerator defining the rewards
* @param r The index of the reward structure to build
* @param statesLists List of states (maybe needed for state look up)
*/
private double getAndCheckStateActionReward(int s, Object action, RewardGenerator rewardGen, int r, List<State> statesList) throws PrismException
{
double rew = 0;
Object stateIndex = null;
if (rewardGen.isRewardLookupSupported(RewardLookup.BY_STATE)) {
State state = statesList.get(s);
stateIndex = state;
rew = rewardGen.getStateActionReward(r, state, mdp.getAction(s, k));
rew = rewardGen.getStateActionReward(r, state, action);
} else if (rewardGen.isRewardLookupSupported(RewardLookup.BY_STATE_INDEX)) {
stateIndex = s;
rew = rewardGen.getStateActionReward(r, s, mdp.getAction(s, k));
rew = rewardGen.getStateActionReward(r, s, action);
} else {
throw new PrismException("Unknown state lookup mechanism for reward generator");
throw new PrismException("Unknown reward lookup mechanism for reward generator");
}
if (Double.isNaN(rew))
throw new PrismException("Transition reward evaluates to NaN at state " + stateIndex);
if (!allowNegative && rew < 0)
throw new PrismException("Transition reward is negative (" + rew + ") at state " + stateIndex + "");
rewSimple.addToTransitionReward(s, k, rew);
}
}
}
return rewSimple;
checkTransitionReward(rew, stateIndex, null);
return rew;
}
/**
@ -222,16 +221,12 @@ public class ConstructRewards extends PrismComponent
int i, j, n, numStates;
if (rewStr.getNumTransItems() > 0) {
// TODO
throw new PrismNotSupportedException("Explicit engine does not yet handle transition rewards for D/CTMCs");
}
// Special case: constant rewards
if (rewStr.getNumStateItems() == 1 && Expression.isTrue(rewStr.getStates(0)) && rewStr.getReward(0).isConstant()) {
double rew = rewStr.getReward(0).evaluateDouble(constantValues);
if (Double.isNaN(rew))
throw new PrismLangException("Reward structure evaluates to NaN (at any state)", rewStr.getReward(0));
if (!allowNegative && rew < 0)
throw new PrismLangException("Reward structure evaluates to " + rew + " (at any state), negative rewards not allowed", rewStr.getReward(0));
checkStateReward(rew, null, rewStr.getReward(0));
return new StateRewardsConstant(rew);
}
// Normal: state rewards
@ -245,10 +240,7 @@ public class ConstructRewards extends PrismComponent
for (j = 0; j < numStates; j++) {
if (guard.evaluateBoolean(constantValues, statesList.get(j))) {
double rew = rewStr.getReward(i).evaluateDouble(constantValues, statesList.get(j));
if (Double.isNaN(rew))
throw new PrismLangException("Reward structure evaluates to NaN at state " + statesList.get(j), rewStr.getReward(i));
if (!allowNegative && rew < 0)
throw new PrismLangException("Reward structure evaluates to " + rew + " at state " + statesList.get(j) +", negative rewards not allowed", rewStr.getReward(i));
checkStateReward(rew, statesList.get(j), rewStr.getReward(i));
rewSA.addToStateReward(j, rew);
}
}
@ -274,10 +266,7 @@ public class ConstructRewards extends PrismComponent
// Special case: constant state rewards
if (rewStr.getNumStateItems() == 1 && Expression.isTrue(rewStr.getStates(0)) && rewStr.getReward(0).isConstant()) {
double rew = rewStr.getReward(0).evaluateDouble(constantValues);
if (Double.isNaN(rew))
throw new PrismLangException("Reward structure evaluates to NaN (at any state)", rewStr.getReward(0));
if (!allowNegative && rew < 0)
throw new PrismLangException("Reward structure evaluates to " + rew + " (at any state), negative rewards not allowed", rewStr.getReward(0));
checkStateReward(rew, null, rewStr.getReward(0));
return new StateRewardsConstant(rew);
}
// Normal: state and transition rewards
@ -305,10 +294,7 @@ public class ConstructRewards extends PrismComponent
mdpAction = mdp.getAction(j, k);
if (mdpAction == null ? (action.isEmpty()) : mdpAction.equals(action)) {
double rew = rewStr.getReward(i).evaluateDouble(constantValues, statesList.get(j));
if (Double.isNaN(rew))
throw new PrismLangException("Reward structure evaluates to NaN at state " + statesList.get(j), rewStr.getReward(i));
if (!allowNegative && rew < 0)
throw new PrismLangException("Reward structure evaluates to " + rew + " at state " + statesList.get(j) +", negative rewards not allowed", rewStr.getReward(i));
checkTransitionReward(rew, statesList.get(j), rewStr.getReward(i));
rewSimple.addToTransitionReward(j, k, rew);
}
}
@ -316,10 +302,7 @@ public class ConstructRewards extends PrismComponent
// State reward
else {
double rew = rewStr.getReward(i).evaluateDouble(constantValues, statesList.get(j));
if (Double.isNaN(rew))
throw new PrismLangException("Reward structure evaluates to NaN at state " + statesList.get(j), rewStr.getReward(i));
if (!allowNegative && rew < 0)
throw new PrismLangException("Reward structure evaluates to " + rew + " at state " + statesList.get(j) +", negative rewards not allowed", rewStr.getReward(i));
checkStateReward(rew, statesList.get(j), rewStr.getReward(i));
rewSimple.addToStateReward(j, rew);
}
}
@ -360,9 +343,7 @@ public class ConstructRewards extends PrismComponent
ss = s.split(" ");
i = Integer.parseInt(ss[0]);
reward = Double.parseDouble(ss[1]);
if (!allowNegative && reward < 0) {
throw new PrismLangException("Found state reward " + reward + " at state " + i +", negative rewards not allowed");
}
checkStateReward(reward, i, null);
rewSA.setStateReward(i, reward);
}
s = in.readLine();
@ -413,9 +394,7 @@ public class ConstructRewards extends PrismComponent
ss = s.split(" ");
i = Integer.parseInt(ss[0]);
reward = Double.parseDouble(ss[1]);
if (!allowNegative && reward < 0) {
throw new PrismLangException("Found state reward " + reward + " at state " + i +", negative rewards not allowed");
}
checkStateReward(reward, i, null);
rs.setStateReward(i, reward);
}
s = in.readLine();
@ -447,9 +426,7 @@ public class ConstructRewards extends PrismComponent
i = Integer.parseInt(ss[0]);
j = Integer.parseInt(ss[1]);
reward = Double.parseDouble(ss[3]);
if (!allowNegative && reward < 0) {
throw new PrismLangException("Found transition reward " + reward + " at state " + i +", action " + j +", negative rewards not allowed");
}
checkTransitionReward(reward, i, null);
rs.setTransitionReward(i, j, reward);
}
s = in.readLine();
@ -465,4 +442,61 @@ public class ConstructRewards extends PrismComponent
return rs;
}
/**
* Check that a state reward is legal. Throw an exception if not.
* Optionally, provide a state where the error occurs (as an Object),
* and/or a pointer to where the error occurs syntactically (as an ASTElement)
* @param rew The reward value
* @param stateIndex The index of the state, for error reporting (optional)
* @param ast Where the error occurred, for error reporting (optional)
*/
private void checkStateReward(double rew, Object stateIndex, ASTElement ast) throws PrismException
{
String error = null;
if (Double.isNaN(rew)) {
error = "State reward evaluates to NaN";
} else if (!allowNegative && rew < 0) {
error = "State reward is negative (" + rew + ")";
}
if (error != null) {
if (stateIndex != null) {
error += " at state " + stateIndex;
}
if (ast != null) {
throw new PrismLangException(error, ast);
} else {
throw new PrismException(error);
}
}
}
/**
* Check that a state reward is legal. Throw an exception if not.
* @param rew The reward value
* Optionally, provide a state where the error occurs (as an Object),
* and/or a pointer to where the error occurs syntactically (as an ASTElement)
* @param rew The reward value
* @param stateIndex The index of the state, for error reporting (optional)
* @param ast Where the error occurred, for error reporting (optional)
*/
private void checkTransitionReward(double rew, Object stateIndex, ASTElement ast) throws PrismException
{
String error = null;
if (Double.isNaN(rew)) {
error = "Transition reward evaluates to NaN";
} else if (!allowNegative && rew < 0) {
error = "Transition reward is negative (" + rew + ")";
}
if (error != null) {
if (stateIndex != null) {
error += " at state " + stateIndex;
}
if (ast != null) {
throw new PrismLangException(error, ast);
} else {
throw new PrismException(error);
}
}
}
}
Loading…
Cancel
Save