diff --git a/prism/src/explicit/STPGAbstrSimple.java b/prism/src/explicit/STPGAbstrSimple.java index d0067738..4d2c6d61 100644 --- a/prism/src/explicit/STPGAbstrSimple.java +++ b/prism/src/explicit/STPGAbstrSimple.java @@ -715,7 +715,7 @@ public class STPGAbstrSimple extends ModelSimple implements STPG @Override public double mvMultRewMinMaxSingle(int s, double vect[], STPGRewards rewards, boolean min1, boolean min2, int adv[]) { - int dsIter, rewIter, dIter, rewCount, k; + int dsIter, dIter, k; double d, prob, minmax1, minmax2; boolean first1, first2; ArrayList step; @@ -728,26 +728,22 @@ public class STPGAbstrSimple extends ModelSimple implements STPG dsIter++; minmax2 = 0; first2 = true; - dIter = -1; for (Distribution distr : distrs) { dIter++; - rewCount = rewards.getTransitionRewardCount(s, dsIter, dIter); - for (rewIter = 0; rewIter < rewCount; rewIter++) { - // Compute sum for this distribution - d = rewards.getTransitionReward(s, dsIter, dIter, rewIter); - - for (Map.Entry e : distr) { - k = (Integer) e.getKey(); - prob = (Double) e.getValue(); - d += prob * vect[k]; - } - // Check whether we have exceeded min/max so far - if (first2 || (min2 && d < minmax2) || (!min2 && d > minmax2)) - minmax2 = d; - first2 = false; + // Compute sum for this distribution + d = rewards.getNestedTransitionReward(s, dsIter, dIter); + for (Map.Entry e : distr) { + k = (Integer) e.getKey(); + prob = (Double) e.getValue(); + d += prob * vect[k]; } + // Check whether we have exceeded min/max so far + if (first2 || (min2 && d < minmax2) || (!min2 && d > minmax2)) + minmax2 = d; + first2 = false; } + minmax2 += rewards.getNestedTransitionReward(s, dsIter); // Check whether we have exceeded min/max so far if (first1 || (min1 && minmax2 < minmax1) || (!min1 && minmax2 > minmax1)) minmax1 = minmax2; @@ -760,7 +756,7 @@ public class STPGAbstrSimple extends ModelSimple implements STPG @Override public List mvMultRewMinMaxSingleChoices(int s, double vect[], STPGRewards rewards, boolean min1, boolean min2, double val) { - int dsIter, rewIter, dIter, rewCount, k; + int dsIter, dIter, k; double d, prob, minmax2; boolean first2; List res; @@ -775,33 +771,28 @@ public class STPGAbstrSimple extends ModelSimple implements STPG dsIter++; minmax2 = 0; first2 = true; - dIter = -1; for (Distribution distr : distrs) { dIter++; - - rewCount = rewards.getTransitionRewardCount(s, dsIter, dIter); - for (rewIter = 0; rewIter < rewCount; rewIter++) { - // Compute sum for this distribution - d = rewards.getTransitionReward(s, dsIter, dIter, rewIter); - for (Map.Entry e : distr) { - k = (Integer) e.getKey(); - prob = (Double) e.getValue(); - d += prob * vect[k]; - } - // Check whether we have exceeded min/max so far - if (first2 || (min2 && d < minmax2) || (!min2 && d > minmax2)) - minmax2 = d; - first2 = false; + // Compute sum for this distribution + d = rewards.getNestedTransitionReward(s, dsIter, dIter); + for (Map.Entry e : distr) { + k = (Integer) e.getKey(); + prob = (Double) e.getValue(); + d += prob * vect[k]; } + // Check whether we have exceeded min/max so far + if (first2 || (min2 && d < minmax2) || (!min2 && d > minmax2)) + minmax2 = d; + first2 = false; } + minmax2 += rewards.getNestedTransitionReward(s, dsIter); // Store strategy info if value matches //if (PrismUtils.doublesAreClose(val, d, termCritParam, termCrit == TermCrit.ABSOLUTE)) { if (PrismUtils.doublesAreClose(val, minmax2, 1e-12, false)) { res.add(dsIter); //res.add(distrs.getAction()); } - } return res; diff --git a/prism/src/explicit/rewards/MDPRewardsSimple.java b/prism/src/explicit/rewards/MDPRewardsSimple.java index 369dd504..7564bf81 100644 --- a/prism/src/explicit/rewards/MDPRewardsSimple.java +++ b/prism/src/explicit/rewards/MDPRewardsSimple.java @@ -83,7 +83,7 @@ public class MDPRewardsSimple implements MDPRewards for (int j = 0; j < numStates; j++) transRewards.add(null); } - // If no rewards for state i yet, create list + // If no rewards for state s yet, create list if (transRewards.get(s) == null) { list = new ArrayList(); transRewards.set(s, list); diff --git a/prism/src/explicit/rewards/STPGRewards.java b/prism/src/explicit/rewards/STPGRewards.java index befc46b3..bd5e2b9c 100644 --- a/prism/src/explicit/rewards/STPGRewards.java +++ b/prism/src/explicit/rewards/STPGRewards.java @@ -28,52 +28,29 @@ package explicit.rewards; /** - * Class that provide access to explicit-state rewards for an abstraction STPG. - * - * There are two type of rewards. Ones are on distribution sets and correspond - * to state rewards in the MDP (SUPPORT FOR THESE IS CURRENTLY NOT IMPLEMENTED). - * The others are on Distributions and correspond to transition rewards. Because - * each of the distributions may abstract several concrete distributions, it can - * have multiple rewards. The number of different rewards for each distribution - * can be obtained using {@link #getTransitionRewardCount(int, int, int)}, - * the rewards itself are then obtained using {@link #getTransitionReward(int, int, int, int)} - * - * - * + * Classes that provide (read) access to explicit-state rewards for an STPG. + * See the {@link explicit.STPG} interface for details of the accompanying model, + * in particular, for an explanation of nested transitions. */ public interface STPGRewards extends Rewards { /** - * Returns the reward associated with {@code ds}th distribution for the state {@code s}. + * Get the state reward for state {@code s}. */ - public double getDistributionSetReward(int s, int ds); - + public abstract double getStateReward(int s); + /** - * Removes all rewards for DistributionSets and Distributions associated with state {@code s}. + * Get the transition reward for the {@code i}th choice from state {@code s}. */ - public void clearRewards(int s); + public abstract double getTransitionReward(int s, int i); /** - * Returns the number of different rewards associated with {@code d}th distribution of - * {@code ds}th distribution set of state {@code s} - * - * @param s State - * @param ds Distribution set - * @param d Distribution - * @return Number of different rewards associated with the distribution + * Get the transition reward for the {@code i}th nested choice from state {@code s}. */ - public int getTransitionRewardCount(int s, int ds, int d); + public abstract double getNestedTransitionReward(int s, int i); /** - * - * Returns {@code i}th reward of {@code d}th distribution of - * {@code ds}th distribution set of state {@code s} - * - * @param s State - * @param ds Distribution set - * @param d Distribution - * @param i Index of the reward to return - * @return The reward. + * Get the transition reward for the {@code i,j}th nested choice from state {@code s}. */ - public double getTransitionReward(int s, int ds, int d, int i); + public abstract double getNestedTransitionReward(int s, int i, int j); } diff --git a/prism/src/explicit/rewards/STPGRewardsConstant.java b/prism/src/explicit/rewards/STPGRewardsConstant.java deleted file mode 100644 index c37107bf..00000000 --- a/prism/src/explicit/rewards/STPGRewardsConstant.java +++ /dev/null @@ -1,69 +0,0 @@ -//============================================================================== -// -// Copyright (c) 2002- -// Authors: -// * Dave Parker (University of Oxford) -// * Vojtech Forejt (University of Oxford) -// -//------------------------------------------------------------------------------ -// -// This file is part of PRISM. -// -// PRISM is free software; you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation; either version 2 of the License, or -// (at your option) any later version. -// -// PRISM is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with PRISM; if not, write to the Free Software Foundation, -// Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA -// -//============================================================================== - -package explicit.rewards; - -/** - * Explicit storage of constant game rewards. - */ -public class STPGRewardsConstant implements STPGRewards -{ - private double dsReward; - private double transReward; - - public STPGRewardsConstant(double dsReward, double transReward) - { - this.dsReward = dsReward; - this.transReward = transReward; - } - - @Override - public double getDistributionSetReward(int s, int d) - { - return this.dsReward; - } - - @Override - public int getTransitionRewardCount(int s, int ds, int d) - { - return 1; - } - - @Override - public double getTransitionReward(int s, int d, int t, int i) - { - return this.transReward; - } - - @Override - public void clearRewards(int s) - { - //do nothing - return; - } - -} diff --git a/prism/src/explicit/rewards/STPGRewardsSimple.java b/prism/src/explicit/rewards/STPGRewardsSimple.java index e16dedfb..a7dfcce2 100644 --- a/prism/src/explicit/rewards/STPGRewardsSimple.java +++ b/prism/src/explicit/rewards/STPGRewardsSimple.java @@ -30,130 +30,132 @@ package explicit.rewards; import java.util.ArrayList; import java.util.List; -public class STPGRewardsSimple implements STPGRewards +public class STPGRewardsSimple extends MDPRewardsSimple implements STPGRewards { - /** Number of states */ - protected int numStates; - - protected List> distributionSetRewards; - - protected List>>> transRewards; + /** Nested transition rewards (level 1) */ + protected List> nestedTransRewards1; + /** Nested transition rewards (level 2) */ + protected List>> nestedTransRewards2; + /** + * Constructor: all zero rewards. + * @param numStates Number of states + */ public STPGRewardsSimple(int numStates) { - this.numStates = numStates; + super(numStates); // Initially lists are just null (denoting all 0) - distributionSetRewards = new ArrayList>(); - - transRewards = new ArrayList>>>(numStates); - for (int j = 0; j < numStates; j++) - { - transRewards.add(null); - distributionSetRewards.add(null); - } + nestedTransRewards1 = null; + nestedTransRewards2 = null; } - + + // Mutators + /** - * NOT IMPLEMENTED + * Set the reward for the {@code i}th nested transition of state {@code s} to {@code r}. */ - @Override - public double getDistributionSetReward(int s, int ds) + public void setNestedTransitionReward(int s, int i, double r) { - return 0; + List list; + // If no rewards array created yet, create it + if (nestedTransRewards1 == null) { + nestedTransRewards1 = new ArrayList>(numStates); + for (int j = 0; j < numStates; j++) + nestedTransRewards1.add(null); + } + // If no rewards for state s yet, create list + if (nestedTransRewards1.get(s) == null) { + list = new ArrayList(); + nestedTransRewards1.set(s, list); + } else { + list = nestedTransRewards1.get(s); + } + // If list not big enough, extend + int n = i - list.size() + 1; + if (n > 0) { + for (int j = 0; j < n; j++) { + list.add(0.0); + } + } + // Set reward + list.set(i, r); } - @Override - public int getTransitionRewardCount(int s, int ds, int d) - { - if (transRewards.get(s) == null || transRewards.get(s).get(ds) == null || transRewards.get(s).get(ds).get(d) == null) - return 0; - else - return transRewards.get(s).get(ds).get(d).size(); - } - /** - * Adds rewards specified by {@code newRewards} to the rewards associated - * with {@code ds}th distribution of state {@code s}. - * - * The rewards are given as a list of lists of doubles, where the - * i-th element of {@code newRewards} specifies the rewards to be added - * to the (possibly empty) list of rewards associated with - * i-th distribution associated with {@code s} and {@code ds}. - * - * @param s - * @param ds - * @param newRewards + * Set the reward for the {@code i},{@code j}th nested transition of state {@code s} to {@code r}. */ - public void addTransitionRewards(int s, int ds, List> newRewards) - { - if (transRewards.get(s) == null) { - List>> distTransRewards = new ArrayList>>(); - transRewards.set(s, distTransRewards); + public void setNestedTransitionReward(int s, int i, int j, double r) + { + List> list1; + List list2; + // If no rewards array created yet, create it + if (nestedTransRewards2 == null) { + nestedTransRewards2 = new ArrayList>>(numStates); + for (int k = 0; k < numStates; k++) + nestedTransRewards2.add(null); } - - if (transRewards.get(s).size() <= ds) { - List> lTransRewards = new ArrayList>(); - transRewards.get(s).add(lTransRewards); + // If no rewards for state s yet, create list1 + if (nestedTransRewards2.get(s) == null) { + list1 = new ArrayList>(); + nestedTransRewards2.set(s, list1); + } else { + list1 = nestedTransRewards2.get(s); } - - List> dsRewards = transRewards.get(s).get(ds); - if (dsRewards.size() < newRewards.size()) - { - for (int i = dsRewards.size(); i < newRewards.size(); i++) - { - dsRewards.add(new ArrayList()); + // If list1 not big enough, extend + int n1 = i - list1.size() + 1; + if (n1 > 0) { + for (int k = 0; k < n1; k++) { + list1.add(null); } } - - - for (int i = 0; i < dsRewards.size(); i++) - { - dsRewards.get(i).addAll(newRewards.get(i)); + // If no rewards for state s, choice i, create list2 + if (list1.get(i) == null) { + list2 = new ArrayList(); + list1.set(i, list2); + } else { + list2 = list1.get(i); } + // If list2 not big enough, extend + int n2 = j - list2.size() + 1; + if (n2 > 0) { + for (int k = 0; k < n2; k++) { + list2.add(null); + } + } + // Set reward + list2.set(j, r); } - + + // Accessors @Override - public double getTransitionReward(int s, int ds, int d, int i) + public double getNestedTransitionReward(int s, int i) { - return this.transRewards.get(s).get(ds).get(d).get(i); + List list; + if (nestedTransRewards1 == null || (list = nestedTransRewards1.get(s)) == null) + return 0.0; + if (list.size() <= i) + return 0.0; + return list.get(i); } @Override - public void clearRewards(int s) + public double getNestedTransitionReward(int s, int i, int j) { - if(this.distributionSetRewards.get(s) != null) - this.distributionSetRewards.get(s).clear(); - if(this.transRewards.get(s) != null) - this.transRewards.get(s).clear(); + List> list1; + List list2; + if (nestedTransRewards1 == null || (list1 = nestedTransRewards2.get(s)) == null) + return 0.0; + if (list1.size() <= i || (list2 = list1.get(i)) == null) + return 0.0; + if (list2.size() <= j) + return 0.0; + return list2.get(j); } - public void addStates(int n) - { - this.numStates += n; - for (int i=0; i