diff --git a/prism/src/explicit/CTMCModelChecker.java b/prism/src/explicit/CTMCModelChecker.java index 26ac5f9e..23acbc32 100644 --- a/prism/src/explicit/CTMCModelChecker.java +++ b/prism/src/explicit/CTMCModelChecker.java @@ -101,7 +101,7 @@ public class CTMCModelChecker extends DTMCModelChecker // a trivial case: "U<=0" if (lTime == 0 && uTime == 0) { // prob is 1 in b2 states, 0 otherwise - probs = StateValues.createFromBitSetAsDoubles(model.getNumStates(), b2); + probs = StateValues.createFromBitSetAsDoubles(b2, model); } else { // break down into different cases to compute probabilities @@ -112,13 +112,13 @@ public class CTMCModelChecker extends DTMCModelChecker if (lTime == 0) { // compute probs res = computeUntilProbs((DTMC) model, b1, b2); - probs = StateValues.createFromDoubleArray(res.soln); + probs = StateValues.createFromDoubleArray(res.soln, model); } else { // compute unbounded until probs tmpRes = computeUntilProbs((DTMC) model, b1, b2); // compute bounded until probs res = computeTransientBackwardsProbs((CTMC) model, b1, b1, lTime, tmpRes.soln); - probs = StateValues.createFromDoubleArray(res.soln); + probs = StateValues.createFromDoubleArray(res.soln, model); } } // <= uTime @@ -126,7 +126,7 @@ public class CTMCModelChecker extends DTMCModelChecker // nb: uTime != 0 since would be caught above (trivial case) b1.andNot(b2); res = computeTransientBackwardsProbs((CTMC) model, b2, b1, uTime, null); - probs = StateValues.createFromDoubleArray(res.soln); + probs = StateValues.createFromDoubleArray(res.soln, model); // set values to exactly 1 for target (b2) states // (these are computed inexactly during uniformisation) int n = model.getNumStates(); @@ -141,7 +141,7 @@ public class CTMCModelChecker extends DTMCModelChecker tmp.andNot(b2); tmpRes = computeTransientBackwardsProbs((CTMC) model, b2, tmp, uTime - lTime, null); res = computeTransientBackwardsProbs((CTMC) model, b1, b1, lTime, tmpRes.soln); - probs = StateValues.createFromDoubleArray(res.soln); + probs = StateValues.createFromDoubleArray(res.soln, model); } } @@ -183,7 +183,7 @@ public class CTMCModelChecker extends DTMCModelChecker // Compute transient probabilities res = computeTransientProbs(ctmc, t, initDistNew); - probs = StateValues.createFromDoubleArray(res.soln); + probs = StateValues.createFromDoubleArray(res.soln, ctmc); return probs; } diff --git a/prism/src/explicit/CTMDPModelChecker.java b/prism/src/explicit/CTMDPModelChecker.java index e4657a79..b40a8dab 100644 --- a/prism/src/explicit/CTMDPModelChecker.java +++ b/prism/src/explicit/CTMDPModelChecker.java @@ -63,10 +63,10 @@ public class CTMDPModelChecker extends MDPModelChecker // a trivial case: "U<=0" if (uTime == 0) { // prob is 1 in b2 states, 0 otherwise - probs = StateValues.createFromBitSetAsDoubles(model.getNumStates(), b2); + probs = StateValues.createFromBitSetAsDoubles(b2, model); } else { res = computeBoundedUntilProbs((CTMDP) model, b1, b2, uTime, min); - probs = StateValues.createFromDoubleArray(res.soln); + probs = StateValues.createFromDoubleArray(res.soln, model); } return probs; diff --git a/prism/src/explicit/DTMCModelChecker.java b/prism/src/explicit/DTMCModelChecker.java index ebb3261a..d04156cc 100644 --- a/prism/src/explicit/DTMCModelChecker.java +++ b/prism/src/explicit/DTMCModelChecker.java @@ -112,10 +112,10 @@ public class DTMCModelChecker extends ProbModelChecker // a trivial case: "U<=0" if (time == 0) { // prob is 1 in b2 states, 0 otherwise - probs = StateValues.createFromBitSetAsDoubles(model.getNumStates(), b2); + probs = StateValues.createFromBitSetAsDoubles(b2, model); } else { res = computeBoundedUntilProbs((DTMC) model, b1, b2, time); - probs = StateValues.createFromDoubleArray(res.soln); + probs = StateValues.createFromDoubleArray(res.soln, model); } return probs; @@ -141,7 +141,7 @@ public class DTMCModelChecker extends ProbModelChecker // allDDRowVars.n()) + " states\n"); res = computeUntilProbs((DTMC) model, b1, b2); - probs = StateValues.createFromDoubleArray(res.soln); + probs = StateValues.createFromDoubleArray(res.soln, model); return probs; } @@ -187,7 +187,7 @@ public class DTMCModelChecker extends ProbModelChecker // allDDRowVars.n())); res = computeReachRewards((DTMC) model, modelRewards, b); - rewards = StateValues.createFromDoubleArray(res.soln); + rewards = StateValues.createFromDoubleArray(res.soln, model); return rewards; } @@ -228,7 +228,7 @@ public class DTMCModelChecker extends ProbModelChecker // Compute transient probabilities res = computeSteadyStateProbs(dtmc, initDistNew); - probs = StateValues.createFromDoubleArray(res.soln); + probs = StateValues.createFromDoubleArray(res.soln, dtmc); return probs; } diff --git a/prism/src/explicit/MDPModelChecker.java b/prism/src/explicit/MDPModelChecker.java index 1cd24d45..fc582274 100644 --- a/prism/src/explicit/MDPModelChecker.java +++ b/prism/src/explicit/MDPModelChecker.java @@ -121,10 +121,10 @@ public class MDPModelChecker extends ProbModelChecker // a trivial case: "U<=0" if (time == 0) { // prob is 1 in b2 states, 0 otherwise - probs = StateValues.createFromBitSetAsDoubles(model.getNumStates(), b2); + probs = StateValues.createFromBitSetAsDoubles(b2, model); } else { res = computeBoundedUntilProbs((MDP) model, b1, b2, time, min); - probs = StateValues.createFromDoubleArray(res.soln); + probs = StateValues.createFromDoubleArray(res.soln, model); } return probs; @@ -150,7 +150,7 @@ public class MDPModelChecker extends ProbModelChecker // allDDRowVars.n()) + " states\n"); res = computeUntilProbs((MDP) model, b1, b2, min); - probs = StateValues.createFromDoubleArray(res.soln); + probs = StateValues.createFromDoubleArray(res.soln, model); return probs; } @@ -196,7 +196,7 @@ public class MDPModelChecker extends ProbModelChecker // allDDRowVars.n())); res = computeReachRewards((MDP) model, modelRewards, b, min); - rewards = StateValues.createFromDoubleArray(res.soln); + rewards = StateValues.createFromDoubleArray(res.soln, model); return rewards; } diff --git a/prism/src/explicit/ProbModelChecker.java b/prism/src/explicit/ProbModelChecker.java index a6892d29..c36bd9c6 100644 --- a/prism/src/explicit/ProbModelChecker.java +++ b/prism/src/explicit/ProbModelChecker.java @@ -141,7 +141,7 @@ public class ProbModelChecker extends StateModelChecker else { BitSet sol = probs.getBitSetFromInterval(relOp, p); probs.clear(); - return StateValues.createFromBitSet(sol, model.getNumStates()); + return StateValues.createFromBitSet(sol, model); } } @@ -246,7 +246,7 @@ public class ProbModelChecker extends StateModelChecker else { BitSet sol = rews.getBitSetFromInterval(relOp, r); rews.clear(); - return StateValues.createFromBitSet(sol, model.getNumStates()); + return StateValues.createFromBitSet(sol, model); } } } diff --git a/prism/src/explicit/STPGModelChecker.java b/prism/src/explicit/STPGModelChecker.java index 771b1b02..971af7b7 100644 --- a/prism/src/explicit/STPGModelChecker.java +++ b/prism/src/explicit/STPGModelChecker.java @@ -119,10 +119,10 @@ public class STPGModelChecker extends ProbModelChecker // a trivial case: "U<=0" if (time == 0) { // prob is 1 in b2 states, 0 otherwise - probs = StateValues.createFromBitSetAsDoubles(model.getNumStates(), b2); + probs = StateValues.createFromBitSetAsDoubles(b2, model); } else { res = computeBoundedUntilProbs((STPG) model, b1, b2, time, min1, min2); - probs = StateValues.createFromDoubleArray(res.soln); + probs = StateValues.createFromDoubleArray(res.soln, model); } return probs; @@ -148,7 +148,7 @@ public class STPGModelChecker extends ProbModelChecker // allDDRowVars.n()) + " states\n"); res = computeUntilProbs((STPG) model, b1, b2, min1, min2); - probs = StateValues.createFromDoubleArray(res.soln); + probs = StateValues.createFromDoubleArray(res.soln, model); return probs; } @@ -168,7 +168,7 @@ public class STPGModelChecker extends ProbModelChecker target = checkExpression(model, expr.getOperand2()).getBitSet(); res = computeReachRewards((STPG) model, rewards, target, min1, min2); - rews = StateValues.createFromDoubleArray(res.soln); + rews = StateValues.createFromDoubleArray(res.soln, model); return rews; } diff --git a/prism/src/explicit/StateModelChecker.java b/prism/src/explicit/StateModelChecker.java index 35c6940b..495a44dd 100644 --- a/prism/src/explicit/StateModelChecker.java +++ b/prism/src/explicit/StateModelChecker.java @@ -434,7 +434,7 @@ public class StateModelChecker expr = (Expression) expr.replaceConstants(constantValues); int numStates = model.getNumStates(); - res = new StateValues(expr.getType(), numStates); + res = new StateValues(expr.getType(), model); List statesList = model.getStatesList(); if (expr.getType() instanceof TypeBool) { for (int i = 0; i < numStates; i++) { @@ -476,7 +476,7 @@ public class StateModelChecker */ protected StateValues checkExpressionLiteral(Model model, ExpressionLiteral expr) throws PrismException { - return new StateValues(expr.getType(), model.getNumStates(), expr.evaluate()); + return new StateValues(expr.getType(), expr.evaluate(), model); } /** @@ -494,14 +494,14 @@ public class StateModelChecker for (i = 0; i < numStates; i++) { bs.set(i, model.isFixedDeadlockState(i)); } - return StateValues.createFromBitSet(bs, numStates); + return StateValues.createFromBitSet(bs, model); } else if (expr.getName().equals("init")) { int numStates = model.getNumStates(); BitSet bs = new BitSet(numStates); for (i = 0; i < numStates; i++) { bs.set(i, model.isInitialState(i)); } - return StateValues.createFromBitSet(bs, numStates); + return StateValues.createFromBitSet(bs, model); } else { ll = propertiesFile.getCombinedLabelList(); i = ll.getLabelIndex(expr.getName()); @@ -596,7 +596,7 @@ public class StateModelChecker // Compute min // Store as object/vector resObj = vals.minOverBitSet(bsFilter); - resVals = new StateValues(expr.getType(), model.getNumStates(), resObj); + resVals = new StateValues(expr.getType(), resObj, model); // Create explanation of result and print some details to log resultExpl = "Minimum value over " + filterStatesString; mainLog.println("\n" + resultExpl + ": " + resObj); @@ -608,7 +608,7 @@ public class StateModelChecker // Compute max // Store as object/vector resObj = vals.maxOverBitSet(bsFilter); - resVals = new StateValues(expr.getType(), model.getNumStates(), resObj); + resVals = new StateValues(expr.getType(), resObj, model); // Create explanation of result and print some details to log resultExpl = "Maximum value over " + filterStatesString; mainLog.println("\n" + resultExpl + ": " + resObj); @@ -625,7 +625,7 @@ public class StateModelChecker bsMatch.and(bsFilter); // Store states in vector; for ARGMIN, don't store a single value (in resObj) // Also, don't bother with explanation string - resVals = StateValues.createFromBitSet(bsMatch, model.getNumStates()); + resVals = StateValues.createFromBitSet(bsMatch, model); // Print out number of matching states, but not the actual states mainLog.println("\nNumber of states with minimum value: " + bsMatch.cardinality()); bsMatch = null; @@ -639,7 +639,7 @@ public class StateModelChecker bsMatch.and(bsFilter); // Store states in vector; for ARGMAX, don't store a single value (in resObj) // Also, don't bother with explanation string - resVals = StateValues.createFromBitSet(bsMatch, model.getNumStates()); + resVals = StateValues.createFromBitSet(bsMatch, model); // Print out number of matching states, but not the actual states mainLog.println("\nNumber of states with maximum value: " + bsMatch.cardinality()); bsMatch = null; @@ -649,7 +649,7 @@ public class StateModelChecker count = vals.countOverBitSet(bsFilter); // Store as object/vector resObj = new Integer(count); - resVals = new StateValues(expr.getType(), model.getNumStates(), resObj); + resVals = new StateValues(expr.getType(), resObj, model); // Create explanation of result and print some details to log resultExpl = filterTrue ? "Count of satisfying states" : "Count of satisfying states also in filter"; mainLog.println("\n" + resultExpl + ": " + resObj); @@ -658,7 +658,7 @@ public class StateModelChecker // Compute sum // Store as object/vector resObj = vals.sumOverBitSet(bsFilter); - resVals = new StateValues(expr.getType(), model.getNumStates(), resObj); + resVals = new StateValues(expr.getType(), resObj, model); // Create explanation of result and print some details to log resultExpl = "Sum over " + filterStatesString; mainLog.println("\n" + resultExpl + ": " + resObj); @@ -667,7 +667,7 @@ public class StateModelChecker // Compute average // Store as object/vector resObj = vals.averageOverBitSet(bsFilter); - resVals = new StateValues(expr.getType(), model.getNumStates(), resObj); + resVals = new StateValues(expr.getType(), resObj, model); // Create explanation of result and print some details to log resultExpl = "Average over " + filterStatesString; mainLog.println("\n" + resultExpl + ": " + resObj); @@ -675,7 +675,7 @@ public class StateModelChecker case FIRST: // Find first value resObj = vals.firstFromBitSet(bsFilter); - resVals = new StateValues(expr.getType(), model.getNumStates(), resObj); + resVals = new StateValues(expr.getType(), resObj, model); // Create explanation of result and print some details to log resultExpl = "Value in "; if (filterInit) { @@ -708,7 +708,7 @@ public class StateModelChecker b = vals.forallOverBitSet(bsFilter); // Store as object/vector resObj = new Boolean(b); - resVals = new StateValues(expr.getType(), model.getNumStates(), resObj); + resVals = new StateValues(expr.getType(), resObj, model); // Create explanation of result and print some details to log resultExpl = "Property " + (b ? "" : "not ") + "satisfied in "; mainLog.print("\nProperty satisfied in " + vals.countOverBitSet(bsFilter)); @@ -736,7 +736,7 @@ public class StateModelChecker b = vals.existsOverBitSet(bsFilter); // Store as object/vector resObj = new Boolean(b); - resVals = new StateValues(expr.getType(), model.getNumStates(), resObj); + resVals = new StateValues(expr.getType(), resObj, model); // Create explanation of result and print some details to log resultExpl = "Property satisfied in "; if (filterTrue) { @@ -756,7 +756,7 @@ public class StateModelChecker // Find first (only) value // Store as object/vector resObj = vals.firstFromBitSet(bsFilter); - resVals = new StateValues(expr.getType(), model.getNumStates(), resObj); + resVals = new StateValues(expr.getType(), resObj, model); // Create explanation of result and print some details to log resultExpl = "Value in "; if (filterInit) { diff --git a/prism/src/explicit/StateValues.java b/prism/src/explicit/StateValues.java index 516e178e..388e0a9f 100644 --- a/prism/src/explicit/StateValues.java +++ b/prism/src/explicit/StateValues.java @@ -29,8 +29,6 @@ package explicit; import java.util.BitSet; import java.util.List; -import jdd.JDDNode; - import parser.State; import parser.type.Type; import parser.type.TypeBool; @@ -73,15 +71,42 @@ public class StateValues valuesB = null; } + /** + * Construct a new state values vector of the given type. + * All values are initially set to zero/false. + * Also set associated model (and this determines the vector size). + * @param type Value type + * @param model Associated model + */ + public StateValues(Type type, Model model) throws PrismLangException + { + this(type, model.getNumStates()); + statesList = model.getStatesList(); + } + /** * Construct a new state values vector of the given type and size. - * All values are initially set to zero. + * All values are initially set to zero/false. */ public StateValues(Type type, int size) throws PrismLangException { this(type, size, type.defaultValue()); } + /** + * Construct a new state values vector of the given type, initialising all values to {@code init}. + * Also set associated model (and this determines the vector size). + * Throws an exception of {@code init} is of the wrong type. + * @param type Value type + * @param init Initial value for all states (as an appropriate Object) + * @param model Associated model + */ + public StateValues(Type type, Object init, Model model) throws PrismLangException + { + this(type, model.getNumStates(), init); + statesList = model.getStatesList(); + } + /** * Construct a new state values vector of the given type and size, * initialising all values to {@code init}. @@ -126,42 +151,47 @@ public class StateValues /** * Create a new (double-valued) state values vector from an existing array of doubles. * The array is stored directly, not copied. + * Also set associated model (whose state space size should match vector size). */ - public static StateValues createFromDoubleArray(double[] array) + public static StateValues createFromDoubleArray(double[] array, Model model) { StateValues sv = new StateValues(); sv.type = TypeDouble.getInstance(); sv.size = array.length; sv.valuesD = array; + sv.statesList = model.getStatesList(); return sv; } /** * Create a new (Boolean-valued) state values vector from an existing BitSet. * The BitSet is stored directly, not copied. + * Also set associated model (and this determines the vector size). */ - public static StateValues createFromBitSet(BitSet bs, int size) + public static StateValues createFromBitSet(BitSet bs, Model model) { StateValues sv = new StateValues(); sv.type = TypeBool.getInstance(); - sv.size = size; + sv.size = model.getNumStates(); sv.valuesB = bs; + sv.statesList = model.getStatesList(); return sv; } /** * Create a new (double-valued) state values vector from a BitSet, * where each entry is 1.0 if in the bitset, 0.0 otherwise. - * The size must also be given since this is not explicit in the bitset. + * Also set associated model (and this determines the vector size). * The bitset is not modified or stored. */ - public static StateValues createFromBitSetAsDoubles(int size, BitSet bitset) + public static StateValues createFromBitSetAsDoubles(BitSet bitset, Model model) { + int size = model.getNumStates(); double[] array = new double[size]; for (int i = 0; i < size; i++) { array[i] = bitset.get(i) ? 1.0 : 0.0; } - return createFromDoubleArray(array); + return createFromDoubleArray(array, model); } /** @@ -489,13 +519,25 @@ public class StateValues // PRINTING STUFF /** - * Print vector to a log/file (non-zero entries only) + * Print vector to a log/file (non-zero entries only). */ public void print(PrismLog log) throws PrismException { printFiltered(log, null, true, false, true); } + /** + * Print vector to a log/file. + * @param log The log + * @param printSparse Print non-zero elements only? + * @param printMatlab Print in Matlab format? + * @param printStates Print states (variable values) for each element? + */ + public void print(PrismLog log, boolean printSparse, boolean printMatlab, boolean printStates) throws PrismException + { + printFiltered(log, null, true, false, true); + } + /** * Print part of vector to a log/file (non-zero entries only). * @param log The log @@ -507,7 +549,7 @@ public class StateValues } /** - * Print part of vector to a log/file (non-zero entries only). + * Print part of vector to a log/file. * @param log The log * @param filter A BitSet specifying which states to print for (null if all). * @param printSparse Print non-zero elements only?