Browse Source

imported patch symb-counter-transform-transformations.patch

tud-infrastructure-2018-10-12
Joachim Klein 7 years ago
parent
commit
5f356b0bf2
  1. 245
      prism/src/prism/CounterTransformation.java
  2. 8
      prism/src/prism/ProbModelTransformationOperator.java
  3. 95
      prism/src/prism/RewardCounterProduct.java
  4. 221
      prism/src/prism/RewardCounterTransformationAdd.java

245
prism/src/prism/CounterTransformation.java

@ -0,0 +1,245 @@
package prism;
import java.util.ArrayList;
import java.util.List;
import jdd.JDD;
import jdd.JDDNode;
import parser.ast.Expression;
import parser.ast.ExpressionReward;
import parser.ast.RewardStruct;
import parser.ast.TemporalOperatorBound;
import parser.ast.TemporalOperatorBounds;
import parser.visitor.ReplaceBound;
import prism.IntegerBound;
import prism.PrismException;
public class CounterTransformation<M extends Model> implements ModelExpressionTransformation<M, M> {
private Expression originalExpression;
private Expression transformedExpression;
private M originalModel;
private RewardCounterProduct<M> product;
StateModelChecker mc;
/**
* The originalExpression will be modified!
* @param mc
* @param originalModel
* @param originalExpression
* @param bound
* @param statesOfInterest
* @throws PrismException
*/
public CounterTransformation(StateModelChecker mc,
M originalModel, Expression originalExpression,
TemporalOperatorBound bound,
JDDNode statesOfInterest) throws PrismException {
this.originalModel = originalModel;
this.originalExpression = originalExpression.deepCopy();
this.mc = mc;
transformedExpression = originalExpression;
doTransformation(originalModel, bound, statesOfInterest);
}
/**
* The originalExpression will be modified!
* @param mc
* @param originalModel
* @param originalExpression
* @param bound
* @param statesOfInterest
* @throws PrismException
*/
public CounterTransformation(StateModelChecker mc,
M originalModel, Expression originalExpression,
List<TemporalOperatorBound> bounds,
JDDNode statesOfInterest) throws PrismException {
this.originalModel = originalModel;
this.originalExpression = originalExpression.deepCopy();
this.mc = mc;
transformedExpression = originalExpression;
doTransformation(originalModel, bounds, statesOfInterest);
}
@Override
public Expression getTransformedExpression() {
return transformedExpression;
}
@Override
public M getTransformedModel() {
return product.getTransformedModel();
}
@Override
public JDDNode getTransformedStatesOfInterest() {
return product.getTransformedStatesOfInterest();
}
@Override
public M getOriginalModel() {
return originalModel;
}
@Override
public Expression getOriginalExpression() {
return originalExpression;
}
@Override
public StateValues projectToOriginalModel(StateValues svTransformedModel)
throws PrismException {
return product.projectToOriginalModel(svTransformedModel);
}
private void doTransformation(M originalModel, TemporalOperatorBound bound, JDDNode statesOfInterest) throws PrismException {
List<TemporalOperatorBound> bounds = new ArrayList<TemporalOperatorBound>();
bounds.add(bound);
doTransformation(originalModel, bounds, statesOfInterest);
}
private void doTransformation(M originalModel, List<TemporalOperatorBound> bounds, JDDNode statesOfInterest) throws PrismException {
if (originalModel instanceof NondetModel) {
doTransformation((NondetModel)originalModel, bounds, statesOfInterest);
} else {
throw new PrismException("Counter-Transformation is not supported for "+originalModel.getClass());
}
}
@SuppressWarnings("unchecked")
private void doTransformation(NondetModel originalModel, List<TemporalOperatorBound> bounds, JDDNode statesOfInterest) throws PrismException
{
List<IntegerBound> intBounds = new ArrayList<IntegerBound>();
if (bounds.isEmpty()) {
throw new IllegalArgumentException("Can not do counter transformation without any bounds.");
}
for (TemporalOperatorBound bound : bounds) {
IntegerBound intBound = IntegerBound.fromTemporalOperatorBound(bound, mc.constantValues, true);
intBounds.add(intBound);
if (!bound.hasSameDomainDiscreteTime(bounds.get(0))) {
throw new IllegalArgumentException("Can only do counter transformation for bounds with same domain.");
}
}
JDDNode trRewards = null;
switch (bounds.get(0).getBoundType()) {
case REWARD_BOUND: {
// Get reward info
Object rsi = bounds.get(0).getRewardStructureIndex();
JDDNode srew = mc.getStateRewardsByIndexObject(rsi, originalModel, mc.constantValues).copy();
JDDNode trew = mc.getTransitionRewardsByIndexObject(rsi, originalModel, mc.constantValues).copy();
trRewards = JDD.Apply(JDD.PLUS, srew, trew);
break;
}
case DEFAULT_BOUND:
case STEP_BOUND:
case TIME_BOUND:
// a time/step bound, use constant reward structure assigning 1 to each state
trRewards = JDD.Constant(1);
break;
}
if (trRewards == null) {
throw new PrismException("Couldn't generate reward information.");
}
int saturation_limit = IntegerBound.getMaximalInterestingValueForConjunction(intBounds);
product = (RewardCounterProduct<M>) RewardCounterProduct.generate(mc.prism, originalModel, trRewards, saturation_limit, statesOfInterest);
// add 'in_bound-x' label
JDDNode statesInBound = product.getStatesWithAccumulatedRewardInBoundConjunction(intBounds);
//JDD.PrintMinterms(mc.prism.getMainLog(), statesInBound.copy(), "statesInBound (1)");
statesInBound = JDD.And(statesInBound, product.getTransformedModel().getReach().copy());
//JDD.PrintMinterms(mc.prism.getMainLog(), statesInBound.copy(), "statesInBound (2)");
String labelInBound = ((NondetModel)product.getTransformedModel()).addUniqueLabelDD("in_bound", statesInBound, mc.getDefinedLabelNames());
// transform expression
for (TemporalOperatorBound bound : bounds) {
ReplaceBound replace = new ReplaceBound(bound, labelInBound);
transformedExpression = (Expression) transformedExpression.accept(replace);
if (!replace.wasSuccessfull()) {
throw new PrismException("Replacing bound was not successfull.");
}
}
}
public static <M extends Model> ModelExpressionTransformation<M, M> replaceBoundsWithCounters(StateModelChecker mc,
M originalModel, Expression originalExpression,
List<TemporalOperatorBound> bounds,
JDDNode statesOfInterest) throws PrismException {
if (bounds.isEmpty()) {
throw new PrismException("No bounds to replace!");
}
if (!originalExpression.isSimplePathFormula()) {
throw new PrismException("Replacing bounds is only supported in simple path formulas.");
}
Prism prism = mc.prism;
// TODO: Check nesting depth = 1
ModelExpressionTransformation<M, M> nested = null;
for (TemporalOperatorBound bound : bounds) {
// resolve RewardStruct for reward bounds
if (bound.isRewardBound()) {
int r = ExpressionReward.getRewardStructIndexByIndexObject(bound.getRewardStructureIndex(), mc.prism.getPRISMModel(), mc.constantValues);
bound.setResolvedRewardStructIndex(r);
}
}
List<List<TemporalOperatorBound>> groupedBoundList = TemporalOperatorBounds.groupBoundsDiscreteTime(bounds);
for (List<TemporalOperatorBound> groupedBounds : groupedBoundList) {
if (groupedBounds.get(0).isRewardBound()) {
String rewStructName = mc.getModulesFile().getRewardStructNames().get(groupedBounds.get(0).getResolvedRewardStructIndex());
prism.getLog().println("Transform to incorporate counter for reward '" + rewStructName + "' and " + groupedBounds);
} else {
prism.getLog().println("Transform to incorporate counter for steps "+groupedBounds);
}
ModelExpressionTransformation<M, M> current;
if (nested == null) {
current = new CounterTransformation<M>(mc, originalModel, originalExpression, groupedBounds, statesOfInterest);
nested = current;
} else {
current = new CounterTransformation<M>(mc, nested.getTransformedModel(), nested.getTransformedExpression(), groupedBounds, nested.getTransformedStatesOfInterest());
nested = new ModelExpressionTransformationNested<M, M, M>(nested, current);
}
prism.getLog().println("Transformed "+nested.getTransformedModel().getModelType()+": ");
nested.getTransformedModel().printTransInfo(prism.getLog());
/* try {
prism.exportTransToFile(nested.getTransformedModel(), true, Prism.EXPORT_DOT_STATES, new java.io.File("t.dot"));
} catch (FileNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}*/
prism.getLog().println("Transformed expression: "+ nested.getTransformedExpression());
}
return nested;
}
@Override
public void clear()
{
product.clear();
}
}

