diff --git a/prism/include/PrismSparse.h b/prism/include/PrismSparse.h index ae30530f..1083df4f 100644 --- a/prism/include/PrismSparse.h +++ b/prism/include/PrismSparse.h @@ -226,10 +226,10 @@ JNIEXPORT jint JNICALL Java_sparse_PrismSparse_PS_1ExportMatrix /* * Class: sparse_PrismSparse * Method: PS_ExportMDP - * Signature: (JLjava/lang/String;JIJIJIJILjava/lang/String;)I + * Signature: (JJLjava/util/List;Ljava/lang/String;JIJIJIJILjava/lang/String;)I */ JNIEXPORT jint JNICALL Java_sparse_PrismSparse_PS_1ExportMDP - (JNIEnv *, jclass, jlong, jstring, jlong, jint, jlong, jint, jlong, jint, jlong, jint, jstring); + (JNIEnv *, jclass, jlong, jlong, jobject, jstring, jlong, jint, jlong, jint, jlong, jint, jlong, jint, jstring); /* * Class: sparse_PrismSparse diff --git a/prism/src/prism/NondetModel.java b/prism/src/prism/NondetModel.java index e8e682c9..0d4258dc 100644 --- a/prism/src/prism/NondetModel.java +++ b/prism/src/prism/NondetModel.java @@ -339,7 +339,7 @@ public class NondetModel extends ProbModel if (!explicit) { // can only do explicit (sparse matrix based) export for mdps } else { - PrismSparse.ExportMDP(trans, getTransSymbol(), allDDRowVars, allDDColVars, allDDNondetVars, odd, + PrismSparse.ExportMDP(trans, transActions, getSynchs(), getTransSymbol(), allDDRowVars, allDDColVars, allDDNondetVars, odd, exportType, (file != null) ? file.getPath() : null); } } diff --git a/prism/src/sparse/PS_ExportMDP.cc b/prism/src/sparse/PS_ExportMDP.cc index 34450c71..c61014b3 100644 --- a/prism/src/sparse/PS_ExportMDP.cc +++ b/prism/src/sparse/PS_ExportMDP.cc @@ -31,6 +31,7 @@ #include #include #include "sparse.h" +#include "prism.h" #include "PrismSparseGlob.h" #include "jnipointer.h" #include @@ -42,6 +43,8 @@ JNIEXPORT jint JNICALL Java_sparse_PrismSparse_PS_1ExportMDP JNIEnv *env, jclass cls, jlong __jlongpointer m, // mdp +jlong __jlongpointer ta, // trans action labels +jobject synchs, jstring na, // mdp name jlong __jlongpointer rv, // row vars jint num_rvars, @@ -54,14 +57,20 @@ jint et, // export type jstring fn // filename ) { - DdNode *mdp = jlong_to_DdNode(m); // mdp - DdNode **rvars = jlong_to_DdNode_array(rv); // row vars - DdNode **cvars = jlong_to_DdNode_array(cv); // col vars + DdNode *mdp = jlong_to_DdNode(m); // mdp + DdNode *trans_actions = jlong_to_DdNode(ta); // trans action labels + DdNode **rvars = jlong_to_DdNode_array(rv); // row vars + DdNode **cvars = jlong_to_DdNode_array(cv); // col vars DdNode **ndvars = jlong_to_DdNode_array(ndv); // nondet vars ODDNode *odd = jlong_to_ODDNode(od); // sparse matrix NDSparseMatrix *ndsm = NULL; + // action info + int *actions = NULL; + jstring *action_names_jstrings = NULL; + const char** action_names = NULL; + int num_actions; // model stats int i, j, k, n, nc, l1, h1, l2, h2; long nnz; @@ -80,6 +89,13 @@ jstring fn // filename nnz = ndsm->nnz; nc = ndsm->nc; + // if needed, and if info is available, build a vector of action indices for the mdp + // also extract list of action names + if (true && trans_actions != NULL) { + actions = build_nd_action_vector(ddman, mdp, trans_actions, ndsm, rvars, cvars, num_rvars, ndvars, num_ndvars, odd); + get_string_array_from_java(env, synchs, action_names_jstrings, action_names, num_actions); + } + // print file header switch (export_type) { case EXPORT_PLAIN: export_string("%d %d %d\n", n, nc, nnz); break; @@ -107,17 +123,24 @@ jstring fn // filename else { l2 = h2; h2 += choice_counts[j]; } if (export_type == EXPORT_ROWS) export_string("%d", i); else if (export_type == EXPORT_DOT || export_type == EXPORT_DOT_STATES) { - export_string("%d -> %d.%d [ arrowhead=none,label=\"%d\" ];\n", i, i, j-l1, j-l1); + export_string("%d -> %d.%d [ arrowhead=none,label=\"%d", i, i, j-l1, j-l1); + if (actions != NULL) export_string(":%s", (actions[j]>0?action_names[actions[j]-1]:"")); + export_string("\" ];\n"); export_string("%d.%d [ shape=circle,width=0.1,height=0.1,label=\"\" ];\n", i, j-l1); } for (k = l2; k < h2; k++) { switch (export_type) { - case EXPORT_PLAIN: export_string("%d %d %d %.12g\n", i, j-l1, cols[k], non_zeros[k]); break; + case EXPORT_PLAIN: + export_string("%d %d %d %.12g", i, j-l1, cols[k], non_zeros[k]); + if (actions != NULL) export_string(" %s", (actions[j]>0?action_names[actions[j]-1]:"")); + export_string("\n"); + break; case EXPORT_MATLAB: export_string("%s%d(%d,%d)=%.12g;\n", export_name, j-l1+1, i+1, cols[k]+1, non_zeros[k]); break; case EXPORT_DOT: case EXPORT_DOT_STATES: export_string("%d.%d -> %d [ label=\"%.12g\" ];\n", i, j-l1, cols[k], non_zeros[k]); break; case EXPORT_ROWS: export_string(" %.12g:%d", non_zeros[k], cols[k]); break; } } + if (export_type == EXPORT_ROWS && actions != NULL) export_string(" %s", (actions[j]>0?action_names[actions[j]-1]:"")); if (export_type == EXPORT_ROWS) export_string("\n"); } } @@ -139,6 +162,7 @@ jstring fn // filename // free memory if (ndsm) delete ndsm; + //if (action_names != NULL) release_string_array_from_java(env, action_names_jstrings, action_names, num_actions); return 0; } diff --git a/prism/src/sparse/PrismSparse.java b/prism/src/sparse/PrismSparse.java index a13f8785..ce2ba6d1 100644 --- a/prism/src/sparse/PrismSparse.java +++ b/prism/src/sparse/PrismSparse.java @@ -341,10 +341,10 @@ public class PrismSparse } // export mdp - private static native int PS_ExportMDP(long mdp, String name, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long odd, int exportType, String filename); - public static void ExportMDP(JDDNode mdp, String name, JDDVars rows, JDDVars cols, JDDVars nondet, ODDNode odd, int exportType, String filename) throws FileNotFoundException, PrismException + private static native int PS_ExportMDP(long mdp, long trans_actions, List synchs, String name, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long odd, int exportType, String filename); + public static void ExportMDP(JDDNode mdp, JDDNode transActions, List synchs, String name, JDDVars rows, JDDVars cols, JDDVars nondet, ODDNode odd, int exportType, String filename) throws FileNotFoundException, PrismException { - int res = PS_ExportMDP(mdp.ptr(), name, rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), odd.ptr(), exportType, filename); + int res = PS_ExportMDP(mdp.ptr(), (transActions == null) ? 0 : transActions.ptr(), synchs, name, rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), odd.ptr(), exportType, filename); if (res == -1) { throw new FileNotFoundException(); }