From 5f356b0bf2d964773e04a66bff37a09796f7f445 Mon Sep 17 00:00:00 2001 From: Joachim Klein Date: Fri, 12 Oct 2018 14:26:24 +0200 Subject: [PATCH] imported patch symb-counter-transform-transformations.patch --- prism/src/prism/CounterTransformation.java | 245 ++++++++++++++++++ .../ProbModelTransformationOperator.java | 8 +- prism/src/prism/RewardCounterProduct.java | 95 +++++++ .../prism/RewardCounterTransformationAdd.java | 221 ++++++++++++++++ 4 files changed, 565 insertions(+), 4 deletions(-) create mode 100644 prism/src/prism/CounterTransformation.java create mode 100644 prism/src/prism/RewardCounterProduct.java create mode 100644 prism/src/prism/RewardCounterTransformationAdd.java diff --git a/prism/src/prism/CounterTransformation.java b/prism/src/prism/CounterTransformation.java new file mode 100644 index 00000000..c9d074ce --- /dev/null +++ b/prism/src/prism/CounterTransformation.java @@ -0,0 +1,245 @@ +package prism; + +import java.util.ArrayList; +import java.util.List; + +import jdd.JDD; +import jdd.JDDNode; +import parser.ast.Expression; +import parser.ast.ExpressionReward; +import parser.ast.RewardStruct; +import parser.ast.TemporalOperatorBound; +import parser.ast.TemporalOperatorBounds; +import parser.visitor.ReplaceBound; +import prism.IntegerBound; +import prism.PrismException; + +public class CounterTransformation implements ModelExpressionTransformation { + private Expression originalExpression; + private Expression transformedExpression; + private M originalModel; + private RewardCounterProduct product; + + StateModelChecker mc; + + /** + * The originalExpression will be modified! + * @param mc + * @param originalModel + * @param originalExpression + * @param bound + * @param statesOfInterest + * @throws PrismException + */ + public CounterTransformation(StateModelChecker mc, + M originalModel, Expression originalExpression, + TemporalOperatorBound bound, + JDDNode statesOfInterest) throws PrismException { + this.originalModel = originalModel; + this.originalExpression = originalExpression.deepCopy(); + this.mc = mc; + + transformedExpression = originalExpression; + doTransformation(originalModel, bound, statesOfInterest); + } + + /** + * The originalExpression will be modified! + * @param mc + * @param originalModel + * @param originalExpression + * @param bound + * @param statesOfInterest + * @throws PrismException + */ + public CounterTransformation(StateModelChecker mc, + M originalModel, Expression originalExpression, + List bounds, + JDDNode statesOfInterest) throws PrismException { + this.originalModel = originalModel; + this.originalExpression = originalExpression.deepCopy(); + this.mc = mc; + + transformedExpression = originalExpression; + doTransformation(originalModel, bounds, statesOfInterest); + } + + + @Override + public Expression getTransformedExpression() { + return transformedExpression; + } + + @Override + public M getTransformedModel() { + return product.getTransformedModel(); + } + + @Override + public JDDNode getTransformedStatesOfInterest() { + return product.getTransformedStatesOfInterest(); + } + + @Override + public M getOriginalModel() { + return originalModel; + } + + @Override + public Expression getOriginalExpression() { + return originalExpression; + } + + @Override + public StateValues projectToOriginalModel(StateValues svTransformedModel) + throws PrismException { + return product.projectToOriginalModel(svTransformedModel); + } + + private void doTransformation(M originalModel, TemporalOperatorBound bound, JDDNode statesOfInterest) throws PrismException { + List bounds = new ArrayList(); + bounds.add(bound); + doTransformation(originalModel, bounds, statesOfInterest); + } + + private void doTransformation(M originalModel, List bounds, JDDNode statesOfInterest) throws PrismException { + if (originalModel instanceof NondetModel) { + doTransformation((NondetModel)originalModel, bounds, statesOfInterest); + } else { + throw new PrismException("Counter-Transformation is not supported for "+originalModel.getClass()); + } + } + + + + @SuppressWarnings("unchecked") + private void doTransformation(NondetModel originalModel, List bounds, JDDNode statesOfInterest) throws PrismException + { + List intBounds = new ArrayList(); + + if (bounds.isEmpty()) { + throw new IllegalArgumentException("Can not do counter transformation without any bounds."); + } + + for (TemporalOperatorBound bound : bounds) { + IntegerBound intBound = IntegerBound.fromTemporalOperatorBound(bound, mc.constantValues, true); + intBounds.add(intBound); + + if (!bound.hasSameDomainDiscreteTime(bounds.get(0))) { + throw new IllegalArgumentException("Can only do counter transformation for bounds with same domain."); + } + } + JDDNode trRewards = null; + + switch (bounds.get(0).getBoundType()) { + case REWARD_BOUND: { + // Get reward info + Object rsi = bounds.get(0).getRewardStructureIndex(); + JDDNode srew = mc.getStateRewardsByIndexObject(rsi, originalModel, mc.constantValues).copy(); + JDDNode trew = mc.getTransitionRewardsByIndexObject(rsi, originalModel, mc.constantValues).copy(); + + trRewards = JDD.Apply(JDD.PLUS, srew, trew); + break; + } + case DEFAULT_BOUND: + case STEP_BOUND: + case TIME_BOUND: + // a time/step bound, use constant reward structure assigning 1 to each state + trRewards = JDD.Constant(1); + break; + } + + if (trRewards == null) { + throw new PrismException("Couldn't generate reward information."); + } + + int saturation_limit = IntegerBound.getMaximalInterestingValueForConjunction(intBounds); + + product = (RewardCounterProduct) RewardCounterProduct.generate(mc.prism, originalModel, trRewards, saturation_limit, statesOfInterest); + + // add 'in_bound-x' label + JDDNode statesInBound = product.getStatesWithAccumulatedRewardInBoundConjunction(intBounds); + //JDD.PrintMinterms(mc.prism.getMainLog(), statesInBound.copy(), "statesInBound (1)"); + statesInBound = JDD.And(statesInBound, product.getTransformedModel().getReach().copy()); + //JDD.PrintMinterms(mc.prism.getMainLog(), statesInBound.copy(), "statesInBound (2)"); + String labelInBound = ((NondetModel)product.getTransformedModel()).addUniqueLabelDD("in_bound", statesInBound, mc.getDefinedLabelNames()); + + // transform expression + for (TemporalOperatorBound bound : bounds) { + ReplaceBound replace = new ReplaceBound(bound, labelInBound); + transformedExpression = (Expression) transformedExpression.accept(replace); + + if (!replace.wasSuccessfull()) { + throw new PrismException("Replacing bound was not successfull."); + } + } + } + + + public static ModelExpressionTransformation replaceBoundsWithCounters(StateModelChecker mc, + M originalModel, Expression originalExpression, + List bounds, + JDDNode statesOfInterest) throws PrismException { + + if (bounds.isEmpty()) { + throw new PrismException("No bounds to replace!"); + } + + if (!originalExpression.isSimplePathFormula()) { + throw new PrismException("Replacing bounds is only supported in simple path formulas."); + } + + Prism prism = mc.prism; + + // TODO: Check nesting depth = 1 + + ModelExpressionTransformation nested = null; + for (TemporalOperatorBound bound : bounds) { + // resolve RewardStruct for reward bounds + if (bound.isRewardBound()) { + int r = ExpressionReward.getRewardStructIndexByIndexObject(bound.getRewardStructureIndex(), mc.prism.getPRISMModel(), mc.constantValues); + bound.setResolvedRewardStructIndex(r); + } + } + + List> groupedBoundList = TemporalOperatorBounds.groupBoundsDiscreteTime(bounds); + + for (List groupedBounds : groupedBoundList) { + if (groupedBounds.get(0).isRewardBound()) { + String rewStructName = mc.getModulesFile().getRewardStructNames().get(groupedBounds.get(0).getResolvedRewardStructIndex()); + prism.getLog().println("Transform to incorporate counter for reward '" + rewStructName + "' and " + groupedBounds); + } else { + prism.getLog().println("Transform to incorporate counter for steps "+groupedBounds); + } + + ModelExpressionTransformation current; + + if (nested == null) { + current = new CounterTransformation(mc, originalModel, originalExpression, groupedBounds, statesOfInterest); + nested = current; + } else { + current = new CounterTransformation(mc, nested.getTransformedModel(), nested.getTransformedExpression(), groupedBounds, nested.getTransformedStatesOfInterest()); + nested = new ModelExpressionTransformationNested(nested, current); + } + + prism.getLog().println("Transformed "+nested.getTransformedModel().getModelType()+": "); + nested.getTransformedModel().printTransInfo(prism.getLog()); +/* try { + prism.exportTransToFile(nested.getTransformedModel(), true, Prism.EXPORT_DOT_STATES, new java.io.File("t.dot")); + } catch (FileNotFoundException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + }*/ + prism.getLog().println("Transformed expression: "+ nested.getTransformedExpression()); + } + + return nested; + } + + @Override + public void clear() + { + product.clear(); + } + +} diff --git a/prism/src/prism/ProbModelTransformationOperator.java b/prism/src/prism/ProbModelTransformationOperator.java index 0d1d02c1..5731cd2c 100644 --- a/prism/src/prism/ProbModelTransformationOperator.java +++ b/prism/src/prism/ProbModelTransformationOperator.java @@ -106,13 +106,13 @@ public abstract class ProbModelTransformationOperator * Get the transformed transition function. *
[ REFS: result, DEREFS: none ] */ - public abstract JDDNode getTransformedTrans(); + public abstract JDDNode getTransformedTrans() throws PrismException; /** * Get the transformed start function. *
[ REFS: result, DEREFS: none ] */ - public abstract JDDNode getTransformedStart(); + public abstract JDDNode getTransformedStart() throws PrismException; /** * Get the transformed state reward relation, given the old reward relation. @@ -120,7 +120,7 @@ public abstract class ProbModelTransformationOperator * Default implementation: Return the old reward relation, unchanged. *
[ REFS: result, DEREFS: none ] */ - public JDDNode getTransformedStateReward(JDDNode oldReward) + public JDDNode getTransformedStateReward(JDDNode oldReward) throws PrismException { return oldReward.copy(); } @@ -131,7 +131,7 @@ public abstract class ProbModelTransformationOperator * Default implementation: Return the old reward relation, unchanged. *
[ REFS: result, DEREFS: none ] */ - public JDDNode getTransformedTransReward(JDDNode oldReward) + public JDDNode getTransformedTransReward(JDDNode oldReward) throws PrismException { return oldReward.copy(); } diff --git a/prism/src/prism/RewardCounterProduct.java b/prism/src/prism/RewardCounterProduct.java new file mode 100644 index 00000000..f8af02e0 --- /dev/null +++ b/prism/src/prism/RewardCounterProduct.java @@ -0,0 +1,95 @@ +package prism; + + +import java.util.List; + +import jdd.JDD; +import jdd.JDDNode; +import jdd.JDDVars; + +public class RewardCounterProduct extends Product +{ + private int limit; + private RewardCounterTransformationAdd transform; + + private RewardCounterProduct(M originalModel, + M productModel, + RewardCounterTransformationAdd transform, + JDDNode productStatesOfInterest, + JDDVars automatonRowVars) { + super(productModel, originalModel, productStatesOfInterest, automatonRowVars); + this.transform = transform; + this.limit = transform.getLimit(); + } + + @Override + public void clear() + { + super.clear(); + transform.clear(); + } + + /** + * Get the states in the product for a given accumulated reward. + * If acc_reward is >= limit, then the states with rewards beyond the + * limit are returned. + */ + public JDDNode getStatesWithAccumulatedReward(int acc_reward) { + if (acc_reward >= transform.getLimit()) { + acc_reward = transform.getLimit(); + } + return JDD.And(productModel.getReach().copy(), + transform.encodeInt(acc_reward, false)); + } + + /** + * Get the states in the product inside a given integer bound. + */ + public JDDNode getStatesWithAccumulatedRewardInBound(IntegerBound bound) { + JDDNode result = JDD.Constant(0); + for (int r=0; r<=bound.getMaximalInterestingValue(); r++) { + if (bound.isInBounds(r)) { + result = JDD.Or(result, getStatesWithAccumulatedReward(r)); + } + } + return result; + } + + /** + * Generate the product of a MDP with an accumulated reward counter. + * The counter has the range [0,limit], with saturation semantics for accumulated + * rewards >=limit. + * @param originalModel the MDP + * @param rewards integer MCRewards + * @param limit the saturation value for the counter + * @param statesOfInterest the set of state of interest, the starting point for the counters + * @return + * @throws PrismException + */ + static public RewardCounterProduct generate(PrismComponent parent, NondetModel originalModel, JDDNode trRewards, int limit, JDDNode statesOfInterest) throws PrismException { + TransitionsByRewardsInfo info = new TransitionsByRewardsInfo(parent, originalModel, trRewards); + RewardCounterTransformationAdd transform = new RewardCounterTransformationAdd(originalModel, info, limit, statesOfInterest); + + NondetModel transformedModel = originalModel.getTransformed(transform); + JDDNode productStatesOfInterest = transformedModel.getStart().copy(); + return new RewardCounterProduct(originalModel, transformedModel, transform, productStatesOfInterest, transform.getExtraRowVars().copy()); + } + + + /** + * Get the states in the product DTMC inside the conjunction of integer bound. + */ + JDDNode getStatesWithAccumulatedRewardInBoundConjunction(List bounds) { + JDDNode result = JDD.Constant(0); + for (int r=0; r<=limit; r++) { + //System.out.println("r="+r+" is in bound?"); + if (IntegerBound.isInBoundForConjunction(bounds, r)) { + //System.out.println("r="+r+" is in bound"); + JDDNode accStates = getStatesWithAccumulatedReward(r); + // JDD.PrintMinterms(new PrismFileLog("stdout"), accStates.copy(), "accStates"); + result = JDD.Or(result, accStates); + } + } + return result; + } +} diff --git a/prism/src/prism/RewardCounterTransformationAdd.java b/prism/src/prism/RewardCounterTransformationAdd.java new file mode 100644 index 00000000..fe50d0d5 --- /dev/null +++ b/prism/src/prism/RewardCounterTransformationAdd.java @@ -0,0 +1,221 @@ +package prism; + +import java.util.BitSet; + +import jdd.JDD; +import jdd.JDDNode; +import jdd.JDDVars; + +public class RewardCounterTransformationAdd extends ProbModelTransformationOperator { + private int bits; + private int limit; + private int maxRepresentable; + private TransitionsByRewardsInfo info; + private JDDNode statesOfInterest; + // private PrismLog log = new PrismFileLog("stdout"); + private boolean msbFirst = true; + + /** Count rewards from [0,limit]. All values >= limit are encoded by limit */ + public RewardCounterTransformationAdd(ProbModel model, + TransitionsByRewardsInfo info, + int limit, + JDDNode statesOfInterest) { + super(model); + + this.info = info; + this.limit = limit; + this.statesOfInterest = statesOfInterest; + + //log.println("Limit = "+limit); + bits = (int) Math.ceil(PrismUtils.log2(limit+1)); + maxRepresentable = (1< maxRepresentable) { + throw new IllegalArgumentException("Value "+value+" is out of range"); + } + + JDDNode result = JDD.Constant(1); + for (int i=0; i < vars.n(); i++) { + if (vBits.get(i)) + result = JDD.And(result, vars.getVar(bitIndex2Var(i)).copy()); + else + result = JDD.And(result, JDD.Not(vars.getVar(bitIndex2Var(i)).copy())); + } + + return result; + } + + public JDDNode getStatesWithAccumulatedReward(int r) { + if (r >= limit) { + return encodeInt(limit, false); + } else { + return encodeInt(r, false); + } + } + + private JDDNode adder(JDDVars row, JDDVars col, int summand) throws PrismException { + JDDNode result; + + if (summand < 0) { + throw new IllegalArgumentException("Can not create adder for negative summand"); + } + if (row.n() != col.n()) { + throw new IllegalArgumentException("Can not create adder for different number of variables"); + } + + //log.println("Summand = "+summand+", bits = "+bits); + + if (summand >= limit) { + // -> limit_next + return encodeInt(limit, true); + } + + // convert summand to BitSet + BitSet summandBits = BitSet.valueOf(new long[]{summand}); + JDDNode nextValues = JDD.Constant(1); + JDDNode carry = JDD.Constant(0.0); + + // for all the bits (0, ..., n-1) + for (int i = 0; i < row.n(); i++) { + // x = i-th bit in the row vector + JDDNode x = row.getVar(bitIndex2Var(i)).copy(); + // y = i-th bit of the summand + JDDNode y = summandBits.get(i) ? JDD.Constant(1.0) : JDD.Constant(0.0); + + JDDNode z = JDD.Xor(JDD.Xor(x.copy(), y.copy()), carry.copy()); + nextValues = JDD.And(nextValues, JDD.Equiv(col.getVar(bitIndex2Var(i)).copy(), z)); + carry = JDD.Or(JDD.Or(JDD.And(x.copy(), carry.copy()), + JDD.And(y.copy(), carry)), + JDD.And(x.copy(), y.copy())); + JDD.Deref(x); + JDD.Deref(y); + } + + JDDNode saturated_now = saturated(false); + //JDD.PrintMinterms(log, saturated_now.copy(), "saturated_now"); + JDDNode saturated_next = JDD.And(nextValues.copy(), saturated(true)); + //JDD.PrintMinterms(log, saturated_next.copy(), "saturated_next (1)"); + saturated_next = JDD.ThereExists(saturated_next, extraColVars); + //JDD.PrintMinterms(log, saturated_next.copy(), "saturated_next (2)"); + saturated_next = JDD.Or(saturated_next, carry.copy()); + //JDD.PrintMinterms(log, saturated_next.copy(), "saturated_next (3)"); + JDDNode limit_next = encodeInt(limit, true); + //JDD.PrintMinterms(log, limit_next.copy(), "limit_next"); + + //JDD.PrintMinterms(qc.getLog(), negative_now, "negative_now"); + //JDD.PrintMinterms(qc.getLog(), negative_next, "negative_next"); + + // result = saturated_now -> limit_next + result = JDD.Implies(saturated_now.copy(), limit_next.copy()); + + // result &= (!saturated_now & saturated_next) -> limit_next + result = JDD.And(result, + JDD.Implies(JDD.And(JDD.Not(saturated_now.copy()), saturated_next.copy()), + limit_next.copy())); + + // result &= (!saturated_now & !satured_next) -> nextValues + result = JDD.And(result, + JDD.Implies(JDD.And(JDD.Not(saturated_now.copy()), + JDD.Not(saturated_next.copy())), + nextValues.copy())); + + JDD.Deref(nextValues); + JDD.Deref(carry); + JDD.Deref(saturated_now); + JDD.Deref(saturated_next); + JDD.Deref(limit_next); + + //JDD.PrintMinterms(qc.getLog(), result.copy(), "adder for "+summand); + + //JDD.PrintMinterms(log, result, "adder("+summand+")"); + return result; + } +}