8
prism/src/prism/ProbModelTransformationOperator.java

@ -106,13 +106,13 @@ public abstract class ProbModelTransformationOperator
* Get the transformed transition function.
* <br>[ REFS: <i>result</i>, DEREFS: <i>none</i> ]
*/
public abstract JDDNode getTransformedTrans();
public abstract JDDNode getTransformedTrans() throws PrismException;
/**
* Get the transformed start function.
* <br>[ REFS: <i>result</i>, DEREFS: <i>none</i> ]
*/
public abstract JDDNode getTransformedStart();
public abstract JDDNode getTransformedStart() throws PrismException;
/**
* Get the transformed state reward relation, given the old reward relation.
@ -120,7 +120,7 @@ public abstract class ProbModelTransformationOperator
* Default implementation: Return the old reward relation, unchanged.
* <br>[ REFS: <i>result</i>, DEREFS: <i>none</i> ]
*/
public JDDNode getTransformedStateReward(JDDNode oldReward)
public JDDNode getTransformedStateReward(JDDNode oldReward) throws PrismException
{
return oldReward.copy();
}
@ -131,7 +131,7 @@ public abstract class ProbModelTransformationOperator
* Default implementation: Return the old reward relation, unchanged.
* <br>[ REFS: <i>result</i>, DEREFS: <i>none</i> ]
*/
public JDDNode getTransformedTransReward(JDDNode oldReward)
public JDDNode getTransformedTransReward(JDDNode oldReward) throws PrismException
{
return oldReward.copy();
}

