From b8a78c4031387bd9a329ab3939b62eaa7e6afeb2 Mon Sep 17 00:00:00 2001 From: Dave Parker Date: Mon, 25 Jul 2011 20:02:50 +0000 Subject: [PATCH] Updates to explicit engine from prism-qar (Vojta): * new rewards code for STPGs * additional utility methods * strip out some old reward stuff git-svn-id: https://www.prismmodelchecker.org/svn/prism/prism/trunk@3323 bbc10eb1-c90d-0410-af57-cb519fbb1720 --- prism/src/explicit/Distribution.java | 6 +- prism/src/explicit/DistributionSet.java | 19 +++ prism/src/explicit/MDP.java | 5 - prism/src/explicit/MDPSimple.java | 139 ++++++++++++++---- prism/src/explicit/MDPSparse.java | 16 +- prism/src/explicit/STPG.java | 9 +- prism/src/explicit/STPGAbstrSimple.java | 91 +++++++----- prism/src/explicit/STPGModelChecker.java | 17 ++- prism/src/explicit/rewards/MCRewards.java | 2 +- prism/src/explicit/rewards/MDPRewards.java | 2 +- .../explicit/rewards/MDPRewardsSimple.java | 6 + prism/src/explicit/rewards/Rewards.java | 9 ++ prism/src/explicit/rewards/STPGRewards.java | 78 ++++++++++ .../explicit/rewards/STPGRewardsConstant.java | 42 ++++++ .../explicit/rewards/STPGRewardsSimple.java | 132 +++++++++++++++++ .../rewards/StateTransitionRewardsSimple.java | 66 +++++++++ 16 files changed, 535 insertions(+), 104 deletions(-) create mode 100644 prism/src/explicit/rewards/Rewards.java create mode 100644 prism/src/explicit/rewards/STPGRewards.java create mode 100644 prism/src/explicit/rewards/STPGRewardsConstant.java create mode 100644 prism/src/explicit/rewards/STPGRewardsSimple.java create mode 100644 prism/src/explicit/rewards/StateTransitionRewardsSimple.java diff --git a/prism/src/explicit/Distribution.java b/prism/src/explicit/Distribution.java index 048c148c..9f5061c3 100644 --- a/prism/src/explicit/Distribution.java +++ b/prism/src/explicit/Distribution.java @@ -38,7 +38,7 @@ import prism.PrismUtils; public class Distribution implements Iterable> { private HashMap map; - + /** * Create an empty distribution. */ @@ -242,7 +242,7 @@ public class Distribution implements Iterable> } return true; } - + @Override public int hashCode() { @@ -253,6 +253,6 @@ public class Distribution implements Iterable> @Override public String toString() { - return "" + map; + return map.toString(); } } diff --git a/prism/src/explicit/DistributionSet.java b/prism/src/explicit/DistributionSet.java index 15a21b25..c0b07a4a 100644 --- a/prism/src/explicit/DistributionSet.java +++ b/prism/src/explicit/DistributionSet.java @@ -64,4 +64,23 @@ public class DistributionSet extends LinkedHashSet { return super.equals(o) && action == ((DistributionSet) o).action; } + + /** + * Returns the index of the distribution {@code d}, i.e. the position in the order given by the iterator of this set + * @param d the distribution to look up + * @return the index of {@code d} or -1 if not found + */ + public int indexOf(Distribution d) + { + int i = -1; + for(Distribution itDist : this) + { + i++; + if (itDist.equals(d)) + { + return i; + } + } + return -1; + } } diff --git a/prism/src/explicit/MDP.java b/prism/src/explicit/MDP.java index 0cee2140..5ca5b8fb 100644 --- a/prism/src/explicit/MDP.java +++ b/prism/src/explicit/MDP.java @@ -51,11 +51,6 @@ public interface MDP extends Model */ public Object getAction(int s, int i); - /** - * Get the transition reward (if any) for choice i of state s. - */ - public double getTransitionReward(int s, int i); - /** * Perform a single step of precomputation algorithm Prob0, i.e., for states i in {@code subset}, * set bit i of {@code result} iff, for all/some choices, diff --git a/prism/src/explicit/MDPSimple.java b/prism/src/explicit/MDPSimple.java index e1583477..d87c0e8f 100644 --- a/prism/src/explicit/MDPSimple.java +++ b/prism/src/explicit/MDPSimple.java @@ -32,6 +32,7 @@ import java.util.Map.Entry; import java.io.*; import explicit.rewards.MDPRewards; +import explicit.rewards.MDPRewardsSimple; import prism.ModelType; import prism.PrismException; @@ -55,8 +56,9 @@ public class MDPSimple extends ModelSimple implements MDP // Rewards // (if transRewardsConstant non-null, use this for all transitions; otherwise, use transRewards list) // (for transRewards, null in element s means no rewards for that state) - protected Double transRewardsConstant; - protected List> transRewards; + //protected Double transRewardsConstant; + //protected List> transRewards; + protected List stateRewards; // Flag: allow duplicates in distribution sets? protected boolean allowDupes = false; @@ -156,7 +158,7 @@ public class MDPSimple extends ModelSimple implements MDP trans.add(new ArrayList()); } actions = null; - clearAllRewards(); + //clearAllRewards(); } @Override @@ -175,8 +177,6 @@ public class MDPSimple extends ModelSimple implements MDP trans.get(s).clear(); if (actions != null && actions.get(s) != null) actions.get(s).clear(); - if (transRewards != null && transRewards.get(s) != null) - transRewards.get(s).clear(); } @Override @@ -193,8 +193,6 @@ public class MDPSimple extends ModelSimple implements MDP trans.add(new ArrayList()); if (actions != null) actions.add(null); - if (transRewards != null) - transRewards.add(null); numStates++; } } @@ -265,6 +263,86 @@ public class MDPSimple extends ModelSimple implements MDP initialStates.add(0); } + public MDPRewardsSimple buildRewardsFromPrismExplicit(String rews, String rewt) throws PrismException + { + BufferedReader in; + Distribution distr; + String s, ss[]; + int i, j, iLast, kLast, n, lineNum = 0; + double reward; + MDPRewardsSimple rs = new MDPRewardsSimple(this.getNumStates()); + + try { + /* WE DO NOT SUPPORT STATE REWARDS YET + // Open rews file + in = new BufferedReader(new FileReader(new File(rews))); + // Parse first line to get num states + s = in.readLine(); + lineNum = 1; + if (s == null) + throw new PrismException("Missing first line of .rews file"); + ss = s.split(" "); + n = Integer.parseInt(ss[0]); + // Go though list of transitions in file + iLast = -1; + kLast = -1; + distr = null; + s = in.readLine(); + lineNum++; + while (s != null) { + s = s.trim(); + if (s.length() > 0) { + ss = s.split(" "); + i = Integer.parseInt(ss[0]); + reward = Double.parseDouble(ss[1]); + this.setStateReward(i,reward); + } + s = in.readLine(); + lineNum++; + } + // Close file + in.close(); + */ + + //Open rewt file + in = new BufferedReader(new FileReader(new File(rewt))); + // Parse first line to get num states + s = in.readLine(); + lineNum = 1; + if (s == null) + throw new PrismException("Missing first line of .rewt file"); + ss = s.split(" "); + n = Integer.parseInt(ss[0]); + // Go though list of transitions in file + iLast = -1; + kLast = -1; + distr = null; + s = in.readLine(); + lineNum++; + while (s != null) { + s = s.trim(); + if (s.length() > 0) { + ss = s.split(" "); + i = Integer.parseInt(ss[0]); + j = Integer.parseInt(ss[1]); + reward = Double.parseDouble(ss[2]); + rs.setTransitionReward(i, j, reward); + } + s = in.readLine(); + lineNum++; + } + // Close file + in.close(); + return rs; + } catch (IOException e) { + System.out.println(e); + System.exit(1); + return null; //will never happen, it's here just to make the compiler happy + } catch (NumberFormatException e) { + throw new PrismException("Problem in .rewt file (line " + lineNum + ") for MDP"); + } + } + // Mutators (other) /** @@ -292,8 +370,6 @@ public class MDPSimple extends ModelSimple implements MDP if (actions != null && actions.get(s) != null) actions.get(s).add(null); // Add zero reward if necessary - if (transRewards != null && transRewards.get(s) != null) - transRewards.get(s).add(0.0); // Update stats numDistrs++; maxNumDistrs = Math.max(maxNumDistrs, set.size()); @@ -304,11 +380,11 @@ public class MDPSimple extends ModelSimple implements MDP /** * Remove all rewards from the model */ - public void clearAllRewards() + /*public void clearAllRewards() { transRewards = null; transRewardsConstant = null; - } + }*/ /** * Set the action label for choice i in some state s. @@ -337,20 +413,34 @@ public class MDPSimple extends ModelSimple implements MDP /** * Set a constant reward for all transitions */ - public void setConstantTransitionReward(double r) + /*public void setConstantTransitionReward(double r) { // This replaces any other reward definitions transRewards = null; // Store as a Double (because we use null to check for its existence) transRewardsConstant = new Double(r); - } + }*/ + public void setStateReward(int s, double r) + { + if(stateRewards == null) { + stateRewards = new ArrayList(numStates); + for (int i = 0; i< numStates; i++) + { + stateRewards.add(0.0); + } + } + + stateRewards.set(s,r); + } + /** * Set the reward for choice i in some state s to r. */ public void setTransitionReward(int s, int i, double r) { - // This would replace any constant reward definition, if it existed + //this.trans.get(s).get(i).setReward(r); + /*// This would replace any constant reward definition, if it existed transRewardsConstant = null; // If no rewards array created yet, create it if (transRewards == null) { @@ -368,9 +458,9 @@ public class MDPSimple extends ModelSimple implements MDP transRewards.set(s, list); } // Set reward - transRewards.get(s).set(i, r); + transRewards.get(s).set(i, r);*/ } - + // Accessors (for ModelSimple) @Override @@ -604,17 +694,6 @@ public class MDPSimple extends ModelSimple implements MDP return trans.get(s).get(i).iterator(); } - @Override - public double getTransitionReward(int s, int i) - { - List list; - if (transRewardsConstant != null) - return transRewardsConstant; - if (transRewards == null || (list = transRewards.get(s)) == null) - return 0.0; - return list.get(i); - } - @Override public void prob0step(BitSet subset, BitSet u, boolean forall, BitSet result) { @@ -900,7 +979,7 @@ public class MDPSimple extends ModelSimple implements MDP for (Distribution distr : step) { j++; // Compute sum for this distribution - d = mdpRewards != null ? mdpRewards.getTransitionReward(s, j) : getTransitionReward(s, j); + d = mdpRewards.getTransitionReward(s, j); for (Map.Entry e : distr) { k = (Integer) e.getKey(); prob = (Double) e.getValue(); @@ -931,7 +1010,7 @@ public class MDPSimple extends ModelSimple implements MDP for (Distribution distr : step) { j++; // Compute sum for this distribution - d = mdpRewards != null ? mdpRewards.getTransitionReward(s, j) : getTransitionReward(s, j); + d = mdpRewards.getTransitionReward(s, j); for (Map.Entry e : distr) { k = (Integer) e.getKey(); prob = (Double) e.getValue(); @@ -1023,8 +1102,6 @@ public class MDPSimple extends ModelSimple implements MDP s += trans.get(i).get(j); } s += "]"; - if (transRewards != null) - s += transRewards.get(i); } s += " ]\n"; return s; diff --git a/prism/src/explicit/MDPSparse.java b/prism/src/explicit/MDPSparse.java index e5db6af1..c06a99ce 100644 --- a/prism/src/explicit/MDPSparse.java +++ b/prism/src/explicit/MDPSparse.java @@ -682,18 +682,6 @@ public class MDPSparse extends ModelSparse implements MDP return actions[rowStarts[s] + i]; } - @Override - public double getTransitionReward(int s, int i) - { - // TODO - return 0;/*List list; - if (transRewardsConstant != null) - return transRewardsConstant; - if (transRewards == null || (list = transRewards.get(s)) == null) - return 0.0; - return list.get(i);*/ - } - @Override public void prob0step(BitSet subset, BitSet u, boolean forall, BitSet result) { @@ -1006,7 +994,7 @@ public class MDPSparse extends ModelSparse implements MDP h1 = rowStarts[s + 1]; for (j = l1; j < h1; j++) { // Compute sum for this distribution - d = mdpRewards != null ? mdpRewards.getTransitionReward(s, j - l1) : getTransitionReward(s, j - l1); + d = mdpRewards.getTransitionReward(s, j - l1); l2 = choiceStarts[j]; h2 = choiceStarts[j + 1]; for (k = l2; k < h2; k++) { @@ -1035,7 +1023,7 @@ public class MDPSparse extends ModelSparse implements MDP h1 = rowStarts[s + 1]; for (j = l1; j < h1; j++) { // Compute sum for this distribution - d = mdpRewards != null ? mdpRewards.getTransitionReward(s, j) : getTransitionReward(s, j); + d = mdpRewards.getTransitionReward(s, j); l2 = choiceStarts[j]; h2 = choiceStarts[j + 1]; for (k = l2; k < h2; k++) { diff --git a/prism/src/explicit/STPG.java b/prism/src/explicit/STPG.java index 4e6d3fab..a3f13b41 100644 --- a/prism/src/explicit/STPG.java +++ b/prism/src/explicit/STPG.java @@ -27,6 +27,7 @@ package explicit; import java.util.*; +import explicit.rewards.STPGRewards; /** * Interface for classes that provide (read) access to an explicit-state stochastic two-player game (STPG), @@ -41,7 +42,7 @@ public interface STPG extends Model /** * Get the transition reward (if any) for choice i of state s. */ - public double getTransitionReward(int s, int i); + //public double getTransitionReward(int s, int i); /** * Perform a single step of precomputation algorithm Prob0, i.e., for states i in {@code subset}, @@ -140,7 +141,7 @@ public interface STPG extends Model * @param complement If true, {@code subset} is taken to be its complement (ignored if {@code subset} is null) * @param adv Storage for adversary choice indices (ignored if null) */ - public void mvMultRewMinMax(double vect[], boolean min1, boolean min2, double result[], BitSet subset, boolean complement, int adv[]); + public void mvMultRewMinMax(double vect[], STPGRewards rewards, boolean min1, boolean min2, double result[], BitSet subset, boolean complement, int adv[]); /** * Do a single row of matrix-vector multiplication and sum of action reward followed by min/max. @@ -151,7 +152,7 @@ public interface STPG extends Model * @param min2 Min or max for player 2 (true=min, false=max) * @param adv Storage for adversary choice indices (ignored if null) */ - public double mvMultRewMinMaxSingle(int s, double vect[], boolean min1, boolean min2, int adv[]); + public double mvMultRewMinMaxSingle(int s, double vect[], STPGRewards rewards, boolean min1, boolean min2, int adv[]); /** * Determine which choices result in min/max after a single row of matrix-vector multiplication and sum of action reward. @@ -161,5 +162,5 @@ public interface STPG extends Model * @param min2 Min or max for player 2 (true=min, false=max) * @param val Min or max value to match */ - public List mvMultRewMinMaxSingleChoices(int s, double vect[], boolean min1, boolean min2, double val); + public List mvMultRewMinMaxSingleChoices(int s, double vect[], STPGRewards rewards, boolean min1, boolean min2, double val); } diff --git a/prism/src/explicit/STPGAbstrSimple.java b/prism/src/explicit/STPGAbstrSimple.java index 94b301a6..b5f7a701 100644 --- a/prism/src/explicit/STPGAbstrSimple.java +++ b/prism/src/explicit/STPGAbstrSimple.java @@ -29,6 +29,8 @@ package explicit; import java.util.*; import java.io.*; +import explicit.rewards.STPGRewards; + import prism.ModelType; import prism.PrismException; import prism.PrismUtils; @@ -44,10 +46,6 @@ public class STPGAbstrSimple extends ModelSimple implements STPG // Transition function (Steps) protected List> trans; - // Rewards - protected List> transRewards; - protected Double transRewardsConstant; - // Flag: allow dupes in distribution sets? public boolean allowDupes = false; @@ -104,7 +102,7 @@ public class STPGAbstrSimple extends ModelSimple implements STPG for (int i = 0; i < numStates; i++) { trans.add(new ArrayList()); } - clearAllRewards(); + //clearAllRewards(); } @Override @@ -270,27 +268,27 @@ public class STPGAbstrSimple extends ModelSimple implements STPG /** * Remove all rewards from the model */ - public void clearAllRewards() + /*public void clearAllRewards() { transRewards = null; transRewardsConstant = null; - } + }*/ /** * Set a constant reward for all transitions */ - public void setConstantTransitionReward(double r) + /*public void setConstantTransitionReward(double r) { // This replaces any other reward definitions transRewards = null; // Store as a Double (because we use null to check for its existence) transRewardsConstant = new Double(r); - } + }*/ /** * Set the reward for choice i in some state s to r. */ - public void setTransitionReward(int s, int i, double r) + /*public void setTransitionReward(int s, int i, double r) { // This would replace any constant reward definition, if it existed transRewardsConstant = null; @@ -311,7 +309,7 @@ public class STPGAbstrSimple extends ModelSimple implements STPG } // Set reward transRewards.get(s).set(i, r); - } + }*/ // Accessors (for ModelSimple) @@ -515,7 +513,7 @@ public class STPGAbstrSimple extends ModelSimple implements STPG return null; } - @Override + /*@Override public double getTransitionReward(int s, int i) { List list; @@ -524,7 +522,7 @@ public class STPGAbstrSimple extends ModelSimple implements STPG if (transRewards == null || (list = transRewards.get(s)) == null) return 0.0; return list.get(i); - } + }*/ @Override public void prob0step(BitSet subset, BitSet u, boolean forall1, boolean forall2, BitSet result) @@ -762,62 +760,72 @@ public class STPGAbstrSimple extends ModelSimple implements STPG } @Override - public void mvMultRewMinMax(double vect[], boolean min1, boolean min2, double result[], BitSet subset, boolean complement, int adv[]) + public void mvMultRewMinMax(double vect[], STPGRewards rewards, boolean min1, boolean min2, double result[], BitSet subset, boolean complement, int adv[]) { int s; // Loop depends on subset/complement arguments if (subset == null) { for (s = 0; s < numStates; s++) - result[s] = mvMultRewMinMaxSingle(s, vect, min1, min2, adv); + result[s] = mvMultRewMinMaxSingle(s, vect, rewards, min1, min2, adv); } else if (complement) { for (s = subset.nextClearBit(0); s < numStates; s = subset.nextClearBit(s + 1)) - result[s] = mvMultRewMinMaxSingle(s, vect, min1, min2, adv); + result[s] = mvMultRewMinMaxSingle(s, vect, rewards, min1, min2, adv); } else { for (s = subset.nextSetBit(0); s >= 0; s = subset.nextSetBit(s + 1)) - result[s] = mvMultRewMinMaxSingle(s, vect, min1, min2, adv); + result[s] = mvMultRewMinMaxSingle(s, vect, rewards, min1, min2, adv); } } @Override - public double mvMultRewMinMaxSingle(int s, double vect[], boolean min1, boolean min2, int adv[]) + public double mvMultRewMinMaxSingle(int s, double vect[], STPGRewards rewards, boolean min1, boolean min2, int adv[]) { - int k; + int dsIter, rewIter, dIter, rewCount, k; double d, prob, minmax1, minmax2; boolean first1, first2; ArrayList step; minmax1 = 0; first1 = true; + dsIter=-1; step = trans.get(s); for (DistributionSet distrs : step) { + dsIter++; minmax2 = 0; first2 = true; + + dIter=-1; for (Distribution distr : distrs) { - // Compute sum for this distribution - d = 0.0; - for (Map.Entry e : distr) { - k = (Integer) e.getKey(); - prob = (Double) e.getValue(); - d += prob * vect[k]; + dIter++; + rewCount = rewards.getTransitionRewardCount(s, dsIter, dIter); + for(rewIter = 0; rewIter e : distr) { + k = (Integer) e.getKey(); + prob = (Double) e.getValue(); + d += prob * vect[k]; + } + // Check whether we have exceeded min/max so far + if (first2 || (min2 && d < minmax2) || (!min2 && d > minmax2)) + minmax2 = d; + first2 = false; } - // Check whether we have exceeded min/max so far - if (first2 || (min2 && d < minmax2) || (!min2 && d > minmax2)) - minmax2 = d; - first2 = false; } // Check whether we have exceeded min/max so far if (first1 || (min1 && minmax2 < minmax1) || (!min1 && minmax2 > minmax1)) minmax1 = minmax2; - first1 = false; + first1 = false; } - return minmax1; + return minmax1; } @Override - public List mvMultRewMinMaxSingleChoices(int s, double vect[], boolean min1, boolean min2, double val) + public List mvMultRewMinMaxSingleChoices(int s, double vect[], STPGRewards rewards, boolean min1, boolean min2, double val) { - int j, k; + int dsIter, rewIter, dIter, rewCount, k; double d, prob, minmax2; boolean first2; List res; @@ -826,15 +834,22 @@ public class STPGAbstrSimple extends ModelSimple implements STPG // Create data structures to store strategy res = new ArrayList(); // One row of matrix-vector operation - j = -1; + dsIter = -1; step = trans.get(s); for (DistributionSet distrs : step) { - j++; + dsIter++; minmax2 = 0; first2 = true; + + dIter = -1; for (Distribution distr : distrs) { + dIter++; + + rewCount = rewards.getTransitionRewardCount(s, dsIter, dIter); + for(rewIter = 0; rewIter e : distr) { k = (Integer) e.getKey(); prob = (Double) e.getValue(); @@ -844,13 +859,15 @@ public class STPGAbstrSimple extends ModelSimple implements STPG if (first2 || (min2 && d < minmax2) || (!min2 && d > minmax2)) minmax2 = d; first2 = false; + } } // Store strategy info if value matches //if (PrismUtils.doublesAreClose(val, d, termCritParam, termCrit == TermCrit.ABSOLUTE)) { if (PrismUtils.doublesAreClose(val, minmax2, 1e-12, false)) { - res.add(j); + res.add(dsIter); //res.add(distrs.getAction()); } + } return res; diff --git a/prism/src/explicit/STPGModelChecker.java b/prism/src/explicit/STPGModelChecker.java index 734ee34e..2e9c5a92 100644 --- a/prism/src/explicit/STPGModelChecker.java +++ b/prism/src/explicit/STPGModelChecker.java @@ -31,6 +31,7 @@ import java.util.*; import parser.ast.Expression; import parser.ast.ExpressionTemporal; import prism.*; +import explicit.rewards.STPGRewards; /** * Explicit-state model checker for two-player stochastic games (STPGs). @@ -154,7 +155,7 @@ public class STPGModelChecker extends ProbModelChecker /** * Compute rewards for the contents of an R operator. */ - protected StateValues checkRewardFormula(Model model, ExpressionTemporal expr, boolean min1, boolean min2) throws PrismException + protected StateValues checkRewardFormula(Model model, STPGRewards rewards, ExpressionTemporal expr, boolean min1, boolean min2) throws PrismException { // Assume R [F ] for now... @@ -165,7 +166,7 @@ public class STPGModelChecker extends ProbModelChecker // model check operands first target = (BitSet) checkExpression(model, expr.getOperand2()); - res = computeReachRewards((STPG) model, target, min1, min2); + res = computeReachRewards((STPG) model, rewards, target, min1, min2); rews = StateValues.createFromDoubleArray(res.soln); return rews; @@ -783,9 +784,9 @@ public class STPGModelChecker extends ProbModelChecker * @param min1 Min or max rewards for player 1 (true=min, false=max) * @param min2 Min or max rewards for player 2 (true=min, false=max) */ - public ModelCheckerResult computeReachRewards(STPG stpg, BitSet target, boolean min1, boolean min2) throws PrismException + public ModelCheckerResult computeReachRewards(STPG stpg, STPGRewards rewards, BitSet target, boolean min1, boolean min2) throws PrismException { - return computeReachRewards(stpg, target, min1, min2, null, null); + return computeReachRewards(stpg, rewards, target, min1, min2, null, null); } /** @@ -799,7 +800,7 @@ public class STPGModelChecker extends ProbModelChecker * @param known Optionally, a set of states for which the exact answer is known * Note: if 'known' is specified (i.e. is non-null, 'init' must also be given and is used for the exact values. */ - public ModelCheckerResult computeReachRewards(STPG stpg, BitSet target, boolean min1, boolean min2, double init[], BitSet known) throws PrismException + public ModelCheckerResult computeReachRewards(STPG stpg, STPGRewards rewards, BitSet target, boolean min1, boolean min2, double init[], BitSet known) throws PrismException { ModelCheckerResult res = null; BitSet inf; @@ -841,7 +842,7 @@ public class STPGModelChecker extends ProbModelChecker // Compute rewards switch (solnMethod) { case VALUE_ITERATION: - res = computeReachRewardsValIter(stpg, target, inf, min1, min2, init, known); + res = computeReachRewardsValIter(stpg, rewards, target, inf, min1, min2, init, known); break; default: throw new PrismException("Unknown STPG solution method " + solnMethod); @@ -870,7 +871,7 @@ public class STPGModelChecker extends ProbModelChecker * @param known Optionally, a set of states for which the exact answer is known * Note: if 'known' is specified (i.e. is non-null, 'init' must also be given and is used for the exact values. */ - protected ModelCheckerResult computeReachRewardsValIter(STPG stpg, BitSet target, BitSet inf, boolean min1, boolean min2, double init[], BitSet known) + protected ModelCheckerResult computeReachRewardsValIter(STPG stpg, STPGRewards rewards, BitSet target, BitSet inf, boolean min1, boolean min2, double init[], BitSet known) throws PrismException { ModelCheckerResult res; @@ -922,7 +923,7 @@ public class STPGModelChecker extends ProbModelChecker //mainLog.println(soln); iters++; // Matrix-vector multiply and min/max ops - stpg.mvMultRewMinMax(soln, min1, min2, soln2, unknown, false, null); + stpg.mvMultRewMinMax(soln, rewards, min1, min2, soln2, unknown, false, null); // Check termination done = PrismUtils.doublesAreClose(soln, soln2, termCritParam, termCrit == TermCrit.ABSOLUTE); // Swap vectors for next iter diff --git a/prism/src/explicit/rewards/MCRewards.java b/prism/src/explicit/rewards/MCRewards.java index 2cf63212..8d22a546 100644 --- a/prism/src/explicit/rewards/MCRewards.java +++ b/prism/src/explicit/rewards/MCRewards.java @@ -29,7 +29,7 @@ package explicit.rewards; /** * Classes that provide (read) access to explicit-state rewards for a Markov chain (DTMC/CTMC). */ -public interface MCRewards +public interface MCRewards extends Rewards { /** * Get the state reward for state {@code s}. diff --git a/prism/src/explicit/rewards/MDPRewards.java b/prism/src/explicit/rewards/MDPRewards.java index a34929f6..c7bbfbe5 100644 --- a/prism/src/explicit/rewards/MDPRewards.java +++ b/prism/src/explicit/rewards/MDPRewards.java @@ -29,7 +29,7 @@ package explicit.rewards; /** * Classes that provide (read) access to explicit-state rewards for an MDP. */ -public interface MDPRewards +public interface MDPRewards extends Rewards { /** * Get the state reward for state {@code s}. diff --git a/prism/src/explicit/rewards/MDPRewardsSimple.java b/prism/src/explicit/rewards/MDPRewardsSimple.java index 79ebf8ad..674ce3c9 100644 --- a/prism/src/explicit/rewards/MDPRewardsSimple.java +++ b/prism/src/explicit/rewards/MDPRewardsSimple.java @@ -104,4 +104,10 @@ public class MDPRewardsSimple implements MDPRewards return 0.0; return list.get(i); } + + @Override + public String toString() + { + return "st: " + this.stateRewards + "; tr:" + this.transRewards; + } } diff --git a/prism/src/explicit/rewards/Rewards.java b/prism/src/explicit/rewards/Rewards.java new file mode 100644 index 00000000..2f239bef --- /dev/null +++ b/prism/src/explicit/rewards/Rewards.java @@ -0,0 +1,9 @@ +package explicit.rewards; + +/** + * A dummy interface implemented by all reward classes. + */ +public interface Rewards +{ + +} diff --git a/prism/src/explicit/rewards/STPGRewards.java b/prism/src/explicit/rewards/STPGRewards.java new file mode 100644 index 00000000..cd34b56b --- /dev/null +++ b/prism/src/explicit/rewards/STPGRewards.java @@ -0,0 +1,78 @@ +//============================================================================== +// +// Copyright (c) 2002- +// Authors: +// * Dave Parker (University of Oxford) +// +//------------------------------------------------------------------------------ +// +// This file is part of PRISM. +// +// PRISM is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation; either version 2 of the License, or +// (at your option) any later version. +// +// PRISM is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with PRISM; if not, write to the Free Software Foundation, +// Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +// +//============================================================================== + +package explicit.rewards; + +/** + * Class that provide access to explicit-state rewards for an abstraction STPG. + * + * There are two type of rewards. Ones are on distribution sets and correspond + * to state rewards in the MDP (SUPPORT FOR THESE IS CURRENTLY NOT IMPLEMENTED). + * The others are on Distributions and correspond to transition rewards. Because + * each of the distributions may abstract several concrete distributions, it can + * have multiple rewards. The number of different rewards for each distribution + * can be obtained using {@link #getTransitionRewardCount(int, int, int)}, + * the rewards itself are then obtained using {@link #getTransitionReward(int, int, int, int)} + * + * + * + */ +public interface STPGRewards extends Rewards +{ + /** + * Returns the reward associated with {@code ds}th distribution for the state {@code s}. + */ + public double getDistributionSetReward(int s, int ds); + + /** + * Removes all rewards for DistributionSets and Distributions associated with state {@code s}. + */ + public void clearRewards(int s); + + /** + * Returns the number of different rewards associated with {@code d}th distribution of + * {@code ds}th distribution set of state {@code s} + * + * @param s State + * @param ds Distribution set + * @param d Distribution + * @return Number of different rewards associated with the distribution + */ + public int getTransitionRewardCount(int s, int ds, int d); + + /** + * + * Returns {@code i}th reward of {@code d}th distribution of + * {@code ds}th distribution set of state {@code s} + * + * @param s State + * @param ds Distribution set + * @param d Distribution + * @param i Index of the reward to return + * @return The reward. + */ + public double getTransitionReward(int s, int ds, int d, int i); +} diff --git a/prism/src/explicit/rewards/STPGRewardsConstant.java b/prism/src/explicit/rewards/STPGRewardsConstant.java new file mode 100644 index 00000000..ca347041 --- /dev/null +++ b/prism/src/explicit/rewards/STPGRewardsConstant.java @@ -0,0 +1,42 @@ +package explicit.rewards; + +/** + * Explicit storage of constant game rewards. + */ +public class STPGRewardsConstant implements STPGRewards +{ + private double dsReward; + private double transReward; + + public STPGRewardsConstant(double dsReward, double transReward) + { + this.dsReward = dsReward; + this.transReward = transReward; + } + + @Override + public double getDistributionSetReward(int s, int d) + { + return this.dsReward; + } + + @Override + public int getTransitionRewardCount(int s, int ds, int d) + { + return 1; + } + + @Override + public double getTransitionReward(int s, int d, int t, int i) + { + return this.transReward; + } + + @Override + public void clearRewards(int s) + { + //do nothing + return; + } + +} diff --git a/prism/src/explicit/rewards/STPGRewardsSimple.java b/prism/src/explicit/rewards/STPGRewardsSimple.java new file mode 100644 index 00000000..2151f229 --- /dev/null +++ b/prism/src/explicit/rewards/STPGRewardsSimple.java @@ -0,0 +1,132 @@ +package explicit.rewards; + +import java.util.ArrayList; +import java.util.List; + +public class STPGRewardsSimple implements STPGRewards +{ + /** Number of states */ + protected int numStates; + + protected List> distributionSetRewards; + + protected List>>> transRewards; + + public STPGRewardsSimple(int numStates) + { + this.numStates = numStates; + // Initially lists are just null (denoting all 0) + distributionSetRewards = new ArrayList>(); + + transRewards = new ArrayList>>>(numStates); + for (int j = 0; j < numStates; j++) + { + transRewards.add(null); + distributionSetRewards.add(null); + } + } + + /** + * NOT IMPLEMENTED + */ + @Override + public double getDistributionSetReward(int s, int ds) + { + return 0; + } + + @Override + public int getTransitionRewardCount(int s, int ds, int d) + { + if (transRewards.get(s) == null || transRewards.get(s).get(ds) == null || transRewards.get(s).get(ds).get(d) == null) + return 0; + else + return transRewards.get(s).get(ds).get(d).size(); + } + + /** + * Adds rewards specified by {@code newRewards} to the rewards associated + * with {@code ds}th distribution of state {@code s}. + * + * The rewards are given as a list of lists of doubles, where the + * i-th element of {@code newRewards} specifies the rewards to be added + * to the (possibly empty) list of rewards associated with + * i-th distribution associated with {@code s} and {@code ds}. + * + * @param s + * @param ds + * @param newRewards + */ + public void addTransitionRewards(int s, int ds, List> newRewards) + { + if (transRewards.get(s) == null) { + List>> distTransRewards = new ArrayList>>(); + transRewards.set(s, distTransRewards); + } + + if (transRewards.get(s).size() <= ds) { + List> lTransRewards = new ArrayList>(); + transRewards.get(s).add(lTransRewards); + } + + List> dsRewards = transRewards.get(s).get(ds); + if (dsRewards.size() < newRewards.size()) + { + for (int i = dsRewards.size(); i < newRewards.size(); i++) + { + dsRewards.add(new ArrayList()); + } + } + + + for (int i = 0; i < dsRewards.size(); i++) + { + dsRewards.get(i).addAll(newRewards.get(i)); + } + } + + + @Override + public double getTransitionReward(int s, int ds, int d, int i) + { + return this.transRewards.get(s).get(ds).get(d).get(i); + } + + @Override + public void clearRewards(int s) + { + if(this.distributionSetRewards.get(s) != null) + this.distributionSetRewards.get(s).clear(); + if(this.transRewards.get(s) != null) + this.transRewards.get(s).clear(); + } + + public void addStates(int n) + { + this.numStates += n; + for (int i=0; i> transRewards; + + public StateTransitionRewardsSimple(int numStates) + { + super(numStates); + this.transRewards = new ArrayList>(); + for(int i = 0; i < numStates; i++) + { + this.transRewards.add(new ArrayList()); + } + } + + /** + * Increase the number of states by {@code numStates} + * + * @param numStates Number of newly added states + */ + public void addStates(int numStates) + { + for(int i = 0; i < numStates; i++) + { + this.transRewards.add(new ArrayList()); + } + } + + /** + * Set the reward of choice {@code c} of state {@code s} to {@code r}. + * + * The number of states added so far must be at least {@code s+1}. + * + * @param s State + * @param c Choice (Transition) + * @param r Reward + */ + public void setTransitionReward(int s, int c, double r) + { + int n = s - transRewards.get(s).size() + 1; + if (n > 0) { + for (int j = 0; j < n; j++) { + transRewards.get(s).add(0.0); + } + } + transRewards.get(s).set(c, r); + } + + @Override + public double getTransitionReward(int s, int i) + { + return transRewards.get(s).get(i); + } + + public String toString() + { + return "rews: " + stateRewards + "; rewt: " + transRewards; + } +}