Browse Source

Added symmetry reduction into main trunk.

git-svn-id: https://www.prismmodelchecker.org/svn/prism/prism/trunk@853 bbc10eb1-c90d-0410-af57-cb519fbb1720
master
Dave Parker 17 years ago
parent
commit
b82b0c91f9
  1. 3
      prism/src/prism/Model.java
  2. 248
      prism/src/prism/Modules2MTBDD.java
  3. 1
      prism/src/prism/NondetModel.java
  4. 10
      prism/src/prism/PrismCL.java
  5. 13
      prism/src/prism/PrismSettings.java
  6. 51
      prism/src/prism/ProbModel.java

3
prism/src/prism/Model.java

@ -98,9 +98,12 @@ public interface Model
ODDNode getODD();
void resetTrans(JDDNode trans);
void resetTransRewards(int i, JDDNode transRewards);
void doReachability();
void doReachability(boolean extraReachInfo);
void skipReachability();
void setReach(JDDNode reach);
void filterReachableStates();
void findDeadlocks();
void fixDeadlocks();

248
prism/src/prism/Modules2MTBDD.java

@ -106,6 +106,14 @@ public class Modules2MTBDD
// flags for keeping track of which variables have been used
private boolean[] varsUsed;
// symmetry info
private boolean doSymmetry; // use symmetry reduction
private JDDNode symm; // dd of symmetric states
private JDDNode nonSymms[]; // dds of non-(i,i+1)-symmetric states (i=0...numSymmModules-1)
private int numModulesBeforeSymm; // number of modules in the PRISM file before the symmetric ones
private int numModulesAfterSymm; // number of modules in the PRISM file after the symmetric ones
private int numSymmModules; // number of symmetric components
// data structure used to store mtbdds and related info
// for some component of the whole model
@ -146,6 +154,9 @@ public class Modules2MTBDD
mainLog = p.getMainLog();
techLog = p.getTechLog();
modulesFile = mf;
// get symmetry reduction info
String s = prism.getSettings().getString(PrismSettings.PRISM_SYMM_RED_PARAMS);
doSymmetry = !(s == null || s == "");
}
// main method - translate
@ -262,6 +273,9 @@ public class Modules2MTBDD
mainLog.print("Reach: " + JDD.GetNumNodes(model.getReach()) + " nodes\n");
}
// symmetrification
if (doSymmetry) doSymmetry(model);
// find any deadlocks
model.findDeadlocks();
@ -289,6 +303,12 @@ public class Modules2MTBDD
JDD.Deref(ddChoiceVars[i]);
}
}
if (doSymmetry) {
JDD.Deref(symm);
for (i = 0; i < numSymmModules-1; i++) {
JDD.Deref(nonSymms[i]);
}
}
expr2mtbdd.clearDummyModel();
@ -1968,6 +1988,234 @@ public class Modules2MTBDD
}
}
}
// symmetrification
private void doSymmetry(Model model) throws PrismException
{
JDDNode tmp, transNew, reach, trans, transRewards[];
int i, j, k, numSwaps;
boolean done;
long clock;
String ss[];
// parse symmetry reduction parameters
ss = prism.getSettings().getString(PrismSettings.PRISM_SYMM_RED_PARAMS).split(" ");
if (ss.length != 2) throw new PrismException ("Invalid parameters for symmetry reduction");
try {
numModulesBeforeSymm = Integer.parseInt(ss[0].trim());
numModulesAfterSymm = Integer.parseInt(ss[1].trim());
}
catch (NumberFormatException e) {
throw new PrismException("Invalid parameters for symmetry reduction");
}
clock = System.currentTimeMillis();
// get a copies of model (MT)BDDs
reach = model.getReach();
JDD.Ref(reach);
trans = model.getTrans();
JDD.Ref(trans);
transRewards = new JDDNode[numRewardStructs];
for (i = 0; i < numRewardStructs; i++) {
transRewards[i] = model.getTransRewards(i);
JDD.Ref(transRewards[i]);
}
mainLog.print("\nApplying symmetry reduction...\n");
//identifySymmetricModules();
numSymmModules = numModules - (numModulesBeforeSymm + numModulesAfterSymm);
computeSymmetryFilters(reach);
// compute number of local states
// JDD.Ref(reach);
// tmp = reach;
// for (i = 0; i < numModules; i++) {
// if (i != numModulesBeforeSymm) tmp = JDD.ThereExists(tmp, moduleDDRowVars[i]);
// }
// tmp = JDD.ThereExists(tmp, globalDDRowVars);
// mainLog.println("Local states: " + (int)JDD.GetNumMinterms(tmp, moduleDDRowVars[numModulesBeforeSymm].n()));
// JDD.Deref(tmp);
//ODDNode odd = ODDUtils.BuildODD(reach, allDDRowVars);
//try {sparse.PrismSparse.NondetExport(trans, allDDRowVars, allDDColVars, allDDNondetVars, odd, Prism.EXPORT_PLAIN, "trans-full.tra"); } catch (FileNotFoundException e) {}
// trans - rows
mainLog.print("trans (full): ");
mainLog.println(JDD.GetInfoString(trans, (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2)));
JDD.Ref(symm);
trans = JDD.Apply(JDD.TIMES, trans, symm);
mainLog.print("trans (symm): ");
mainLog.println(JDD.GetInfoString(trans, (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2)));
// trans rewards - rows
for (k = 0; k < numRewardStructs; k++) {
mainLog.print("transrew["+k+"] (full): ");
mainLog.println(JDD.GetInfoString(transRewards[k], (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2)));
JDD.Ref(symm);
transRewards[k] = JDD.Apply(JDD.TIMES, transRewards[k], symm);
mainLog.print("transrew["+k+"] (symm): ");
mainLog.println(JDD.GetInfoString(transRewards[k], (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2)));
}
mainLog.println("Starting quicksort...");
done = false;
numSwaps = 0;
for (i = numSymmModules; i > 1 && !done; i--) {
// store trans from previous iter
JDD.Ref(trans);
transNew = trans;
for (j = 0; j < i-1; j++) {
// are there any states where j+1>j+2?
if (nonSymms[j].equals(JDD.ZERO)) continue;
// identify offending block in trans
JDD.Ref(transNew);
JDD.Ref(nonSymms[j]);
tmp = JDD.Apply(JDD.TIMES, transNew, JDD.PermuteVariables(nonSymms[j], allDDRowVars, allDDColVars));
//mainLog.print("bad block: ");
//mainLog.println(JDD.GetInfoString(tmp, (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2)));
if (tmp.equals(JDD.ZERO)) { JDD.Deref(tmp); continue; }
numSwaps++;
mainLog.println("Iteration "+(numSymmModules-i+1)+"."+(j+1));
// swap
tmp = JDD.SwapVariables(tmp, moduleDDColVars[numModulesBeforeSymm+j], moduleDDColVars[numModulesBeforeSymm+j+1]);
//mainLog.print("bad block (swapped): ");
//mainLog.println(JDD.GetInfoString(tmp, (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2)));
// insert swapped block
JDD.Ref(nonSymms[j]);
JDD.Ref(tmp);
transNew = JDD.ITE(JDD.PermuteVariables(nonSymms[j], allDDRowVars, allDDColVars), JDD.Constant(0), JDD.Apply(JDD.PLUS, transNew, tmp));
//mainLog.print("trans (symm): ");
//mainLog.println(JDD.GetInfoString(transNew, (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2)));
JDD.Deref(tmp);
for (k = 0; k < numRewardStructs; k++) {
// identify offending block in trans rewards
JDD.Ref(transRewards[k]);
JDD.Ref(nonSymms[j]);
tmp = JDD.Apply(JDD.TIMES, transRewards[k], JDD.PermuteVariables(nonSymms[j], allDDRowVars, allDDColVars));
//mainLog.print("bad block: ");
//mainLog.println(JDD.GetInfoString(tmp, (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2)));
// swap
tmp = JDD.SwapVariables(tmp, moduleDDColVars[numModulesBeforeSymm+j], moduleDDColVars[numModulesBeforeSymm+j+1]);
//mainLog.print("bad block (swapped): ");
//mainLog.println(JDD.GetInfoString(tmp, (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2)));
// insert swapped block
JDD.Ref(nonSymms[j]);
JDD.Ref(tmp);
transRewards[k] = JDD.ITE(JDD.PermuteVariables(nonSymms[j], allDDRowVars, allDDColVars), JDD.Constant(0), JDD.Apply(JDD.PLUS, transRewards[k], tmp));
//mainLog.print("transrew["+k+"] (symm): ");
//mainLog.println(JDD.GetInfoString(transRewards[k], (type==ModulesFile.NONDETERMINISTIC)?(allDDRowVars.n()*2+allDDNondetVars.n()):(allDDRowVars.n()*2)));
JDD.Deref(tmp);
}
}
if (transNew.equals(trans)) {
done = true;
}
JDD.Deref(trans);
trans = transNew;
}
// reset (MT)BDDs in model
model.resetTrans(trans);
for (i = 0; i < numRewardStructs; i++) {
model.resetTransRewards(i, transRewards[i]);
}
// reset reach bdd, etc.
JDD.Ref(symm);
reach = JDD.And(reach, symm);
model.setReach(reach);
model.filterReachableStates();
clock = System.currentTimeMillis() - clock;
mainLog.println("Symmetry complete: " + (numSymmModules-i) + " iterations, " + numSwaps + " swaps, " + clock/1000.0 + " seconds");
}
private void computeSymmetryFilters(JDDNode reach) throws PrismException
{
int i;
JDDNode tmp;
// array for non-symmetric parts
nonSymms = new JDDNode[numSymmModules-1];
// dd for all symmetric states
JDD.Ref(reach);
symm = reach;
// loop through symmetric module pairs
for (i = 0; i < numSymmModules-1; i++) {
// (locally) symmetric states, i.e. where i+1 <= i+2
tmp = JDD.VariablesLessThanEquals(moduleDDRowVars[numModulesBeforeSymm+i], moduleDDRowVars[numModulesBeforeSymm+i+1]);
// non-(locally)-symmetric states
JDD.Ref(tmp);
JDD.Ref(reach);
nonSymms[i] = JDD.And(JDD.Not(tmp), reach);
// all symmetric states
symm = JDD.And(symm, tmp);
}
}
// old version of computeSymmetryFilters()
/*private void computeSymmetryFilters() throws PrismException
{
int i, j, k, n;
String varNames[][] = null;
JDDNode tmp;
Expression expr, exprTmp;
// get var names for each symm module
n = modulesFile.getModule(numModulesBeforeSymm).getNumDeclarations();
varNames = new String[numModules][];
for (i = numModulesBeforeSymm; i < numModulesBeforeSymm+numSymmModules; i++) {
varNames[i-numModulesBeforeSymm] = new String[n];
j = 0;
while (j < numVars && varList.getModule(j) != i) j++;
for (k = 0; k < n; k++) {
varNames[i-numModulesBeforeSymm][k] = varList.getName(j+k);
}
}
// array for non-symmetric parts
nonSymms = new JDDNode[numSymmModules-1];
// dd for all symmetric states
JDD.Ref(reach);
symm = reach;
// loop through symmetric module pairs
for (i = 0; i < numSymmModules-1; i++) {
// expression for (locally) symmetric states, i.e. where i+1 <= i+2
expr = new ExpressionTrue();
for (j = varNames[0].length-1; j >= 0 ; j--) {
exprTmp = new ExpressionAnd();
((ExpressionAnd)exprTmp).addOperand(new ExpressionBrackets(new ExpressionRelOp(new ExpressionVar(varNames[i][j], 0), "=", new ExpressionVar(varNames[i+1][j], 0))));
((ExpressionAnd)exprTmp).addOperand(new ExpressionBrackets(expr));
expr = exprTmp;
exprTmp = new ExpressionOr();
((ExpressionOr)exprTmp).addOperand(new ExpressionBrackets(new ExpressionRelOp(new ExpressionVar(varNames[i][j], 0), "<", new ExpressionVar(varNames[i+1][j], 0))));
((ExpressionOr)exprTmp).addOperand(expr);
expr = exprTmp;
}
mainLog.println(expr);
// bdd for (locally) symmetric states, i.e. where i+1 <= i+2
tmp = expr2mtbdd.translateExpression(expr);
// non-(locally)-symmetric states
JDD.Ref(tmp);
JDD.Ref(reach);
nonSymms[i] = JDD.And(JDD.Not(tmp), reach);
// all symmetric states
symm = JDD.And(symm, tmp);
}
}*/
}
//------------------------------------------------------------------------------

