diff --git a/prism/include/PrismSparse.h b/prism/include/PrismSparse.h index 947a7e99..ae30530f 100644 --- a/prism/include/PrismSparse.h +++ b/prism/include/PrismSparse.h @@ -162,10 +162,10 @@ JNIEXPORT jlong JNICALL Java_sparse_PrismSparse_PS_1NondetBoundedUntil /* * Class: sparse_PrismSparse * Method: PS_NondetUntil - * Signature: (JJJIJIJIJJZ)J + * Signature: (JJLjava/util/List;JJIJIJIJJZ)J */ JNIEXPORT jlong JNICALL Java_sparse_PrismSparse_PS_1NondetUntil - (JNIEnv *, jclass, jlong, jlong, jlong, jint, jlong, jint, jlong, jint, jlong, jlong, jboolean); + (JNIEnv *, jclass, jlong, jlong, jobject, jlong, jlong, jint, jlong, jint, jlong, jint, jlong, jlong, jboolean); /* * Class: sparse_PrismSparse diff --git a/prism/include/prism.h b/prism/include/prism.h index 6ab9ad26..962d5cda 100644 --- a/prism/include/prism.h +++ b/prism/include/prism.h @@ -24,6 +24,8 @@ // //============================================================================== +#include + // Flags for building Windows DLLs #ifdef __MINGW32__ #define EXPORT __declspec(dllexport) @@ -41,6 +43,8 @@ typedef struct FoxGlynnWeights } FoxGlynnWeights; // Function prototypes +EXPORT void get_string_array_from_java(JNIEnv *env, jobject strings_list, jstring *&strings_jstrings, const char **&strings, jint &size); +EXPORT void release_string_array_from_java(JNIEnv *env, jstring *strings_jstrings, const char **strings, jint size); EXPORT FoxGlynnWeights fox_glynn(double q_tmax, double underflow, double overflow, double accuracy); //------------------------------------------------------------------------------ diff --git a/prism/include/sparse.h b/prism/include/sparse.h index 7f8bcbe0..8eec3241 100644 --- a/prism/include/sparse.h +++ b/prism/include/sparse.h @@ -172,6 +172,7 @@ CMSRSparseMatrix *build_cmsr_sparse_matrix(DdManager *ddman, DdNode *matrix, DdN CMSCSparseMatrix *build_cmsc_sparse_matrix(DdManager *ddman, DdNode *matrix, DdNode **rvars, DdNode **cvars, int num_vars, ODDNode *odd); NDSparseMatrix *build_nd_sparse_matrix(DdManager *ddman, DdNode *matrix, DdNode **rvars, DdNode **cvars, int num_vars, DdNode **ndvars, int num_ndvars, ODDNode *odd); NDSparseMatrix *build_sub_nd_sparse_matrix(DdManager *ddman, DdNode *mdp, DdNode *submdp, DdNode **rvars, DdNode **cvars, int num_vars, DdNode **ndvars, int num_ndvars, ODDNode *odd); +int *build_nd_action_vector(DdManager *ddman, DdNode *trans_actions, NDSparseMatrix *mdp_ndsm, DdNode **rvars, int num_vars, DdNode **ndvars, int num_ndvars, ODDNode *odd); RMSparseMatrix *build_rm_sparse_matrix(DdManager *ddman, DdNode *matrix, DdNode **rvars, DdNode **cvars, int num_vars, ODDNode *odd, bool transpose); CMSparseMatrix *build_cm_sparse_matrix(DdManager *ddman, DdNode *matrix, DdNode **rvars, DdNode **cvars, int num_vars, ODDNode *odd, bool transpose); diff --git a/prism/src/prism/Model.java b/prism/src/prism/Model.java index e87a4a2d..86fe0482 100644 --- a/prism/src/prism/Model.java +++ b/prism/src/prism/Model.java @@ -27,7 +27,7 @@ package prism; import java.io.*; -import java.util.Vector; +import java.util.*; import jdd.*; import odd.*; @@ -50,6 +50,7 @@ public interface Model int getVarHigh(int i); int getVarRange(int i); Values getConstantValues(); + List getSynchs(); String globalToLocal(long x); int globalToLocal(long x, int l); State convertBddToState(JDDNode dd); @@ -78,6 +79,7 @@ public interface Model JDDNode getTransRewards(); JDDNode getTransRewards(int i); JDDNode getTransRewards(String s); + JDDNode getTransActions(); JDDVars[] getVarDDRowVars(); JDDVars[] getVarDDColVars(); JDDVars getVarDDRowVars(int i); @@ -95,12 +97,14 @@ public interface Model ODDNode getODD(); + void setSynchs(Vector synchs); void resetTrans(JDDNode trans); void resetTransRewards(int i, JDDNode transRewards); void doReachability(); void doReachability(boolean extraReachInfo); void skipReachability(); void setReach(JDDNode reach); + void setTransActions(JDDNode transActions); void filterReachableStates(); void findDeadlocks(); void fixDeadlocks(); diff --git a/prism/src/prism/Modules2MTBDD.java b/prism/src/prism/Modules2MTBDD.java index 236061cc..5b3bdb1b 100644 --- a/prism/src/prism/Modules2MTBDD.java +++ b/prism/src/prism/Modules2MTBDD.java @@ -77,6 +77,7 @@ public class Modules2MTBDD private JDDNode start; // dd for start state private JDDNode stateRewards[]; // dds for state rewards private JDDNode transRewards[]; // dds for transition rewards + private JDDNode transActions; // dd for transition action labels private JDDNode transInd; // dds for independent bits of trans private JDDNode transSynch[]; // dds for synch action parts of trans private JDDVars allDDRowVars; // all dd vars (rows) @@ -116,8 +117,10 @@ public class Modules2MTBDD private int numModulesAfterSymm; // number of modules in the PRISM file after the symmetric ones private int numSymmModules; // number of symmetric components - // hidden option - do we also store each part of the transition matrix separately? - private boolean storeTransParts = true; + // hidden option - do we also store each part of the transition matrix separately? (now defunct) + private boolean storeTransParts = false; + // hidden option - do we also store action info for the transition matrix? (supersedes the above) + private boolean storeTransActions = true; // data structure used to store mtbdds and related info // for some component of the whole model @@ -164,8 +167,9 @@ public class Modules2MTBDD doSymmetry = !(s == null || s == ""); } - // main method - translate + @SuppressWarnings("unchecked") // for clone of vector in translate() + // main method - translate public Model translate() throws PrismException { Model model = null; @@ -261,14 +265,21 @@ public class Modules2MTBDD numVars, varList, varDDRowVars, varDDColVars, constantValues); } + // We also store a copy of the list of action label names + model.setSynchs((Vector)synchs.clone()); + // For MDPs, we also store the DDs used to construct the part // of the transition matrix that corresponds to each action if (modelType == ModelType.MDP && storeTransParts) { - ((NondetModel)model).setSynchs((Vector)synchs.clone()); ((NondetModel)model).setTransInd(transInd); ((NondetModel)model).setTransSynch(transSynch); } + // if required, we also store info about action labels + if (storeTransActions) { + model.setTransActions(transActions); + } + // do reachability (or not) if (prism.getDoReach()) { mainLog.print("\nComputing reachable states...\n"); @@ -838,6 +849,21 @@ public class Modules2MTBDD } } + // If required, we also build an MTBDD to store the action labels for each transition + if (storeTransActions) { + if (modelType == ModelType.MDP) { + transActions = JDD.Constant(0); + JDD.Ref(sysDDs.ind.trans); + tmp = JDD.ThereExists(JDD.GreaterThan(sysDDs.ind.trans, 0), allDDColVars); + transActions = JDD.Apply(JDD.PLUS, transActions, JDD.Apply(JDD.TIMES, tmp, JDD.Constant(1))); + for (i = 0; i < numSynchs; i++) { + JDD.Ref(sysDDs.synchs[i].trans); + tmp = JDD.ThereExists(JDD.GreaterThan(sysDDs.synchs[i].trans, 0), allDDColVars); + transActions = JDD.Apply(JDD.PLUS, transActions, JDD.Apply(JDD.TIMES, tmp, JDD.Constant(2+i))); + } + } + } + // deref bits of ComponentDD objects - we don't need them any more JDD.Deref(sysDDs.ind.guards); JDD.Deref(sysDDs.ind.trans); diff --git a/prism/src/prism/NondetModel.java b/prism/src/prism/NondetModel.java index 99d4206d..7c75e31b 100644 --- a/prism/src/prism/NondetModel.java +++ b/prism/src/prism/NondetModel.java @@ -41,8 +41,6 @@ import sparse.*; public class NondetModel extends ProbModel { // Extra info - protected Vector synchs; // synchronising action labels - protected double numSynchs; // number of synchronising actions protected double numChoices; // number of choices // Extra dd stuff @@ -97,11 +95,6 @@ public class NondetModel extends ProbModel return allDDNondetVars; } - public Vector getSynchs() - { - return synchs; - } - public JDDNode getTransInd() { return transInd; @@ -135,12 +128,6 @@ public class NondetModel extends ProbModel // set methods for things not set up in constructor - public void setSynchs(Vector synchs) - { - this.synchs = synchs; - this.numSynchs = synchs.size(); - } - public void setTransInd(JDDNode transInd) { this.transInd = transInd; @@ -336,6 +323,11 @@ public class NondetModel extends ProbModel } } } + if (transActions != null && !transActions.equals(JDD.ZERO)) { + log.print("Action label indices: "); + log.print(JDD.GetNumNodes(transActions) + " nodes ("); + log.print(JDD.GetNumTerminals(transActions) + " terminal)\n"); + } } // export transition matrix to a file diff --git a/prism/src/prism/NondetModelChecker.java b/prism/src/prism/NondetModelChecker.java index 97b97718..65066667 100644 --- a/prism/src/prism/NondetModelChecker.java +++ b/prism/src/prism/NondetModelChecker.java @@ -645,7 +645,7 @@ public class NondetModelChecker extends NonProbModelChecker else { // for fairness, we compute max here try { - probs = computeUntilProbs(trans, trans01, newb1, newb2, min && !fairness); + probs = computeUntilProbs(trans, transActions, trans01, newb1, newb2, min && !fairness); } catch (PrismException e) { JDD.Deref(newb1); JDD.Deref(newb2); @@ -934,7 +934,7 @@ public class NondetModelChecker extends NonProbModelChecker // note: this function doesn't need to know anything about fairness // it is just told whether to compute min or max probabilities - protected StateProbs computeUntilProbs(JDDNode tr, JDDNode tr01, JDDNode b1, JDDNode b2, boolean min) + protected StateProbs computeUntilProbs(JDDNode tr, JDDNode tra, JDDNode tr01, JDDNode b1, JDDNode b2, boolean min) throws PrismException { JDDNode yes, no, maybe; @@ -1016,7 +1016,7 @@ public class NondetModelChecker extends NonProbModelChecker probs = new StateProbsMTBDD(probsMTBDD, model); break; case Prism.SPARSE: - probsDV = PrismSparse.NondetUntil(tr, odd, allDDRowVars, allDDColVars, allDDNondetVars, yes, maybe, + probsDV = PrismSparse.NondetUntil(tr, tra, model.getSynchs(), odd, allDDRowVars, allDDColVars, allDDNondetVars, yes, maybe, min); probs = new StateProbsDV(probsDV, model); break; diff --git a/prism/src/prism/ProbModel.java b/prism/src/prism/ProbModel.java index 7496cac2..38ec3894 100644 --- a/prism/src/prism/ProbModel.java +++ b/prism/src/prism/ProbModel.java @@ -27,8 +27,7 @@ package prism; import java.io.*; -import java.util.BitSet; -import java.util.Vector; +import java.util.*; import jdd.*; import odd.*; @@ -51,6 +50,9 @@ public class ProbModel implements Model protected VarList varList; // list of module variables protected long[] gtol; // numbers for use by globalToLocal protected Values constantValues; // values of constants + // actions + protected int numSynchs; // number of synchronising actions + protected Vector synchs; // synchronising action labels // rewards protected int numRewardStructs; // number of reward structs protected String[] rewardStructNames; // reward struct names @@ -71,6 +73,8 @@ public class ProbModel implements Model protected JDDNode fixdl; // fixed deadlock states dd protected JDDNode stateRewards[]; // state rewards dds protected JDDNode transRewards[]; // transition rewards dds + protected JDDNode transActions; // dd for transition action labels + // dd vars protected JDDVars[] varDDRowVars; // dd vars for each module variable (rows) protected JDDVars[] varDDColVars; // dd vars for each module variable (cols) @@ -155,6 +159,14 @@ public class ProbModel implements Model return constantValues; } + /** + * Get vector of action label names. + */ + public List getSynchs() + { + return synchs; + } + // rewards public int getNumRewardStructs() { @@ -277,6 +289,11 @@ public class ProbModel implements Model return null; } + public JDDNode getTransActions() + { + return transActions; + } + // dd vars public JDDVars[] getVarDDRowVars() { @@ -391,7 +408,10 @@ public class ProbModel implements Model varDDRowVars = vrv; varDDColVars = vcv; constantValues = cv; - + + // action label info (optional) is initially null + transActions = null; + // compute numbers for globalToLocal converter gtol = new long[numVars]; for (i = 0; i < numVars; i++) { @@ -409,6 +429,15 @@ public class ProbModel implements Model numStartStates = JDD.GetNumMinterms(start, allDDRowVars.n()); } + /** + * Set vector of action label names. + */ + public void setSynchs(Vector synchs) + { + this.synchs = synchs; + this.numSynchs = synchs.size(); + } + /** * Reset transition matrix DD */ @@ -474,7 +503,15 @@ public class ProbModel implements Model // build odd odd = ODDUtils.BuildODD(reach, allDDRowVars); } - + + /** + * Set the DD used to store transitoin action label indices. + */ + public void setTransActions(JDDNode transActions) + { + this.transActions = transActions; + } + // remove non-reachable states from various dds // (and calculate num transitions) @@ -508,6 +545,13 @@ public class ProbModel implements Model transRewards[i] = JDD.Apply(JDD.TIMES, tmp, transRewards[i]); } + // Action label indices matrix + // (just filter rows here; subclasses, e.g. CTMCs, may do more subsequently) + if (transActions != null) { + JDD.Ref(reach); + transActions = JDD.Apply(JDD.TIMES, reach, transActions); + } + // filter start states, work out number of initial states JDD.Ref(reach); start = JDD.Apply(JDD.TIMES, reach, start); @@ -619,6 +663,11 @@ public class ProbModel implements Model } } } + if (transActions != null && !transActions.equals(JDD.ZERO)) { + log.print("Action label indices: "); + log.print(JDD.GetNumNodes(transActions) + " nodes ("); + log.print(JDD.GetNumTerminals(transActions) + " terminal)\n"); + } } // export transition matrix to a file @@ -795,5 +844,6 @@ public class ProbModel implements Model JDD.Deref(stateRewards[i]); JDD.Deref(transRewards[i]); } + if (transActions != null) JDD.Deref(transActions); } } diff --git a/prism/src/prism/StateModelChecker.java b/prism/src/prism/StateModelChecker.java index af4471e4..f6e93d3e 100644 --- a/prism/src/prism/StateModelChecker.java +++ b/prism/src/prism/StateModelChecker.java @@ -53,6 +53,7 @@ public class StateModelChecker implements ModelChecker protected VarList varList; protected JDDNode trans; protected JDDNode trans01; + protected JDDNode transActions; protected JDDNode start; protected JDDNode reach; protected ODDNode odd; @@ -89,6 +90,7 @@ public class StateModelChecker implements ModelChecker varList = model.getVarList(); trans = model.getTrans(); trans01 = model.getTrans01(); + transActions = model.getTransActions(); start = model.getStart(); reach = model.getReach(); odd = model.getODD(); diff --git a/prism/src/prism/StochModel.java b/prism/src/prism/StochModel.java index a15fad48..1f19dfcf 100644 --- a/prism/src/prism/StochModel.java +++ b/prism/src/prism/StochModel.java @@ -61,4 +61,22 @@ public class StochModel extends ProbModel { super(tr, s, sr, trr, rsn, arv, acv, ddvn, nm, mn, mrv, mcv, nv, vl, vrv, vcv, cv); } + + /** + * Remove non-reachable states from various DDs. + * Most of the work is done in the superclass (ProbModel); + * just a few extra things to do here. + */ + public void filterReachableStates() + { + super.filterReachableStates(); + + // If required, also filter columns of action label indices matrix + // (for the superclass - DTMCs - we only store info per state). + if (transActions != null) { + JDD.Ref(reach); + JDDNode tmp = JDD.PermuteVariables(reach, allDDRowVars, allDDColVars); + transActions = JDD.Apply(JDD.TIMES, tmp, transActions); + } + } } diff --git a/prism/src/prism/prism.cc b/prism/src/prism/prism.cc index 53239846..633b1d6c 100644 --- a/prism/src/prism/prism.cc +++ b/prism/src/prism/prism.cc @@ -29,6 +29,52 @@ #include "prism.h" #include #include +#include + +//------------------------------------------------------------------------------ + +// convert a list of strings (from java/jni) to an array of c strings. +// actually stores arrays of both jstring objects and c strings, and also size +// (because need these to free memory afterwards). + +void get_string_array_from_java(JNIEnv *env, jobject strings_list, jstring *&strings_jstrings, const char **&strings, jint &size) +{ + int i, j; + jclass vn_cls; + jmethodID vn_mid; + // get size of vector of strings + vn_cls = env->GetObjectClass(strings_list); + vn_mid = env->GetMethodID(vn_cls, "size", "()I"); + if (vn_mid == 0) { + return; + } + size = env->CallIntMethod(strings_list,vn_mid); + // put strings from vector into array + strings_jstrings = new jstring[size]; + strings = new const char*[size]; + vn_mid = env->GetMethodID(vn_cls, "get", "(I)Ljava/lang/Object;"); + if (vn_mid == 0) { + return; + } + for (i = 0; i < size; i++) { + strings_jstrings[i] = (jstring)env->CallObjectMethod(strings_list, vn_mid, i); + strings[i] = env->GetStringUTFChars(strings_jstrings[i], 0); + } +} + +//------------------------------------------------------------------------------ + +// release the memory from a list of strings created by get_string_array_from_java + +void release_string_array_from_java(JNIEnv *env, jstring *strings_jstrings, const char **strings, jint size) +{ + // release memory + for (int i = 0; i < size; i++) { + env->ReleaseStringUTFChars(strings_jstrings[i], strings[i]); + } + delete[] strings_jstrings; + delete[] strings; +} //------------------------------------------------------------------------------ diff --git a/prism/src/sparse/PS_NondetUntil.cc b/prism/src/sparse/PS_NondetUntil.cc index 479a2f5c..f116cbfa 100644 --- a/prism/src/sparse/PS_NondetUntil.cc +++ b/prism/src/sparse/PS_NondetUntil.cc @@ -33,6 +33,7 @@ #include #include #include "sparse.h" +#include "prism.h" #include "PrismSparseGlob.h" #include "jnipointer.h" #include @@ -43,7 +44,9 @@ JNIEXPORT jlong __jlongpointer JNICALL Java_sparse_PrismSparse_PS_1NondetUntil ( JNIEnv *env, jclass cls, -jlong __jlongpointer t, // trans matrix +jlong __jlongpointer t, // trans matrix +jlong __jlongpointer ta, // trans action labels +jobject synchs, jlong __jlongpointer od, // odd jlong __jlongpointer rv, // row vars jint num_rvars, @@ -51,22 +54,23 @@ jlong __jlongpointer cv, // col vars jint num_cvars, jlong __jlongpointer ndv, // nondet vars jint num_ndvars, -jlong __jlongpointer y, // 'yes' states -jlong __jlongpointer m, // 'maybe' states -jboolean min // min or max probabilities (true = min, false = max) +jlong __jlongpointer y, // 'yes' states +jlong __jlongpointer m, // 'maybe' states +jboolean min // min or max probabilities (true = min, false = max) ) { // cast function parameters - DdNode *trans = jlong_to_DdNode(t); // trans matrix - ODDNode *odd = jlong_to_ODDNode(od); // reachable states + DdNode *trans = jlong_to_DdNode(t); // trans matrix + DdNode *trans_actions = jlong_to_DdNode(ta); // trans action labels + ODDNode *odd = jlong_to_ODDNode(od); // reachable states DdNode **rvars = jlong_to_DdNode_array(rv); // row vars DdNode **cvars = jlong_to_DdNode_array(cv); // col vars DdNode **ndvars = jlong_to_DdNode_array(ndv); // nondet vars - DdNode *yes = jlong_to_DdNode(y); // 'yes' states - DdNode *maybe = jlong_to_DdNode(m); // 'maybe' states + DdNode *yes = jlong_to_DdNode(y); // 'yes' states + DdNode *maybe = jlong_to_DdNode(m); // 'maybe' states // mtbdds - DdNode *a = NULL; + DdNode *a = NULL, *tmp = NULL; // model stats int n, nc; long nnz; @@ -81,6 +85,10 @@ jboolean min // min or max probabilities (true = min, false = max) bool adv = false, adv_loop = false; FILE *fp_adv = NULL; int adv_l, adv_h; + int *actions; + jstring *action_names_jstrings; + const char** action_names; + int num_actions; // misc int i, j, k, l1, h1, l2, h2, iters; double d1, d2, kb, kbt; @@ -112,6 +120,25 @@ jboolean min // min or max probabilities (true = min, false = max) PS_PrintToMainLog(env, "[n=%d, nc=%d, nnz=%d, k=%d] ", n, nc, nnz, ndsm->k); PS_PrintMemoryToMainLog(env, "[", kb, "]\n"); + // if needed, and if info is available, build a vector of action indices for the mdp + if (adv && trans_actions != NULL) { + PS_PrintToMainLog(env, "Building action information... "); + // first need to filter out unwanted rows + Cudd_Ref(trans_actions); + Cudd_Ref(maybe); + tmp = DD_Apply(ddman, APPLY_TIMES, trans_actions, maybe); + // then convert to a vector of integer indices + actions = build_nd_action_vector(ddman, tmp, ndsm, rvars, num_rvars, ndvars, num_ndvars, odd); + Cudd_RecursiveDeref(ddman, tmp); + kb = n*4.0/1024.0; + kbt += kb; + PS_PrintMemoryToMainLog(env, "[", kb, "]\n"); + } else { + actions = NULL; + } + // also extract list of action name + get_string_array_from_java(env, synchs, action_names_jstrings, action_names, num_actions); + // get vector for yes PS_PrintToMainLog(env, "Creating vector for yes... "); yes_vec = mtbdd_to_double_vector(ddman, yes, rvars, num_rvars, odd); @@ -193,7 +220,11 @@ jboolean min // min or max probabilities (true = min, false = max) soln2[i] = (h1 > l1) ? d1 : yes_vec[i]; // store adversary info (if required) if (adv_loop) if (h1 > l1) - for (k = adv_l; k < adv_h; k++) fprintf(fp_adv, "%d %d %g\n", i, cols[k], non_zeros[k]); + for (k = adv_l; k < adv_h; k++) { + fprintf(fp_adv, "%d %d %g", i, cols[k], non_zeros[k]); + if (actions != NULL) fprintf(fp_adv, " %s", actions[l1]>1?action_names[actions[l1]-2]:""); + fprintf(fp_adv, "\n"); + } } // check convergence @@ -260,8 +291,10 @@ jboolean min // min or max probabilities (true = min, false = max) if (ndsm) delete ndsm; if (yes_vec) delete[] yes_vec; if (soln2) delete[] soln2; + release_string_array_from_java(env, action_names_jstrings, action_names, num_actions); return ptr_to_jlong(soln); } //------------------------------------------------------------------------------ + diff --git a/prism/src/sparse/PrismSparse.java b/prism/src/sparse/PrismSparse.java index 45688f73..32b51c88 100644 --- a/prism/src/sparse/PrismSparse.java +++ b/prism/src/sparse/PrismSparse.java @@ -27,6 +27,7 @@ package sparse; import java.io.FileNotFoundException; +import java.util.List; import prism.*; import jdd.*; @@ -255,10 +256,10 @@ public class PrismSparse } // pctl until (nondeterministic/mdp) - private static native long PS_NondetUntil(long trans, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long yes, long maybe, boolean minmax); - public static DoubleVector NondetUntil(JDDNode trans, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, JDDNode yes, JDDNode maybe, boolean minmax) throws PrismException + private static native long PS_NondetUntil(long trans, long trans_actions, List synchs, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long yes, long maybe, boolean minmax); + public static DoubleVector NondetUntil(JDDNode trans, JDDNode transActions, List synchs, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, JDDNode yes, JDDNode maybe, boolean minmax) throws PrismException { - long ptr = PS_NondetUntil(trans.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), yes.ptr(), maybe.ptr(), minmax); + long ptr = PS_NondetUntil(trans.ptr(), (transActions == null) ? 0 : transActions.ptr(), synchs, odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), yes.ptr(), maybe.ptr(), minmax); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); } diff --git a/prism/src/sparse/sparse.cc b/prism/src/sparse/sparse.cc index d147563d..dbfcb6f4 100644 --- a/prism/src/sparse/sparse.cc +++ b/prism/src/sparse/sparse.cc @@ -43,6 +43,7 @@ static void traverse_mtbdd_vect_rec(DdManager *ddman, DdNode *dd, DdNode **vars, // global variables (used by local functions) static int count; static int *starts, *starts2; +static int *actions; static RMSparseMatrix *rmsm; static CMSparseMatrix *cmsm; static RCSparseMatrix *rcsm; @@ -668,7 +669,7 @@ NDSparseMatrix *build_nd_sparse_matrix(DdManager *ddman, DdNode *mdp, DdNode **r // try/catch for memory allocation/deallocation } catch(std::bad_alloc e) { - if (ndsm) delete cmscsm; + if (ndsm) delete ndsm; if (matrices) delete[] matrices; if (matrices_bdds) { for (i = 0; i < nm; i++) Cudd_RecursiveDeref(ddman, matrices_bdds[i]); @@ -808,7 +809,7 @@ NDSparseMatrix *build_sub_nd_sparse_matrix(DdManager *ddman, DdNode *mdp, DdNode // try/catch for memory allocation/deallocation } catch(std::bad_alloc e) { - if (ndsm) delete cmscsm; + if (ndsm) delete ndsm; if (matrices) delete[] matrices; if (matrices_bdds) { for (i = 0; i < nm; i++) Cudd_RecursiveDeref(ddman, matrices_bdds[i]); @@ -833,6 +834,88 @@ NDSparseMatrix *build_sub_nd_sparse_matrix(DdManager *ddman, DdNode *mdp, DdNode //------------------------------------------------------------------------------ +// build nondeterministic (mdp) action vector to accompany a sparse matrix +// (i.e. a vector containing for every state and nondet choice, an index +// into the list of all action labels) +// throws std::bad_alloc on out-of-memory + +int *build_nd_action_vector(DdManager *ddman, DdNode *trans_actions, NDSparseMatrix *mdp_ndsm, DdNode **rvars, int num_vars, DdNode **ndvars, int num_ndvars, ODDNode *odd) +{ + int i, n, nm, nc; + DdNode *tmp = NULL, **matrices = NULL, **matrices_bdds = NULL; + + // try/catch for memory allocation/deallocation + try { + + // get stats from mdp sparse storage (num states/choices) + n = mdp_ndsm->n; + nc = mdp_ndsm->nc; + // break the mtbdd storing the action info into several (nm) mtbdds + // (this number nm should match the figure for the corresponding mdp; + // but we can't do this sanity check because that statistic is not retained.) + Cudd_Ref(trans_actions); + tmp = DD_Not(ddman, DD_Equals(ddman, trans_actions, 0)); + tmp = DD_ThereExists(ddman, tmp, rvars, num_vars); + nm = (int)DD_GetNumMinterms(ddman, tmp, num_ndvars); + Cudd_RecursiveDeref(ddman, tmp); + matrices = new DdNode*[nm]; + count = 0; + split_mdp_rec(ddman, trans_actions, ndvars, num_ndvars, 0, matrices); + // and for each one create a bdd storing which rows are non-empty + matrices_bdds = new DdNode*[nm]; + for (i = 0; i < nm; i++) { + Cudd_Ref(matrices[i]); + matrices_bdds[i] = DD_Not(ddman, DD_Equals(ddman, matrices[i], 0)); + } + + // create arrays + actions = NULL; actions = new int[nc]; + starts = NULL; starts = new int[n+1]; + + // build the (temporary) array 'starts' (like was done when building the sparse matrix for the mdp). + // in fact, this information is retrievable from the sparse matrix, but it may have + // been converted to counts, rather than offsets, so its easier to rebuild it. + // first traverse mtbdds to compute how many choices are in each row + for (i = 0; i < n+1; i++) starts[i] = 0; + for (i = 0; i < nm; i++) { + traverse_mtbdd_vect_rec(ddman, matrices_bdds[i], rvars, num_vars, 0, odd, 0, 1); + } + // and use this to compute the starts information + for (i = 1 ; i < n+1; i++) { + starts[i] += starts[i-1]; + } + + // now traverse the mtbdd to get the actual entries (action indices) + for (i = 0; i < nm; i++) { + traverse_mtbdd_vect_rec(ddman, matrices[i], rvars, num_vars, 0, odd, 0, 3); + } + + // try/catch for memory allocation/deallocation + } catch(std::bad_alloc e) { + if (actions) delete[] actions; + if (matrices) delete[] matrices; + if (matrices_bdds) { + for (i = 0; i < nm; i++) Cudd_RecursiveDeref(ddman, matrices_bdds[i]); + delete[] matrices_bdds; + } + if (starts) delete[] starts; + throw e; + } + + // clear up memory + for (i = 0; i < nm; i++) { + Cudd_RecursiveDeref(ddman, matrices_bdds[i]); + // nb: don't deref matrices array because that was just pointers, not new copies + } + delete[] starts; + delete[] matrices; + delete[] matrices_bdds; + + return actions; +} + +//------------------------------------------------------------------------------ + void split_mdp_rec(DdManager *ddman, DdNode *dd, DdNode **ndvars, int num_ndvars, int level, DdNode **matrices) { DdNode *e, *t; @@ -1062,7 +1145,14 @@ void traverse_mtbdd_vect_rec(DdManager *ddman, DdNode *dd, DdNode **vars, int nu case 2: starts[i]++; break; + + // mdp action vector - single pass + case 3: + actions[starts[i]] = (int)Cudd_V(dd); + starts[i]++; + break; } + return; }