diff --git a/prism/include/DoubleVector.h b/prism/include/DoubleVector.h index 8790c6e7..15e87aee 100644 --- a/prism/include/DoubleVector.h +++ b/prism/include/DoubleVector.h @@ -143,6 +143,14 @@ JNIEXPORT jdouble JNICALL Java_dv_DoubleVector_DV_1SumOverBDD JNIEXPORT jdouble JNICALL Java_dv_DoubleVector_DV_1SumOverMTBDD (JNIEnv *, jobject, jlong, jlong, jlong, jint, jlong); +/* + * Class: dv_DoubleVector + * Method: DV_SumOverDDVars + * Signature: (JJJIIIJJ)V + */ +JNIEXPORT void JNICALL Java_dv_DoubleVector_DV_1SumOverDDVars + (JNIEnv *, jobject, jlong, jlong, jlong, jint, jint, jint, jlong, jlong); + /* * Class: dv_DoubleVector * Method: DV_BDDGreaterThanEquals diff --git a/prism/include/dv.h b/prism/include/dv.h index 812837aa..62b31f45 100644 --- a/prism/include/dv.h +++ b/prism/include/dv.h @@ -73,6 +73,7 @@ EXPORT double min_double_vector_over_bdd(DdManager *ddman, double *vec, DdNode * EXPORT double max_double_vector_over_bdd(DdManager *ddman, double *vec, DdNode *filter, DdNode **vars, int num_vars, ODDNode *odd); EXPORT double sum_double_vector_over_bdd(DdManager *ddman, double *vec, DdNode *filter, DdNode **vars, int num_vars, ODDNode *odd); EXPORT double sum_double_vector_over_mtbdd(DdManager *ddman, double *vec, DdNode *mult, DdNode **vars, int num_vars, ODDNode *odd); +EXPORT void sum_double_vector_over_dd_vars(DdManager *ddman, double *vec, double *vec2, DdNode **vars, int num_vars, int first_var, int last_var, ODDNode *odd, ODDNode *odd2); EXPORT DistVector *double_vector_to_dist(double *v, int n); EXPORT void free_dist_vector(DistVector *&dv); diff --git a/prism/src/dv/DoubleVector.cc b/prism/src/dv/DoubleVector.cc index b40802a5..110ba943 100644 --- a/prism/src/dv/DoubleVector.cc +++ b/prism/src/dv/DoubleVector.cc @@ -373,6 +373,33 @@ jlong __pointer odd //------------------------------------------------------------------------------ +JNIEXPORT void JNICALL Java_dv_DoubleVector_DV_1SumOverDDVars +( +JNIEnv *env, +jobject obj, +jlong __pointer vector, +jlong __pointer vector2, +jlong __pointer vars, +jint num_vars, +jint first_var, +jint last_var, +jlong __pointer odd, +jlong __pointer odd2 +) +{ + sum_double_vector_over_dd_vars( + ddman, + jlong_to_double(vector), + jlong_to_double(vector2), + jlong_to_DdNode_array(vars), num_vars, + first_var, last_var, + jlong_to_ODDNode(odd), + jlong_to_ODDNode(odd2) + ); +} + +//------------------------------------------------------------------------------ + JNIEXPORT jlong __pointer JNICALL Java_dv_DoubleVector_DV_1BDDGreaterThanEquals ( JNIEnv *env, diff --git a/prism/src/dv/DoubleVector.java b/prism/src/dv/DoubleVector.java index 3bf9699d..aca2bca7 100644 --- a/prism/src/dv/DoubleVector.java +++ b/prism/src/dv/DoubleVector.java @@ -202,6 +202,17 @@ public class DoubleVector return DV_SumOverMTBDD(v, mult.ptr(), vars.array(), vars.n(), odd.ptr()); } + // sum up the elements of a double array, over a subset of its dd vars + // the dd var subset must be a continuous range of vars, identified by indices: first_var, last_var + // store the result in a new DoubleVector (whose indices are given by odd2) + private native void DV_SumOverDDVars(long vec, long vec2, long vars, int num_vars, int first_var, int last_var, long odd, long odd2); + public DoubleVector sumOverDDVars(JDDVars vars, ODDNode odd, ODDNode odd2, int first_var, int last_var) + { + DoubleVector dv2 = new DoubleVector((int)(odd2.getEOff() + odd2.getTOff())); + DV_SumOverDDVars(v, dv2.v, vars.array(), vars.n(), first_var, last_var, odd.ptr(), odd2.ptr()); + return dv2; + } + // generate bdd (from an interval: relative operator and bound) private native long DV_BDDGreaterThanEquals(long v, double bound, long vars, int num_vars, long odd); private native long DV_BDDGreaterThan(long v, double bound, long vars, int num_vars, long odd); diff --git a/prism/src/dv/dv.cc b/prism/src/dv/dv.cc index 2934b06d..c6f8c0bd 100644 --- a/prism/src/dv/dv.cc +++ b/prism/src/dv/dv.cc @@ -39,6 +39,7 @@ static double min_double_vector_over_bdd_rec(DdManager *ddman, double *vec, DdNo static double max_double_vector_over_bdd_rec(DdManager *ddman, double *vec, DdNode *filter, DdNode **vars, int num_vars, int level, ODDNode *odd, long o); static double sum_double_vector_over_bdd_rec(DdManager *ddman, double *vec, DdNode *filter, DdNode **vars, int num_vars, int level, ODDNode *odd, long o); static double sum_double_vector_over_mtbdd_rec(DdManager *ddman, double *vec, DdNode *mult, DdNode **vars, int num_vars, int level, ODDNode *odd, long o); +static void sum_double_vector_over_dd_vars_rec(DdManager *ddman, double *vec, double *vec2, DdNode **vars, int num_vars, int level, int first_var, int last_var, ODDNode *odd, ODDNode *odd2, long o, long o2); //------------------------------------------------------------------------------ @@ -384,6 +385,40 @@ double sum_double_vector_over_mtbdd_rec(DdManager *ddman, double *vec, DdNode *m } } +//------------------------------------------------------------------------------ + +// sum up the elements of a double array, over a subset of its dd vars +// the dd var subset must be a continuous range of vars, identified by indices: first_var, last_var +// store the result in the vector vec2 + +EXPORT void sum_double_vector_over_dd_vars(DdManager *ddman, double *vec, double *vec2, DdNode **vars, int num_vars, int first_var, int last_var, ODDNode *odd, ODDNode *odd2) +{ + return sum_double_vector_over_dd_vars_rec(ddman, vec, vec2, vars, num_vars, 0, first_var, last_var, odd, odd2, 0, 0); +} + +void sum_double_vector_over_dd_vars_rec(DdManager *ddman, double *vec, double *vec2, DdNode **vars, int num_vars, int level, int first_var, int last_var, ODDNode *odd, ODDNode *odd2, long o, long o2) +{ + if (level == num_vars) { + vec2[o2] += vec[o]; + } + else { + if (odd->eoff > 0) { + if (vars[level]->index >= first_var && vars[level]->index <= last_var) { + sum_double_vector_over_dd_vars_rec(ddman, vec, vec2, vars, num_vars, level+1, first_var, last_var, odd->e, odd2, o, o2); + } else { + sum_double_vector_over_dd_vars_rec(ddman, vec, vec2, vars, num_vars, level+1, first_var, last_var, odd->e, odd2->e, o, o2); + } + } + if (odd->toff > 0) { + if (vars[level]->index >= first_var && vars[level]->index <= last_var) { + sum_double_vector_over_dd_vars_rec(ddman, vec, vec2, vars, num_vars, level+1, first_var, last_var, odd->t, odd2, o+odd->eoff, o2); + } else { + sum_double_vector_over_dd_vars_rec(ddman, vec, vec2, vars, num_vars, level+1, first_var, last_var, odd->t, odd2->t, o+odd->eoff, o2+odd2->eoff); + } + } + } +} + //----------------------------------------------------------------------------- // Converts an array of doubles to a a DistVector, which stores only one copy of each distinct double diff --git a/prism/src/parser/VarList.java b/prism/src/parser/VarList.java index 7d2d237e..0df06c56 100644 --- a/prism/src/parser/VarList.java +++ b/prism/src/parser/VarList.java @@ -35,42 +35,65 @@ import prism.PrismUtils; public class VarList { int numVars; - Vector names; - Vector lows; - Vector highs; - Vector ranges; - Vector rangeLogTwos; - Vector starts; - Vector modules; - Vector types; + Vector names; + Vector lows; + Vector highs; + Vector ranges; + Vector rangeLogTwos; + Vector starts; + Vector modules; + Vector types; public VarList() { numVars = 0; - names = new Vector(); - lows = new Vector(); - highs = new Vector(); - ranges = new Vector(); - rangeLogTwos = new Vector(); - starts = new Vector(); - modules = new Vector(); - types = new Vector(); + names = new Vector(); + lows = new Vector(); + highs = new Vector(); + ranges = new Vector(); + rangeLogTwos = new Vector(); + starts = new Vector(); + modules = new Vector(); + types = new Vector(); } + /** + * Add a new variable to the end of the VarList. + */ public void addVar(String n, int l, int h, int s, int m, int t) { int r, r2; - names.addElement(n); - lows.addElement(new Integer(l)); - highs.addElement(new Integer(h)); + names.add(n); + lows.add(l); + highs.add(h); r = h - l + 1; - ranges.addElement(new Integer(r)); + ranges.add(r); r2 = (int)Math.ceil(PrismUtils.log2(r)); - rangeLogTwos.addElement(new Integer(r2)); - starts.addElement(new Integer(s)); - modules.addElement(new Integer(m)); - types.addElement(new Integer(t)); + rangeLogTwos.add(r2); + starts.add(s); + modules.add(m); + types.add(t); + numVars++; + } + + /** + * Add a new variable at position i in the VarList. + */ + public void addVar(int i, String n, int l, int h, int s, int m, int t) + { + int r, r2; + + names.add(i, n); + lows.add(i, l); + highs.add(i, h); + r = h - l + 1; + ranges.add(i, r); + r2 = (int)Math.ceil(PrismUtils.log2(r)); + rangeLogTwos.add(i, r2); + starts.add(i, s); + modules.add(i, m); + types.add(i, t); numVars++; } @@ -128,6 +151,22 @@ public class VarList { return ((Integer)types.elementAt(i)).intValue(); } + + public Object clone() + { + VarList rv = new VarList(); + rv.numVars = numVars; + rv.names.addAll(names); + rv.lows.addAll(lows); + rv.highs.addAll(highs); + rv.ranges.addAll(ranges); + rv.rangeLogTwos.addAll(rangeLogTwos); + rv.starts.addAll(starts); + rv.modules.addAll(modules); + rv.types.addAll(types); + + return rv; + } } //------------------------------------------------------------------------------ diff --git a/prism/src/prism/LTLModelChecker.java b/prism/src/prism/LTLModelChecker.java index d332cbcf..6e44a0b1 100644 --- a/prism/src/prism/LTLModelChecker.java +++ b/prism/src/prism/LTLModelChecker.java @@ -30,10 +30,6 @@ package prism; import java.util.*; import jdd.*; -import dv.*; -import mtbdd.*; -import sparse.*; -import hybrid.*; import parser.*; import parser.ast.*; import jltl2ba.APElement; @@ -51,7 +47,7 @@ public class LTLModelChecker protected JDDVars draDDRowVars; protected JDDVars draDDColVars; - public LTLModelChecker(Prism prism, ModelChecker parent, Model model) throws PrismException + public LTLModelChecker(Prism prism, ModelChecker parent) throws PrismException { this.prism = prism; mainLog = prism.getMainLog(); @@ -63,13 +59,13 @@ public class LTLModelChecker * As an optimisation, model checking that results in true/false for all states is * converted to an actual true/false, and duplicate results are given the same label. */ - public Expression checkMaximalStateFormulas(ModelChecker mc, Model model, Expression expr, Vector labelDDs) + public Expression checkMaximalStateFormulas(Model model, Expression expr, Vector labelDDs) throws PrismException { // A state formula if (expr.getType() == Expression.BOOLEAN) { // Model check - JDDNode dd = mc.checkExpressionDD(expr); + JDDNode dd = parent.checkExpressionDD(expr); // Detect special cases (true, false) for optimisation if (dd.equals(JDD.ZERO)) { JDD.Deref(dd); @@ -94,66 +90,66 @@ public class LTLModelChecker else if (expr.getType() == Expression.PATH_BOOLEAN) { if (expr instanceof ExpressionBinaryOp) { ExpressionBinaryOp exprBinOp = (ExpressionBinaryOp) expr; - exprBinOp.setOperand1(checkMaximalStateFormulas(mc, model, exprBinOp.getOperand1(), labelDDs)); - exprBinOp.setOperand2(checkMaximalStateFormulas(mc, model, exprBinOp.getOperand2(), labelDDs)); + exprBinOp.setOperand1(checkMaximalStateFormulas(model, exprBinOp.getOperand1(), labelDDs)); + exprBinOp.setOperand2(checkMaximalStateFormulas(model, exprBinOp.getOperand2(), labelDDs)); } else if (expr instanceof ExpressionUnaryOp) { ExpressionUnaryOp exprUnOp = (ExpressionUnaryOp) expr; - exprUnOp.setOperand(checkMaximalStateFormulas(mc, model, exprUnOp.getOperand(), labelDDs)); + exprUnOp.setOperand(checkMaximalStateFormulas(model, exprUnOp.getOperand(), labelDDs)); } else if (expr instanceof ExpressionTemporal) { ExpressionTemporal exprTemp = (ExpressionTemporal) expr; if (exprTemp.getOperand1() != null) { - exprTemp.setOperand1(checkMaximalStateFormulas(mc, model, exprTemp.getOperand1(), labelDDs)); + exprTemp.setOperand1(checkMaximalStateFormulas(model, exprTemp.getOperand1(), labelDDs)); } if (exprTemp.getOperand2() != null) { - exprTemp.setOperand2(checkMaximalStateFormulas(mc, model, exprTemp.getOperand2(), labelDDs)); + exprTemp.setOperand2(checkMaximalStateFormulas(model, exprTemp.getOperand2(), labelDDs)); } } } return expr; } - public NondetModel constructProductModel(DRA dra, Model model, Vector labelDDs) + public NondetModel constructProductModel(DRA dra, Model model, Vector labelDDs) throws PrismException { - // Old model stuff - VarList varList; + // Existing model - dds, vars, etc. JDDVars varDDRowVars[]; JDDVars varDDColVars[]; JDDVars allDDRowVars; JDDVars allDDColVars; - JDDVars allDDNondetVars; Vector ddVarNames; - - JDDNode draDD, newTrans, newStart; + VarList varList; + // New (product) model - dds, vars, etc. + JDDNode newTrans, newStart; JDDVars newVarDDRowVars[], newVarDDColVars[]; JDDVars newAllDDRowVars, newAllDDColVars; Vector newDDVarNames; + VarList newVarList; String draVar; + // Misc int i, j, n; boolean before; - varList = model.getVarList(); + // Get details of old model varDDRowVars = model.getVarDDRowVars(); varDDColVars = model.getVarDDColVars(); allDDRowVars = model.getAllDDRowVars(); allDDColVars = model.getAllDDColVars(); - allDDNondetVars = ((NondetModel) model).getAllDDNondetVars(); ddVarNames = model.getDDVarNames(); - - // Build new variables and lists + varList = model.getVarList(); - // Create a (new, unique) name for the new variable that represents DRA states + // Create a (new, unique) name for the variable that will represent DRA states draVar = "_dra"; while (varList.getIndex(draVar) != -1) { draVar = "_" + draVar; } + // See how many new dd vars will be needed for DRA // and whether there is room to put them before rather than after the existing vars n = (int) Math.ceil(PrismUtils.log2(dra.size())); before = true; - if (allDDRowVars.getMinVarIndex() - allDDNondetVars.getMaxVarIndex() < 2 * n) { + if (allDDRowVars.getMinVarIndex() - ((NondetModel) model).getAllDDNondetVars().getMaxVarIndex() < 2 * n) { before = false; } - before = false; + // Create the new dd variables draDDRowVars = new JDDVars(); draDDColVars = new JDDVars(); @@ -163,10 +159,14 @@ public class LTLModelChecker for (i = 0; i < n; i++) { draDDRowVars.addVar(JDD.Var(j++)); draDDColVars.addVar(JDD.Var(j++)); - if (!before) { newDDVarNames.add(""); newDDVarNames.add(""); } - newDDVarNames.set(j - 2, draVar+"."+i); - newDDVarNames.set(j - 1, draVar+"'."+i); + if (!before) { + newDDVarNames.add(""); + newDDVarNames.add(""); + } + newDDVarNames.set(j - 2, draVar + "." + i); + newDDVarNames.set(j - 1, draVar + "'." + i); } + // Create/populate new lists newVarDDRowVars = new JDDVars[varDDRowVars.length + 1]; newVarDDColVars = new JDDVars[varDDRowVars.length + 1]; @@ -191,38 +191,62 @@ public class LTLModelChecker newAllDDRowVars.addVars(draDDRowVars); newAllDDColVars.addVars(draDDColVars); } + newVarList = (VarList)varList.clone(); + newVarList.addVar(before ? 0 : varList.getNumVars(), draVar, 0, dra.size() - 1, 0, 1, Expression.INT); + // Extra references (because will get derefed when new model is done with) + // TODO: tidy this up, make it corresond to model.clear() allDDRowVars.refAll(); allDDRowVars.refAll(); - allDDRowVars.refAll(); - allDDColVars.refAll(); allDDColVars.refAll(); allDDColVars.refAll(); + for (i = 0; i < model.getNumModules(); i++) { + model.getModuleDDRowVars(i).refAll(); + model.getModuleDDColVars(i).refAll(); + } draDDRowVars.refAll(); draDDColVars.refAll(); ((NondetModel) model).getAllDDSchedVars().refAll(); ((NondetModel) model).getAllDDSynchVars().refAll(); ((NondetModel) model).getAllDDChoiceVars().refAll(); - allDDNondetVars.refAll(); - - draDD = buildTransMask(model, dra, labelDDs, draDDRowVars, draDDColVars); + ((NondetModel) model).getAllDDNondetVars().refAll(); + + newTrans = buildTransMask(dra, labelDDs, allDDRowVars, allDDColVars, draDDRowVars, draDDColVars); JDD.Ref(model.getTrans()); - newTrans = JDD.Apply(JDD.TIMES, model.getTrans(), draDD); + newTrans = JDD.Apply(JDD.TIMES, model.getTrans(), newTrans); newStart = buildStartMask(dra, labelDDs); JDD.Ref(model.getStart()); newStart = JDD.And(model.getStart(), newStart); - - //TODO: new varlist - NondetModel modelProd = new NondetModel(newTrans, newStart, new JDDNode[0], new JDDNode[0], new String[0], - newAllDDRowVars, newAllDDColVars, ((NondetModel) model).getAllDDSchedVars(), ((NondetModel) model) - .getAllDDSynchVars(), ((NondetModel) model).getAllDDChoiceVars(), allDDNondetVars, newDDVarNames, model.getNumModules(), model.getModuleNames(), - model.getModuleDDRowVars(), model.getModuleDDColVars(), model.getNumVars()+1, model.getVarList(), - newVarDDRowVars, newVarDDColVars, model.getConstantValues()); + + // Create a new model model object to store the product model + NondetModel modelProd = new NondetModel( + // New transition matrix/start state + newTrans, newStart, + // Don't pass in any rewards info + new JDDNode[0], new JDDNode[0], new String[0], + // New list of all row/col vars + newAllDDRowVars, newAllDDColVars, + // Nondet variables (unchanged) + ((NondetModel) model).getAllDDSchedVars(), ((NondetModel) model).getAllDDSynchVars(), + ((NondetModel) model).getAllDDChoiceVars(), ((NondetModel) model).getAllDDNondetVars(), + // New list of var names + newDDVarNames, + // Module info (unchanged) + model.getNumModules(), model.getModuleNames(), model.getModuleDDRowVars(), model.getModuleDDColVars(), + // New var info + model.getNumVars() + 1, newVarList, newVarDDRowVars, newVarDDColVars, + // Constants (no change) + model.getConstantValues()); + + // Do reachability/etc. for the new model modelProd.doReachability(prism.getExtraReachInfo()); modelProd.filterReachableStates(); modelProd.findDeadlocks(); - + if (modelProd.getDeadlockStates().size() > 0) { + throw new PrismException("Model-DRA product has deadlock states"); + } + return modelProd; } @@ -232,13 +256,12 @@ public class LTLModelChecker * that exist in the DRA. * @return a referenced mask BDD over trans */ - private JDDNode buildTransMask(Model model, DRA dra, Vector labelDDs, JDDVars draDDRowVars, - JDDVars draDDColVars) + public JDDNode buildTransMask(DRA dra, Vector labelDDs, JDDVars allDDRowVars, JDDVars allDDColVars, + JDDVars draDDRowVars, JDDVars draDDColVars) { - JDDNode draMask; Iterator it; DA_State state; - JDDNode label, exprBDD, transition; + JDDNode draMask, label, exprBDD, transition; int i, n; draMask = JDD.Constant(0); @@ -260,7 +283,7 @@ public class LTLModelChecker label = JDD.And(label, exprBDD); } // Switch label BDD to col vars - label = JDD.PermuteVariables(label, model.getAllDDRowVars(), model.getAllDDColVars()); + label = JDD.PermuteVariables(label, allDDRowVars, allDDColVars); // Build a BDD for the edge transition = JDD.SetMatrixElement(JDD.Constant(0), draDDRowVars, draDDColVars, state.getName(), edge .getValue().getName(), 1); @@ -273,26 +296,24 @@ public class LTLModelChecker return draMask; } - + /** * Builds a mask BDD for start (which contains start nodes for every * DRA state after adding draRow/ColVars) that includes only the start states * (s, q) such that q = delta(q_in, label(s)) in the DRA. * @return a referenced mask BDD over start */ - private JDDNode buildStartMask(DRA dra, Vector labelDDs) { - JDDNode startMask; - JDDNode label; - JDDNode dest; - JDDNode tmp; + public JDDNode buildStartMask(DRA dra, Vector labelDDs) + { + JDDNode startMask, label, exprBDD, dest, tmp; startMask = JDD.Constant(0); for (Map.Entry edge : dra.getStartState().edges().entrySet()) { // Build a transition label BDD for each edge - //System.out.println(state.getName() + " to " + edge.getValue().getName() + " through " + edge.getKey().toString(dra.getAPSet(), false)); + //System.out.println("To " + edge.getValue().getName() + " through " + edge.getKey().toString(dra.getAPSet(), false)); label = JDD.Constant(1); for (int i = 0; i < dra.getAPSize(); i++) { - JDDNode exprBDD = labelDDs.get(Integer.parseInt(dra.getAPSet().getAP(i).substring(1))); + exprBDD = labelDDs.get(Integer.parseInt(dra.getAPSet().getAP(i).substring(1))); JDD.Ref(exprBDD); if (!edge.getKey().get(i)) { exprBDD = JDD.Not(exprBDD); @@ -309,21 +330,20 @@ public class LTLModelChecker // Add this destination to our start mask startMask = JDD.Or(startMask, tmp); } - // mainLog.println("Start state mask BDD:"); - // JDD.PrintVector(startMask, allDDRowVars); return startMask; } - + /** * computes maximal accepting SCSSs for each Rabin acceptance pair * * @returns a referenced BDD of the union of all the accepting SCSSs - */ - public JDDNode findAcceptingSCSSs(DRA dra, NondetModel model) throws PrismException { + */ + public JDDNode findAcceptingSCSSs(DRA dra, NondetModel model) throws PrismException + { JDDNode allAcceptingSCSSs = JDD.Constant(0); - + // for each acceptance pair (H_i, L_i) in the DRA, build H'_i = S x H_i // and compute the SCSS maximals in H'_i for (int i = 0; i < dra.acceptance().size(); i++) { @@ -356,10 +376,10 @@ public class LTLModelChecker acceptingStates = JDD.Apply(JDD.TIMES, acceptingStates, acceptanceVector_H); acceptingStates = JDD.ThereExists(acceptingStates, model.getAllDDColVars()); acceptingStates = JDD.ThereExists(acceptingStates, model.getAllDDNondetVars()); - + // find SCSSs in acceptingStates that are accepting under L_i JDDNode acceptingSCSSs = filteredUnion(findMaximalSCSSs(model, acceptingStates), acceptanceVector_L); - + // Add SCSSs to our destination bdd allAcceptingSCSSs = JDD.Or(allAcceptingSCSSs, acceptingSCSSs); } @@ -371,20 +391,21 @@ public class LTLModelChecker * @param states BDD of a set of states (dereferenced after calling this function) * @return a vector of referenced BDDs containing all the maximal SCSSs in states */ - private Vector findMaximalSCSSs(NondetModel model, JDDNode states) throws PrismException { - + private Vector findMaximalSCSSs(NondetModel model, JDDNode states) throws PrismException + { + boolean initialCandidate = true; Stack candidates = new Stack(); candidates.push(states); Vector scsss = new Vector(); - + while (!candidates.isEmpty()) { JDDNode candidate = candidates.pop(); - + // Compute the stable set JDD.Ref(candidate); JDDNode stableSet = findMaxStableSet(model, candidate); - + if (!initialCandidate) { // candidate is an SCC, check if it's stable if (stableSet.equals(candidate)) { @@ -392,29 +413,33 @@ public class LTLModelChecker JDD.Deref(stableSet); continue; } - } - else initialCandidate = false; + } else + initialCandidate = false; JDD.Deref(candidate); - + // Filter bad transitions JDD.Ref(stableSet); JDDNode stableSetTrans = maxStableSetTrans(model, stableSet); - + // now find the maximal SCCs in (stableSet, stableSetTrans) Vector sccs; SCCComputer sccComputer; switch (prism.getSCCMethod()) { case Prism.LOCKSTEP: - sccComputer = new SCCComputerLockstep(prism, stableSet, stableSetTrans, model.getAllDDRowVars(), model.getAllDDColVars()); + sccComputer = new SCCComputerLockstep(prism, stableSet, stableSetTrans, model.getAllDDRowVars(), model + .getAllDDColVars()); break; case Prism.SCCFIND: - sccComputer = new SCCComputerSCCFind(prism, stableSet, stableSetTrans, model.getAllDDRowVars(), model.getAllDDColVars()); + sccComputer = new SCCComputerSCCFind(prism, stableSet, stableSetTrans, model.getAllDDRowVars(), model + .getAllDDColVars()); break; case Prism.XIEBEEREL: - sccComputer = new SCCComputerXB(prism, stableSet, stableSetTrans, model.getAllDDRowVars(), model.getAllDDColVars()); + sccComputer = new SCCComputerXB(prism, stableSet, stableSetTrans, model.getAllDDRowVars(), model + .getAllDDColVars()); break; default: - sccComputer = new SCCComputerLockstep(prism, stableSet, stableSetTrans, model.getAllDDRowVars(), model.getAllDDColVars()); + sccComputer = new SCCComputerLockstep(prism, stableSet, stableSetTrans, model.getAllDDRowVars(), model + .getAllDDColVars()); } sccComputer.computeBSCCs(); JDD.Deref(stableSet); @@ -425,25 +450,26 @@ public class LTLModelChecker } return scsss; } - + /** * Returns the maximal stable set in c * @param c a set of nodes where we want to find a stable set * (dereferenced after calling this function) * @return a referenced BDD with the maximal stable set in c */ - private JDDNode findMaxStableSet(NondetModel model, JDDNode c) { + private JDDNode findMaxStableSet(NondetModel model, JDDNode c) + { JDDNode old; JDDNode current; JDDNode mask; - + JDD.Ref(c); current = c; do { /* if (verbose) { - mainLog.println("Stable set pass " + i + ":"); - } */ + mainLog.println("Stable set pass " + i + ":"); + } */ old = current; // states that aren't in B (column vector) JDD.Ref(current); @@ -460,11 +486,11 @@ public class LTLModelChecker // states in B that have an action that always transitions to other states in B current = JDD.Apply(JDD.TIMES, current, mask); /* if (verbose) { - mainLog.println("Stable set search pass " + i); - JDD.PrintVector(current, allDDRowVars); - mainLog.println(); - i++; - } */ + mainLog.println("Stable set search pass " + i); + JDD.PrintVector(current, allDDRowVars); + mainLog.println(); + i++; + } */ } while (!current.equals(old)); JDD.Deref(c); return current; @@ -475,7 +501,8 @@ public class LTLModelChecker * @param b BDD of a stable set (dereferenced after calling this function) * @return referenced BDD of the transition relation restricted to the stable set */ - private JDDNode maxStableSetTrans(NondetModel model, JDDNode b) { + private JDDNode maxStableSetTrans(NondetModel model, JDDNode b) + { JDDNode ssTrans; JDDNode mask; @@ -500,7 +527,7 @@ public class LTLModelChecker return ssTrans; } - + /** * Returns the union of each set in the vector that has nonempty intersection * with the filter BDD. @@ -508,10 +535,11 @@ public class LTLModelChecker * @param filter filter BDD against which each set is checked for nonempty intersection * also dereferenced after calling this function * @return Referenced BDD with the filtered union - */ - private JDDNode filteredUnion(Vector sets, JDDNode filter) { + */ + private JDDNode filteredUnion(Vector sets, JDDNode filter) + { JDDNode union = JDD.Constant(0); - for (JDDNode set: sets) { + for (JDDNode set : sets) { JDD.Ref(filter); union = JDD.Or(union, JDD.And(set, filter)); } diff --git a/prism/src/prism/NondetModelChecker.java b/prism/src/prism/NondetModelChecker.java index 40fab54a..4c49de9f 100644 --- a/prism/src/prism/NondetModelChecker.java +++ b/prism/src/prism/NondetModelChecker.java @@ -312,7 +312,7 @@ public class NondetModelChecker extends StateModelChecker { // Test whether this is a simple path formula (i.e. PCTL) // and then pass control to appropriate method. - if (expr.isSimplePathFormula()) { + if (1==2){//expr.isSimplePathFormula()) { return checkProbPathFormulaSimple(expr, qual, min); } else { return checkProbPathFormulaLTL(expr, qual, min); @@ -372,18 +372,22 @@ public class NondetModelChecker extends StateModelChecker protected StateProbs checkProbPathFormulaLTL(Expression expr, boolean qual, boolean min) throws PrismException { LTLModelChecker mcLtl; - StateProbs probs = null; + StateProbs probsProduct = null, probs = null; Expression ltl; Vector labelDDs; DRA dra; + NondetModel modelProduct; + NondetModelChecker mcProduct; + JDDNode startMask; + JDDVars draDDRowVars; int i; long l; - mcLtl = new LTLModelChecker(prism, this, model); + mcLtl = new LTLModelChecker(prism, this); // Model check maximal state formulas labelDDs = new Vector(); - ltl = mcLtl.checkMaximalStateFormulas(this, model, expr.deepCopy(), labelDDs); + ltl = mcLtl.checkMaximalStateFormulas(model, expr.deepCopy(), labelDDs); // Convert LTL formula to deterministic Rabin automaton (DRA) mainLog.println("\nBuilding deterministic Rabin automaton (for "+ltl+")..."); @@ -396,18 +400,31 @@ public class NondetModelChecker extends StateModelChecker // Build product of MDP and automaton mainLog.println("\nConstructing Model-DRA product..."); - NondetModel productModel = mcLtl.constructProductModel(dra, model, labelDDs); + modelProduct = mcLtl.constructProductModel(dra, model, labelDDs); mainLog.println(); - productModel.printTransInfo(mainLog, prism.getExtraDDInfo()); + modelProduct.printTransInfo(mainLog, prism.getExtraDDInfo()); - JDDNode acc = mcLtl.findAcceptingSCSSs(dra, productModel); + mainLog.println("\nFinding accepting SCCs..."); + JDDNode acc = mcLtl.findAcceptingSCSSs(dra, modelProduct); - NondetModelChecker mc2 = new NondetModelChecker(prism, productModel, null); - probs = mc2.computeUntilProbs(productModel.getTrans(), productModel.getTrans01(), productModel.getReach(), acc, min); + mainLog.println("\nComputing reachability probabilities..."); + mcProduct = new NondetModelChecker(prism, modelProduct, null); + probsProduct = mcProduct.computeUntilProbs(modelProduct.getTrans(), modelProduct.getTrans01(), modelProduct.getReach(), acc, min); - productModel.clear(); + // Convert probability vector to original model + // First, filter over DRA start states + startMask = mcLtl.buildStartMask(dra, labelDDs); + probsProduct.filter(startMask); + JDD.Deref(startMask); + // Then sum over DD vars for the DRA state + draDDRowVars = new JDDVars(); + draDDRowVars.addVars(modelProduct.getAllDDRowVars()); + draDDRowVars.removeVars(allDDRowVars); + probs = probsProduct.sumOverDDVars(draDDRowVars, model); // Deref, clean up + probsProduct.clear(); + modelProduct.clear(); for (i = 0; i < labelDDs.size(); i++) { JDD.Deref(labelDDs.get(i)); } diff --git a/prism/src/prism/StateProbs.java b/prism/src/prism/StateProbs.java index 91701953..8a542015 100644 --- a/prism/src/prism/StateProbs.java +++ b/prism/src/prism/StateProbs.java @@ -27,6 +27,7 @@ package prism; import jdd.JDDNode; +import jdd.JDDVars; // interface for state probability vector classes @@ -47,6 +48,7 @@ public interface StateProbs double maxOverBDD(JDDNode filter); double sumOverBDD(JDDNode filter); double sumOverMTBDD(JDDNode mult); + StateProbs sumOverDDVars(JDDVars sumVars, Model newModel); JDDNode getBDDFromInterval(String relOp, double bound); JDDNode getBDDFromInterval(double lo, double hi); void print(PrismLog log); diff --git a/prism/src/prism/StateProbsDV.java b/prism/src/prism/StateProbsDV.java index 7eb62043..680bc96e 100644 --- a/prism/src/prism/StateProbsDV.java +++ b/prism/src/prism/StateProbsDV.java @@ -197,7 +197,7 @@ public class StateProbsDV implements StateProbs return probs.sumOverBDD(filter, vars, odd); } - // do a weighted sum of the elements of a double array and the values the mtbdd passed in + // do a weighted sum of the elements of the vector and the values the mtbdd passed in // (used for csl reward steady state operator) public double sumOverMTBDD(JDDNode mult) @@ -205,6 +205,18 @@ public class StateProbsDV implements StateProbs return probs.sumOverMTBDD(mult, vars, odd); } + // sum up the elements of the vector, over a subset of its dd vars + // store the result in a new StateProbsDV (for newModel) + + public StateProbs sumOverDDVars(JDDVars sumVars, Model newModel) + { + DoubleVector tmp; + + tmp = probs.sumOverDDVars(model.getAllDDRowVars(), odd, newModel.getODD(), sumVars.getMinVarIndex(), sumVars.getMaxVarIndex()); + + return new StateProbsDV(tmp, newModel); + } + // generate bdd from an interval (relative operator and bound) public JDDNode getBDDFromInterval(String relOp, double bound) diff --git a/prism/src/prism/StateProbsMTBDD.java b/prism/src/prism/StateProbsMTBDD.java index 04f3b580..0508061c 100644 --- a/prism/src/prism/StateProbsMTBDD.java +++ b/prism/src/prism/StateProbsMTBDD.java @@ -283,6 +283,16 @@ public class StateProbsMTBDD implements StateProbs return d; } + public StateProbs sumOverDDVars(JDDVars sumVars, Model newModel) + { + JDDNode tmp; + + JDD.Ref(probs); + tmp = JDD.SumAbstract(probs, sumVars); + + return new StateProbsMTBDD(tmp, newModel); + } + // generate bdd (from an interval: relative operator and bound) public JDDNode getBDDFromInterval(String relOp, double bound)