diff --git a/prism/include/PrismSparse.h b/prism/include/PrismSparse.h
index ae30530f..1083df4f 100644
--- a/prism/include/PrismSparse.h
+++ b/prism/include/PrismSparse.h
@@ -226,10 +226,10 @@ JNIEXPORT jint JNICALL Java_sparse_PrismSparse_PS_1ExportMatrix
/*
* Class: sparse_PrismSparse
* Method: PS_ExportMDP
- * Signature: (JLjava/lang/String;JIJIJIJILjava/lang/String;)I
+ * Signature: (JJLjava/util/List;Ljava/lang/String;JIJIJIJILjava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_sparse_PrismSparse_PS_1ExportMDP
- (JNIEnv *, jclass, jlong, jstring, jlong, jint, jlong, jint, jlong, jint, jlong, jint, jstring);
+ (JNIEnv *, jclass, jlong, jlong, jobject, jstring, jlong, jint, jlong, jint, jlong, jint, jlong, jint, jstring);
/*
* Class: sparse_PrismSparse
diff --git a/prism/src/prism/NondetModel.java b/prism/src/prism/NondetModel.java
index e8e682c9..0d4258dc 100644
--- a/prism/src/prism/NondetModel.java
+++ b/prism/src/prism/NondetModel.java
@@ -339,7 +339,7 @@ public class NondetModel extends ProbModel
if (!explicit) {
// can only do explicit (sparse matrix based) export for mdps
} else {
- PrismSparse.ExportMDP(trans, getTransSymbol(), allDDRowVars, allDDColVars, allDDNondetVars, odd,
+ PrismSparse.ExportMDP(trans, transActions, getSynchs(), getTransSymbol(), allDDRowVars, allDDColVars, allDDNondetVars, odd,
exportType, (file != null) ? file.getPath() : null);
}
}
diff --git a/prism/src/sparse/PS_ExportMDP.cc b/prism/src/sparse/PS_ExportMDP.cc
index 34450c71..c61014b3 100644
--- a/prism/src/sparse/PS_ExportMDP.cc
+++ b/prism/src/sparse/PS_ExportMDP.cc
@@ -31,6 +31,7 @@
#include
#include
#include "sparse.h"
+#include "prism.h"
#include "PrismSparseGlob.h"
#include "jnipointer.h"
#include
@@ -42,6 +43,8 @@ JNIEXPORT jint JNICALL Java_sparse_PrismSparse_PS_1ExportMDP
JNIEnv *env,
jclass cls,
jlong __jlongpointer m, // mdp
+jlong __jlongpointer ta, // trans action labels
+jobject synchs,
jstring na, // mdp name
jlong __jlongpointer rv, // row vars
jint num_rvars,
@@ -54,14 +57,20 @@ jint et, // export type
jstring fn // filename
)
{
- DdNode *mdp = jlong_to_DdNode(m); // mdp
- DdNode **rvars = jlong_to_DdNode_array(rv); // row vars
- DdNode **cvars = jlong_to_DdNode_array(cv); // col vars
+ DdNode *mdp = jlong_to_DdNode(m); // mdp
+ DdNode *trans_actions = jlong_to_DdNode(ta); // trans action labels
+ DdNode **rvars = jlong_to_DdNode_array(rv); // row vars
+ DdNode **cvars = jlong_to_DdNode_array(cv); // col vars
DdNode **ndvars = jlong_to_DdNode_array(ndv); // nondet vars
ODDNode *odd = jlong_to_ODDNode(od);
// sparse matrix
NDSparseMatrix *ndsm = NULL;
+ // action info
+ int *actions = NULL;
+ jstring *action_names_jstrings = NULL;
+ const char** action_names = NULL;
+ int num_actions;
// model stats
int i, j, k, n, nc, l1, h1, l2, h2;
long nnz;
@@ -80,6 +89,13 @@ jstring fn // filename
nnz = ndsm->nnz;
nc = ndsm->nc;
+ // if needed, and if info is available, build a vector of action indices for the mdp
+ // also extract list of action names
+ if (true && trans_actions != NULL) {
+ actions = build_nd_action_vector(ddman, mdp, trans_actions, ndsm, rvars, cvars, num_rvars, ndvars, num_ndvars, odd);
+ get_string_array_from_java(env, synchs, action_names_jstrings, action_names, num_actions);
+ }
+
// print file header
switch (export_type) {
case EXPORT_PLAIN: export_string("%d %d %d\n", n, nc, nnz); break;
@@ -107,17 +123,24 @@ jstring fn // filename
else { l2 = h2; h2 += choice_counts[j]; }
if (export_type == EXPORT_ROWS) export_string("%d", i);
else if (export_type == EXPORT_DOT || export_type == EXPORT_DOT_STATES) {
- export_string("%d -> %d.%d [ arrowhead=none,label=\"%d\" ];\n", i, i, j-l1, j-l1);
+ export_string("%d -> %d.%d [ arrowhead=none,label=\"%d", i, i, j-l1, j-l1);
+ if (actions != NULL) export_string(":%s", (actions[j]>0?action_names[actions[j]-1]:""));
+ export_string("\" ];\n");
export_string("%d.%d [ shape=circle,width=0.1,height=0.1,label=\"\" ];\n", i, j-l1);
}
for (k = l2; k < h2; k++) {
switch (export_type) {
- case EXPORT_PLAIN: export_string("%d %d %d %.12g\n", i, j-l1, cols[k], non_zeros[k]); break;
+ case EXPORT_PLAIN:
+ export_string("%d %d %d %.12g", i, j-l1, cols[k], non_zeros[k]);
+ if (actions != NULL) export_string(" %s", (actions[j]>0?action_names[actions[j]-1]:""));
+ export_string("\n");
+ break;
case EXPORT_MATLAB: export_string("%s%d(%d,%d)=%.12g;\n", export_name, j-l1+1, i+1, cols[k]+1, non_zeros[k]); break;
case EXPORT_DOT: case EXPORT_DOT_STATES: export_string("%d.%d -> %d [ label=\"%.12g\" ];\n", i, j-l1, cols[k], non_zeros[k]); break;
case EXPORT_ROWS: export_string(" %.12g:%d", non_zeros[k], cols[k]); break;
}
}
+ if (export_type == EXPORT_ROWS && actions != NULL) export_string(" %s", (actions[j]>0?action_names[actions[j]-1]:""));
if (export_type == EXPORT_ROWS) export_string("\n");
}
}
@@ -139,6 +162,7 @@ jstring fn // filename
// free memory
if (ndsm) delete ndsm;
+ //if (action_names != NULL) release_string_array_from_java(env, action_names_jstrings, action_names, num_actions);
return 0;
}
diff --git a/prism/src/sparse/PrismSparse.java b/prism/src/sparse/PrismSparse.java
index a13f8785..ce2ba6d1 100644
--- a/prism/src/sparse/PrismSparse.java
+++ b/prism/src/sparse/PrismSparse.java
@@ -341,10 +341,10 @@ public class PrismSparse
}
// export mdp
- private static native int PS_ExportMDP(long mdp, String name, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long odd, int exportType, String filename);
- public static void ExportMDP(JDDNode mdp, String name, JDDVars rows, JDDVars cols, JDDVars nondet, ODDNode odd, int exportType, String filename) throws FileNotFoundException, PrismException
+ private static native int PS_ExportMDP(long mdp, long trans_actions, List synchs, String name, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long odd, int exportType, String filename);
+ public static void ExportMDP(JDDNode mdp, JDDNode transActions, List synchs, String name, JDDVars rows, JDDVars cols, JDDVars nondet, ODDNode odd, int exportType, String filename) throws FileNotFoundException, PrismException
{
- int res = PS_ExportMDP(mdp.ptr(), name, rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), odd.ptr(), exportType, filename);
+ int res = PS_ExportMDP(mdp.ptr(), (transActions == null) ? 0 : transActions.ptr(), synchs, name, rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), odd.ptr(), exportType, filename);
if (res == -1) {
throw new FileNotFoundException();
}