From 3032a048197e1c014e236c53348da55b007feef0 Mon Sep 17 00:00:00 2001 From: Dave Parker Date: Tue, 23 Jul 2013 23:17:55 +0000 Subject: [PATCH] Refactor symbolic LTL code: pull out some MEC stuff. git-svn-id: https://www.prismmodelchecker.org/svn/prism/prism/trunk@7146 bbc10eb1-c90d-0410-af57-cb519fbb1720 --- prism/src/prism/ECComputer.java | 20 ++- prism/src/prism/ECComputerDefault.java | 16 ++- prism/src/prism/LTLModelChecker.java | 169 ++++-------------------- prism/src/prism/NondetModelChecker.java | 18 +-- 4 files changed, 70 insertions(+), 153 deletions(-) diff --git a/prism/src/prism/ECComputer.java b/prism/src/prism/ECComputer.java index bb54fdc7..e63cca56 100644 --- a/prism/src/prism/ECComputer.java +++ b/prism/src/prism/ECComputer.java @@ -90,12 +90,30 @@ public abstract class ECComputer extends PrismComponent } /** - * Compute states of maximal end components (MECs) and store them. + * Compute states of all maximal end components (MECs) and store them. * They can be retrieved using {@link #getMECStates()}. * You will need to to deref these afterwards. */ public abstract void computeMECStates() throws PrismException; + /** + * Compute states of all maximal end components (MECs) contained within {@code states}, and store them. + * They can be retrieved using {@link #getMECStates()}. + * You will need to to deref these afterwards. + * @param states BDD of the set of containing states + */ + public abstract void computeMECStates(JDDNode states) throws PrismException; + + /** + * Compute states of all accepting maximal end components (MECs) contained within {@code states}, and store them, + * where acceptance is defined as those which intersect with {@code filter}. + * They can be retrieved using {@link #getMECStates()}. + * You will need to to deref these afterwards. + * @param states BDD of the set of containing states + * @param filter BDD for the set of accepting states + */ + public abstract void computeMECStates(JDDNode states, JDDNode filter) throws PrismException; + /** * Get the list of states for computed MECs. * You need to deref these BDDs when you are finished with them. diff --git a/prism/src/prism/ECComputerDefault.java b/prism/src/prism/ECComputerDefault.java index 9d479fb4..d7c04498 100644 --- a/prism/src/prism/ECComputerDefault.java +++ b/prism/src/prism/ECComputerDefault.java @@ -55,13 +55,25 @@ public class ECComputerDefault extends ECComputer @Override public void computeMECStates() throws PrismException { - mecs = findEndComponents(reach, reach); + mecs = findEndComponents(reach, null); + } + + @Override + public void computeMECStates(JDDNode states) throws PrismException + { + mecs = findEndComponents(states, null); + } + + @Override + public void computeMECStates(JDDNode states, JDDNode filter) throws PrismException + { + mecs = findEndComponents(states, filter); } // Computation /** - * Find all maximal accepting end components (ECs) contained within {@code states}, + * Find all accepting maximal end components (MECs) contained within {@code states}, * where acceptance is defined as those which intersect with {@code filter}. * (If {@code filter} is null, the acceptance condition is trivially satisfied.) * @param states BDD of the set of containing states diff --git a/prism/src/prism/LTLModelChecker.java b/prism/src/prism/LTLModelChecker.java index 8a24c98c..ac03bf62 100644 --- a/prism/src/prism/LTLModelChecker.java +++ b/prism/src/prism/LTLModelChecker.java @@ -37,15 +37,14 @@ import parser.type.*; /** * LTL model checking functionality */ -public class LTLModelChecker +public class LTLModelChecker extends PrismComponent { - protected Prism prism; - protected PrismLog mainLog; - - public LTLModelChecker(Prism prism) throws PrismException + /** + * Create a new DTMCModelChecker, inherit basic state from parent (unless null). + */ + public LTLModelChecker(PrismComponent parent) throws PrismException { - this.prism = prism; - mainLog = prism.getMainLog(); + super(parent); } /** @@ -633,7 +632,7 @@ public class LTLModelChecker int i, j, n; // Compute BSCCs for model - sccComputer = prism.getSCCComputer(model); + sccComputer = SCCComputer.createSCCComputer(this, model); sccComputer.computeBSCCs(); vectBSCCs = sccComputer.getBSCCs(); JDD.Deref(sccComputer.getNotInBSCCs()); @@ -743,7 +742,7 @@ public class LTLModelChecker candidateStates = JDD.ThereExists(candidateStates, model.getAllDDColVars()); candidateStates = JDD.ThereExists(candidateStates, model.getAllDDNondetVars()); // find all maximal end components - Vector allecs = findEndComponents(model, candidateStates, acceptanceVector_L); + List allecs = findEndComponents(model, candidateStates, acceptanceVector_L); JDD.Deref(candidateStates); for (i = 0; i < dra.getNumAcceptancePairs(); i++) { @@ -752,7 +751,7 @@ public class LTLModelChecker acceptanceVector_L = statesL.get(i); for (JDDNode ec : allecs) { // build bdd of accepting states (under H_i) in the product model - Vector ecs; + List ecs; JDD.Ref(ec); JDD.Ref(acceptanceVector_H); candidateStates = JDD.And(ec, acceptanceVector_H); @@ -840,7 +839,7 @@ public class LTLModelChecker acceptingStates = findFairECs(model, candidateStates); } else { // find ECs in acceptingStates that are accepting under L_i - Vector ecs = findEndComponents(model, candidateStates); + List ecs = findEndComponents(model, candidateStates); JDD.Deref(candidateStates); acceptingStates = filteredUnion(ecs, acceptanceVector_L); } @@ -854,7 +853,7 @@ public class LTLModelChecker } public JDDNode findMultiAcceptingStates(DRA dra, NondetModel model, JDDVars draDDRowVars, JDDVars draDDColVars, boolean fairness, - Vector allecs, ArrayList statesH, ArrayList statesL) throws PrismException + List allecs, List statesH, List statesL) throws PrismException { JDDNode acceptingStates = null, allAcceptingStates, candidateStates; JDDNode acceptanceVector_H, acceptanceVector_L; @@ -870,7 +869,7 @@ public class LTLModelChecker acceptanceVector_L = statesL.get(i); for (JDDNode ec : allecs) { // build bdd of accepting states (under H_i) in the product model - Vector ecs = null; + List ecs = null; JDD.Ref(ec); JDD.Ref(acceptanceVector_H); candidateStates = JDD.And(ec, acceptanceVector_H); @@ -1002,7 +1001,7 @@ public class LTLModelChecker JDD.Ref(e.statesL.get(j)); JDD.Ref(nextstatesL.get(k)); JDDNode acceptanceVector_L = JDD.And(e.statesL.get(j), nextstatesL.get(k)); - Vector ecs = null; + List ecs = null; ecs = findEndComponents(model, candidateStates1, acceptanceVector_L); JDD.Deref(candidateStates1); @@ -1144,76 +1143,30 @@ public class LTLModelChecker } /** - * Find all maximal end components (ECs) contained within {@code states}. + * Find all maximal end components (MECs) contained within {@code states}. * @param states BDD of the set of containing states * @return a vector of (referenced) BDDs representing the ECs */ - public Vector findEndComponents(NondetModel model, JDDNode states) throws PrismException + public List findEndComponents(NondetModel model, JDDNode states) throws PrismException { - return findEndComponents(model, states, null); + ECComputer ecComputer = ECComputer.createECComputer(this, model); + ecComputer.computeMECStates(states, null); + return ecComputer.getMECStates(); } /** - * Find all maximal accepting end components (ECs) contained within {@code states}, + * Find all accepting maximal end components (MECs) contained within {@code states}, * where acceptance is defined as those which intersect with {@code filter}. * (If {@code filter} is null, the acceptance condition is trivially satisfied.) * @param states BDD of the set of containing states * @param filter BDD for the set of accepting states * @return a vector of (referenced) BDDs representing the ECs */ - public Vector findEndComponents(NondetModel model, JDDNode states, JDDNode filter) throws PrismException + public List findEndComponents(NondetModel model, JDDNode states, JDDNode filter) throws PrismException { - Stack candidates = new Stack(); - JDD.Ref(states); - candidates.push(states); - Vector ecs = new Vector(); - SCCComputer sccComputer; - - while (!candidates.isEmpty()) { - JDDNode candidate = candidates.pop(); - // Compute the stable set - JDD.Ref(candidate); - JDDNode stableSet = findMaximalStableSet(model, candidate); - // Drop empty sets - if (stableSet.equals(JDD.ZERO)) { - JDD.Deref(stableSet); - JDD.Deref(candidate); - continue; - } - - if (stableSet.equals(candidate) && JDD.GetNumMinterms(stableSet, model.getNumDDRowVars()) == 1) { - ecs.add(candidate); - JDD.Deref(stableSet); - continue; - } - - // Filter bad transitions - JDD.Ref(stableSet); - JDDNode stableSetTrans = maxStableSetTrans(model, stableSet); - - // now find the maximal SCCs in (stableSet, stableSetTrans) - List sccs; - sccComputer = prism.getSCCComputer(stableSet, stableSetTrans, model.getAllDDRowVars(), model.getAllDDColVars()); - if (filter != null) - sccComputer.computeSCCs(filter); - else - sccComputer.computeSCCs(); - JDD.Deref(stableSet); - JDD.Deref(stableSetTrans); - sccs = sccComputer.getSCCs(); - JDD.Deref(sccComputer.getNotInSCCs()); - if (sccs.size() > 0) { - if (sccs.size() > 1 || !sccs.get(0).equals(candidate)) { - candidates.addAll(sccs); - JDD.Deref(candidate); - } else { - ecs.add(candidate); - JDD.Deref(sccs.get(0)); - } - } else - JDD.Deref(candidate); - } - return ecs; + ECComputer ecComputer = ECComputer.createECComputer(this, model); + ecComputer.computeMECStates(states, filter); + return ecComputer.getMECStates(); } /** @@ -1222,10 +1175,10 @@ public class LTLModelChecker * @param states BDD of the set of containing states * @return a vector of (referenced) BDDs representing the ECs */ - public Vector findBottomEndComponents(NondetModel model, JDDNode states) throws PrismException + public List findBottomEndComponents(NondetModel model, JDDNode states) throws PrismException { - Vector ecs = findEndComponents(model, states); - Vector becs = new Vector(); + List ecs = findEndComponents(model, states); + List becs = new Vector(); JDDNode out; for (JDDNode scc : ecs) { @@ -1235,7 +1188,7 @@ public class LTLModelChecker JDD.Ref(scc); out = JDD.And(out, JDD.Not(JDD.PermuteVariables(scc, model.getAllDDRowVars(), model.getAllDDColVars()))); if (out.equals(JDD.ZERO)) { - becs.addElement(scc); + becs.add(scc); } else { JDD.Deref(scc); } @@ -1244,72 +1197,6 @@ public class LTLModelChecker return becs; } - /** - * Returns a stable set of states contained in candidateStates - * - * @param candidateStates - * set of candidate states S x H_i (dereferenced after calling this function) - * @return a referenced BDD with the maximal stable set in c - */ - private JDDNode findMaximalStableSet(NondetModel model, JDDNode candidateStates) - { - - JDDNode old = JDD.Constant(0); - JDDNode current = candidateStates; - - while (!current.equals(old)) { - JDD.Deref(old); - JDD.Ref(current); - old = current; - - JDD.Ref(current); - JDD.Ref(model.getTrans()); - // Select transitions starting in current - JDDNode currTrans = JDD.Apply(JDD.TIMES, model.getTrans(), current); - // Select transitions starting in current and ending in current - JDDNode tmp = JDD.PermuteVariables(current, model.getAllDDRowVars(), model.getAllDDColVars()); - tmp = JDD.Apply(JDD.TIMES, currTrans, tmp); - // Sum all successor probabilities for each (state, action) tuple - tmp = JDD.SumAbstract(tmp, model.getAllDDColVars()); - // If the sum for a (state,action) tuple is 1, - // there is an action that remains in the stable set with prob 1 - tmp = JDD.GreaterThan(tmp, 1 - prism.getSumRoundOff()); - // Without fairness, we just need one action per state - current = JDD.ThereExists(tmp, model.getAllDDNondetVars()); - } - JDD.Deref(old); - return current; - } - - /** - * Returns the transition relation of a stable set - * - * @param b - * BDD of a stable set (dereferenced after calling this function) - * @return referenced BDD of the transition relation restricted to the stable set - */ - public JDDNode maxStableSetTrans(NondetModel model, JDDNode b) - { - - JDD.Ref(b); - JDD.Ref(model.getTrans()); - // Select transitions starting in b - JDDNode currTrans = JDD.Apply(JDD.TIMES, model.getTrans(), b); - JDDNode mask = JDD.PermuteVariables(b, model.getAllDDRowVars(), model.getAllDDColVars()); - // Select transitions starting in current and ending in current - mask = JDD.Apply(JDD.TIMES, currTrans, mask); - // Sum all successor probabilities for each (state, action) tuple - mask = JDD.SumAbstract(mask, model.getAllDDColVars()); - // If the sum for a (state,action) tuple is 1, - // there is an action that remains in the stable set with prob 1 - mask = JDD.GreaterThan(mask, 1 - prism.getSumRoundOff()); - // select the transitions starting in these tuples - JDD.Ref(model.getTrans01()); - JDDNode stableTrans01 = JDD.And(model.getTrans01(), mask); - // Abstract over actions - return JDD.ThereExists(stableTrans01, model.getAllDDNondetVars()); - } - public JDDNode maxStableSetTrans1(NondetModel model, JDDNode b) { @@ -1324,7 +1211,7 @@ public class LTLModelChecker mask = JDD.SumAbstract(mask, model.getAllDDColVars()); // If the sum for a (state,action) tuple is 1, // there is an action that remains in the stable set with prob 1 - mask = JDD.GreaterThan(mask, 1 - prism.getSumRoundOff()); + mask = JDD.GreaterThan(mask, 1 - settings.getDouble(PrismSettings.PRISM_SUM_ROUND_OFF)); // select the transitions starting in these tuples JDD.Ref(model.getTrans01()); JDDNode stableTrans01 = JDD.And(model.getTrans01(), mask); diff --git a/prism/src/prism/NondetModelChecker.java b/prism/src/prism/NondetModelChecker.java index 33916d9c..5220bff9 100644 --- a/prism/src/prism/NondetModelChecker.java +++ b/prism/src/prism/NondetModelChecker.java @@ -490,7 +490,7 @@ public class NondetModelChecker extends NonProbModelChecker //Vojta: in addition to calling a method which does the computation //there are some other bits which I don't currently understand protected JDDNode computeAcceptingEndComponent(DRA dra, NondetModel modelProduct, JDDVars draDDRowVars, JDDVars draDDColVars, - Vector allecs, ArrayList statesH, ArrayList statesL, //Vojta: at the time of writing this I have no idea what these two parameters do, so I don't know how to call them + List allecs, List statesH, List statesL, //Vojta: at the time of writing this I have no idea what these two parameters do, so I don't know how to call them LTLModelChecker mcLtl, boolean conflictformulaeGtOne, String name) throws PrismException { mainLog.println("\nFinding accepting end components for " + name + "..."); @@ -512,7 +512,7 @@ public class NondetModelChecker extends NonProbModelChecker protected void removeNonZeroMecsForMax(NondetModel modelProduct, LTLModelChecker mcLtl, List rewardsIndex, OpsAndBoundsList opsAndBounds, int numTargets, DRA dra[], JDDVars draDDRowVars[], JDDVars draDDColVars[]) throws PrismException { - Vector mecs = mcLtl.findEndComponents(modelProduct, modelProduct.getReach()); + List mecs = mcLtl.findEndComponents(modelProduct, modelProduct.getReach()); JDDNode removedActions = JDD.Constant(0); JDDNode rmecs = JDD.Constant(0); boolean mecflags[] = new boolean[mecs.size()]; @@ -707,7 +707,7 @@ public class NondetModelChecker extends NonProbModelChecker protected void addDummyFormula(NondetModel modelProduct, LTLModelChecker mcLtl, List targetDDs, OpsAndBoundsList opsAndBounds) throws PrismException { - Vector tmpecs = mcLtl.findEndComponents(modelProduct, modelProduct.getReach()); + List tmpecs = mcLtl.findEndComponents(modelProduct, modelProduct.getReach()); JDDNode acceptingStates = JDD.Constant(0); for (JDDNode set : tmpecs) acceptingStates = JDD.Or(acceptingStates, set); @@ -749,7 +749,7 @@ public class NondetModelChecker extends NonProbModelChecker return modelNew; } - protected Vector computeAllEcs(NondetModel modelProduct, LTLModelChecker mcLtl, ArrayList> allstatesH, + protected List computeAllEcs(NondetModel modelProduct, LTLModelChecker mcLtl, ArrayList> allstatesH, ArrayList> allstatesL, JDDNode acceptanceVector_H, JDDNode acceptanceVector_L, JDDVars draDDRowVars[], JDDVars draDDColVars[], OpsAndBoundsList opsAndBounds, int numTargets) throws PrismException { @@ -765,7 +765,7 @@ public class NondetModelChecker extends NonProbModelChecker candidateStates = JDD.ThereExists(candidateStates, modelProduct.getAllDDColVars()); candidateStates = JDD.ThereExists(candidateStates, modelProduct.getAllDDNondetVars()); // find all maximal end components - Vector allecs = mcLtl.findEndComponents(modelProduct, candidateStates, acceptanceVector_L); + List allecs = mcLtl.findEndComponents(modelProduct, candidateStates, acceptanceVector_L); JDD.Deref(candidateStates); JDD.Deref(acceptanceVector_L); return allecs; @@ -957,7 +957,7 @@ public class NondetModelChecker extends NonProbModelChecker } // Find accepting maximum end components for each LTL formula - Vector allecs = computeAllEcs(modelProduct, mcLtl, allstatesH, allstatesL, acceptanceVector_H, acceptanceVector_L, draDDRowVars, draDDColVars, + List allecs = computeAllEcs(modelProduct, mcLtl, allstatesH, allstatesL, acceptanceVector_H, acceptanceVector_L, draDDRowVars, draDDColVars, opsAndBounds, numTargets); for (int i = 0; i < numTargets; i++) { @@ -1058,7 +1058,7 @@ public class NondetModelChecker extends NonProbModelChecker } protected void findTargetStates(NondetModel modelProduct, LTLModelChecker mcLtl, int numTargets, int conflictformulae, boolean reachExpr[], DRA dra[], - JDDVars draDDRowVars[], JDDVars draDDColVars[], Vector targetDDs, List multitargetDDs, List multitargetIDs) + JDDVars draDDRowVars[], JDDVars draDDColVars[], List targetDDs, List multitargetDDs, List multitargetIDs) throws PrismException { int i, j; @@ -1100,7 +1100,7 @@ public class NondetModelChecker extends NonProbModelChecker } // Find accepting maximum end components for each LTL formula - Vector allecs = null; + List allecs = null; //use acceptanceVector_H and acceptanceVector_L to speed up scc computation /*// check number of states in each scc allecs = mcLtl.findEndComponents(modelProduct, modelProduct.getReach()); @@ -2060,7 +2060,7 @@ public class NondetModelChecker extends NonProbModelChecker no = JDD.Constant(0); bottomec = PrismMTBDD.Prob0A(modelProduct.getTrans01(), modelProduct.getReach(), modelProduct.getAllDDRowVars(), modelProduct.getAllDDColVars(), modelProduct.getAllDDNondetVars(), modelProduct.getReach(), yes); - Vector becs = mcLtl.findBottomEndComponents(modelProduct, bottomec); + List becs = mcLtl.findBottomEndComponents(modelProduct, bottomec); JDD.Deref(bottomec); bottomec = JDD.Constant(0); for (JDDNode ec : becs)