95
prism/src/prism/RewardCounterProduct.java

@ -0,0 +1,95 @@
package prism;
import java.util.List;
import jdd.JDD;
import jdd.JDDNode;
import jdd.JDDVars;
public class RewardCounterProduct<M extends Model> extends Product<M>
{
private int limit;
private RewardCounterTransformationAdd transform;
private RewardCounterProduct(M originalModel,
M productModel,
RewardCounterTransformationAdd transform,
JDDNode productStatesOfInterest,
JDDVars automatonRowVars) {
super(productModel, originalModel, productStatesOfInterest, automatonRowVars);
this.transform = transform;
this.limit = transform.getLimit();
}
@Override
public void clear()
{
super.clear();
transform.clear();
}
/**
* Get the states in the product for a given accumulated reward.
* If acc_reward is >= limit, then the states with rewards beyond the
* limit are returned.
*/
public JDDNode getStatesWithAccumulatedReward(int acc_reward) {
if (acc_reward >= transform.getLimit()) {
acc_reward = transform.getLimit();
}
return JDD.And(productModel.getReach().copy(),
transform.encodeInt(acc_reward, false));
}
/**
* Get the states in the product inside a given integer bound.
*/
public JDDNode getStatesWithAccumulatedRewardInBound(IntegerBound bound) {
JDDNode result = JDD.Constant(0);
for (int r=0; r<=bound.getMaximalInterestingValue(); r++) {
if (bound.isInBounds(r)) {
result = JDD.Or(result, getStatesWithAccumulatedReward(r));
}
}
return result;
}
/**
* Generate the product of a MDP with an accumulated reward counter.
* The counter has the range [0,limit], with saturation semantics for accumulated
* rewards >=limit.
* @param originalModel the MDP
* @param rewards integer MCRewards
* @param limit the saturation value for the counter
* @param statesOfInterest the set of state of interest, the starting point for the counters
* @return
* @throws PrismException
*/
static public RewardCounterProduct<NondetModel> generate(PrismComponent parent, NondetModel originalModel, JDDNode trRewards, int limit, JDDNode statesOfInterest) throws PrismException {
TransitionsByRewardsInfo info = new TransitionsByRewardsInfo(parent, originalModel, trRewards);
RewardCounterTransformationAdd transform = new RewardCounterTransformationAdd(originalModel, info, limit, statesOfInterest);
NondetModel transformedModel = originalModel.getTransformed(transform);
JDDNode productStatesOfInterest = transformedModel.getStart().copy();
return new RewardCounterProduct<NondetModel>(originalModel, transformedModel, transform, productStatesOfInterest, transform.getExtraRowVars().copy());
}
/**
* Get the states in the product DTMC inside the conjunction of integer bound.
*/
JDDNode getStatesWithAccumulatedRewardInBoundConjunction(List<IntegerBound> bounds) {
JDDNode result = JDD.Constant(0);
for (int r=0; r<=limit; r++) {
//System.out.println("r="+r+" is in bound?");
if (IntegerBound.isInBoundForConjunction(bounds, r)) {
//System.out.println("r="+r+" is in bound");
JDDNode accStates = getStatesWithAccumulatedReward(r);
// JDD.PrintMinterms(new PrismFileLog("stdout"), accStates.copy(), "accStates");
result = JDD.Or(result, accStates);
}
}
return result;
}
}

