Browse Source
imported patch symb-counter-transform-transformations.patch
tud-infrastructure-2018-10-12
imported patch symb-counter-transform-transformations.patch
tud-infrastructure-2018-10-12
4 changed files with 565 additions and 4 deletions
-
245prism/src/prism/CounterTransformation.java
-
8prism/src/prism/ProbModelTransformationOperator.java
-
95prism/src/prism/RewardCounterProduct.java
-
221prism/src/prism/RewardCounterTransformationAdd.java
@ -0,0 +1,245 @@ |
|||
package prism; |
|||
|
|||
import java.util.ArrayList; |
|||
import java.util.List; |
|||
|
|||
import jdd.JDD; |
|||
import jdd.JDDNode; |
|||
import parser.ast.Expression; |
|||
import parser.ast.ExpressionReward; |
|||
import parser.ast.RewardStruct; |
|||
import parser.ast.TemporalOperatorBound; |
|||
import parser.ast.TemporalOperatorBounds; |
|||
import parser.visitor.ReplaceBound; |
|||
import prism.IntegerBound; |
|||
import prism.PrismException; |
|||
|
|||
public class CounterTransformation<M extends Model> implements ModelExpressionTransformation<M, M> { |
|||
private Expression originalExpression; |
|||
private Expression transformedExpression; |
|||
private M originalModel; |
|||
private RewardCounterProduct<M> product; |
|||
|
|||
StateModelChecker mc; |
|||
|
|||
/** |
|||
* The originalExpression will be modified! |
|||
* @param mc |
|||
* @param originalModel |
|||
* @param originalExpression |
|||
* @param bound |
|||
* @param statesOfInterest |
|||
* @throws PrismException |
|||
*/ |
|||
public CounterTransformation(StateModelChecker mc, |
|||
M originalModel, Expression originalExpression, |
|||
TemporalOperatorBound bound, |
|||
JDDNode statesOfInterest) throws PrismException { |
|||
this.originalModel = originalModel; |
|||
this.originalExpression = originalExpression.deepCopy(); |
|||
this.mc = mc; |
|||
|
|||
transformedExpression = originalExpression; |
|||
doTransformation(originalModel, bound, statesOfInterest); |
|||
} |
|||
|
|||
/** |
|||
* The originalExpression will be modified! |
|||
* @param mc |
|||
* @param originalModel |
|||
* @param originalExpression |
|||
* @param bound |
|||
* @param statesOfInterest |
|||
* @throws PrismException |
|||
*/ |
|||
public CounterTransformation(StateModelChecker mc, |
|||
M originalModel, Expression originalExpression, |
|||
List<TemporalOperatorBound> bounds, |
|||
JDDNode statesOfInterest) throws PrismException { |
|||
this.originalModel = originalModel; |
|||
this.originalExpression = originalExpression.deepCopy(); |
|||
this.mc = mc; |
|||
|
|||
transformedExpression = originalExpression; |
|||
doTransformation(originalModel, bounds, statesOfInterest); |
|||
} |
|||
|
|||
|
|||
@Override |
|||
public Expression getTransformedExpression() { |
|||
return transformedExpression; |
|||
} |
|||
|
|||
@Override |
|||
public M getTransformedModel() { |
|||
return product.getTransformedModel(); |
|||
} |
|||
|
|||
@Override |
|||
public JDDNode getTransformedStatesOfInterest() { |
|||
return product.getTransformedStatesOfInterest(); |
|||
} |
|||
|
|||
@Override |
|||
public M getOriginalModel() { |
|||
return originalModel; |
|||
} |
|||
|
|||
@Override |
|||
public Expression getOriginalExpression() { |
|||
return originalExpression; |
|||
} |
|||
|
|||
@Override |
|||
public StateValues projectToOriginalModel(StateValues svTransformedModel) |
|||
throws PrismException { |
|||
return product.projectToOriginalModel(svTransformedModel); |
|||
} |
|||
|
|||
private void doTransformation(M originalModel, TemporalOperatorBound bound, JDDNode statesOfInterest) throws PrismException { |
|||
List<TemporalOperatorBound> bounds = new ArrayList<TemporalOperatorBound>(); |
|||
bounds.add(bound); |
|||
doTransformation(originalModel, bounds, statesOfInterest); |
|||
} |
|||
|
|||
private void doTransformation(M originalModel, List<TemporalOperatorBound> bounds, JDDNode statesOfInterest) throws PrismException { |
|||
if (originalModel instanceof NondetModel) { |
|||
doTransformation((NondetModel)originalModel, bounds, statesOfInterest); |
|||
} else { |
|||
throw new PrismException("Counter-Transformation is not supported for "+originalModel.getClass()); |
|||
} |
|||
} |
|||
|
|||
|
|||
|
|||
@SuppressWarnings("unchecked") |
|||
private void doTransformation(NondetModel originalModel, List<TemporalOperatorBound> bounds, JDDNode statesOfInterest) throws PrismException |
|||
{ |
|||
List<IntegerBound> intBounds = new ArrayList<IntegerBound>(); |
|||
|
|||
if (bounds.isEmpty()) { |
|||
throw new IllegalArgumentException("Can not do counter transformation without any bounds."); |
|||
} |
|||
|
|||
for (TemporalOperatorBound bound : bounds) { |
|||
IntegerBound intBound = IntegerBound.fromTemporalOperatorBound(bound, mc.constantValues, true); |
|||
intBounds.add(intBound); |
|||
|
|||
if (!bound.hasSameDomainDiscreteTime(bounds.get(0))) { |
|||
throw new IllegalArgumentException("Can only do counter transformation for bounds with same domain."); |
|||
} |
|||
} |
|||
JDDNode trRewards = null; |
|||
|
|||
switch (bounds.get(0).getBoundType()) { |
|||
case REWARD_BOUND: { |
|||
// Get reward info |
|||
Object rsi = bounds.get(0).getRewardStructureIndex(); |
|||
JDDNode srew = mc.getStateRewardsByIndexObject(rsi, originalModel, mc.constantValues).copy(); |
|||
JDDNode trew = mc.getTransitionRewardsByIndexObject(rsi, originalModel, mc.constantValues).copy(); |
|||
|
|||
trRewards = JDD.Apply(JDD.PLUS, srew, trew); |
|||
break; |
|||
} |
|||
case DEFAULT_BOUND: |
|||
case STEP_BOUND: |
|||
case TIME_BOUND: |
|||
// a time/step bound, use constant reward structure assigning 1 to each state |
|||
trRewards = JDD.Constant(1); |
|||
break; |
|||
} |
|||
|
|||
if (trRewards == null) { |
|||
throw new PrismException("Couldn't generate reward information."); |
|||
} |
|||
|
|||
int saturation_limit = IntegerBound.getMaximalInterestingValueForConjunction(intBounds); |
|||
|
|||
product = (RewardCounterProduct<M>) RewardCounterProduct.generate(mc.prism, originalModel, trRewards, saturation_limit, statesOfInterest); |
|||
|
|||
// add 'in_bound-x' label |
|||
JDDNode statesInBound = product.getStatesWithAccumulatedRewardInBoundConjunction(intBounds); |
|||
//JDD.PrintMinterms(mc.prism.getMainLog(), statesInBound.copy(), "statesInBound (1)"); |
|||
statesInBound = JDD.And(statesInBound, product.getTransformedModel().getReach().copy()); |
|||
//JDD.PrintMinterms(mc.prism.getMainLog(), statesInBound.copy(), "statesInBound (2)"); |
|||
String labelInBound = ((NondetModel)product.getTransformedModel()).addUniqueLabelDD("in_bound", statesInBound, mc.getDefinedLabelNames()); |
|||
|
|||
// transform expression |
|||
for (TemporalOperatorBound bound : bounds) { |
|||
ReplaceBound replace = new ReplaceBound(bound, labelInBound); |
|||
transformedExpression = (Expression) transformedExpression.accept(replace); |
|||
|
|||
if (!replace.wasSuccessfull()) { |
|||
throw new PrismException("Replacing bound was not successfull."); |
|||
} |
|||
} |
|||
} |
|||
|
|||
|
|||
public static <M extends Model> ModelExpressionTransformation<M, M> replaceBoundsWithCounters(StateModelChecker mc, |
|||
M originalModel, Expression originalExpression, |
|||
List<TemporalOperatorBound> bounds, |
|||
JDDNode statesOfInterest) throws PrismException { |
|||
|
|||
if (bounds.isEmpty()) { |
|||
throw new PrismException("No bounds to replace!"); |
|||
} |
|||
|
|||
if (!originalExpression.isSimplePathFormula()) { |
|||
throw new PrismException("Replacing bounds is only supported in simple path formulas."); |
|||
} |
|||
|
|||
Prism prism = mc.prism; |
|||
|
|||
// TODO: Check nesting depth = 1 |
|||
|
|||
ModelExpressionTransformation<M, M> nested = null; |
|||
for (TemporalOperatorBound bound : bounds) { |
|||
// resolve RewardStruct for reward bounds |
|||
if (bound.isRewardBound()) { |
|||
int r = ExpressionReward.getRewardStructIndexByIndexObject(bound.getRewardStructureIndex(), mc.prism.getPRISMModel(), mc.constantValues); |
|||
bound.setResolvedRewardStructIndex(r); |
|||
} |
|||
} |
|||
|
|||
List<List<TemporalOperatorBound>> groupedBoundList = TemporalOperatorBounds.groupBoundsDiscreteTime(bounds); |
|||
|
|||
for (List<TemporalOperatorBound> groupedBounds : groupedBoundList) { |
|||
if (groupedBounds.get(0).isRewardBound()) { |
|||
String rewStructName = mc.getModulesFile().getRewardStructNames().get(groupedBounds.get(0).getResolvedRewardStructIndex()); |
|||
prism.getLog().println("Transform to incorporate counter for reward '" + rewStructName + "' and " + groupedBounds); |
|||
} else { |
|||
prism.getLog().println("Transform to incorporate counter for steps "+groupedBounds); |
|||
} |
|||
|
|||
ModelExpressionTransformation<M, M> current; |
|||
|
|||
if (nested == null) { |
|||
current = new CounterTransformation<M>(mc, originalModel, originalExpression, groupedBounds, statesOfInterest); |
|||
nested = current; |
|||
} else { |
|||
current = new CounterTransformation<M>(mc, nested.getTransformedModel(), nested.getTransformedExpression(), groupedBounds, nested.getTransformedStatesOfInterest()); |
|||
nested = new ModelExpressionTransformationNested<M, M, M>(nested, current); |
|||
} |
|||
|
|||
prism.getLog().println("Transformed "+nested.getTransformedModel().getModelType()+": "); |
|||
nested.getTransformedModel().printTransInfo(prism.getLog()); |
|||
/* try { |
|||
prism.exportTransToFile(nested.getTransformedModel(), true, Prism.EXPORT_DOT_STATES, new java.io.File("t.dot")); |
|||
} catch (FileNotFoundException e) { |
|||
// TODO Auto-generated catch block |
|||
e.printStackTrace(); |
|||
}*/ |
|||
prism.getLog().println("Transformed expression: "+ nested.getTransformedExpression()); |
|||
} |
|||
|
|||
return nested; |
|||
} |
|||
|
|||
@Override |
|||
public void clear() |
|||
{ |
|||
product.clear(); |
|||
} |
|||
|
|||
} |
|||
@ -0,0 +1,95 @@ |
|||
package prism; |
|||
|
|||
|
|||
import java.util.List; |
|||
|
|||
import jdd.JDD; |
|||
import jdd.JDDNode; |
|||
import jdd.JDDVars; |
|||
|
|||
public class RewardCounterProduct<M extends Model> extends Product<M> |
|||
{ |
|||
private int limit; |
|||
private RewardCounterTransformationAdd transform; |
|||
|
|||
private RewardCounterProduct(M originalModel, |
|||
M productModel, |
|||
RewardCounterTransformationAdd transform, |
|||
JDDNode productStatesOfInterest, |
|||
JDDVars automatonRowVars) { |
|||
super(productModel, originalModel, productStatesOfInterest, automatonRowVars); |
|||
this.transform = transform; |
|||
this.limit = transform.getLimit(); |
|||
} |
|||
|
|||
@Override |
|||
public void clear() |
|||
{ |
|||
super.clear(); |
|||
transform.clear(); |
|||
} |
|||
|
|||
/** |
|||
* Get the states in the product for a given accumulated reward. |
|||
* If acc_reward is >= limit, then the states with rewards beyond the |
|||
* limit are returned. |
|||
*/ |
|||
public JDDNode getStatesWithAccumulatedReward(int acc_reward) { |
|||
if (acc_reward >= transform.getLimit()) { |
|||
acc_reward = transform.getLimit(); |
|||
} |
|||
return JDD.And(productModel.getReach().copy(), |
|||
transform.encodeInt(acc_reward, false)); |
|||
} |
|||
|
|||
/** |
|||
* Get the states in the product inside a given integer bound. |
|||
*/ |
|||
public JDDNode getStatesWithAccumulatedRewardInBound(IntegerBound bound) { |
|||
JDDNode result = JDD.Constant(0); |
|||
for (int r=0; r<=bound.getMaximalInterestingValue(); r++) { |
|||
if (bound.isInBounds(r)) { |
|||
result = JDD.Or(result, getStatesWithAccumulatedReward(r)); |
|||
} |
|||
} |
|||
return result; |
|||
} |
|||
|
|||
/** |
|||
* Generate the product of a MDP with an accumulated reward counter. |
|||
* The counter has the range [0,limit], with saturation semantics for accumulated |
|||
* rewards >=limit. |
|||
* @param originalModel the MDP |
|||
* @param rewards integer MCRewards |
|||
* @param limit the saturation value for the counter |
|||
* @param statesOfInterest the set of state of interest, the starting point for the counters |
|||
* @return |
|||
* @throws PrismException |
|||
*/ |
|||
static public RewardCounterProduct<NondetModel> generate(PrismComponent parent, NondetModel originalModel, JDDNode trRewards, int limit, JDDNode statesOfInterest) throws PrismException { |
|||
TransitionsByRewardsInfo info = new TransitionsByRewardsInfo(parent, originalModel, trRewards); |
|||
RewardCounterTransformationAdd transform = new RewardCounterTransformationAdd(originalModel, info, limit, statesOfInterest); |
|||
|
|||
NondetModel transformedModel = originalModel.getTransformed(transform); |
|||
JDDNode productStatesOfInterest = transformedModel.getStart().copy(); |
|||
return new RewardCounterProduct<NondetModel>(originalModel, transformedModel, transform, productStatesOfInterest, transform.getExtraRowVars().copy()); |
|||
} |
|||
|
|||
|
|||
/** |
|||
* Get the states in the product DTMC inside the conjunction of integer bound. |
|||
*/ |
|||
JDDNode getStatesWithAccumulatedRewardInBoundConjunction(List<IntegerBound> bounds) { |
|||
JDDNode result = JDD.Constant(0); |
|||
for (int r=0; r<=limit; r++) { |
|||
//System.out.println("r="+r+" is in bound?"); |
|||
if (IntegerBound.isInBoundForConjunction(bounds, r)) { |
|||
//System.out.println("r="+r+" is in bound"); |
|||
JDDNode accStates = getStatesWithAccumulatedReward(r); |
|||
// JDD.PrintMinterms(new PrismFileLog("stdout"), accStates.copy(), "accStates"); |
|||
result = JDD.Or(result, accStates); |
|||
} |
|||
} |
|||
return result; |
|||
} |
|||
} |
|||
@ -0,0 +1,221 @@ |
|||
package prism; |
|||
|
|||
import java.util.BitSet; |
|||
|
|||
import jdd.JDD; |
|||
import jdd.JDDNode; |
|||
import jdd.JDDVars; |
|||
|
|||
public class RewardCounterTransformationAdd extends ProbModelTransformationOperator { |
|||
private int bits; |
|||
private int limit; |
|||
private int maxRepresentable; |
|||
private TransitionsByRewardsInfo info; |
|||
private JDDNode statesOfInterest; |
|||
// private PrismLog log = new PrismFileLog("stdout"); |
|||
private boolean msbFirst = true; |
|||
|
|||
/** Count rewards from [0,limit]. All values >= limit are encoded by limit */ |
|||
public RewardCounterTransformationAdd(ProbModel model, |
|||
TransitionsByRewardsInfo info, |
|||
int limit, |
|||
JDDNode statesOfInterest) { |
|||
super(model); |
|||
|
|||
this.info = info; |
|||
this.limit = limit; |
|||
this.statesOfInterest = statesOfInterest; |
|||
|
|||
//log.println("Limit = "+limit); |
|||
bits = (int) Math.ceil(PrismUtils.log2(limit+1)); |
|||
maxRepresentable = (1<<bits) - 1; |
|||
} |
|||
|
|||
public void clear() |
|||
{ |
|||
super.clear(); |
|||
info.clear(); |
|||
if (statesOfInterest != null) |
|||
JDD.Deref(statesOfInterest); |
|||
} |
|||
|
|||
private int bitIndex2Var(int i) |
|||
{ |
|||
if (msbFirst) { |
|||
return bits-i-1; |
|||
} else { |
|||
return i; |
|||
} |
|||
} |
|||
|
|||
@Override |
|||
public int getExtraStateVariableCount() { |
|||
return bits; |
|||
} |
|||
|
|||
public int getLimit() { |
|||
return limit; |
|||
} |
|||
|
|||
@Override |
|||
public JDDNode getTransformedTrans() throws PrismException { |
|||
JDDNode newTrans = JDD.Constant(0); |
|||
|
|||
for (int rew : info.getOccuringRewards()) { |
|||
JDDNode tr_rew = info.getTransitionsWithReward(rew); |
|||
|
|||
JDDNode tr_rew_with_counter = |
|||
JDD.Apply(JDD.TIMES, tr_rew, |
|||
adder(extraRowVars, extraColVars, rew)); |
|||
|
|||
// JDD.PrintMinterms(log, tr_rew_with_counter.copy(), "tr_rew_with_counter ("+rew+")"); |
|||
newTrans = JDD.Apply(JDD.MAX, newTrans, tr_rew_with_counter); |
|||
} |
|||
|
|||
return newTrans; |
|||
} |
|||
|
|||
@Override |
|||
public JDDNode getTransformedStart() { |
|||
JDDNode newStart = JDD.And(statesOfInterest.copy(), |
|||
encodeInt(0, false)); |
|||
|
|||
return newStart; |
|||
} |
|||
|
|||
public JDDVars getExtraRowVars() { |
|||
return extraRowVars; |
|||
} |
|||
|
|||
public JDDNode saturated(boolean col) { |
|||
int max = (1 << bits) - 1; |
|||
//log.println("Max = "+max); |
|||
JDDNode result = JDD.Constant(0); |
|||
for (int i = limit; i <= max; i++) { |
|||
JDDNode iDD = encodeInt(i, col); |
|||
//JDD.PrintMinterms(log, iDD, "i="+i); |
|||
result = JDD.Or(result, iDD); |
|||
} |
|||
return result; |
|||
} |
|||
|
|||
public int decodeInt(BitSet bitset) { |
|||
long[] v = bitset.toLongArray(); |
|||
if (v.length == 0) { |
|||
return 0; |
|||
} else if (v.length > 1 || v[0] > Integer.MAX_VALUE) { |
|||
throw new IllegalArgumentException("Integer value out of range"); |
|||
} |
|||
|
|||
return (int)v[0]; |
|||
} |
|||
|
|||
public JDDNode encodeInt(int value, boolean col) { |
|||
if (value < 0) |
|||
throw new IllegalArgumentException("Can not encode negative integer"); |
|||
|
|||
JDDVars vars = col ? extraColVars : extraRowVars; |
|||
BitSet vBits = BitSet.valueOf(new long[]{value}); |
|||
//log.println(vBits); |
|||
|
|||
if (value > maxRepresentable) { |
|||
throw new IllegalArgumentException("Value "+value+" is out of range"); |
|||
} |
|||
|
|||
JDDNode result = JDD.Constant(1); |
|||
for (int i=0; i < vars.n(); i++) { |
|||
if (vBits.get(i)) |
|||
result = JDD.And(result, vars.getVar(bitIndex2Var(i)).copy()); |
|||
else |
|||
result = JDD.And(result, JDD.Not(vars.getVar(bitIndex2Var(i)).copy())); |
|||
} |
|||
|
|||
return result; |
|||
} |
|||
|
|||
public JDDNode getStatesWithAccumulatedReward(int r) { |
|||
if (r >= limit) { |
|||
return encodeInt(limit, false); |
|||
} else { |
|||
return encodeInt(r, false); |
|||
} |
|||
} |
|||
|
|||
private JDDNode adder(JDDVars row, JDDVars col, int summand) throws PrismException { |
|||
JDDNode result; |
|||
|
|||
if (summand < 0) { |
|||
throw new IllegalArgumentException("Can not create adder for negative summand"); |
|||
} |
|||
if (row.n() != col.n()) { |
|||
throw new IllegalArgumentException("Can not create adder for different number of variables"); |
|||
} |
|||
|
|||
//log.println("Summand = "+summand+", bits = "+bits); |
|||
|
|||
if (summand >= limit) { |
|||
// -> limit_next |
|||
return encodeInt(limit, true); |
|||
} |
|||
|
|||
// convert summand to BitSet |
|||
BitSet summandBits = BitSet.valueOf(new long[]{summand}); |
|||
JDDNode nextValues = JDD.Constant(1); |
|||
JDDNode carry = JDD.Constant(0.0); |
|||
|
|||
// for all the bits (0, ..., n-1) |
|||
for (int i = 0; i < row.n(); i++) { |
|||
// x = i-th bit in the row vector |
|||
JDDNode x = row.getVar(bitIndex2Var(i)).copy(); |
|||
// y = i-th bit of the summand |
|||
JDDNode y = summandBits.get(i) ? JDD.Constant(1.0) : JDD.Constant(0.0); |
|||
|
|||
JDDNode z = JDD.Xor(JDD.Xor(x.copy(), y.copy()), carry.copy()); |
|||
nextValues = JDD.And(nextValues, JDD.Equiv(col.getVar(bitIndex2Var(i)).copy(), z)); |
|||
carry = JDD.Or(JDD.Or(JDD.And(x.copy(), carry.copy()), |
|||
JDD.And(y.copy(), carry)), |
|||
JDD.And(x.copy(), y.copy())); |
|||
JDD.Deref(x); |
|||
JDD.Deref(y); |
|||
} |
|||
|
|||
JDDNode saturated_now = saturated(false); |
|||
//JDD.PrintMinterms(log, saturated_now.copy(), "saturated_now"); |
|||
JDDNode saturated_next = JDD.And(nextValues.copy(), saturated(true)); |
|||
//JDD.PrintMinterms(log, saturated_next.copy(), "saturated_next (1)"); |
|||
saturated_next = JDD.ThereExists(saturated_next, extraColVars); |
|||
//JDD.PrintMinterms(log, saturated_next.copy(), "saturated_next (2)"); |
|||
saturated_next = JDD.Or(saturated_next, carry.copy()); |
|||
//JDD.PrintMinterms(log, saturated_next.copy(), "saturated_next (3)"); |
|||
JDDNode limit_next = encodeInt(limit, true); |
|||
//JDD.PrintMinterms(log, limit_next.copy(), "limit_next"); |
|||
|
|||
//JDD.PrintMinterms(qc.getLog(), negative_now, "negative_now"); |
|||
//JDD.PrintMinterms(qc.getLog(), negative_next, "negative_next"); |
|||
|
|||
// result = saturated_now -> limit_next |
|||
result = JDD.Implies(saturated_now.copy(), limit_next.copy()); |
|||
|
|||
// result &= (!saturated_now & saturated_next) -> limit_next |
|||
result = JDD.And(result, |
|||
JDD.Implies(JDD.And(JDD.Not(saturated_now.copy()), saturated_next.copy()), |
|||
limit_next.copy())); |
|||
|
|||
// result &= (!saturated_now & !satured_next) -> nextValues |
|||
result = JDD.And(result, |
|||
JDD.Implies(JDD.And(JDD.Not(saturated_now.copy()), |
|||
JDD.Not(saturated_next.copy())), |
|||
nextValues.copy())); |
|||
|
|||
JDD.Deref(nextValues); |
|||
JDD.Deref(carry); |
|||
JDD.Deref(saturated_now); |
|||
JDD.Deref(saturated_next); |
|||
JDD.Deref(limit_next); |
|||
|
|||
//JDD.PrintMinterms(qc.getLog(), result.copy(), "adder for "+summand); |
|||
|
|||
//JDD.PrintMinterms(log, result, "adder("+summand+")"); |
|||
return result; |
|||
} |
|||
} |
|||
Write
Preview
Loading…
Cancel
Save
Reference in new issue