From cc90e7a20c7b9f6a395c709023ecac56b0307461 Mon Sep 17 00:00:00 2001 From: Joachim Klein Date: Thu, 13 Jul 2017 15:10:34 +0000 Subject: [PATCH] Sparse engine: Error if number of reachable states is too large Currently, the sparse engine internally uses int (signed 32bit) index variables So, if the number of states is larger than Integer.MAX_VALUE, there is a problem and the code will most probably crash or do nonsensical things. We check this before calling into the native sparse engine code and throw a PrismNotSupportedException. git-svn-id: https://www.prismmodelchecker.org/svn/prism/prism/trunk@12020 bbc10eb1-c90d-0410-af57-cb519fbb1720 --- prism/src/sparse/PrismSparse.java | 70 +++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/prism/src/sparse/PrismSparse.java b/prism/src/sparse/PrismSparse.java index c1475169..0ab390ac 100644 --- a/prism/src/sparse/PrismSparse.java +++ b/prism/src/sparse/PrismSparse.java @@ -37,6 +37,7 @@ import prism.NativeIntArray; import prism.OpsAndBoundsList; import prism.PrismException; import prism.PrismLog; +import prism.PrismNotSupportedException; import dv.DoubleVector; import dv.IntegerVector; @@ -79,6 +80,21 @@ public class PrismSparse // tidy up in jni (free global references) private static native void PS_FreeGlobalRefs(); + /** + * Check that number of reachable states is in a range that can be handled by + * the sparse engine methods. + * @throws PrismNotSupportedException if that is not the case + */ + private static void checkNumStates(ODDNode odd) throws PrismNotSupportedException + { + // currently, the sparse engine internally uses int (signed 32bit) index values + // so, if the number of states is larger than Integer.MAX_VALUE, there is a problem + long n = odd.getEOff() + odd.getTOff(); + if (n >= Integer.MAX_VALUE) { + throw new PrismNotSupportedException("The sparse engine can currently only handle up to " + Integer.MAX_VALUE + " reachable states, model has " + n + " states"); + } + } + //---------------------------------------------------------------------------------------------- // cudd manager //---------------------------------------------------------------------------------------------- @@ -144,6 +160,8 @@ public class PrismSparse private static native long PS_ProbBoundedUntil(long trans, long odd, long rv, int nrv, long cv, int ncv, long yes, long maybe, int bound); public static DoubleVector ProbBoundedUntil(JDDNode trans, ODDNode odd, JDDVars rows, JDDVars cols, JDDNode yes, JDDNode maybe, int bound) throws PrismException { + checkNumStates(odd); + long ptr = PS_ProbBoundedUntil(trans.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), yes.ptr(), maybe.ptr(), bound); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); @@ -153,6 +171,8 @@ public class PrismSparse private static native long PS_ProbUntil(long trans, long odd, long rv, int nrv, long cv, int ncv, long yes, long maybe); public static DoubleVector ProbUntil(JDDNode trans, ODDNode odd, JDDVars rows, JDDVars cols, JDDNode yes, JDDNode maybe) throws PrismException { + checkNumStates(odd); + long ptr = PS_ProbUntil(trans.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), yes.ptr(), maybe.ptr()); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); @@ -162,6 +182,8 @@ public class PrismSparse private static native long PS_ProbCumulReward(long trans, long sr, long trr, long odd, long rv, int nrv, long cv, int ncv, int bound); public static DoubleVector ProbCumulReward(JDDNode trans, JDDNode sr, JDDNode trr, ODDNode odd, JDDVars rows, JDDVars cols, int bound) throws PrismException { + checkNumStates(odd); + long ptr = PS_ProbCumulReward(trans.ptr(), sr.ptr(), trr.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), bound); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); @@ -171,6 +193,8 @@ public class PrismSparse private static native long PS_ProbInstReward(long trans, long sr, long odd, long rv, int nrv, long cv, int ncv, int time); public static DoubleVector ProbInstReward(JDDNode trans, JDDNode sr, ODDNode odd, JDDVars rows, JDDVars cols, int time) throws PrismException { + checkNumStates(odd); + long ptr = PS_ProbInstReward(trans.ptr(), sr.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), time); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); @@ -180,6 +204,8 @@ public class PrismSparse private static native long PS_ProbReachReward(long trans, long sr, long trr, long odd, long rv, int nrv, long cv, int ncv, long goal, long inf, long maybe); public static DoubleVector ProbReachReward(JDDNode trans, JDDNode sr, JDDNode trr, ODDNode odd, JDDVars rows, JDDVars cols, JDDNode goal, JDDNode inf, JDDNode maybe) throws PrismException { + checkNumStates(odd); + long ptr = PS_ProbReachReward(trans.ptr(), sr.ptr(), trr.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), goal.ptr(), inf.ptr(), maybe.ptr()); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); @@ -189,6 +215,8 @@ public class PrismSparse private static native long PS_ProbTransient(long trans, long odd, long init, long rv, int nrv, long cv, int ncv, int time); public static DoubleVector ProbTransient(JDDNode trans, ODDNode odd, DoubleVector init, JDDVars rows, JDDVars cols, int time) throws PrismException { + checkNumStates(odd); + long ptr = PS_ProbTransient(trans.ptr(), odd.ptr(), init.getPtr(), rows.array(), rows.n(), cols.array(), cols.n(), time); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); @@ -202,6 +230,8 @@ public class PrismSparse private static native long PS_NondetBoundedUntil(long trans, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long yes, long maybe, int time, boolean minmax); public static DoubleVector NondetBoundedUntil(JDDNode trans, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, JDDNode yes, JDDNode maybe, int time, boolean minmax) throws PrismException { + checkNumStates(odd); + long ptr = PS_NondetBoundedUntil(trans.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), yes.ptr(), maybe.ptr(), time, minmax); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); @@ -211,6 +241,8 @@ public class PrismSparse private static native long PS_NondetUntil(long trans, long trans_actions, List synchs, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long yes, long maybe, boolean minmax, long strat); public static DoubleVector NondetUntil(JDDNode trans, JDDNode transActions, List synchs, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, JDDNode yes, JDDNode maybe, boolean minmax, IntegerVector strat) throws PrismException { + checkNumStates(odd); + long ptr = PS_NondetUntil(trans.ptr(), (transActions == null) ? 0 : transActions.ptr(), synchs, odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), yes.ptr(), maybe.ptr(), minmax, (strat == null) ? 0 : strat.getPtr()); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); @@ -220,6 +252,8 @@ public class PrismSparse private static native long PS_NondetCumulReward(long trans, long sr, long trr, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, int bound, boolean minmax); public static DoubleVector NondetCumulReward(JDDNode trans, JDDNode sr, JDDNode trr, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, int bound, boolean minmax) throws PrismException { + checkNumStates(odd); + long ptr = PS_NondetCumulReward(trans.ptr(), sr.ptr(), trr.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), bound, minmax); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); @@ -229,6 +263,8 @@ public class PrismSparse private static native long PS_NondetInstReward(long trans, long sr, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, int time, boolean minmax, long init); public static DoubleVector NondetInstReward(JDDNode trans, JDDNode sr, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, int time, boolean minmax, JDDNode init) throws PrismException { + checkNumStates(odd); + long ptr = PS_NondetInstReward(trans.ptr(), sr.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), time, minmax, init.ptr()); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); @@ -238,6 +274,8 @@ public class PrismSparse private static native long PS_NondetReachReward(long trans, long trans_actions, List synchs, long sr, long trr, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long goal, long inf, long maybe, boolean minmax); public static DoubleVector NondetReachReward(JDDNode trans, JDDNode transActions, List synchs, JDDNode sr, JDDNode trr, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, JDDNode goal, JDDNode inf, JDDNode maybe, boolean minmax) throws PrismException { + checkNumStates(odd); + long ptr = PS_NondetReachReward(trans.ptr(), (transActions == null) ? 0 : transActions.ptr(), synchs, sr.ptr(), trr.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), goal.ptr(), inf.ptr(), maybe.ptr(), minmax); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); @@ -246,6 +284,8 @@ public class PrismSparse private static native double[] PS_NondetMultiObj(long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, boolean minmax, long start, long ptr_adversary, long ptr_TransSparseMatrix, List synchs, long[] ptr_yes_vec, int[] probStepBounds, long[] ptr_RewSparseMatrix, double[] rewardWeights, int[] rewardStepBounds); public static double[] NondetMultiObj(ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, boolean minmax, JDDNode start, NativeIntArray adversary, NDSparseMatrix transSparseMatrix, List synchs, DoubleVector[] yes_vec, int[] probStepBounds, NDSparseMatrix[] rewSparseMatrix, double[] rewardWeights, int[] rewardStepBounds) throws PrismException { + checkNumStates(odd); + long[] ptr_ndsp_r = null; if (rewSparseMatrix != null) { ptr_ndsp_r = new long[rewSparseMatrix.length]; @@ -271,6 +311,8 @@ public class PrismSparse private static native double[] PS_NondetMultiObjGS(long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, boolean minmax, long start, long ptr_adversary, long ptr_TransSparseMatrix, long[] ptr_yes_vec, long[] ptr_RewSparseMatrix, double[] rewardWeights); public static double[] NondetMultiObjGS(ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, boolean minmax, JDDNode start, NativeIntArray adversary, NDSparseMatrix transSparseMatrix, DoubleVector[] yes_vec, NDSparseMatrix[] rewSparseMatrix, double[] rewardWeights) throws PrismException { + checkNumStates(odd); + long[] ptr_ndsp_r = null; if (rewSparseMatrix != null) { ptr_ndsp_r = new long[rewSparseMatrix.length]; @@ -296,6 +338,8 @@ public class PrismSparse private static native double PS_NondetMultiReach(long trans, long trans_actions, List synchs, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long targets[], int relops[], double bounds[], long maybe, long start); public static double NondetMultiReach(JDDNode trans, JDDNode transActions, List synchs, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, List targets, OpsAndBoundsList opsAndBounds, JDDNode maybe, JDDNode start) throws PrismException { + checkNumStates(odd); + // Convert lists to arrays for passing to JNI int i, n = targets.size(); long targetsArr[] = new long[n]; @@ -317,6 +361,8 @@ public class PrismSparse private static native double PS_NondetMultiReach1(long trans, long trans_actions, List synchs, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long targets[], long combinations[], int combinationIDs[], int relops[], double bounds[], long maybe, long start); public static double NondetMultiReach1(JDDNode trans, JDDNode transActions, List synchs, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, List targets, List combinations, List combinationIDs, OpsAndBoundsList opsAndBounds, JDDNode maybe, JDDNode start) throws PrismException { + checkNumStates(odd); + // Convert lists to arrays for passing to JNI int i, n = targets.size(); long targetsArr[] = new long[n]; @@ -347,6 +393,8 @@ public class PrismSparse public static double NondetMultiReachReward(JDDNode trans, JDDNode transActions, List synchs, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, List targets, OpsAndBoundsList opsAndBounds, JDDNode maybe, JDDNode start, List trr, JDDNode becs) throws PrismException { + checkNumStates(odd); + // Convert lists to arrays for passing to JNI int i;//, n = targets.size(); long targetsArr[] = new long[targets.size()]; @@ -381,6 +429,8 @@ public class PrismSparse public static double NondetMultiReachReward1(JDDNode trans, JDDNode transActions, List synchs, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, List targets, List combinations, List combinationIDs, OpsAndBoundsList opsAndBounds, JDDNode maybe, JDDNode start, List trr, JDDNode becs) throws PrismException { + checkNumStates(odd); + // Convert lists to arrays for passing to JNI int i;//, n = targets.size(); long targetsArr[] = new long[targets.size()]; @@ -427,6 +477,8 @@ public class PrismSparse private static native long PS_StochBoundedUntil(long trans, long odd, long rv, int nrv, long cv, int ncv, long yes, long maybe, double time, long mult); public static DoubleVector StochBoundedUntil(JDDNode trans, ODDNode odd, JDDVars rows, JDDVars cols, JDDNode yes, JDDNode maybe, double time, DoubleVector multProbs) throws PrismException { + checkNumStates(odd); + long mult = (multProbs == null) ? 0 : multProbs.getPtr(); long ptr = PS_StochBoundedUntil(trans.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), yes.ptr(), maybe.ptr(), time, mult); if (ptr == 0) throw new PrismException(getErrorMessage()); @@ -437,6 +489,8 @@ public class PrismSparse private static native long PS_StochCumulReward(long trans, long sr, long trr, long odd, long rv, int nrv, long cv, int ncv, double time); public static DoubleVector StochCumulReward(JDDNode trans, JDDNode sr, JDDNode trr, ODDNode odd, JDDVars rows, JDDVars cols, double time) throws PrismException { + checkNumStates(odd); + long ptr = PS_StochCumulReward(trans.ptr(), sr.ptr(), trr.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), time); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); @@ -446,6 +500,8 @@ public class PrismSparse private static native long PS_StochSteadyState(long trans, long odd, long init, long rv, int nrv, long cv, int ncv); public static DoubleVector StochSteadyState(JDDNode trans, ODDNode odd, JDDNode init, JDDVars rows, JDDVars cols) throws PrismException { + checkNumStates(odd); + long ptr = PS_StochSteadyState(trans.ptr(), odd.ptr(), init.ptr(), rows.array(), rows.n(), cols.array(), cols.n()); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); @@ -455,6 +511,8 @@ public class PrismSparse private static native long PS_StochTransient(long trans, long odd, long init, long rv, int nrv, long cv, int ncv, double time); public static DoubleVector StochTransient(JDDNode trans, ODDNode odd, DoubleVector init, JDDVars rows, JDDVars cols, double time) throws PrismException { + checkNumStates(odd); + long ptr = PS_StochTransient(trans.ptr(), odd.ptr(), init.getPtr(), rows.array(), rows.n(), cols.array(), cols.n(), time); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); @@ -468,6 +526,8 @@ public class PrismSparse private static native int PS_ExportMatrix(long matrix, String name, long rv, int nrv, long cv, int ncv, long odd, int exportType, String filename); public static void ExportMatrix(JDDNode matrix, String name, JDDVars rows, JDDVars cols, ODDNode odd, int exportType, String filename) throws FileNotFoundException, PrismException { + checkNumStates(odd); + int res = PS_ExportMatrix(matrix.ptr(), name, rows.array(), rows.n(), cols.array(), cols.n(), odd.ptr(), exportType, filename); if (res == -1) { throw new FileNotFoundException(); @@ -481,6 +541,8 @@ public class PrismSparse private static native int PS_ExportMDP(long mdp, long trans_actions, List synchs, String name, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long odd, int exportType, String filename); public static void ExportMDP(JDDNode mdp, JDDNode transActions, List synchs, String name, JDDVars rows, JDDVars cols, JDDVars nondet, ODDNode odd, int exportType, String filename) throws FileNotFoundException, PrismException { + checkNumStates(odd); + int res = PS_ExportMDP(mdp.ptr(), (transActions == null) ? 0 : transActions.ptr(), synchs, name, rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), odd.ptr(), exportType, filename); if (res == -1) { throw new FileNotFoundException(); @@ -494,6 +556,8 @@ public class PrismSparse private static native int PS_ExportSubMDP(long mdp, long submdp, String name, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long odd, int exportType, String filename); public static void ExportSubMDP(JDDNode mdp, JDDNode submdp, String name, JDDVars rows, JDDVars cols, JDDVars nondet, ODDNode odd, int exportType, String filename) throws FileNotFoundException, PrismException { + checkNumStates(odd); + int res = PS_ExportSubMDP(mdp.ptr(), submdp.ptr(), name, rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), odd.ptr(), exportType, filename); if (res == -1) { throw new FileNotFoundException(); @@ -511,6 +575,8 @@ public class PrismSparse private static native long PS_Power(long odd, long rv, int nrv, long cv, int ncv, long a, long b, long init, boolean transpose); public static DoubleVector Power(ODDNode odd, JDDVars rows, JDDVars cols, JDDNode a, JDDNode b, JDDNode init, boolean transpose) throws PrismException { + checkNumStates(odd); + long ptr = PS_Power(odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), a.ptr(), b.ptr(), init.ptr(), transpose); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); @@ -520,6 +586,8 @@ public class PrismSparse private static native long PS_JOR(long odd, long rv, int nrv, long cv, int ncv, long a, long b, long init, boolean transpose, boolean row_sums, double omega); public static DoubleVector JOR(ODDNode odd, JDDVars rows, JDDVars cols, JDDNode a, JDDNode b, JDDNode init, boolean transpose, boolean row_sums, double omega) throws PrismException { + checkNumStates(odd); + long ptr = PS_JOR(odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), a.ptr(), b.ptr(), init.ptr(), transpose, row_sums, omega); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); @@ -529,6 +597,8 @@ public class PrismSparse private static native long PS_SOR(long odd, long rv, int nrv, long cv, int ncv, long a, long b, long init, boolean transpose, boolean row_sums, double omega, boolean forwards); public static DoubleVector SOR(ODDNode odd, JDDVars rows, JDDVars cols, JDDNode a, JDDNode b, JDDNode init, boolean transpose, boolean row_sums, double omega, boolean forwards) throws PrismException { + checkNumStates(odd); + long ptr = PS_SOR(odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), a.ptr(), b.ptr(), init.ptr(), transpose, row_sums, omega, forwards); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));