221
prism/src/prism/RewardCounterTransformationAdd.java

@ -0,0 +1,221 @@
package prism;
import java.util.BitSet;
import jdd.JDD;
import jdd.JDDNode;
import jdd.JDDVars;
public class RewardCounterTransformationAdd extends ProbModelTransformationOperator {
private int bits;
private int limit;
private int maxRepresentable;
private TransitionsByRewardsInfo info;
private JDDNode statesOfInterest;
// private PrismLog log = new PrismFileLog("stdout");
private boolean msbFirst = true;
/** Count rewards from [0,limit]. All values >= limit are encoded by limit */
public RewardCounterTransformationAdd(ProbModel model,
TransitionsByRewardsInfo info,
int limit,
JDDNode statesOfInterest) {
super(model);
this.info = info;
this.limit = limit;
this.statesOfInterest = statesOfInterest;
//log.println("Limit = "+limit);
bits = (int) Math.ceil(PrismUtils.log2(limit+1));
maxRepresentable = (1<<bits) - 1;
}
public void clear()
{
super.clear();
info.clear();
if (statesOfInterest != null)
JDD.Deref(statesOfInterest);
}
private int bitIndex2Var(int i)
{
if (msbFirst) {
return bits-i-1;
} else {
return i;
}
}
@Override
public int getExtraStateVariableCount() {
return bits;
}
public int getLimit() {
return limit;
}
@Override
public JDDNode getTransformedTrans() throws PrismException {
JDDNode newTrans = JDD.Constant(0);
for (int rew : info.getOccuringRewards()) {
JDDNode tr_rew = info.getTransitionsWithReward(rew);
JDDNode tr_rew_with_counter =
JDD.Apply(JDD.TIMES, tr_rew,
adder(extraRowVars, extraColVars, rew));
// JDD.PrintMinterms(log, tr_rew_with_counter.copy(), "tr_rew_with_counter ("+rew+")");
newTrans = JDD.Apply(JDD.MAX, newTrans, tr_rew_with_counter);
}
return newTrans;
}
@Override
public JDDNode getTransformedStart() {
JDDNode newStart = JDD.And(statesOfInterest.copy(),
encodeInt(0, false));
return newStart;
}
public JDDVars getExtraRowVars() {
return extraRowVars;
}
public JDDNode saturated(boolean col) {
int max = (1 << bits) - 1;
//log.println("Max = "+max);
JDDNode result = JDD.Constant(0);
for (int i = limit; i <= max; i++) {
JDDNode iDD = encodeInt(i, col);
//JDD.PrintMinterms(log, iDD, "i="+i);
result = JDD.Or(result, iDD);
}
return result;
}
public int decodeInt(BitSet bitset) {
long[] v = bitset.toLongArray();
if (v.length == 0) {
return 0;
} else if (v.length > 1 || v[0] > Integer.MAX_VALUE) {
throw new IllegalArgumentException("Integer value out of range");
}
return (int)v[0];
}
public JDDNode encodeInt(int value, boolean col) {
if (value < 0)
throw new IllegalArgumentException("Can not encode negative integer");
JDDVars vars = col ? extraColVars : extraRowVars;
BitSet vBits = BitSet.valueOf(new long[]{value});
//log.println(vBits);
if (value > maxRepresentable) {
throw new IllegalArgumentException("Value "+value+" is out of range");
}
JDDNode result = JDD.Constant(1);
for (int i=0; i < vars.n(); i++) {
if (vBits.get(i))
result = JDD.And(result, vars.getVar(bitIndex2Var(i)).copy());
else
result = JDD.And(result, JDD.Not(vars.getVar(bitIndex2Var(i)).copy()));
}
return result;
}
public JDDNode getStatesWithAccumulatedReward(int r) {
if (r >= limit) {
return encodeInt(limit, false);
} else {
return encodeInt(r, false);
}
}
private JDDNode adder(JDDVars row, JDDVars col, int summand) throws PrismException {
JDDNode result;
if (summand < 0) {
throw new IllegalArgumentException("Can not create adder for negative summand");
}
if (row.n() != col.n()) {
throw new IllegalArgumentException("Can not create adder for different number of variables");
}
//log.println("Summand = "+summand+", bits = "+bits);
if (summand >= limit) {
// -> limit_next
return encodeInt(limit, true);
}
// convert summand to BitSet
BitSet summandBits = BitSet.valueOf(new long[]{summand});
JDDNode nextValues = JDD.Constant(1);
JDDNode carry = JDD.Constant(0.0);
// for all the bits (0, ..., n-1)
for (int i = 0; i < row.n(); i++) {
// x = i-th bit in the row vector
JDDNode x = row.getVar(bitIndex2Var(i)).copy();
// y = i-th bit of the summand
JDDNode y = summandBits.get(i) ? JDD.Constant(1.0) : JDD.Constant(0.0);
JDDNode z = JDD.Xor(JDD.Xor(x.copy(), y.copy()), carry.copy());
nextValues = JDD.And(nextValues, JDD.Equiv(col.getVar(bitIndex2Var(i)).copy(), z));
carry = JDD.Or(JDD.Or(JDD.And(x.copy(), carry.copy()),
JDD.And(y.copy(), carry)),
JDD.And(x.copy(), y.copy()));
JDD.Deref(x);
JDD.Deref(y);
}
JDDNode saturated_now = saturated(false);
//JDD.PrintMinterms(log, saturated_now.copy(), "saturated_now");
JDDNode saturated_next = JDD.And(nextValues.copy(), saturated(true));
//JDD.PrintMinterms(log, saturated_next.copy(), "saturated_next (1)");
saturated_next = JDD.ThereExists(saturated_next, extraColVars);
//JDD.PrintMinterms(log, saturated_next.copy(), "saturated_next (2)");
saturated_next = JDD.Or(saturated_next, carry.copy());
//JDD.PrintMinterms(log, saturated_next.copy(), "saturated_next (3)");
JDDNode limit_next = encodeInt(limit, true);
//JDD.PrintMinterms(log, limit_next.copy(), "limit_next");
//JDD.PrintMinterms(qc.getLog(), negative_now, "negative_now");
//JDD.PrintMinterms(qc.getLog(), negative_next, "negative_next");
// result = saturated_now -> limit_next
result = JDD.Implies(saturated_now.copy(), limit_next.copy());
// result &= (!saturated_now & saturated_next) -> limit_next
result = JDD.And(result,
JDD.Implies(JDD.And(JDD.Not(saturated_now.copy()), saturated_next.copy()),
limit_next.copy()));
// result &= (!saturated_now & !satured_next) -> nextValues
result = JDD.And(result,
JDD.Implies(JDD.And(JDD.Not(saturated_now.copy()),
JDD.Not(saturated_next.copy())),
nextValues.copy()));
JDD.Deref(nextValues);
JDD.Deref(carry);
JDD.Deref(saturated_now);
JDD.Deref(saturated_next);
JDD.Deref(limit_next);
//JDD.PrintMinterms(qc.getLog(), result.copy(), "adder for "+summand);
//JDD.PrintMinterms(log, result, "adder("+summand+")");
return result;
}
}
Loading…
Cancel
Save