diff --git a/prism/src/explicit/ProbModelChecker.java b/prism/src/explicit/ProbModelChecker.java index 9e8c23ea..a891bf34 100644 --- a/prism/src/explicit/ProbModelChecker.java +++ b/prism/src/explicit/ProbModelChecker.java @@ -875,12 +875,26 @@ public class ProbModelChecker extends NonProbModelChecker } /** - * Construct rewards from a reward structure and a model. + * Construct rewards from a (non-negative) reward structure and a model. */ protected Rewards constructRewards(Model model, RewardStruct rewStruct) throws PrismException + { + return constructRewards(model, rewStruct, false); + } + + /** + * Construct rewards from a reward structure and a model. + *
+ * If {@code allowNegativeRewards} is true, the rewards may be positive and negative, i.e., weights. + */ + protected Rewards constructRewards(Model model, RewardStruct rewStruct, boolean allowNegativeRewards) throws PrismException { Rewards rewards; ConstructRewards constructRewards = new ConstructRewards(mainLog); + + if (allowNegativeRewards) + constructRewards.allowNegativeRewards(); + switch (model.getModelType()) { case CTMC: case DTMC: @@ -894,7 +908,7 @@ public class ProbModelChecker extends NonProbModelChecker } return rewards; } - + /** * Compute rewards for the contents of an R operator. */ diff --git a/prism/src/explicit/rewards/ConstructRewards.java b/prism/src/explicit/rewards/ConstructRewards.java index b163ad31..401936c7 100644 --- a/prism/src/explicit/rewards/ConstructRewards.java +++ b/prism/src/explicit/rewards/ConstructRewards.java @@ -50,6 +50,9 @@ public class ConstructRewards { protected PrismLog mainLog; + /** Allow negative rewards, i.e., weights. Defaults to false. */ + protected boolean allowNegative = false; + public ConstructRewards() { this(new PrismFileLog("stdout")); @@ -60,6 +63,12 @@ public class ConstructRewards this.mainLog = mainLog; } + /** Set flag that negative rewards are allowed, i.e., weights */ + public void allowNegativeRewards() + { + allowNegative = true; + } + /** * Construct rewards from a model and reward structure. * @param model The model @@ -100,6 +109,8 @@ public class ConstructRewards double rew = rewStr.getReward(0).evaluateDouble(constantValues); if (Double.isNaN(rew)) throw new PrismLangException("Reward structure evaluates to NaN (at any state)", rewStr.getReward(0)); + if (!allowNegative && rew < 0) + throw new PrismLangException("Reward structure evaluates to " + rew + " (at any state), negative rewards not allowed", rewStr.getReward(0)); return new StateRewardsConstant(rew); } // Normal: state rewards @@ -115,6 +126,8 @@ public class ConstructRewards double rew = rewStr.getReward(i).evaluateDouble(constantValues, statesList.get(j)); if (Double.isNaN(rew)) throw new PrismLangException("Reward structure evaluates to NaN at state " + statesList.get(j), rewStr.getReward(i)); + if (!allowNegative && rew < 0) + throw new PrismLangException("Reward structure evaluates to " + rew + " at state " + statesList.get(j) +", negative rewards not allowed", rewStr.getReward(i)); rewSA.addToStateReward(j, rew); } } @@ -142,6 +155,8 @@ public class ConstructRewards double rew = rewStr.getReward(0).evaluateDouble(constantValues); if (Double.isNaN(rew)) throw new PrismLangException("Reward structure evaluates to NaN (at any state)", rewStr.getReward(0)); + if (!allowNegative && rew < 0) + throw new PrismLangException("Reward structure evaluates to " + rew + " (at any state), negative rewards not allowed", rewStr.getReward(0)); return new StateRewardsConstant(rew); } // Normal: state and transition rewards @@ -165,6 +180,8 @@ public class ConstructRewards double rew = rewStr.getReward(i).evaluateDouble(constantValues, statesList.get(j)); if (Double.isNaN(rew)) throw new PrismLangException("Reward structure evaluates to NaN at state " + statesList.get(j), rewStr.getReward(i)); + if (!allowNegative && rew < 0) + throw new PrismLangException("Reward structure evaluates to " + rew + " at state " + statesList.get(j) +", negative rewards not allowed", rewStr.getReward(i)); rewSimple.addToTransitionReward(j, k, rew); } } @@ -174,6 +191,8 @@ public class ConstructRewards double rew = rewStr.getReward(i).evaluateDouble(constantValues, statesList.get(j)); if (Double.isNaN(rew)) throw new PrismLangException("Reward structure evaluates to NaN at state " + statesList.get(j), rewStr.getReward(i)); + if (!allowNegative && rew < 0) + throw new PrismLangException("Reward structure evaluates to " + rew + " at state " + statesList.get(j) +", negative rewards not allowed", rewStr.getReward(i)); rewSimple.addToStateReward(j, rew); } } @@ -191,21 +210,18 @@ public class ConstructRewards */ public MCRewards buildMCRewardsFromPrismExplicit(DTMC mc, File rews, File rewt) throws PrismException { - BufferedReader in; String s, ss[]; int i, lineNum = 0; double reward; StateRewardsArray rewSA = new StateRewardsArray(mc.getNumStates()); - try { - if (rews != null) { - // Open state rewards file - in = new BufferedReader(new FileReader(rews)); + if (rews != null) { + // Open state rewards file, automatic close + try (BufferedReader in = new BufferedReader(new FileReader(rews))) { // Ignore first line s = in.readLine(); lineNum = 1; if (s == null) { - in.close(); throw new PrismException("Missing first line of state rewards file"); } // Go though list of state rewards in file @@ -217,18 +233,19 @@ public class ConstructRewards ss = s.split(" "); i = Integer.parseInt(ss[0]); reward = Double.parseDouble(ss[1]); + if (!allowNegative && reward < 0) { + throw new PrismLangException("Found state reward " + reward + " at state " + i +", negative rewards not allowed"); + } rewSA.setStateReward(i, reward); } s = in.readLine(); lineNum++; } - // Close file - in.close(); + } catch (IOException e) { + throw new PrismException("Could not read state rewards from file \"" + rews + "\"" + e); + } catch (NumberFormatException e) { + throw new PrismException("Problem in state rewards file (line " + lineNum + ") for MDP"); } - } catch (IOException e) { - throw new PrismException("Could not read state rewards from file \"" + rews + "\"" + e); - } catch (NumberFormatException e) { - throw new PrismException("Problem in state rewards file (line " + lineNum + ") for MDP"); } if (rewt != null) { @@ -246,21 +263,18 @@ public class ConstructRewards */ public MDPRewards buildMDPRewardsFromPrismExplicit(MDP mdp, File rews, File rewt) throws PrismException { - BufferedReader in; String s, ss[]; int i, j, lineNum = 0; double reward; MDPRewardsSimple rs = new MDPRewardsSimple(mdp.getNumStates()); - try { - if (rews != null) { - // Open state rewards file - in = new BufferedReader(new FileReader(rews)); + if (rews != null) { + // Open state rewards file, automatic close + try (BufferedReader in = new BufferedReader(new FileReader(rews))) { // Ignore first line s = in.readLine(); lineNum = 1; if (s == null) { - in.close(); throw new PrismException("Missing first line of state rewards file"); } // Go though list of state rewards in file @@ -272,29 +286,28 @@ public class ConstructRewards ss = s.split(" "); i = Integer.parseInt(ss[0]); reward = Double.parseDouble(ss[1]); + if (!allowNegative && reward < 0) { + throw new PrismLangException("Found state reward " + reward + " at state " + i +", negative rewards not allowed"); + } rs.setStateReward(i, reward); } s = in.readLine(); lineNum++; } - // Close file - in.close(); + } catch (IOException e) { + throw new PrismException("Could not read state rewards from file \"" + rews + "\"" + e); + } catch (NumberFormatException e) { + throw new PrismException("Problem in state rewards file (line " + lineNum + ") for MDP"); } - } catch (IOException e) { - throw new PrismException("Could not read state rewards from file \"" + rews + "\"" + e); - } catch (NumberFormatException e) { - throw new PrismException("Problem in state rewards file (line " + lineNum + ") for MDP"); } - try { - if (rewt != null) { - // Open transition rewards file - in = new BufferedReader(new FileReader(rewt)); + if (rewt != null) { + // Open transition rewards file, automatic close + try (BufferedReader in = new BufferedReader(new FileReader(rewt))) { // Ignore first line s = in.readLine(); lineNum = 1; if (s == null) { - in.close(); throw new PrismException("Missing first line of transition rewards file"); } // Go though list of transition rewards in file @@ -307,20 +320,23 @@ public class ConstructRewards i = Integer.parseInt(ss[0]); j = Integer.parseInt(ss[1]); reward = Double.parseDouble(ss[3]); + if (!allowNegative && reward < 0) { + throw new PrismLangException("Found transition reward " + reward + " at state " + i +", action " + j +", negative rewards not allowed"); + } rs.setTransitionReward(i, j, reward); } s = in.readLine(); lineNum++; } - // Close file - in.close(); + + } catch (IOException e) { + throw new PrismException("Could not read transition rewards from file \"" + rewt + "\"" + e); + } catch (NumberFormatException e) { + throw new PrismException("Problem in transition rewards file (line " + lineNum + ") for MDP"); } - } catch (IOException e) { - throw new PrismException("Could not read transition rewards from file \"" + rewt + "\"" + e); - } catch (NumberFormatException e) { - throw new PrismException("Problem in transition rewards file (line " + lineNum + ") for MDP"); } return rs; } + }