diff --git a/prism/src/prism/Explicit2MTBDD.java b/prism/src/prism/Explicit2MTBDD.java index 41b2bf5e..1157421c 100644 --- a/prism/src/prism/Explicit2MTBDD.java +++ b/prism/src/prism/Explicit2MTBDD.java @@ -100,6 +100,11 @@ public class Explicit2MTBDD private JDDNode[] ddChoiceVars; // individual dd vars for local non-det. // names for all dd vars used private Vector ddVarNames; + // action info + private Vector synchs; // list of action names + private JDDNode transActions; // dd for transition action labels (MDPs) + private Vector transPerAction; // dds for transition action labels (D/CTMCs) + private int maxNumChoices = 0; @@ -431,6 +436,14 @@ public class Explicit2MTBDD numModules, moduleNames, moduleDDRowVars, moduleDDColVars, numVars, varList, varDDRowVars, varDDColVars, constantValues); } + // set action info + // TODO: disable if not required? + model.setSynchs(synchs); + if (modelType != ModelType.MDP) { + model.setTransPerAction((JDDNode[])transPerAction.toArray(new JDDNode[0])); + } else { + model.setTransActions(transActions); + } // do reachability (or not) if (prism.getDoReach()) { @@ -690,15 +703,23 @@ public class Explicit2MTBDD private void buildTrans() throws PrismException { BufferedReader in; - String s, ss[]; - int i, r, c, k = 0, lineNum = 0; - double d, x = 0; - boolean foundReward; - JDDNode tmp; + String s, ss[], a; + int i, j, r, c, k = 0, lineNum = 0; + double d; + JDDNode elem, tmp; + + // initailise action list + synchs = new Vector(); // initialise mtbdds trans = JDD.Constant(0); transRewards = JDD.Constant(0); + if (modelType != ModelType.MDP) { + transPerAction = new Vector(); + transPerAction.add(JDD.Constant(0)); + } else { + transActions = JDD.Constant(0); + } try { // open file for reading @@ -711,7 +732,7 @@ public class Explicit2MTBDD // skip blank lines s = s.trim(); if (s.length() > 0) { - foundReward = false; + a = ""; // parse line, split into parts ss = s.split(" "); // case for dtmcs/ctmcs... @@ -721,10 +742,8 @@ public class Explicit2MTBDD c = Integer.parseInt(ss[1]); d = Double.parseDouble(ss[2]); if (ss.length == 4) { - //foundReward = true; - //x = Double.parseDouble(ss[3]); + a = ss[3]; } - //System.out.println("("+r+","+c+") = "+d); } // case for mdps... else { @@ -734,42 +753,70 @@ public class Explicit2MTBDD c = Integer.parseInt(ss[2]); d = Double.parseDouble(ss[3]); if (ss.length == 5) { - //foundReward = true; - //x = Double.parseDouble(ss[4]); + a = ss[4]; } - //System.out.println("("+r+","+k+","+c+") = "+d); } // construct element of matrix mtbdd // case where we don't have a state list... if (statesFile == null) { /// ...for dtmcs/ctmcs... if (modelType != ModelType.MDP) { - tmp = JDD.SetMatrixElement(JDD.Constant(0), varDDRowVars[0], varDDColVars[0], r, c, 1.0); + elem = JDD.SetMatrixElement(JDD.Constant(0), varDDRowVars[0], varDDColVars[0], r, c, 1.0); } /// ...for mdps... else { - tmp = JDD.Set3DMatrixElement(JDD.Constant(0), varDDRowVars[0], varDDColVars[0], allDDChoiceVars, r, c, k, 1.0); + elem = JDD.Set3DMatrixElement(JDD.Constant(0), varDDRowVars[0], varDDColVars[0], allDDChoiceVars, r, c, k, 1.0); } } // case where we do have a state list... else { - tmp = JDD.Constant(1); + elem = JDD.Constant(1); for (i = 0; i < numVars; i++) { - tmp = JDD.Apply(JDD.TIMES, tmp, JDD.SetVectorElement(JDD.Constant(0), varDDRowVars[i], statesArray[r][i], 1)); - tmp = JDD.Apply(JDD.TIMES, tmp, JDD.SetVectorElement(JDD.Constant(0), varDDColVars[i], statesArray[c][i], 1)); + elem = JDD.Apply(JDD.TIMES, elem, JDD.SetVectorElement(JDD.Constant(0), varDDRowVars[i], statesArray[r][i], 1)); + elem = JDD.Apply(JDD.TIMES, elem, JDD.SetVectorElement(JDD.Constant(0), varDDColVars[i], statesArray[c][i], 1)); } if (modelType == ModelType.MDP) { - tmp = JDD.Apply(JDD.TIMES, tmp, JDD.SetVectorElement(JDD.Constant(0), allDDChoiceVars, k, 1)); + elem = JDD.Apply(JDD.TIMES, elem, JDD.SetVectorElement(JDD.Constant(0), allDDChoiceVars, k, 1)); } } // add it into mtbdds for transition matrix and transition rewards - JDD.Ref(tmp); - trans = JDD.Apply(JDD.PLUS, trans, JDD.Apply(JDD.TIMES, JDD.Constant(d), tmp)); - if (foundReward) { - JDD.Ref(tmp); - transRewards = JDD.Apply(JDD.PLUS, transRewards, JDD.Apply(JDD.TIMES, JDD.Constant(x), tmp)); + JDD.Ref(elem); + trans = JDD.Apply(JDD.PLUS, trans, JDD.Apply(JDD.TIMES, JDD.Constant(d), elem)); + // look up action name + if (!("".equals(a))) { + j = synchs.indexOf(a); + // add to list if first time seen + if (j == -1) { + synchs.add(a); + j = synchs.size() - 1; + } + j++; + } else { + j = 0; + } + /// ...for dtmcs/ctmcs... + if (modelType != ModelType.MDP) { + // get (or create) dd for action j + if (j < transPerAction.size()) { + tmp = transPerAction.get(j); + } else { + tmp = JDD.Constant(0); + transPerAction.add(tmp); + } + // add element to matrix + JDD.Ref(elem); + tmp = JDD.Apply(JDD.PLUS, tmp, JDD.Apply(JDD.TIMES, JDD.Constant(d), elem)); + transPerAction.set(j, tmp); + } + /// ...for mdps... + else { + JDD.Ref(elem); + tmp = JDD.ThereExists(elem, allDDColVars); + // use max here because we see multiple transitions for a sinlge choice + transActions = JDD.Apply(JDD.MAX, transActions, JDD.Apply(JDD.TIMES, JDD.Constant(j), tmp)); } - JDD.Deref(tmp); + // deref element dd + JDD.Deref(elem); } // read next line s = in.readLine(); lineNum++; diff --git a/prism/src/prism/ProbModel.java b/prism/src/prism/ProbModel.java index 7e873d09..4afde2f3 100644 --- a/prism/src/prism/ProbModel.java +++ b/prism/src/prism/ProbModel.java @@ -695,19 +695,14 @@ public class ProbModel implements Model } } } - if (transActions != null && !transActions.equals(JDD.ZERO)) { - log.print("Action label indices: "); - log.print(JDD.GetNumNodes(transActions) + " nodes ("); - log.print(JDD.GetNumTerminals(transActions) + " terminal)\n"); - } if (transPerAction != null) { for (i = 0; i < numSynchs + 1; i++) { log.print("Action label info ("); log.print((i == 0 ? "" : synchs.get(i - 1)) + "): "); - log.print(JDD.GetNumNodes(transPerAction[i]) + " nodes ("); - log.print(JDD.GetNumTerminals(transPerAction[i]) + " terminal)\n"); + log.println(JDD.GetInfoString(transPerAction[i], getNumDDVarsInTrans())); } } + // Don't need to print info for transActions (only stored for MDPs) } // export transition matrix to a file