From ea1af800d7dcbed207ce1de86c542866c53a5c36 Mon Sep 17 00:00:00 2001 From: Dave Parker Date: Sat, 23 May 2020 00:54:30 +0100 Subject: [PATCH] Refactoring in ConstructRewards. Separate out repeated code. --- .../explicit/rewards/ConstructRewards.java | 210 ++++++++++-------- 1 file changed, 122 insertions(+), 88 deletions(-) diff --git a/prism/src/explicit/rewards/ConstructRewards.java b/prism/src/explicit/rewards/ConstructRewards.java index 23d3ebfa..fc7378c4 100644 --- a/prism/src/explicit/rewards/ConstructRewards.java +++ b/prism/src/explicit/rewards/ConstructRewards.java @@ -38,6 +38,7 @@ import explicit.MDP; import explicit.Model; import parser.State; import parser.Values; +import parser.ast.ASTElement; import parser.ast.Expression; import parser.ast.RewardStruct; import prism.PrismComponent; @@ -90,37 +91,15 @@ public class ConstructRewards extends PrismComponent */ public MCRewards buildMCRewardStructure(DTMC mc, RewardGenerator rewardGen, int r) throws PrismException { - if (rewardGen == null) { - throw new PrismException("No reward generator to build reward structure"); - } - - // TODO: Transition rewards for Markov chains not supported yet if (rewardGen.rewardStructHasTransitionRewards(r)) { throw new PrismNotSupportedException("Explicit engine does not yet handle transition rewards for D/CTMCs"); } - int numStates = mc.getNumStates(); List statesList = mc.getStatesList(); StateRewardsArray rewSA = new StateRewardsArray(numStates); for (int s = 0; s < numStates; s++) { - // State rewards - double rew = 0; - Object stateIndex = null; if (rewardGen.rewardStructHasStateRewards(r)) { - if (rewardGen.isRewardLookupSupported(RewardLookup.BY_STATE)) { - State state = statesList.get(s); - stateIndex = state; - rew = rewardGen.getStateReward(r, state); - } else if (rewardGen.isRewardLookupSupported(RewardLookup.BY_STATE_INDEX)) { - stateIndex = s; - rew = rewardGen.getStateReward(r, s); - } else { - throw new PrismException("Unknown state lookup mechanism for 
reward generator"); - } - if (Double.isNaN(rew)) - throw new PrismException("State reward evaluates to NaN at state " + stateIndex); - if (!allowNegative && rew < 0) - throw new PrismException("State reward is negative (" + rew + ") at state " + stateIndex + ""); + double rew = getAndCheckStateReward(s, rewardGen, r, statesList); rewSA.addToStateReward(s, rew); } } @@ -139,50 +118,18 @@ public class ConstructRewards extends PrismComponent List statesList = mdp.getStatesList(); MDPRewardsSimple rewSimple = new MDPRewardsSimple(numStates); for (int s = 0; s < numStates; s++) { - // State rewards - double rew = 0; - Object stateIndex = null; if (rewardGen.rewardStructHasStateRewards(r)) { - if (rewardGen.isRewardLookupSupported(RewardLookup.BY_STATE)) { - State state = statesList.get(s); - stateIndex = state; - rew = rewardGen.getStateReward(r, state); - } else if (rewardGen.isRewardLookupSupported(RewardLookup.BY_STATE_INDEX)) { - stateIndex = s; - rew = rewardGen.getStateReward(r, s); - } else { - throw new PrismException("Unknown state lookup mechanism for reward generator"); - } - if (Double.isNaN(rew)) - throw new PrismException("State reward evaluates to NaN at state " + stateIndex); - if (!allowNegative && rew < 0) - throw new PrismException("State reward is negative (" + rew + ") at state " + stateIndex + ""); + double rew = getAndCheckStateReward(s, rewardGen, r, statesList); rewSimple.addToStateReward(s, rew); } - // State-action rewards if (rewardGen.rewardStructHasTransitionRewards(r)) { + // Don't add rewards to transitions added to "fix" deadlock states if (mdp.isDeadlockState(s)) { - // As state s is a deadlock state, any outgoing transition - // was added to "fix" the deadlock and thus does not get a reward. 
- // Skip to next state continue; } int numChoices = mdp.getNumChoices(s); for (int k = 0; k < numChoices; k++) { - if (rewardGen.isRewardLookupSupported(RewardLookup.BY_STATE)) { - State state = statesList.get(s); - stateIndex = state; - rew = rewardGen.getStateActionReward(r, state, mdp.getAction(s, k)); - } else if (rewardGen.isRewardLookupSupported(RewardLookup.BY_STATE_INDEX)) { - stateIndex = s; - rew = rewardGen.getStateActionReward(r, s, mdp.getAction(s, k)); - } else { - throw new PrismException("Unknown state lookup mechanism for reward generator"); - } - if (Double.isNaN(rew)) - throw new PrismException("Transition reward evaluates to NaN at state " + stateIndex); - if (!allowNegative && rew < 0) - throw new PrismException("Transition reward is negative (" + rew + ") at state " + stateIndex + ""); + double rew = getAndCheckStateActionReward(s, mdp.getAction(s, k), rewardGen, r, statesList); rewSimple.addToTransitionReward(s, k, rew); } } @@ -190,6 +137,58 @@ public class ConstructRewards extends PrismComponent return rewSimple; } + /** + * Get a state reward for a specific state and reward structure from a RewardGenerator. + * Also check that the state reward is legal. Throw an exception if not. 
+ * @param s The index of the state + * @param rewardGen The RewardGenerator defining the rewards + * @param r The index of the reward structure to build + * @param statesList List of states (maybe needed for state look up) + */ + private double getAndCheckStateReward(int s, RewardGenerator rewardGen, int r, List statesList) throws PrismException + { + double rew = 0; + Object stateIndex = null; + if (rewardGen.isRewardLookupSupported(RewardLookup.BY_STATE)) { + State state = statesList.get(s); + stateIndex = state; + rew = rewardGen.getStateReward(r, state); + } else if (rewardGen.isRewardLookupSupported(RewardLookup.BY_STATE_INDEX)) { + stateIndex = s; + rew = rewardGen.getStateReward(r, s); + } else { + throw new PrismException("Unknown reward lookup mechanism for reward generator"); + } + checkStateReward(rew, stateIndex, null); + return rew; + } + + /** + * Get a state-action (transition) reward for a specific state, action and reward structure from a RewardGenerator. + * Also check that the reward is legal. Throw an exception if not. 
+ * @param s The index of the state + * @param rewardGen The RewardGenerator defining the rewards + * @param r The index of the reward structure to build + * @param statesList List of states (maybe needed for state look up) + */ + private double getAndCheckStateActionReward(int s, Object action, RewardGenerator rewardGen, int r, List statesList) throws PrismException + { + double rew = 0; + Object stateIndex = null; + if (rewardGen.isRewardLookupSupported(RewardLookup.BY_STATE)) { + State state = statesList.get(s); + stateIndex = state; + rew = rewardGen.getStateActionReward(r, state, action); + } else if (rewardGen.isRewardLookupSupported(RewardLookup.BY_STATE_INDEX)) { + stateIndex = s; + rew = rewardGen.getStateActionReward(r, s, action); + } else { + throw new PrismException("Unknown reward lookup mechanism for reward generator"); + } + checkTransitionReward(rew, stateIndex, null); + return rew; + } + /** * Construct rewards from a model and reward structure. * @param model The model @@ -222,16 +221,12 @@ public class ConstructRewards extends PrismComponent int i, j, n, numStates; if (rewStr.getNumTransItems() > 0) { - // TODO throw new PrismNotSupportedException("Explicit engine does not yet handle transition rewards for D/CTMCs"); } // Special case: constant rewards if (rewStr.getNumStateItems() == 1 && Expression.isTrue(rewStr.getStates(0)) && rewStr.getReward(0).isConstant()) { double rew = rewStr.getReward(0).evaluateDouble(constantValues); - if (Double.isNaN(rew)) - throw new PrismLangException("Reward structure evaluates to NaN (at any state)", rewStr.getReward(0)); - if (!allowNegative && rew < 0) - throw new PrismLangException("Reward structure evaluates to " + rew + " (at any state), negative rewards not allowed", rewStr.getReward(0)); + checkStateReward(rew, null, rewStr.getReward(0)); return new StateRewardsConstant(rew); } // Normal: state rewards @@ -245,10 +240,7 @@ public class ConstructRewards extends PrismComponent for (j = 0; j < numStates; 
j++) { if (guard.evaluateBoolean(constantValues, statesList.get(j))) { double rew = rewStr.getReward(i).evaluateDouble(constantValues, statesList.get(j)); - if (Double.isNaN(rew)) - throw new PrismLangException("Reward structure evaluates to NaN at state " + statesList.get(j), rewStr.getReward(i)); - if (!allowNegative && rew < 0) - throw new PrismLangException("Reward structure evaluates to " + rew + " at state " + statesList.get(j) +", negative rewards not allowed", rewStr.getReward(i)); + checkStateReward(rew, statesList.get(j), rewStr.getReward(i)); rewSA.addToStateReward(j, rew); } } @@ -274,10 +266,7 @@ public class ConstructRewards extends PrismComponent // Special case: constant state rewards if (rewStr.getNumStateItems() == 1 && Expression.isTrue(rewStr.getStates(0)) && rewStr.getReward(0).isConstant()) { double rew = rewStr.getReward(0).evaluateDouble(constantValues); - if (Double.isNaN(rew)) - throw new PrismLangException("Reward structure evaluates to NaN (at any state)", rewStr.getReward(0)); - if (!allowNegative && rew < 0) - throw new PrismLangException("Reward structure evaluates to " + rew + " (at any state), negative rewards not allowed", rewStr.getReward(0)); + checkStateReward(rew, null, rewStr.getReward(0)); return new StateRewardsConstant(rew); } // Normal: state and transition rewards @@ -305,10 +294,7 @@ public class ConstructRewards extends PrismComponent mdpAction = mdp.getAction(j, k); if (mdpAction == null ? 
(action.isEmpty()) : mdpAction.equals(action)) { double rew = rewStr.getReward(i).evaluateDouble(constantValues, statesList.get(j)); - if (Double.isNaN(rew)) - throw new PrismLangException("Reward structure evaluates to NaN at state " + statesList.get(j), rewStr.getReward(i)); - if (!allowNegative && rew < 0) - throw new PrismLangException("Reward structure evaluates to " + rew + " at state " + statesList.get(j) +", negative rewards not allowed", rewStr.getReward(i)); + checkTransitionReward(rew, statesList.get(j), rewStr.getReward(i)); rewSimple.addToTransitionReward(j, k, rew); } } @@ -316,10 +302,7 @@ public class ConstructRewards extends PrismComponent // State reward else { double rew = rewStr.getReward(i).evaluateDouble(constantValues, statesList.get(j)); - if (Double.isNaN(rew)) - throw new PrismLangException("Reward structure evaluates to NaN at state " + statesList.get(j), rewStr.getReward(i)); - if (!allowNegative && rew < 0) - throw new PrismLangException("Reward structure evaluates to " + rew + " at state " + statesList.get(j) +", negative rewards not allowed", rewStr.getReward(i)); + checkStateReward(rew, statesList.get(j), rewStr.getReward(i)); rewSimple.addToStateReward(j, rew); } } @@ -360,9 +343,7 @@ public class ConstructRewards extends PrismComponent ss = s.split(" "); i = Integer.parseInt(ss[0]); reward = Double.parseDouble(ss[1]); - if (!allowNegative && reward < 0) { - throw new PrismLangException("Found state reward " + reward + " at state " + i +", negative rewards not allowed"); - } + checkStateReward(reward, i, null); rewSA.setStateReward(i, reward); } s = in.readLine(); @@ -413,9 +394,7 @@ public class ConstructRewards extends PrismComponent ss = s.split(" "); i = Integer.parseInt(ss[0]); reward = Double.parseDouble(ss[1]); - if (!allowNegative && reward < 0) { - throw new PrismLangException("Found state reward " + reward + " at state " + i +", negative rewards not allowed"); - } + checkStateReward(reward, i, null); rs.setStateReward(i, 
reward); } s = in.readLine(); @@ -447,9 +426,7 @@ public class ConstructRewards extends PrismComponent i = Integer.parseInt(ss[0]); j = Integer.parseInt(ss[1]); reward = Double.parseDouble(ss[3]); - if (!allowNegative && reward < 0) { - throw new PrismLangException("Found transition reward " + reward + " at state " + i +", action " + j +", negative rewards not allowed"); - } + checkTransitionReward(reward, i, null); rs.setTransitionReward(i, j, reward); } s = in.readLine(); @@ -465,4 +442,61 @@ public class ConstructRewards extends PrismComponent return rs; } + + /** + * Check that a state reward is legal. Throw an exception if not. + * Optionally, provide a state where the error occurs (as an Object), + * and/or a pointer to where the error occurs syntactically (as an ASTElement) + * @param rew The reward value + * @param stateIndex The index of the state, for error reporting (optional) + * @param ast Where the error occurred, for error reporting (optional) + */ + private void checkStateReward(double rew, Object stateIndex, ASTElement ast) throws PrismException + { + String error = null; + if (Double.isNaN(rew)) { + error = "State reward evaluates to NaN"; + } else if (!allowNegative && rew < 0) { + error = "State reward is negative (" + rew + ")"; + } + if (error != null) { + if (stateIndex != null) { + error += " at state " + stateIndex; + } + if (ast != null) { + throw new PrismLangException(error, ast); + } else { + throw new PrismException(error); + } + } + } + + /** + * Check that a transition reward is legal. Throw an exception if not. 
+ * + * Optionally, provide a state where the error occurs (as an Object), + * and/or a pointer to where the error occurs syntactically (as an ASTElement) + * @param rew The reward value + * @param stateIndex The index of the state, for error reporting (optional) + * @param ast Where the error occurred, for error reporting (optional) + */ + private void checkTransitionReward(double rew, Object stateIndex, ASTElement ast) throws PrismException + { + String error = null; + if (Double.isNaN(rew)) { + error = "Transition reward evaluates to NaN"; + } else if (!allowNegative && rew < 0) { + error = "Transition reward is negative (" + rew + ")"; + } + if (error != null) { + if (stateIndex != null) { + error += " at state " + stateIndex; + } + if (ast != null) { + throw new PrismLangException(error, ast); + } else { + throw new PrismException(error); + } + } + } }