Browse Source

Sparse engine: Error if number of reachable states is too large

Currently, the sparse engine internally uses int (signed 32bit) index variables
So, if the number of states is larger than Integer.MAX_VALUE, there is a problem
and the code will most probably crash or do nonsensical things.

We check this before calling into the native sparse engine code and
throw a PrismNotSupportedException.



git-svn-id: https://www.prismmodelchecker.org/svn/prism/prism/trunk@12020 bbc10eb1-c90d-0410-af57-cb519fbb1720
master
Joachim Klein 9 years ago
parent
commit
cc90e7a20c
  1. 70
      prism/src/sparse/PrismSparse.java

70
prism/src/sparse/PrismSparse.java

@ -37,6 +37,7 @@ import prism.NativeIntArray;
import prism.OpsAndBoundsList;
import prism.PrismException;
import prism.PrismLog;
import prism.PrismNotSupportedException;
import dv.DoubleVector;
import dv.IntegerVector;
@ -79,6 +80,21 @@ public class PrismSparse
// tidy up in jni (free global references)
private static native void PS_FreeGlobalRefs();
/**
* Check that number of reachable states is in a range that can be handled by
* the sparse engine methods.
* @throws PrismNotSupportedException if that is not the case
*/
private static void checkNumStates(ODDNode odd) throws PrismNotSupportedException
{
// currently, the sparse engine internally uses int (signed 32bit) index values
// so, if the number of states is larger than Integer.MAX_VALUE, there is a problem
long n = odd.getEOff() + odd.getTOff();
if (n >= Integer.MAX_VALUE) {
throw new PrismNotSupportedException("The sparse engine can currently only handle up to " + Integer.MAX_VALUE + " reachable states, model has " + n + " states");
}
}
//----------------------------------------------------------------------------------------------
// cudd manager
//----------------------------------------------------------------------------------------------
@ -144,6 +160,8 @@ public class PrismSparse
private static native long PS_ProbBoundedUntil(long trans, long odd, long rv, int nrv, long cv, int ncv, long yes, long maybe, int bound);
public static DoubleVector ProbBoundedUntil(JDDNode trans, ODDNode odd, JDDVars rows, JDDVars cols, JDDNode yes, JDDNode maybe, int bound) throws PrismException
{
checkNumStates(odd);
long ptr = PS_ProbBoundedUntil(trans.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), yes.ptr(), maybe.ptr(), bound);
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
@ -153,6 +171,8 @@ public class PrismSparse
private static native long PS_ProbUntil(long trans, long odd, long rv, int nrv, long cv, int ncv, long yes, long maybe);
public static DoubleVector ProbUntil(JDDNode trans, ODDNode odd, JDDVars rows, JDDVars cols, JDDNode yes, JDDNode maybe) throws PrismException
{
checkNumStates(odd);
long ptr = PS_ProbUntil(trans.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), yes.ptr(), maybe.ptr());
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
@ -162,6 +182,8 @@ public class PrismSparse
private static native long PS_ProbCumulReward(long trans, long sr, long trr, long odd, long rv, int nrv, long cv, int ncv, int bound);
public static DoubleVector ProbCumulReward(JDDNode trans, JDDNode sr, JDDNode trr, ODDNode odd, JDDVars rows, JDDVars cols, int bound) throws PrismException
{
checkNumStates(odd);
long ptr = PS_ProbCumulReward(trans.ptr(), sr.ptr(), trr.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), bound);
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
@ -171,6 +193,8 @@ public class PrismSparse
private static native long PS_ProbInstReward(long trans, long sr, long odd, long rv, int nrv, long cv, int ncv, int time);
public static DoubleVector ProbInstReward(JDDNode trans, JDDNode sr, ODDNode odd, JDDVars rows, JDDVars cols, int time) throws PrismException
{
checkNumStates(odd);
long ptr = PS_ProbInstReward(trans.ptr(), sr.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), time);
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
@ -180,6 +204,8 @@ public class PrismSparse
private static native long PS_ProbReachReward(long trans, long sr, long trr, long odd, long rv, int nrv, long cv, int ncv, long goal, long inf, long maybe);
public static DoubleVector ProbReachReward(JDDNode trans, JDDNode sr, JDDNode trr, ODDNode odd, JDDVars rows, JDDVars cols, JDDNode goal, JDDNode inf, JDDNode maybe) throws PrismException
{
checkNumStates(odd);
long ptr = PS_ProbReachReward(trans.ptr(), sr.ptr(), trr.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), goal.ptr(), inf.ptr(), maybe.ptr());
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
@ -189,6 +215,8 @@ public class PrismSparse
private static native long PS_ProbTransient(long trans, long odd, long init, long rv, int nrv, long cv, int ncv, int time);
public static DoubleVector ProbTransient(JDDNode trans, ODDNode odd, DoubleVector init, JDDVars rows, JDDVars cols, int time) throws PrismException
{
checkNumStates(odd);
long ptr = PS_ProbTransient(trans.ptr(), odd.ptr(), init.getPtr(), rows.array(), rows.n(), cols.array(), cols.n(), time);
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
@ -202,6 +230,8 @@ public class PrismSparse
private static native long PS_NondetBoundedUntil(long trans, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long yes, long maybe, int time, boolean minmax);
public static DoubleVector NondetBoundedUntil(JDDNode trans, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, JDDNode yes, JDDNode maybe, int time, boolean minmax) throws PrismException
{
checkNumStates(odd);
long ptr = PS_NondetBoundedUntil(trans.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), yes.ptr(), maybe.ptr(), time, minmax);
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
@ -211,6 +241,8 @@ public class PrismSparse
private static native long PS_NondetUntil(long trans, long trans_actions, List<String> synchs, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long yes, long maybe, boolean minmax, long strat);
public static DoubleVector NondetUntil(JDDNode trans, JDDNode transActions, List<String> synchs, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, JDDNode yes, JDDNode maybe, boolean minmax, IntegerVector strat) throws PrismException
{
checkNumStates(odd);
long ptr = PS_NondetUntil(trans.ptr(), (transActions == null) ? 0 : transActions.ptr(), synchs, odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), yes.ptr(), maybe.ptr(), minmax, (strat == null) ? 0 : strat.getPtr());
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
@ -220,6 +252,8 @@ public class PrismSparse
private static native long PS_NondetCumulReward(long trans, long sr, long trr, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, int bound, boolean minmax);
public static DoubleVector NondetCumulReward(JDDNode trans, JDDNode sr, JDDNode trr, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, int bound, boolean minmax) throws PrismException
{
checkNumStates(odd);
long ptr = PS_NondetCumulReward(trans.ptr(), sr.ptr(), trr.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), bound, minmax);
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
@ -229,6 +263,8 @@ public class PrismSparse
private static native long PS_NondetInstReward(long trans, long sr, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, int time, boolean minmax, long init);
public static DoubleVector NondetInstReward(JDDNode trans, JDDNode sr, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, int time, boolean minmax, JDDNode init) throws PrismException
{
checkNumStates(odd);
long ptr = PS_NondetInstReward(trans.ptr(), sr.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), time, minmax, init.ptr());
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
@ -238,6 +274,8 @@ public class PrismSparse
private static native long PS_NondetReachReward(long trans, long trans_actions, List<String> synchs, long sr, long trr, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long goal, long inf, long maybe, boolean minmax);
public static DoubleVector NondetReachReward(JDDNode trans, JDDNode transActions, List<String> synchs, JDDNode sr, JDDNode trr, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, JDDNode goal, JDDNode inf, JDDNode maybe, boolean minmax) throws PrismException
{
checkNumStates(odd);
long ptr = PS_NondetReachReward(trans.ptr(), (transActions == null) ? 0 : transActions.ptr(), synchs, sr.ptr(), trr.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), goal.ptr(), inf.ptr(), maybe.ptr(), minmax);
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
@ -246,6 +284,8 @@ public class PrismSparse
private static native double[] PS_NondetMultiObj(long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, boolean minmax, long start, long ptr_adversary, long ptr_TransSparseMatrix, List<String> synchs, long[] ptr_yes_vec, int[] probStepBounds, long[] ptr_RewSparseMatrix, double[] rewardWeights, int[] rewardStepBounds);
public static double[] NondetMultiObj(ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, boolean minmax, JDDNode start, NativeIntArray adversary, NDSparseMatrix transSparseMatrix, List<String> synchs, DoubleVector[] yes_vec, int[] probStepBounds, NDSparseMatrix[] rewSparseMatrix, double[] rewardWeights, int[] rewardStepBounds) throws PrismException
{
checkNumStates(odd);
long[] ptr_ndsp_r = null;
if (rewSparseMatrix != null) {
ptr_ndsp_r = new long[rewSparseMatrix.length];
@ -271,6 +311,8 @@ public class PrismSparse
private static native double[] PS_NondetMultiObjGS(long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, boolean minmax, long start, long ptr_adversary, long ptr_TransSparseMatrix, long[] ptr_yes_vec, long[] ptr_RewSparseMatrix, double[] rewardWeights);
public static double[] NondetMultiObjGS(ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, boolean minmax, JDDNode start, NativeIntArray adversary, NDSparseMatrix transSparseMatrix, DoubleVector[] yes_vec, NDSparseMatrix[] rewSparseMatrix, double[] rewardWeights) throws PrismException
{
checkNumStates(odd);
long[] ptr_ndsp_r = null;
if (rewSparseMatrix != null) {
ptr_ndsp_r = new long[rewSparseMatrix.length];
@ -296,6 +338,8 @@ public class PrismSparse
private static native double PS_NondetMultiReach(long trans, long trans_actions, List<String> synchs, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long targets[], int relops[], double bounds[], long maybe, long start);
public static double NondetMultiReach(JDDNode trans, JDDNode transActions, List<String> synchs, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, List<JDDNode> targets, OpsAndBoundsList opsAndBounds, JDDNode maybe, JDDNode start) throws PrismException
{
checkNumStates(odd);
// Convert lists to arrays for passing to JNI
int i, n = targets.size();
long targetsArr[] = new long[n];
@ -317,6 +361,8 @@ public class PrismSparse
private static native double PS_NondetMultiReach1(long trans, long trans_actions, List<String> synchs, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long targets[], long combinations[], int combinationIDs[], int relops[], double bounds[], long maybe, long start);
public static double NondetMultiReach1(JDDNode trans, JDDNode transActions, List<String> synchs, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, List<JDDNode> targets, List<JDDNode> combinations, List<Integer> combinationIDs, OpsAndBoundsList opsAndBounds, JDDNode maybe, JDDNode start) throws PrismException
{
checkNumStates(odd);
// Convert lists to arrays for passing to JNI
int i, n = targets.size();
long targetsArr[] = new long[n];
@ -347,6 +393,8 @@ public class PrismSparse
public static double NondetMultiReachReward(JDDNode trans, JDDNode transActions, List<String> synchs, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, List<JDDNode> targets, OpsAndBoundsList opsAndBounds, JDDNode maybe, JDDNode start,
List<JDDNode> trr, JDDNode becs) throws PrismException
{
checkNumStates(odd);
// Convert lists to arrays for passing to JNI
int i;//, n = targets.size();
long targetsArr[] = new long[targets.size()];
@ -381,6 +429,8 @@ public class PrismSparse
public static double NondetMultiReachReward1(JDDNode trans, JDDNode transActions, List<String> synchs, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, List<JDDNode> targets, List<JDDNode> combinations, List<Integer> combinationIDs, OpsAndBoundsList opsAndBounds, JDDNode maybe, JDDNode start,
List<JDDNode> trr, JDDNode becs) throws PrismException
{
checkNumStates(odd);
// Convert lists to arrays for passing to JNI
int i;//, n = targets.size();
long targetsArr[] = new long[targets.size()];
@ -427,6 +477,8 @@ public class PrismSparse
private static native long PS_StochBoundedUntil(long trans, long odd, long rv, int nrv, long cv, int ncv, long yes, long maybe, double time, long mult);
public static DoubleVector StochBoundedUntil(JDDNode trans, ODDNode odd, JDDVars rows, JDDVars cols, JDDNode yes, JDDNode maybe, double time, DoubleVector multProbs) throws PrismException
{
checkNumStates(odd);
long mult = (multProbs == null) ? 0 : multProbs.getPtr();
long ptr = PS_StochBoundedUntil(trans.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), yes.ptr(), maybe.ptr(), time, mult);
if (ptr == 0) throw new PrismException(getErrorMessage());
@ -437,6 +489,8 @@ public class PrismSparse
private static native long PS_StochCumulReward(long trans, long sr, long trr, long odd, long rv, int nrv, long cv, int ncv, double time);
public static DoubleVector StochCumulReward(JDDNode trans, JDDNode sr, JDDNode trr, ODDNode odd, JDDVars rows, JDDVars cols, double time) throws PrismException
{
checkNumStates(odd);
long ptr = PS_StochCumulReward(trans.ptr(), sr.ptr(), trr.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), time);
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
@ -446,6 +500,8 @@ public class PrismSparse
private static native long PS_StochSteadyState(long trans, long odd, long init, long rv, int nrv, long cv, int ncv);
public static DoubleVector StochSteadyState(JDDNode trans, ODDNode odd, JDDNode init, JDDVars rows, JDDVars cols) throws PrismException
{
checkNumStates(odd);
long ptr = PS_StochSteadyState(trans.ptr(), odd.ptr(), init.ptr(), rows.array(), rows.n(), cols.array(), cols.n());
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
@ -455,6 +511,8 @@ public class PrismSparse
private static native long PS_StochTransient(long trans, long odd, long init, long rv, int nrv, long cv, int ncv, double time);
public static DoubleVector StochTransient(JDDNode trans, ODDNode odd, DoubleVector init, JDDVars rows, JDDVars cols, double time) throws PrismException
{
checkNumStates(odd);
long ptr = PS_StochTransient(trans.ptr(), odd.ptr(), init.getPtr(), rows.array(), rows.n(), cols.array(), cols.n(), time);
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
@ -468,6 +526,8 @@ public class PrismSparse
private static native int PS_ExportMatrix(long matrix, String name, long rv, int nrv, long cv, int ncv, long odd, int exportType, String filename);
public static void ExportMatrix(JDDNode matrix, String name, JDDVars rows, JDDVars cols, ODDNode odd, int exportType, String filename) throws FileNotFoundException, PrismException
{
checkNumStates(odd);
int res = PS_ExportMatrix(matrix.ptr(), name, rows.array(), rows.n(), cols.array(), cols.n(), odd.ptr(), exportType, filename);
if (res == -1) {
throw new FileNotFoundException();
@ -481,6 +541,8 @@ public class PrismSparse
private static native int PS_ExportMDP(long mdp, long trans_actions, List<String> synchs, String name, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long odd, int exportType, String filename);
public static void ExportMDP(JDDNode mdp, JDDNode transActions, List<String> synchs, String name, JDDVars rows, JDDVars cols, JDDVars nondet, ODDNode odd, int exportType, String filename) throws FileNotFoundException, PrismException
{
checkNumStates(odd);
int res = PS_ExportMDP(mdp.ptr(), (transActions == null) ? 0 : transActions.ptr(), synchs, name, rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), odd.ptr(), exportType, filename);
if (res == -1) {
throw new FileNotFoundException();
@ -494,6 +556,8 @@ public class PrismSparse
private static native int PS_ExportSubMDP(long mdp, long submdp, String name, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long odd, int exportType, String filename);
public static void ExportSubMDP(JDDNode mdp, JDDNode submdp, String name, JDDVars rows, JDDVars cols, JDDVars nondet, ODDNode odd, int exportType, String filename) throws FileNotFoundException, PrismException
{
checkNumStates(odd);
int res = PS_ExportSubMDP(mdp.ptr(), submdp.ptr(), name, rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), odd.ptr(), exportType, filename);
if (res == -1) {
throw new FileNotFoundException();
@ -511,6 +575,8 @@ public class PrismSparse
private static native long PS_Power(long odd, long rv, int nrv, long cv, int ncv, long a, long b, long init, boolean transpose);
public static DoubleVector Power(ODDNode odd, JDDVars rows, JDDVars cols, JDDNode a, JDDNode b, JDDNode init, boolean transpose) throws PrismException
{
checkNumStates(odd);
long ptr = PS_Power(odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), a.ptr(), b.ptr(), init.ptr(), transpose);
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
@ -520,6 +586,8 @@ public class PrismSparse
private static native long PS_JOR(long odd, long rv, int nrv, long cv, int ncv, long a, long b, long init, boolean transpose, boolean row_sums, double omega);
public static DoubleVector JOR(ODDNode odd, JDDVars rows, JDDVars cols, JDDNode a, JDDNode b, JDDNode init, boolean transpose, boolean row_sums, double omega) throws PrismException
{
checkNumStates(odd);
long ptr = PS_JOR(odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), a.ptr(), b.ptr(), init.ptr(), transpose, row_sums, omega);
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
@ -529,6 +597,8 @@ public class PrismSparse
private static native long PS_SOR(long odd, long rv, int nrv, long cv, int ncv, long a, long b, long init, boolean transpose, boolean row_sums, double omega, boolean forwards);
public static DoubleVector SOR(ODDNode odd, JDDVars rows, JDDVars cols, JDDNode a, JDDNode b, JDDNode init, boolean transpose, boolean row_sums, double omega, boolean forwards) throws PrismException
{
checkNumStates(odd);
long ptr = PS_SOR(odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), a.ptr(), b.ptr(), init.ptr(), transpose, row_sums, omega, forwards);
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));

Loading…
Cancel
Save