diff --git a/prism/src/prism/Model.java b/prism/src/prism/Model.java index acf2f6e0..cdca9160 100644 --- a/prism/src/prism/Model.java +++ b/prism/src/prism/Model.java @@ -98,9 +98,12 @@ public interface Model ODDNode getODD(); + void resetTrans(JDDNode trans); + void resetTransRewards(int i, JDDNode transRewards); void doReachability(); void doReachability(boolean extraReachInfo); void skipReachability(); + void setReach(JDDNode reach); void filterReachableStates(); void findDeadlocks(); void fixDeadlocks(); diff --git a/prism/src/prism/Modules2MTBDD.java b/prism/src/prism/Modules2MTBDD.java index 1ac154d8..c157621f 100644 --- a/prism/src/prism/Modules2MTBDD.java +++ b/prism/src/prism/Modules2MTBDD.java @@ -106,6 +106,14 @@ public class Modules2MTBDD // flags for keeping track of which variables have been used private boolean[] varsUsed; + // symmetry info + private boolean doSymmetry; // use symmetry reduction + private JDDNode symm; // dd of symmetric states + private JDDNode nonSymms[]; // dds of non-(i,i+1)-symmetric states (i=0...numSymmModules-1) + private int numModulesBeforeSymm; // number of modules in the PRISM file before the symmetric ones + private int numModulesAfterSymm; // number of modules in the PRISM file after the symmetric ones + private int numSymmModules; // number of symmetric components + // data structure used to store mtbdds and related info // for some component of the whole model @@ -146,6 +154,9 @@ public class Modules2MTBDD mainLog = p.getMainLog(); techLog = p.getTechLog(); modulesFile = mf; + // get symmetry reduction info + String s = prism.getSettings().getString(PrismSettings.PRISM_SYMM_RED_PARAMS); + doSymmetry = !(s == null || s == ""); } // main method - translate @@ -262,6 +273,9 @@ public class Modules2MTBDD mainLog.print("Reach: " + JDD.GetNumNodes(model.getReach()) + " nodes\n"); } + // symmetrification + if (doSymmetry) doSymmetry(model); + // find any deadlocks model.findDeadlocks(); @@ -289,6 +303,12 @@ public class Modules2MTBDD JDD.Deref(ddChoiceVars[i]); } } + if (doSymmetry) { + JDD.Deref(symm); + for (i = 0; i < numSymmModules-1; i++) { + JDD.Deref(nonSymms[i]); + } + } expr2mtbdd.clearDummyModel(); @@ -1968,6 +1988,234 @@ public class Modules2MTBDD } } } + + // symmetrification + + private void doSymmetry(Model model) throws PrismException + { + JDDNode tmp, transNew, reach, trans, transRewards[]; + int i, j, k, numSwaps; + boolean done; + long clock; + String ss[]; + + // parse symmetry reduction parameters + ss = prism.getSettings().getString(PrismSettings.PRISM_SYMM_RED_PARAMS).split(" "); + if (ss.length != 2) throw new PrismException ("Invalid parameters for symmetry reduction"); + try { + numModulesBeforeSymm = Integer.parseInt(ss[0].trim()); + numModulesAfterSymm = Integer.parseInt(ss[1].trim()); + } + catch (NumberFormatException e) { + throw new PrismException("Invalid parameters for symmetry reduction"); + } + + clock = System.currentTimeMillis(); + + // get a copies of model (MT)BDDs + reach = model.getReach(); + JDD.Ref(reach); + trans = model.getTrans(); + JDD.Ref(trans); + transRewards = new JDDNode[numRewardStructs]; + for (i = 0; i < numRewardStructs; i++) { + transRewards[i] = model.getTransRewards(i); + JDD.Ref(transRewards[i]); + } + + mainLog.print("\nApplying symmetry reduction...\n"); + + //identifySymmetricModules(); + numSymmModules = numModules - (numModulesBeforeSymm + numModulesAfterSymm); + computeSymmetryFilters(reach); + + // compute number of local states +// JDD.Ref(reach); +// tmp = reach; +// for (i = 0; i < numModules; i++) { +// if (i != numModulesBeforeSymm) tmp = JDD.ThereExists(tmp, moduleDDRowVars[i]); +// } +// tmp = JDD.ThereExists(tmp, globalDDRowVars); +// mainLog.println("Local states: " + (int)JDD.GetNumMinterms(tmp, moduleDDRowVars[numModulesBeforeSymm].n())); +// JDD.Deref(tmp); + + //ODDNode odd = ODDUtils.BuildODD(reach, allDDRowVars); + //try {sparse.PrismSparse.NondetExport(trans, allDDRowVars, allDDColVars, allDDNondetVars, odd, Prism.EXPORT_PLAIN, "trans-full.tra"); } catch (FileNotFoundException e) {} + + // trans - rows + mainLog.print("trans (full): "); + mainLog.println(JDD.GetInfoString(trans, (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2))); + JDD.Ref(symm); + trans = JDD.Apply(JDD.TIMES, trans, symm); + mainLog.print("trans (symm): "); + mainLog.println(JDD.GetInfoString(trans, (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2))); + + // trans rewards - rows + for (k = 0; k < numRewardStructs; k++) { + mainLog.print("transrew["+k+"] (full): "); + mainLog.println(JDD.GetInfoString(transRewards[k], (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2))); + JDD.Ref(symm); + transRewards[k] = JDD.Apply(JDD.TIMES, transRewards[k], symm); + mainLog.print("transrew["+k+"] (symm): "); + mainLog.println(JDD.GetInfoString(transRewards[k], (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2))); + } + + mainLog.println("Starting quicksort..."); + done = false; + numSwaps = 0; + for (i = numSymmModules; i > 1 && !done; i--) { + // store trans from previous iter + JDD.Ref(trans); + transNew = trans; + for (j = 0; j < i-1; j++) { + + // are there any states where j+1>j+2? + if (nonSymms[j].equals(JDD.ZERO)) continue; + + // identify offending block in trans + JDD.Ref(transNew); + JDD.Ref(nonSymms[j]); + tmp = JDD.Apply(JDD.TIMES, transNew, JDD.PermuteVariables(nonSymms[j], allDDRowVars, allDDColVars)); + //mainLog.print("bad block: "); + //mainLog.println(JDD.GetInfoString(tmp, (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2))); + + if (tmp.equals(JDD.ZERO)) { JDD.Deref(tmp); continue; } + numSwaps++; + mainLog.println("Iteration "+(numSymmModules-i+1)+"."+(j+1)); + + // swap + tmp = JDD.SwapVariables(tmp, moduleDDColVars[numModulesBeforeSymm+j], moduleDDColVars[numModulesBeforeSymm+j+1]); + //mainLog.print("bad block (swapped): "); + //mainLog.println(JDD.GetInfoString(tmp, (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2))); + + // insert swapped block + JDD.Ref(nonSymms[j]); + JDD.Ref(tmp); + transNew = JDD.ITE(JDD.PermuteVariables(nonSymms[j], allDDRowVars, allDDColVars), JDD.Constant(0), JDD.Apply(JDD.PLUS, transNew, tmp)); + //mainLog.print("trans (symm): "); + //mainLog.println(JDD.GetInfoString(transNew, (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2))); + JDD.Deref(tmp); + + for (k = 0; k < numRewardStructs; k++) { + // identify offending block in trans rewards + JDD.Ref(transRewards[k]); + JDD.Ref(nonSymms[j]); + tmp = JDD.Apply(JDD.TIMES, transRewards[k], JDD.PermuteVariables(nonSymms[j], allDDRowVars, allDDColVars)); + //mainLog.print("bad block: "); + //mainLog.println(JDD.GetInfoString(tmp, (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2))); + + // swap + tmp = JDD.SwapVariables(tmp, moduleDDColVars[numModulesBeforeSymm+j], moduleDDColVars[numModulesBeforeSymm+j+1]); + //mainLog.print("bad block (swapped): "); + //mainLog.println(JDD.GetInfoString(tmp, (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2))); + + // insert swapped block + JDD.Ref(nonSymms[j]); + JDD.Ref(tmp); + transRewards[k] = JDD.ITE(JDD.PermuteVariables(nonSymms[j], allDDRowVars, allDDColVars), JDD.Constant(0), JDD.Apply(JDD.PLUS, transRewards[k], tmp)); + //mainLog.print("transrew["+k+"] (symm): "); + //mainLog.println(JDD.GetInfoString(transRewards[k], (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2))); + JDD.Deref(tmp); + } + } + + if (transNew.equals(trans)) { + done = true; + } + JDD.Deref(trans); + trans = transNew; + } + + // reset (MT)BDDs in model + model.resetTrans(trans); + for (i = 0; i < numRewardStructs; i++) { + model.resetTransRewards(i, transRewards[i]); + } + + // reset reach bdd, etc. + JDD.Ref(symm); + reach = JDD.And(reach, symm); + + model.setReach(reach); + model.filterReachableStates(); + + clock = System.currentTimeMillis() - clock; + mainLog.println("Symmetry complete: " + (numSymmModules-i) + " iterations, " + numSwaps + " swaps, " + clock/1000.0 + " seconds"); + } + + private void computeSymmetryFilters(JDDNode reach) throws PrismException + { + int i; + JDDNode tmp; + + // array for non-symmetric parts + nonSymms = new JDDNode[numSymmModules-1]; + // dd for all symmetric states + JDD.Ref(reach); + symm = reach; + // loop through symmetric module pairs + for (i = 0; i < numSymmModules-1; i++) { + // (locally) symmetric states, i.e. where i+1 <= i+2 + tmp = JDD.VariablesLessThanEquals(moduleDDRowVars[numModulesBeforeSymm+i], moduleDDRowVars[numModulesBeforeSymm+i+1]); + // non-(locally)-symmetric states + JDD.Ref(tmp); + JDD.Ref(reach); + nonSymms[i] = JDD.And(JDD.Not(tmp), reach); + // all symmetric states + symm = JDD.And(symm, tmp); + } + } + + // old version of computeSymmetryFilters() + /*private void computeSymmetryFilters() throws PrismException + { + int i, j, k, n; + String varNames[][] = null; + JDDNode tmp; + Expression expr, exprTmp; + + // get var names for each symm module + n = modulesFile.getModule(numModulesBeforeSymm).getNumDeclarations(); + varNames = new String[numModules][]; + for (i = numModulesBeforeSymm; i < numModulesBeforeSymm+numSymmModules; i++) { + varNames[i-numModulesBeforeSymm] = new String[n]; + j = 0; + while (j < numVars && varList.getModule(j) != i) j++; + for (k = 0; k < n; k++) { + varNames[i-numModulesBeforeSymm][k] = varList.getName(j+k); + } + } + + // array for non-symmetric parts + nonSymms = new JDDNode[numSymmModules-1]; + // dd for all symmetric states + JDD.Ref(reach); + symm = reach; + // loop through symmetric module pairs + for (i = 0; i < numSymmModules-1; i++) { + // expression for (locally) symmetric states, i.e. where i+1 <= i+2 + expr = new ExpressionTrue(); + for (j = varNames[0].length-1; j >= 0 ; j--) { + exprTmp = new ExpressionAnd(); + ((ExpressionAnd)exprTmp).addOperand(new ExpressionBrackets(new ExpressionRelOp(new ExpressionVar(varNames[i][j], 0), "=", new ExpressionVar(varNames[i+1][j], 0)))); + ((ExpressionAnd)exprTmp).addOperand(new ExpressionBrackets(expr)); + expr = exprTmp; + exprTmp = new ExpressionOr(); + ((ExpressionOr)exprTmp).addOperand(new ExpressionBrackets(new ExpressionRelOp(new ExpressionVar(varNames[i][j], 0), "<", new ExpressionVar(varNames[i+1][j], 0)))); + ((ExpressionOr)exprTmp).addOperand(expr); + expr = exprTmp; + } + mainLog.println(expr); + // bdd for (locally) symmetric states, i.e. where i+1 <= i+2 + tmp = expr2mtbdd.translateExpression(expr); + // non-(locally)-symmetric states + JDD.Ref(tmp); + JDD.Ref(reach); + nonSymms[i] = JDD.And(JDD.Not(tmp), reach); + // all symmetric states + symm = JDD.And(symm, tmp); + } + }*/ } //------------------------------------------------------------------------------ diff --git a/prism/src/prism/NondetModel.java b/prism/src/prism/NondetModel.java index 463ef670..593640ce 100644 --- a/prism/src/prism/NondetModel.java +++ b/prism/src/prism/NondetModel.java @@ -165,6 +165,7 @@ public class NondetModel extends ProbModel // build mask for nondeterminstic choices JDD.Ref(trans01); JDD.Ref(reach); + if (this.nondetMask != null) JDD.Deref(this.nondetMask); // nb: this assumes that there are no deadlock states nondetMask = JDD.And(JDD.Not(JDD.ThereExists(trans01, allDDColVars)), reach); diff --git a/prism/src/prism/PrismCL.java b/prism/src/prism/PrismCL.java index bea5e6fd..51be031a 100644 --- a/prism/src/prism/PrismCL.java +++ b/prism/src/prism/PrismCL.java @@ -1403,6 +1403,16 @@ public class PrismCL } } + // enable symmetry reduction + else if (sw.equals("symm")) { + if (i < args.length-2) { + prism.getSettings().set(PrismSettings.PRISM_SYMM_RED_PARAMS, args[++i]+" "+args[++i]); + } + else { + errorAndExit("-symm switch requires two parameters (num. modules before/after symmetric ones)"); + } + } + // unknown switch - error else { errorAndExit("Invalid switch -" + sw + " (type \"prism -help\" for full list)"); diff --git a/prism/src/prism/PrismSettings.java b/prism/src/prism/PrismSettings.java index b37586fc..7cb4ce19 100644 --- a/prism/src/prism/PrismSettings.java +++ b/prism/src/prism/PrismSettings.java @@ -89,6 +89,7 @@ public class PrismSettings implements Observer public static final String PRISM_EXTRA_DD_INFO = "prism.extraDDInfo"; public static final String PRISM_EXTRA_REACH_INFO = "prism.extraReachInfo"; public static final String PRISM_SCC_METHOD = "prism.sccMethod"; + public static final String PRISM_SYMM_RED_PARAMS = "prism.symmRedParams"; //GUI Model public static final String MODEL_AUTO_PARSE = "model.autoParse"; @@ -185,7 +186,8 @@ public class PrismSettings implements Observer { BOOLEAN_TYPE, PRISM_DO_SS_DETECTION, "Use steady-state detection", "3.0", new Boolean(true), "0,", "Use steady-state detection during CTMC transient probability computation." }, { BOOLEAN_TYPE, PRISM_EXTRA_DD_INFO, "Extra MTBDD information", "3.2", new Boolean(false), "0,", "Display extra information about (MT)BDDs used during and after model construction." }, { BOOLEAN_TYPE, PRISM_EXTRA_REACH_INFO, "Extra reachability information", "3.2", new Boolean(false), "0,", "Display extra information about progress of reachability during model construction." }, - { CHOICE_TYPE, PRISM_SCC_METHOD, "SCC decomposition method", "3.2", "Lockstep", "Xie-Beerel,Lockstep,SCC-Find", "Which algorithm to use for decomposing a graph into strongly connected components (SCCs)." } + { CHOICE_TYPE, PRISM_SCC_METHOD, "SCC decomposition method", "3.2", "Lockstep", "Xie-Beerel,Lockstep,SCC-Find", "Which algorithm to use for decomposing a graph into strongly connected components (SCCs)." }, + { STRING_TYPE, PRISM_SYMM_RED_PARAMS, "Symmetry reduction parameters", "3.2", "", "", "Parameters for symmetry reduction (format: \"i j\" where i and j are the number of modules before and after the symmetric ones; empty string means symmetry reduction disabled)." } }, { { BOOLEAN_TYPE, MODEL_AUTO_PARSE, "Auto parse", "3.0", new Boolean(true), "", "Parse PRISM models automatically as they are loaded/edited in the text editor." }, @@ -262,7 +264,14 @@ public class PrismSettings implements Observer Setting set; - if(setting[0].equals(INTEGER_TYPE)) + if(setting[0].equals(STRING_TYPE)) + { + set = new SingleLineStringSetting(display, (String)value, comment, optionOwners[i], false); + set.setKey(key); + set.setVersion(version); + optionOwners[i].addSetting(set); + } + else if(setting[0].equals(INTEGER_TYPE)) { if(constraint.equals("")) set = new IntegerSetting(display, (Integer)value, comment, optionOwners[i], false); diff --git a/prism/src/prism/ProbModel.java b/prism/src/prism/ProbModel.java index 95538128..c57289d4 100644 --- a/prism/src/prism/ProbModel.java +++ b/prism/src/prism/ProbModel.java @@ -413,6 +413,28 @@ public class ProbModel implements Model numStartStates = JDD.GetNumMinterms(start, allDDRowVars.n()); } + /** + * Reset transition matrix DD + */ + + public void resetTrans(JDDNode trans) + { + if (this.trans != null) JDD.Deref(this.trans); + this.trans = trans; + } + + /** + * Reset transition rewards DDs + */ + + public void resetTransRewards(int i, JDDNode transRewards) + { + if (this.transRewards[i] != null) { + JDD.Deref(this.transRewards[i]); + } + this.transRewards[i] = transRewards; + } + // do reachability public void doReachability() @@ -423,13 +445,7 @@ public class ProbModel implements Model public void doReachability(boolean extraReachInfo) { // compute reachable states - reach = PrismMTBDD.Reachability(trans01, allDDRowVars, allDDColVars, start, extraReachInfo ? 1 : 0); - - // work out number of reachable states - numStates = JDD.GetNumMinterms(reach, allDDRowVars.n()); - - // build odd - odd = ODDUtils.BuildODD(reach, allDDRowVars); + setReach(PrismMTBDD.Reachability(trans01, allDDRowVars, allDDColVars, start, extraReachInfo ? 1 : 0)); } // this method allows you to skip the reachability phase @@ -447,6 +463,22 @@ public class ProbModel implements Model odd = ODDUtils.BuildODD(reach, allDDRowVars); } + /** + * Set reachable states BDD (and compute number of states and ODD) + */ + + public void setReach(JDDNode reach) + { + if (this.reach != null) JDD.Deref(this.reach); + this.reach = reach; + + // work out number of reachable states + numStates = JDD.GetNumMinterms(reach, allDDRowVars.n()); + + // build odd + odd = ODDUtils.BuildODD(reach, allDDRowVars); + } + // remove non-reachable states from various dds // (and calculate num transitions) @@ -480,6 +512,11 @@ public class ProbModel implements Model transRewards[i] = JDD.Apply(JDD.TIMES, tmp, transRewards[i]); } + // filter start states, work out number of initial states + JDD.Ref(reach); + start = JDD.Apply(JDD.TIMES, reach, start); + numStartStates = JDD.GetNumMinterms(start, allDDRowVars.n()); + // work out number of transitions numTransitions = JDD.GetNumMinterms(trans01, getNumDDVarsInTrans()); }