package explicit.rewards; import java.util.ArrayList; import java.util.List; public class STPGRewardsSimple implements STPGRewards { /** Number of states */ protected int numStates; protected List> distributionSetRewards; protected List>>> transRewards; public STPGRewardsSimple(int numStates) { this.numStates = numStates; // Initially lists are just null (denoting all 0) distributionSetRewards = new ArrayList>(); transRewards = new ArrayList>>>(numStates); for (int j = 0; j < numStates; j++) { transRewards.add(null); distributionSetRewards.add(null); } } /** * NOT IMPLEMENTED */ @Override public double getDistributionSetReward(int s, int ds) { return 0; } @Override public int getTransitionRewardCount(int s, int ds, int d) { if (transRewards.get(s) == null || transRewards.get(s).get(ds) == null || transRewards.get(s).get(ds).get(d) == null) return 0; else return transRewards.get(s).get(ds).get(d).size(); } /** * Adds rewards specified by {@code newRewards} to the rewards associated * with {@code ds}th distribution of state {@code s}. * * The rewards are given as a list of lists of doubles, where the * i-th element of {@code newRewards} specifies the rewards to be added * to the (possibly empty) list of rewards associated with * i-th distribution associated with {@code s} and {@code ds}. * * @param s * @param ds * @param newRewards */ public void addTransitionRewards(int s, int ds, List> newRewards) { if (transRewards.get(s) == null) { List>> distTransRewards = new ArrayList>>(); transRewards.set(s, distTransRewards); } if (transRewards.get(s).size() <= ds) { List> lTransRewards = new ArrayList>(); transRewards.get(s).add(lTransRewards); } List> dsRewards = transRewards.get(s).get(ds); if (dsRewards.size() < newRewards.size()) { for (int i = dsRewards.size(); i < newRewards.size(); i++) { dsRewards.add(new ArrayList()); } } for (int i = 0; i < dsRewards.size(); i++) { dsRewards.get(i).addAll(newRewards.get(i)); } } @Override public double getTransitionReward(int s, int ds, int d, int i) { return this.transRewards.get(s).get(ds).get(d).get(i); } @Override public void clearRewards(int s) { if(this.distributionSetRewards.get(s) != null) this.distributionSetRewards.get(s).clear(); if(this.transRewards.get(s) != null) this.transRewards.get(s).clear(); } public void addStates(int n) { this.numStates += n; for (int i=0; i