Browse Source

imported patch symb-counter-transform-TransitionsByRewardsInfo.patch

accumulation-v4.7
Joachim Klein 7 years ago
committed by Joachim Klein
parent
commit
c5fce8f507
  1. 146
      prism/src/prism/TransitionsByRewardsInfo.java

146
prism/src/prism/TransitionsByRewardsInfo.java

@ -0,0 +1,146 @@
package prism;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeMap;
import common.SafeCast;
import jdd.JDD;
import jdd.JDDNode;
public class TransitionsByRewardsInfo extends PrismComponent
{
protected Model model;
protected TreeMap<Integer, JDDNode> rewToTrans = new TreeMap<Integer, JDDNode>();
protected JDDNode transRewards;
private Integer maxReward = null;
public TransitionsByRewardsInfo(PrismComponent parent, Model model, JDDNode transRewards) throws PrismException
{
super(parent);
this.model = model;
this.transRewards = transRewards;
splitTransitionMatrix(false);
}
public Model getModel()
{
return model;
}
public JDDNode getTransRewards()
{
return transRewards.copy();
}
private void putTransitionsWithReward(int rew, JDDNode tr)
{
JDDNode old = rewToTrans.put(rew, tr);
if (old != null) JDD.Deref(old);
}
public Set<Integer> getOccuringRewards()
{
return rewToTrans.keySet();
}
public JDDNode getTransitionsWithReward(int rew)
{
JDDNode result = rewToTrans.get(rew);
if (result != null) {
result = result.copy();
}
return result;
}
public JDDNode getStatesWithPosRewardTransitions()
{
JDDNode tr01_pos = JDD.GreaterThan(transRewards.copy(), 0.0);
tr01_pos = JDD.And(tr01_pos, model.getTrans01().copy());
if (model.getModelType() == ModelType.MDP) {
tr01_pos = JDD.ThereExists(tr01_pos, ((NondetModel)model).getAllDDNondetVars());
}
JDDNode states_with_pos_tr = JDD.ThereExists(tr01_pos, model.getAllDDColVars());
states_with_pos_tr = JDD.And(states_with_pos_tr, model.getReach().copy());
return states_with_pos_tr;
}
public JDDNode getTransitions01WithPosReward()
{
JDDNode trZero = getTransitionsWithReward(0);
if (trZero == null) trZero = JDD.Constant(0.0);
JDDNode trZero01 = JDD.GreaterThan(trZero, 0.0);
JDDNode trPos01 = JDD.And(getModel().getTrans01().copy(), JDD.Not(trZero01));
return trPos01;
}
public Iterable<Entry<Integer, JDDNode>> getTransitionsWithReward()
{
return rewToTrans.entrySet();
}
private void setMaxReward(int maxReward)
{
this.maxReward = maxReward;
}
public int getMaxReward()
{
return maxReward;
}
private void splitTransitionMatrix(boolean debug) throws PrismException
{
Model model = getModel();
JDDNode transRewards = getTransRewards();
if (debug) JDD.PrintMinterms(getLog(), model.getTrans().copy(), "tr");
if (debug) JDD.PrintMinterms(getLog(), transRewards.copy(), "transRewards");
// zero reward
JDDNode tr01ZeroRew = JDD.Equals(transRewards.copy(), 0.0);
JDDNode trZeroRew = JDD.Apply(JDD.TIMES, model.getTrans().copy(), tr01ZeroRew);
if (debug) JDD.PrintMinterms(getLog(), trZeroRew.copy(), "trZeroRew");
putTransitionsWithReward(0, trZeroRew);
int maxReward = 0;
while (!transRewards.equals(JDD.ZERO)) {
// find maximal occurring reward
double rew = JDD.FindMax(transRewards);
int rewInt = SafeCast.toInt(rew);
// track maximal reward
if (rewInt > maxReward) maxReward = rewInt;
// get set of transitions with this reward
JDDNode tr01WithRew = JDD.Equals(transRewards.copy(), rew);
JDDNode trWithRew = JDD.Apply(JDD.TIMES, model.getTrans().copy(), tr01WithRew.copy());
JDDNode remaining = JDD.Not(tr01WithRew);
if (debug) JDD.PrintMinterms(getLog(), trWithRew.copy(), "trWithRew_"+rewInt);
putTransitionsWithReward(rewInt, trWithRew);
// set tRew to 0 for the transitions in tr01WithRew
transRewards = JDD.Apply(JDD.TIMES, transRewards, remaining);
}
JDD.Deref(transRewards);
setMaxReward(maxReward);
}
public void clear()
{
if (transRewards != null) JDD.Deref(transRewards);
for (JDDNode tr : rewToTrans.values()) {
JDD.Deref(tr);
}
}
}
Loading…
Cancel
Save