Browse Source
imported patch symb-counter-transform-TransitionsByRewardsInfo.patch
accumulation-v4.7
imported patch symb-counter-transform-TransitionsByRewardsInfo.patch
accumulation-v4.7
committed by
Joachim Klein
1 changed files with 146 additions and 0 deletions
@ -0,0 +1,146 @@ |
|||||
|
package prism; |
||||
|
|
||||
|
import java.util.Map.Entry; |
||||
|
import java.util.Set; |
||||
|
import java.util.TreeMap; |
||||
|
|
||||
|
import common.SafeCast; |
||||
|
import jdd.JDD; |
||||
|
import jdd.JDDNode; |
||||
|
|
||||
|
public class TransitionsByRewardsInfo extends PrismComponent |
||||
|
{ |
||||
|
protected Model model; |
||||
|
protected TreeMap<Integer, JDDNode> rewToTrans = new TreeMap<Integer, JDDNode>(); |
||||
|
protected JDDNode transRewards; |
||||
|
private Integer maxReward = null; |
||||
|
|
||||
|
public TransitionsByRewardsInfo(PrismComponent parent, Model model, JDDNode transRewards) throws PrismException |
||||
|
{ |
||||
|
super(parent); |
||||
|
this.model = model; |
||||
|
this.transRewards = transRewards; |
||||
|
|
||||
|
splitTransitionMatrix(false); |
||||
|
} |
||||
|
|
||||
|
public Model getModel() |
||||
|
{ |
||||
|
return model; |
||||
|
} |
||||
|
|
||||
|
public JDDNode getTransRewards() |
||||
|
{ |
||||
|
return transRewards.copy(); |
||||
|
} |
||||
|
|
||||
|
private void putTransitionsWithReward(int rew, JDDNode tr) |
||||
|
{ |
||||
|
JDDNode old = rewToTrans.put(rew, tr); |
||||
|
if (old != null) JDD.Deref(old); |
||||
|
} |
||||
|
|
||||
|
public Set<Integer> getOccuringRewards() |
||||
|
{ |
||||
|
return rewToTrans.keySet(); |
||||
|
} |
||||
|
|
||||
|
public JDDNode getTransitionsWithReward(int rew) |
||||
|
{ |
||||
|
JDDNode result = rewToTrans.get(rew); |
||||
|
if (result != null) { |
||||
|
result = result.copy(); |
||||
|
} |
||||
|
return result; |
||||
|
} |
||||
|
|
||||
|
public JDDNode getStatesWithPosRewardTransitions() |
||||
|
{ |
||||
|
JDDNode tr01_pos = JDD.GreaterThan(transRewards.copy(), 0.0); |
||||
|
tr01_pos = JDD.And(tr01_pos, model.getTrans01().copy()); |
||||
|
if (model.getModelType() == ModelType.MDP) { |
||||
|
tr01_pos = JDD.ThereExists(tr01_pos, ((NondetModel)model).getAllDDNondetVars()); |
||||
|
} |
||||
|
JDDNode states_with_pos_tr = JDD.ThereExists(tr01_pos, model.getAllDDColVars()); |
||||
|
states_with_pos_tr = JDD.And(states_with_pos_tr, model.getReach().copy()); |
||||
|
|
||||
|
return states_with_pos_tr; |
||||
|
|
||||
|
} |
||||
|
|
||||
|
public JDDNode getTransitions01WithPosReward() |
||||
|
{ |
||||
|
JDDNode trZero = getTransitionsWithReward(0); |
||||
|
if (trZero == null) trZero = JDD.Constant(0.0); |
||||
|
JDDNode trZero01 = JDD.GreaterThan(trZero, 0.0); |
||||
|
|
||||
|
JDDNode trPos01 = JDD.And(getModel().getTrans01().copy(), JDD.Not(trZero01)); |
||||
|
|
||||
|
return trPos01; |
||||
|
} |
||||
|
|
||||
|
public Iterable<Entry<Integer, JDDNode>> getTransitionsWithReward() |
||||
|
{ |
||||
|
return rewToTrans.entrySet(); |
||||
|
} |
||||
|
|
||||
|
private void setMaxReward(int maxReward) |
||||
|
{ |
||||
|
this.maxReward = maxReward; |
||||
|
} |
||||
|
|
||||
|
public int getMaxReward() |
||||
|
{ |
||||
|
return maxReward; |
||||
|
} |
||||
|
|
||||
|
private void splitTransitionMatrix(boolean debug) throws PrismException |
||||
|
{ |
||||
|
Model model = getModel(); |
||||
|
JDDNode transRewards = getTransRewards(); |
||||
|
|
||||
|
if (debug) JDD.PrintMinterms(getLog(), model.getTrans().copy(), "tr"); |
||||
|
if (debug) JDD.PrintMinterms(getLog(), transRewards.copy(), "transRewards"); |
||||
|
|
||||
|
// zero reward |
||||
|
|
||||
|
JDDNode tr01ZeroRew = JDD.Equals(transRewards.copy(), 0.0); |
||||
|
JDDNode trZeroRew = JDD.Apply(JDD.TIMES, model.getTrans().copy(), tr01ZeroRew); |
||||
|
if (debug) JDD.PrintMinterms(getLog(), trZeroRew.copy(), "trZeroRew"); |
||||
|
putTransitionsWithReward(0, trZeroRew); |
||||
|
|
||||
|
int maxReward = 0; |
||||
|
|
||||
|
while (!transRewards.equals(JDD.ZERO)) { |
||||
|
// find maximal occurring reward |
||||
|
double rew = JDD.FindMax(transRewards); |
||||
|
int rewInt = SafeCast.toInt(rew); |
||||
|
|
||||
|
// track maximal reward |
||||
|
if (rewInt > maxReward) maxReward = rewInt; |
||||
|
|
||||
|
// get set of transitions with this reward |
||||
|
JDDNode tr01WithRew = JDD.Equals(transRewards.copy(), rew); |
||||
|
JDDNode trWithRew = JDD.Apply(JDD.TIMES, model.getTrans().copy(), tr01WithRew.copy()); |
||||
|
JDDNode remaining = JDD.Not(tr01WithRew); |
||||
|
|
||||
|
if (debug) JDD.PrintMinterms(getLog(), trWithRew.copy(), "trWithRew_"+rewInt); |
||||
|
putTransitionsWithReward(rewInt, trWithRew); |
||||
|
|
||||
|
// set tRew to 0 for the transitions in tr01WithRew |
||||
|
transRewards = JDD.Apply(JDD.TIMES, transRewards, remaining); |
||||
|
} |
||||
|
JDD.Deref(transRewards); |
||||
|
|
||||
|
setMaxReward(maxReward); |
||||
|
} |
||||
|
|
||||
|
public void clear() |
||||
|
{ |
||||
|
if (transRewards != null) JDD.Deref(transRewards); |
||||
|
|
||||
|
for (JDDNode tr : rewToTrans.values()) { |
||||
|
JDD.Deref(tr); |
||||
|
} |
||||
|
} |
||||
|
} |
||||
Write
Preview
Loading…
Cancel
Save
Reference in new issue