1
prism/src/prism/NondetModel.java

@ -165,6 +165,7 @@ public class NondetModel extends ProbModel
// build mask for nondeterminstic choices
JDD.Ref(trans01);
JDD.Ref(reach);
if (this.nondetMask != null) JDD.Deref(this.nondetMask);
// nb: this assumes that there are no deadlock states
nondetMask = JDD.And(JDD.Not(JDD.ThereExists(trans01, allDDColVars)), reach);

10
prism/src/prism/PrismCL.java

@ -1403,6 +1403,16 @@ public class PrismCL
}
}
// enable symmetry reduction
else if (sw.equals("symm")) {
if (i < args.length-2) {
prism.getSettings().set(PrismSettings.PRISM_SYMM_RED_PARAMS, args[++i]+" "+args[++i]);
}
else {
errorAndExit("-symm switch requires two parameters (num. modules before/after symmetric ones)");
}
}
// unknown switch - error
else {
errorAndExit("Invalid switch -" + sw + " (type \"prism -help\" for full list)");

13
prism/src/prism/PrismSettings.java

@ -89,6 +89,7 @@ public class PrismSettings implements Observer
public static final String PRISM_EXTRA_DD_INFO = "prism.extraDDInfo";
public static final String PRISM_EXTRA_REACH_INFO = "prism.extraReachInfo";
public static final String PRISM_SCC_METHOD = "prism.sccMethod";
public static final String PRISM_SYMM_RED_PARAMS = "prism.symmRedParams";
//GUI Model
public static final String MODEL_AUTO_PARSE = "model.autoParse";
@ -185,7 +186,8 @@ public class PrismSettings implements Observer
{ BOOLEAN_TYPE, PRISM_DO_SS_DETECTION, "Use steady-state detection", "3.0", new Boolean(true), "0,", "Use steady-state detection during CTMC transient probability computation." },
{ BOOLEAN_TYPE, PRISM_EXTRA_DD_INFO, "Extra MTBDD information", "3.2", new Boolean(false), "0,", "Display extra information about (MT)BDDs used during and after model construction." },
{ BOOLEAN_TYPE, PRISM_EXTRA_REACH_INFO, "Extra reachability information", "3.2", new Boolean(false), "0,", "Display extra information about progress of reachability during model construction." },
{ CHOICE_TYPE, PRISM_SCC_METHOD, "SCC decomposition method", "3.2", "Lockstep", "Xie-Beerel,Lockstep,SCC-Find", "Which algorithm to use for decomposing a graph into strongly connected components (SCCs)." }
{ CHOICE_TYPE, PRISM_SCC_METHOD, "SCC decomposition method", "3.2", "Lockstep", "Xie-Beerel,Lockstep,SCC-Find", "Which algorithm to use for decomposing a graph into strongly connected components (SCCs)." },
{ STRING_TYPE, PRISM_SYMM_RED_PARAMS, "Symmetry reduction parameters", "3.2", "", "", "Parameters for symmetry reduction (format: \"i j\" where i and j are the number of modules before and after the symmetric ones; empty string means symmetry reduction disabled)." }
},
{
{ BOOLEAN_TYPE, MODEL_AUTO_PARSE, "Auto parse", "3.0", new Boolean(true), "", "Parse PRISM models automatically as they are loaded/edited in the text editor." },
@ -262,7 +264,14 @@ public class PrismSettings implements Observer
Setting set;
if(setting[0].equals(INTEGER_TYPE))
if(setting[0].equals(STRING_TYPE))
{
set = new SingleLineStringSetting(display, (String)value, comment, optionOwners[i], false);
set.setKey(key);
set.setVersion(version);
optionOwners[i].addSetting(set);
}
else if(setting[0].equals(INTEGER_TYPE))
{
if(constraint.equals(""))
set = new IntegerSetting(display, (Integer)value, comment, optionOwners[i], false);

51
prism/src/prism/ProbModel.java

@ -413,6 +413,28 @@ public class ProbModel implements Model
numStartStates = JDD.GetNumMinterms(start, allDDRowVars.n());
}
/**
* Reset transition matrix DD
*/
public void resetTrans(JDDNode trans)
{
if (this.trans != null) JDD.Deref(this.trans);
this.trans = trans;
}
/**
* Reset transition rewards DDs
*/
public void resetTransRewards(int i, JDDNode transRewards)
{
if (this.transRewards[i] != null) {
JDD.Deref(this.transRewards[i]);
}
this.transRewards[i] = transRewards;
}
// do reachability
public void doReachability()
@ -423,13 +445,7 @@ public class ProbModel implements Model
public void doReachability(boolean extraReachInfo)
{
// compute reachable states
reach = PrismMTBDD.Reachability(trans01, allDDRowVars, allDDColVars, start, extraReachInfo ? 1 : 0);
// work out number of reachable states
numStates = JDD.GetNumMinterms(reach, allDDRowVars.n());
// build odd
odd = ODDUtils.BuildODD(reach, allDDRowVars);
setReach(PrismMTBDD.Reachability(trans01, allDDRowVars, allDDColVars, start, extraReachInfo ? 1 : 0));
}
// this method allows you to skip the reachability phase
@ -447,6 +463,22 @@ public class ProbModel implements Model
odd = ODDUtils.BuildODD(reach, allDDRowVars);
}
/**
* Set reachable states BDD (and compute number of states and ODD)
*/
public void setReach(JDDNode reach)
{
if (this.reach != null) JDD.Deref(this.reach);
this.reach = reach;
// work out number of reachable states
numStates = JDD.GetNumMinterms(reach, allDDRowVars.n());
// build odd
odd = ODDUtils.BuildODD(reach, allDDRowVars);
}
// remove non-reachable states from various dds
// (and calculate num transitions)
@ -480,6 +512,11 @@ public class ProbModel implements Model
transRewards[i] = JDD.Apply(JDD.TIMES, tmp, transRewards[i]);
}
// filter start states, work out number of initial states
JDD.Ref(reach);
start = JDD.Apply(JDD.TIMES, reach, start);
numStartStates = JDD.GetNumMinterms(start, allDDRowVars.n());
// work out number of transitions
numTransitions = JDD.GetNumMinterms(trans01, getNumDDVarsInTrans());
}

Loading…
Cancel
Save