You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
132 lines
3.0 KiB
132 lines
3.0 KiB
package explicit.rewards;
|
|
|
|
import java.util.ArrayList;
|
|
import java.util.List;
|
|
|
|
public class STPGRewardsSimple implements STPGRewards
|
|
{
|
|
/** Number of states */
|
|
protected int numStates;
|
|
|
|
protected List<List<Double>> distributionSetRewards;
|
|
|
|
protected List<List<List<List<Double>>>> transRewards;
|
|
|
|
public STPGRewardsSimple(int numStates)
|
|
{
|
|
this.numStates = numStates;
|
|
// Initially lists are just null (denoting all 0)
|
|
distributionSetRewards = new ArrayList<List<Double>>();
|
|
|
|
transRewards = new ArrayList<List<List<List<Double>>>>(numStates);
|
|
for (int j = 0; j < numStates; j++)
|
|
{
|
|
transRewards.add(null);
|
|
distributionSetRewards.add(null);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* NOT IMPLEMENTED
|
|
*/
|
|
@Override
|
|
public double getDistributionSetReward(int s, int ds)
|
|
{
|
|
return 0;
|
|
}
|
|
|
|
@Override
|
|
public int getTransitionRewardCount(int s, int ds, int d)
|
|
{
|
|
if (transRewards.get(s) == null || transRewards.get(s).get(ds) == null || transRewards.get(s).get(ds).get(d) == null)
|
|
return 0;
|
|
else
|
|
return transRewards.get(s).get(ds).get(d).size();
|
|
}
|
|
|
|
/**
|
|
* Adds rewards specified by {@code newRewards} to the rewards associated
|
|
* with {@code ds}th distribution of state {@code s}.
|
|
*
|
|
* The rewards are given as a list of lists of doubles, where the
|
|
* i-th element of {@code newRewards} specifies the rewards to be added
|
|
* to the (possibly empty) list of rewards associated with
|
|
* i-th distribution associated with {@code s} and {@code ds}.
|
|
*
|
|
* @param s
|
|
* @param ds
|
|
* @param newRewards
|
|
*/
|
|
public void addTransitionRewards(int s, int ds, List<List<Double>> newRewards)
|
|
{
|
|
if (transRewards.get(s) == null) {
|
|
List<List<List<Double>>> distTransRewards = new ArrayList<List<List<Double>>>();
|
|
transRewards.set(s, distTransRewards);
|
|
}
|
|
|
|
if (transRewards.get(s).size() <= ds) {
|
|
List<List<Double>> lTransRewards = new ArrayList<List<Double>>();
|
|
transRewards.get(s).add(lTransRewards);
|
|
}
|
|
|
|
List<List<Double>> dsRewards = transRewards.get(s).get(ds);
|
|
if (dsRewards.size() < newRewards.size())
|
|
{
|
|
for (int i = dsRewards.size(); i < newRewards.size(); i++)
|
|
{
|
|
dsRewards.add(new ArrayList<Double>());
|
|
}
|
|
}
|
|
|
|
|
|
for (int i = 0; i < dsRewards.size(); i++)
|
|
{
|
|
dsRewards.get(i).addAll(newRewards.get(i));
|
|
}
|
|
}
|
|
|
|
|
|
@Override
|
|
public double getTransitionReward(int s, int ds, int d, int i)
|
|
{
|
|
return this.transRewards.get(s).get(ds).get(d).get(i);
|
|
}
|
|
|
|
@Override
|
|
public void clearRewards(int s)
|
|
{
|
|
if(this.distributionSetRewards.get(s) != null)
|
|
this.distributionSetRewards.get(s).clear();
|
|
if(this.transRewards.get(s) != null)
|
|
this.transRewards.get(s).clear();
|
|
}
|
|
|
|
public void addStates(int n)
|
|
{
|
|
this.numStates += n;
|
|
for (int i=0; i<n; i++)
|
|
{
|
|
this.distributionSetRewards.add(null);
|
|
this.transRewards.add(null);
|
|
}
|
|
}
|
|
|
|
@Override
|
|
public String toString()
|
|
{
|
|
int i;
|
|
boolean first;
|
|
String s = "";
|
|
first = true;
|
|
s = "[ ";
|
|
for (i = 0; i < numStates; i++) {
|
|
if (first)
|
|
first = false;
|
|
else
|
|
s += ", ";
|
|
s += i + ": " + transRewards.get(i);
|
|
}
|
|
s += " ]";
|
|
return s;
|
|
}
|
|
}
|