From b09727fda4b54553f92923c88d2009c1aebcb488 Mon Sep 17 00:00:00 2001 From: Dave Parker Date: Tue, 17 Nov 2009 16:56:30 +0000 Subject: [PATCH] Construction of (symbolic) action label info (currently enabled), functions to convert to sparse storage, and use of this in the adversary generation for MDP until (still switched off for now). git-svn-id: https://www.prismmodelchecker.org/svn/prism/prism/trunk@1559 bbc10eb1-c90d-0410-af57-cb519fbb1720 --- prism/include/PrismSparse.h | 4 +- prism/include/prism.h | 4 ++ prism/include/sparse.h | 1 + prism/src/prism/Model.java | 6 +- prism/src/prism/Modules2MTBDD.java | 34 +++++++-- prism/src/prism/NondetModel.java | 18 ++--- prism/src/prism/NondetModelChecker.java | 6 +- prism/src/prism/ProbModel.java | 58 +++++++++++++-- prism/src/prism/StateModelChecker.java | 2 + prism/src/prism/StochModel.java | 18 +++++ prism/src/prism/prism.cc | 46 ++++++++++++ prism/src/sparse/PS_NondetUntil.cc | 53 +++++++++++--- prism/src/sparse/PrismSparse.java | 7 +- prism/src/sparse/sparse.cc | 94 ++++++++++++++++++++++++- 14 files changed, 309 insertions(+), 42 deletions(-) diff --git a/prism/include/PrismSparse.h b/prism/include/PrismSparse.h index 947a7e99..ae30530f 100644 --- a/prism/include/PrismSparse.h +++ b/prism/include/PrismSparse.h @@ -162,10 +162,10 @@ JNIEXPORT jlong JNICALL Java_sparse_PrismSparse_PS_1NondetBoundedUntil /* * Class: sparse_PrismSparse * Method: PS_NondetUntil - * Signature: (JJJIJIJIJJZ)J + * Signature: (JJLjava/util/List;JJIJIJIJJZ)J */ JNIEXPORT jlong JNICALL Java_sparse_PrismSparse_PS_1NondetUntil - (JNIEnv *, jclass, jlong, jlong, jlong, jint, jlong, jint, jlong, jint, jlong, jlong, jboolean); + (JNIEnv *, jclass, jlong, jlong, jobject, jlong, jlong, jint, jlong, jint, jlong, jint, jlong, jlong, jboolean); /* * Class: sparse_PrismSparse diff --git a/prism/include/prism.h b/prism/include/prism.h index 6ab9ad26..962d5cda 100644 --- a/prism/include/prism.h +++ b/prism/include/prism.h @@ -24,6 +24,8 @@ // //============================================================================== +#include + // Flags for building Windows DLLs #ifdef __MINGW32__ #define EXPORT __declspec(dllexport) @@ -41,6 +43,8 @@ typedef struct FoxGlynnWeights } FoxGlynnWeights; // Function prototypes +EXPORT void get_string_array_from_java(JNIEnv *env, jobject strings_list, jstring *&strings_jstrings, const char **&strings, jint &size); +EXPORT void release_string_array_from_java(JNIEnv *env, jstring *strings_jstrings, const char **strings, jint size); EXPORT FoxGlynnWeights fox_glynn(double q_tmax, double underflow, double overflow, double accuracy); //------------------------------------------------------------------------------ diff --git a/prism/include/sparse.h b/prism/include/sparse.h index 7f8bcbe0..8eec3241 100644 --- a/prism/include/sparse.h +++ b/prism/include/sparse.h @@ -172,6 +172,7 @@ CMSRSparseMatrix *build_cmsr_sparse_matrix(DdManager *ddman, DdNode *matrix, DdN CMSCSparseMatrix *build_cmsc_sparse_matrix(DdManager *ddman, DdNode *matrix, DdNode **rvars, DdNode **cvars, int num_vars, ODDNode *odd); NDSparseMatrix *build_nd_sparse_matrix(DdManager *ddman, DdNode *matrix, DdNode **rvars, DdNode **cvars, int num_vars, DdNode **ndvars, int num_ndvars, ODDNode *odd); NDSparseMatrix *build_sub_nd_sparse_matrix(DdManager *ddman, DdNode *mdp, DdNode *submdp, DdNode **rvars, DdNode **cvars, int num_vars, DdNode **ndvars, int num_ndvars, ODDNode *odd); +int *build_nd_action_vector(DdManager *ddman, DdNode *trans_actions, NDSparseMatrix *mdp_ndsm, DdNode **rvars, int num_vars, DdNode **ndvars, int num_ndvars, ODDNode *odd); RMSparseMatrix *build_rm_sparse_matrix(DdManager *ddman, DdNode *matrix, DdNode **rvars, DdNode **cvars, int num_vars, ODDNode *odd, bool transpose); CMSparseMatrix *build_cm_sparse_matrix(DdManager *ddman, DdNode *matrix, DdNode **rvars, DdNode **cvars, int num_vars, ODDNode *odd, bool transpose); diff --git a/prism/src/prism/Model.java b/prism/src/prism/Model.java index e87a4a2d..86fe0482 100644 --- a/prism/src/prism/Model.java +++ b/prism/src/prism/Model.java @@ -27,7 +27,7 @@ package prism; import java.io.*; -import java.util.Vector; +import java.util.*; import jdd.*; import odd.*; @@ -50,6 +50,7 @@ public interface Model int getVarHigh(int i); int getVarRange(int i); Values getConstantValues(); + List getSynchs(); String globalToLocal(long x); int globalToLocal(long x, int l); State convertBddToState(JDDNode dd); @@ -78,6 +79,7 @@ public interface Model JDDNode getTransRewards(); JDDNode getTransRewards(int i); JDDNode getTransRewards(String s); + JDDNode getTransActions(); JDDVars[] getVarDDRowVars(); JDDVars[] getVarDDColVars(); JDDVars getVarDDRowVars(int i); @@ -95,12 +97,14 @@ public interface Model ODDNode getODD(); + void setSynchs(Vector synchs); void resetTrans(JDDNode trans); void resetTransRewards(int i, JDDNode transRewards); void doReachability(); void doReachability(boolean extraReachInfo); void skipReachability(); void setReach(JDDNode reach); + void setTransActions(JDDNode transActions); void filterReachableStates(); void findDeadlocks(); void fixDeadlocks(); diff --git a/prism/src/prism/Modules2MTBDD.java b/prism/src/prism/Modules2MTBDD.java index 236061cc..5b3bdb1b 100644 --- a/prism/src/prism/Modules2MTBDD.java +++ b/prism/src/prism/Modules2MTBDD.java @@ -77,6 +77,7 @@ public class Modules2MTBDD private JDDNode start; // dd for start state private JDDNode stateRewards[]; // dds for state rewards private JDDNode transRewards[]; // dds for transition rewards + private JDDNode transActions; // dd for transition action labels private JDDNode transInd; // dds for independent bits of trans private JDDNode transSynch[]; // dds for synch action parts of trans private JDDVars allDDRowVars; // all dd vars (rows) @@ -116,8 +117,10 @@ public class Modules2MTBDD private int numModulesAfterSymm; // number of modules in the PRISM file after the symmetric ones private int numSymmModules; // number of symmetric components - // hidden option - do we also store each part of the transition matrix separately? - private boolean storeTransParts = true; + // hidden option - do we also store each part of the transition matrix separately? (now defunct) + private boolean storeTransParts = false; + // hidden option - do we also store action info for the transition matrix? (supersedes the above) + private boolean storeTransActions = true; // data structure used to store mtbdds and related info // for some component of the whole model @@ -164,8 +167,9 @@ public class Modules2MTBDD doSymmetry = !(s == null || s == ""); } - // main method - translate + @SuppressWarnings("unchecked") // for clone of vector in translate() + // main method - translate public Model translate() throws PrismException { Model model = null; @@ -261,14 +265,21 @@ public class Modules2MTBDD numVars, varList, varDDRowVars, varDDColVars, constantValues); } + // We also store a copy of the list of action label names + model.setSynchs((Vector)synchs.clone()); + // For MDPs, we also store the DDs used to construct the part // of the transition matrix that corresponds to each action if (modelType == ModelType.MDP && storeTransParts) { - ((NondetModel)model).setSynchs((Vector)synchs.clone()); ((NondetModel)model).setTransInd(transInd); ((NondetModel)model).setTransSynch(transSynch); } + // if required, we also store info about action labels + if (storeTransActions) { + model.setTransActions(transActions); + } + // do reachability (or not) if (prism.getDoReach()) { mainLog.print("\nComputing reachable states...\n"); @@ -838,6 +849,21 @@ public class Modules2MTBDD } } + // If required, we also build an MTBDD to store the action labels for each transition + if (storeTransActions) { + if (modelType == ModelType.MDP) { + transActions = JDD.Constant(0); + JDD.Ref(sysDDs.ind.trans); + tmp = JDD.ThereExists(JDD.GreaterThan(sysDDs.ind.trans, 0), allDDColVars); + transActions = JDD.Apply(JDD.PLUS, transActions, JDD.Apply(JDD.TIMES, tmp, JDD.Constant(1))); + for (i = 0; i < numSynchs; i++) { + JDD.Ref(sysDDs.synchs[i].trans); + tmp = JDD.ThereExists(JDD.GreaterThan(sysDDs.synchs[i].trans, 0), allDDColVars); + transActions = JDD.Apply(JDD.PLUS, transActions, JDD.Apply(JDD.TIMES, tmp, JDD.Constant(2+i))); + } + } + } + // deref bits of ComponentDD objects - we don't need them any more JDD.Deref(sysDDs.ind.guards); JDD.Deref(sysDDs.ind.trans); diff --git a/prism/src/prism/NondetModel.java b/prism/src/prism/NondetModel.java index 99d4206d..7c75e31b 100644 --- a/prism/src/prism/NondetModel.java +++ b/prism/src/prism/NondetModel.java @@ -41,8 +41,6 @@ import sparse.*; public class NondetModel extends ProbModel { // Extra info - protected Vector synchs; // synchronising action labels - protected double numSynchs; // number of synchronising actions protected double numChoices; // number of choices // Extra dd stuff @@ -97,11 +95,6 @@ public class NondetModel extends ProbModel return allDDNondetVars; } - public Vector getSynchs() - { - return synchs; - } - public JDDNode getTransInd() { return transInd; @@ -135,12 +128,6 @@ public class NondetModel extends ProbModel // set methods for things not set up in constructor - public void setSynchs(Vector synchs) - { - this.synchs = synchs; - this.numSynchs = synchs.size(); - } - public void setTransInd(JDDNode transInd) { this.transInd = transInd; @@ -336,6 +323,11 @@ public class NondetModel extends ProbModel } } } + if (transActions != null && !transActions.equals(JDD.ZERO)) { + log.print("Action label indices: "); + log.print(JDD.GetNumNodes(transActions) + " nodes ("); + log.print(JDD.GetNumTerminals(transActions) + " terminal)\n"); + } } // export transition matrix to a file diff --git a/prism/src/prism/NondetModelChecker.java b/prism/src/prism/NondetModelChecker.java index 97b97718..65066667 100644 --- a/prism/src/prism/NondetModelChecker.java +++ b/prism/src/prism/NondetModelChecker.java @@ -645,7 +645,7 @@ public class NondetModelChecker extends NonProbModelChecker else { // for fairness, we compute max here try { - probs = computeUntilProbs(trans, trans01, newb1, newb2, min && !fairness); + probs = computeUntilProbs(trans, transActions, trans01, newb1, newb2, min && !fairness); } catch (PrismException e) { JDD.Deref(newb1); JDD.Deref(newb2); @@ -934,7 +934,7 @@ public class NondetModelChecker extends NonProbModelChecker // note: this function doesn't need to know anything about fairness // it is just told whether to compute min or max probabilities - protected StateProbs computeUntilProbs(JDDNode tr, JDDNode tr01, JDDNode b1, JDDNode b2, boolean min) + protected StateProbs computeUntilProbs(JDDNode tr, JDDNode tra, JDDNode tr01, JDDNode b1, JDDNode b2, boolean min) throws PrismException { JDDNode yes, no, maybe; @@ -1016,7 +1016,7 @@ public class NondetModelChecker extends NonProbModelChecker probs = new StateProbsMTBDD(probsMTBDD, model); break; case Prism.SPARSE: - probsDV = PrismSparse.NondetUntil(tr, odd, allDDRowVars, allDDColVars, allDDNondetVars, yes, maybe, + probsDV = PrismSparse.NondetUntil(tr, tra, model.getSynchs(), odd, allDDRowVars, allDDColVars, allDDNondetVars, yes, maybe, min); probs = new StateProbsDV(probsDV, model); break; diff --git a/prism/src/prism/ProbModel.java b/prism/src/prism/ProbModel.java index 7496cac2..38ec3894 100644 --- a/prism/src/prism/ProbModel.java +++ b/prism/src/prism/ProbModel.java @@ -27,8 +27,7 @@ package prism; import java.io.*; -import java.util.BitSet; -import java.util.Vector; +import java.util.*; import jdd.*; import odd.*; @@ -51,6 +50,9 @@ public class ProbModel implements Model protected VarList varList; // list of module variables protected long[] gtol; // numbers for use by globalToLocal protected Values constantValues; // values of constants + // actions + protected int numSynchs; // number of synchronising actions + protected Vector synchs; // synchronising action labels // rewards protected int numRewardStructs; // number of reward structs protected String[] rewardStructNames; // reward struct names @@ -71,6 +73,8 @@ public class ProbModel implements Model protected JDDNode fixdl; // fixed deadlock states dd protected JDDNode stateRewards[]; // state rewards dds protected JDDNode transRewards[]; // transition rewards dds + protected JDDNode transActions; // dd for transition action labels + // dd vars protected JDDVars[] varDDRowVars; // dd vars for each module variable (rows) protected JDDVars[] varDDColVars; // dd vars for each module variable (cols) @@ -155,6 +159,14 @@ public class ProbModel implements Model return constantValues; } + /** + * Get vector of action label names. + */ + public List getSynchs() + { + return synchs; + } + // rewards public int getNumRewardStructs() { @@ -277,6 +289,11 @@ public class ProbModel implements Model return null; } + public JDDNode getTransActions() + { + return transActions; + } + // dd vars public JDDVars[] getVarDDRowVars() { @@ -391,7 +408,10 @@ public class ProbModel implements Model varDDRowVars = vrv; varDDColVars = vcv; constantValues = cv; - + + // action label info (optional) is initially null + transActions = null; + // compute numbers for globalToLocal converter gtol = new long[numVars]; for (i = 0; i < numVars; i++) { @@ -409,6 +429,15 @@ public class ProbModel implements Model numStartStates = JDD.GetNumMinterms(start, allDDRowVars.n()); } + /** + * Set vector of action label names. + */ + public void setSynchs(Vector synchs) + { + this.synchs = synchs; + this.numSynchs = synchs.size(); + } + /** * Reset transition matrix DD */ @@ -474,7 +503,15 @@ public class ProbModel implements Model // build odd odd = ODDUtils.BuildODD(reach, allDDRowVars); } - + + /** + * Set the DD used to store transitoin action label indices. + */ + public void setTransActions(JDDNode transActions) + { + this.transActions = transActions; + } + // remove non-reachable states from various dds // (and calculate num transitions) @@ -508,6 +545,13 @@ public class ProbModel implements Model transRewards[i] = JDD.Apply(JDD.TIMES, tmp, transRewards[i]); } + // Action label indices matrix + // (just filter rows here; subclasses, e.g. CTMCs, may do more subsequently) + if (transActions != null) { + JDD.Ref(reach); + transActions = JDD.Apply(JDD.TIMES, reach, transActions); + } + // filter start states, work out number of initial states JDD.Ref(reach); start = JDD.Apply(JDD.TIMES, reach, start); @@ -619,6 +663,11 @@ public class ProbModel implements Model } } } + if (transActions != null && !transActions.equals(JDD.ZERO)) { + log.print("Action label indices: "); + log.print(JDD.GetNumNodes(transActions) + " nodes ("); + log.print(JDD.GetNumTerminals(transActions) + " terminal)\n"); + } } // export transition matrix to a file @@ -795,5 +844,6 @@ public class ProbModel implements Model JDD.Deref(stateRewards[i]); JDD.Deref(transRewards[i]); } + if (transActions != null) JDD.Deref(transActions); } } diff --git a/prism/src/prism/StateModelChecker.java b/prism/src/prism/StateModelChecker.java index af4471e4..f6e93d3e 100644 --- a/prism/src/prism/StateModelChecker.java +++ b/prism/src/prism/StateModelChecker.java @@ -53,6 +53,7 @@ public class StateModelChecker implements ModelChecker protected VarList varList; protected JDDNode trans; protected JDDNode trans01; + protected JDDNode transActions; protected JDDNode start; protected JDDNode reach; protected ODDNode odd; @@ -89,6 +90,7 @@ public class StateModelChecker implements ModelChecker varList = model.getVarList(); trans = model.getTrans(); trans01 = model.getTrans01(); + transActions = model.getTransActions(); start = model.getStart(); reach = model.getReach(); odd = model.getODD(); diff --git a/prism/src/prism/StochModel.java b/prism/src/prism/StochModel.java index a15fad48..1f19dfcf 100644 --- a/prism/src/prism/StochModel.java +++ b/prism/src/prism/StochModel.java @@ -61,4 +61,22 @@ public class StochModel extends ProbModel { super(tr, s, sr, trr, rsn, arv, acv, ddvn, nm, mn, mrv, mcv, nv, vl, vrv, vcv, cv); } + + /** + * Remove non-reachable states from various DDs. + * Most of the work is done in the superclass (ProbModel); + * just a few extra things to do here. + */ + public void filterReachableStates() + { + super.filterReachableStates(); + + // If required, also filter columns of action label indices matrix + // (for the superclass - DTMCs - we only store info per state). + if (transActions != null) { + JDD.Ref(reach); + JDDNode tmp = JDD.PermuteVariables(reach, allDDRowVars, allDDColVars); + transActions = JDD.Apply(JDD.TIMES, tmp, transActions); + } + } } diff --git a/prism/src/prism/prism.cc b/prism/src/prism/prism.cc index 53239846..633b1d6c 100644 --- a/prism/src/prism/prism.cc +++ b/prism/src/prism/prism.cc @@ -29,6 +29,52 @@ #include "prism.h" #include #include +#include + +//------------------------------------------------------------------------------ + +// convert a list of strings (from java/jni) to an array of c strings. +// actually stores arrays of both jstring objects and c strings, and also size +// (because need these to free memory afterwards). + +void get_string_array_from_java(JNIEnv *env, jobject strings_list, jstring *&strings_jstrings, const char **&strings, jint &size) +{ + int i, j; + jclass vn_cls; + jmethodID vn_mid; + // get size of vector of strings + vn_cls = env->GetObjectClass(strings_list); + vn_mid = env->GetMethodID(vn_cls, "size", "()I"); + if (vn_mid == 0) { + return; + } + size = env->CallIntMethod(strings_list,vn_mid); + // put strings from vector into array + strings_jstrings = new jstring[size]; + strings = new const char*[size]; + vn_mid = env->GetMethodID(vn_cls, "get", "(I)Ljava/lang/Object;"); + if (vn_mid == 0) { + return; + } + for (i = 0; i < size; i++) { + strings_jstrings[i] = (jstring)env->CallObjectMethod(strings_list, vn_mid, i); + strings[i] = env->GetStringUTFChars(strings_jstrings[i], 0); + } +} + +//------------------------------------------------------------------------------ + +// release the memory from a list of strings created by get_string_array_from_java + +void release_string_array_from_java(JNIEnv *env, jstring *strings_jstrings, const char **strings, jint size) +{ + // release memory + for (int i = 0; i < size; i++) { + env->ReleaseStringUTFChars(strings_jstrings[i], strings[i]); + } + delete[] strings_jstrings; + delete[] strings; +} //------------------------------------------------------------------------------ diff --git a/prism/src/sparse/PS_NondetUntil.cc b/prism/src/sparse/PS_NondetUntil.cc index 479a2f5c..f116cbfa 100644 --- a/prism/src/sparse/PS_NondetUntil.cc +++ b/prism/src/sparse/PS_NondetUntil.cc @@ -33,6 +33,7 @@ #include #include #include "sparse.h" +#include "prism.h" #include "PrismSparseGlob.h" #include "jnipointer.h" #include @@ -43,7 +44,9 @@ JNIEXPORT jlong __jlongpointer JNICALL Java_sparse_PrismSparse_PS_1NondetUntil ( JNIEnv *env, jclass cls, -jlong __jlongpointer t, // trans matrix +jlong __jlongpointer t, // trans matrix +jlong __jlongpointer ta, // trans action labels +jobject synchs, jlong __jlongpointer od, // odd jlong __jlongpointer rv, // row vars jint num_rvars, @@ -51,22 +54,23 @@ jlong __jlongpointer cv, // col vars jint num_cvars, jlong __jlongpointer ndv, // nondet vars jint num_ndvars, -jlong __jlongpointer y, // 'yes' states -jlong __jlongpointer m, // 'maybe' states -jboolean min // min or max probabilities (true = min, false = max) +jlong __jlongpointer y, // 'yes' states +jlong __jlongpointer m, // 'maybe' states +jboolean min // min or max probabilities (true = min, false = max) ) { // cast function parameters - DdNode *trans = jlong_to_DdNode(t); // trans matrix - ODDNode *odd = jlong_to_ODDNode(od); // reachable states + DdNode *trans = jlong_to_DdNode(t); // trans matrix + DdNode *trans_actions = jlong_to_DdNode(ta); // trans action labels + ODDNode *odd = jlong_to_ODDNode(od); // reachable states DdNode **rvars = jlong_to_DdNode_array(rv); // row vars DdNode **cvars = jlong_to_DdNode_array(cv); // col vars DdNode **ndvars = jlong_to_DdNode_array(ndv); // nondet vars - DdNode *yes = jlong_to_DdNode(y); // 'yes' states - DdNode *maybe = jlong_to_DdNode(m); // 'maybe' states + DdNode *yes = jlong_to_DdNode(y); // 'yes' states + DdNode *maybe = jlong_to_DdNode(m); // 'maybe' states // mtbdds - DdNode *a = NULL; + DdNode *a = NULL, *tmp = NULL; // model stats int n, nc; long nnz; @@ -81,6 +85,10 @@ jboolean min // min or max probabilities (true = min, false = max) bool adv = false, adv_loop = false; FILE *fp_adv = NULL; int adv_l, adv_h; + int *actions; + jstring *action_names_jstrings; + const char** action_names; + int num_actions; // misc int i, j, k, l1, h1, l2, h2, iters; double d1, d2, kb, kbt; @@ -112,6 +120,25 @@ jboolean min // min or max probabilities (true = min, false = max) PS_PrintToMainLog(env, "[n=%d, nc=%d, nnz=%d, k=%d] ", n, nc, nnz, ndsm->k); PS_PrintMemoryToMainLog(env, "[", kb, "]\n"); + // if needed, and if info is available, build a vector of action indices for the mdp + if (adv && trans_actions != NULL) { + PS_PrintToMainLog(env, "Building action information... "); + // first need to filter out unwanted rows + Cudd_Ref(trans_actions); + Cudd_Ref(maybe); + tmp = DD_Apply(ddman, APPLY_TIMES, trans_actions, maybe); + // then convert to a vector of integer indices + actions = build_nd_action_vector(ddman, tmp, ndsm, rvars, num_rvars, ndvars, num_ndvars, odd); + Cudd_RecursiveDeref(ddman, tmp); + kb = n*4.0/1024.0; + kbt += kb; + PS_PrintMemoryToMainLog(env, "[", kb, "]\n"); + } else { + actions = NULL; + } + // also extract list of action name + get_string_array_from_java(env, synchs, action_names_jstrings, action_names, num_actions); + // get vector for yes PS_PrintToMainLog(env, "Creating vector for yes... "); yes_vec = mtbdd_to_double_vector(ddman, yes, rvars, num_rvars, odd); @@ -193,7 +220,11 @@ jboolean min // min or max probabilities (true = min, false = max) soln2[i] = (h1 > l1) ? d1 : yes_vec[i]; // store adversary info (if required) if (adv_loop) if (h1 > l1) - for (k = adv_l; k < adv_h; k++) fprintf(fp_adv, "%d %d %g\n", i, cols[k], non_zeros[k]); + for (k = adv_l; k < adv_h; k++) { + fprintf(fp_adv, "%d %d %g", i, cols[k], non_zeros[k]); + if (actions != NULL) fprintf(fp_adv, " %s", actions[l1]>1?action_names[actions[l1]-2]:""); + fprintf(fp_adv, "\n"); + } } // check convergence @@ -260,8 +291,10 @@ jboolean min // min or max probabilities (true = min, false = max) if (ndsm) delete ndsm; if (yes_vec) delete[] yes_vec; if (soln2) delete[] soln2; + release_string_array_from_java(env, action_names_jstrings, action_names, num_actions); return ptr_to_jlong(soln); } //------------------------------------------------------------------------------ + diff --git a/prism/src/sparse/PrismSparse.java b/prism/src/sparse/PrismSparse.java index 45688f73..32b51c88 100644 --- a/prism/src/sparse/PrismSparse.java +++ b/prism/src/sparse/PrismSparse.java @@ -27,6 +27,7 @@ package sparse; import java.io.FileNotFoundException; +import java.util.List; import prism.*; import jdd.*; @@ -255,10 +256,10 @@ public class PrismSparse } // pctl until (nondeterministic/mdp) - private static native long PS_NondetUntil(long trans, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long yes, long maybe, boolean minmax); - public static DoubleVector NondetUntil(JDDNode trans, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, JDDNode yes, JDDNode maybe, boolean minmax) throws PrismException + private static native long PS_NondetUntil(long trans, long trans_actions, List synchs, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long yes, long maybe, boolean minmax); + public static DoubleVector NondetUntil(JDDNode trans, JDDNode transActions, List synchs, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, JDDNode yes, JDDNode maybe, boolean minmax) throws PrismException { - long ptr = PS_NondetUntil(trans.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), yes.ptr(), maybe.ptr(), minmax); + long ptr = PS_NondetUntil(trans.ptr(), (transActions == null) ? 0 : transActions.ptr(), synchs, odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), yes.ptr(), maybe.ptr(), minmax); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); } diff --git a/prism/src/sparse/sparse.cc b/prism/src/sparse/sparse.cc index d147563d..dbfcb6f4 100644 --- a/prism/src/sparse/sparse.cc +++ b/prism/src/sparse/sparse.cc @@ -43,6 +43,7 @@ static void traverse_mtbdd_vect_rec(DdManager *ddman, DdNode *dd, DdNode **vars, // global variables (used by local functions) static int count; static int *starts, *starts2; +static int *actions; static RMSparseMatrix *rmsm; static CMSparseMatrix *cmsm; static RCSparseMatrix *rcsm; @@ -668,7 +669,7 @@ NDSparseMatrix *build_nd_sparse_matrix(DdManager *ddman, DdNode *mdp, DdNode **r // try/catch for memory allocation/deallocation } catch(std::bad_alloc e) { - if (ndsm) delete cmscsm; + if (ndsm) delete ndsm; if (matrices) delete[] matrices; if (matrices_bdds) { for (i = 0; i < nm; i++) Cudd_RecursiveDeref(ddman, matrices_bdds[i]); @@ -808,7 +809,7 @@ NDSparseMatrix *build_sub_nd_sparse_matrix(DdManager *ddman, DdNode *mdp, DdNode // try/catch for memory allocation/deallocation } catch(std::bad_alloc e) { - if (ndsm) delete cmscsm; + if (ndsm) delete ndsm; if (matrices) delete[] matrices; if (matrices_bdds) { for (i = 0; i < nm; i++) Cudd_RecursiveDeref(ddman, matrices_bdds[i]); @@ -833,6 +834,88 @@ NDSparseMatrix *build_sub_nd_sparse_matrix(DdManager *ddman, DdNode *mdp, DdNode //------------------------------------------------------------------------------ +// build nondeterministic (mdp) action vector to accompany a sparse matrix +// (i.e. a vector containing for every state and nondet choice, an index +// into the list of all action labels) +// throws std::bad_alloc on out-of-memory + +int *build_nd_action_vector(DdManager *ddman, DdNode *trans_actions, NDSparseMatrix *mdp_ndsm, DdNode **rvars, int num_vars, DdNode **ndvars, int num_ndvars, ODDNode *odd) +{ + int i, n, nm, nc; + DdNode *tmp = NULL, **matrices = NULL, **matrices_bdds = NULL; + + // try/catch for memory allocation/deallocation + try { + + // get stats from mdp sparse storage (num states/choices) + n = mdp_ndsm->n; + nc = mdp_ndsm->nc; + // break the mtbdd storing the action info into several (nm) mtbdds + // (this number nm should match the figure for the corresponding mdp; + // but we can't do this sanity check because that statistic is not retained.) + Cudd_Ref(trans_actions); + tmp = DD_Not(ddman, DD_Equals(ddman, trans_actions, 0)); + tmp = DD_ThereExists(ddman, tmp, rvars, num_vars); + nm = (int)DD_GetNumMinterms(ddman, tmp, num_ndvars); + Cudd_RecursiveDeref(ddman, tmp); + matrices = new DdNode*[nm]; + count = 0; + split_mdp_rec(ddman, trans_actions, ndvars, num_ndvars, 0, matrices); + // and for each one create a bdd storing which rows are non-empty + matrices_bdds = new DdNode*[nm]; + for (i = 0; i < nm; i++) { + Cudd_Ref(matrices[i]); + matrices_bdds[i] = DD_Not(ddman, DD_Equals(ddman, matrices[i], 0)); + } + + // create arrays + actions = NULL; actions = new int[nc]; + starts = NULL; starts = new int[n+1]; + + // build the (temporary) array 'starts' (like was done when building the sparse matrix for the mdp). + // in fact, this information is retrievable from the sparse matrix, but it may have + // been converted to counts, rather than offsets, so its easier to rebuild it. + // first traverse mtbdds to compute how many choices are in each row + for (i = 0; i < n+1; i++) starts[i] = 0; + for (i = 0; i < nm; i++) { + traverse_mtbdd_vect_rec(ddman, matrices_bdds[i], rvars, num_vars, 0, odd, 0, 1); + } + // and use this to compute the starts information + for (i = 1 ; i < n+1; i++) { + starts[i] += starts[i-1]; + } + + // now traverse the mtbdd to get the actual entries (action indices) + for (i = 0; i < nm; i++) { + traverse_mtbdd_vect_rec(ddman, matrices[i], rvars, num_vars, 0, odd, 0, 3); + } + + // try/catch for memory allocation/deallocation + } catch(std::bad_alloc e) { + if (actions) delete[] actions; + if (matrices) delete[] matrices; + if (matrices_bdds) { + for (i = 0; i < nm; i++) Cudd_RecursiveDeref(ddman, matrices_bdds[i]); + delete[] matrices_bdds; + } + if (starts) delete[] starts; + throw e; + } + + // clear up memory + for (i = 0; i < nm; i++) { + Cudd_RecursiveDeref(ddman, matrices_bdds[i]); + // nb: don't deref matrices array because that was just pointers, not new copies + } + delete[] starts; + delete[] matrices; + delete[] matrices_bdds; + + return actions; +} + +//------------------------------------------------------------------------------ + void split_mdp_rec(DdManager *ddman, DdNode *dd, DdNode **ndvars, int num_ndvars, int level, DdNode **matrices) { DdNode *e, *t; @@ -1062,7 +1145,14 @@ void traverse_mtbdd_vect_rec(DdManager *ddman, DdNode *dd, DdNode **vars, int nu case 2: starts[i]++; break; + + // mdp action vector - single pass + case 3: + actions[starts[i]] = (int)Cudd_V(dd); + starts[i]++; + break; } + return; }