Browse Source

Construction of (symbolic) action label info (currently enabled), functions to convert to sparse storage, and use of this in the adversary generation for MDP until (still switched off for now).

git-svn-id: https://www.prismmodelchecker.org/svn/prism/prism/trunk@1559 bbc10eb1-c90d-0410-af57-cb519fbb1720
master
Dave Parker 16 years ago
parent
commit
b09727fda4
  1. 4
      prism/include/PrismSparse.h
  2. 4
      prism/include/prism.h
  3. 1
      prism/include/sparse.h
  4. 6
      prism/src/prism/Model.java
  5. 34
      prism/src/prism/Modules2MTBDD.java
  6. 18
      prism/src/prism/NondetModel.java
  7. 6
      prism/src/prism/NondetModelChecker.java
  8. 58
      prism/src/prism/ProbModel.java
  9. 2
      prism/src/prism/StateModelChecker.java
  10. 18
      prism/src/prism/StochModel.java
  11. 46
      prism/src/prism/prism.cc
  12. 53
      prism/src/sparse/PS_NondetUntil.cc
  13. 7
      prism/src/sparse/PrismSparse.java
  14. 94
      prism/src/sparse/sparse.cc

4
prism/include/PrismSparse.h

@ -162,10 +162,10 @@ JNIEXPORT jlong JNICALL Java_sparse_PrismSparse_PS_1NondetBoundedUntil
/*
* Class: sparse_PrismSparse
* Method: PS_NondetUntil
* Signature: (JJJIJIJIJJZ)J
* Signature: (JJLjava/util/List;JJIJIJIJJZ)J
*/
JNIEXPORT jlong JNICALL Java_sparse_PrismSparse_PS_1NondetUntil
(JNIEnv *, jclass, jlong, jlong, jlong, jint, jlong, jint, jlong, jint, jlong, jlong, jboolean);
(JNIEnv *, jclass, jlong, jlong, jobject, jlong, jlong, jint, jlong, jint, jlong, jint, jlong, jlong, jboolean);
/*
* Class: sparse_PrismSparse

4
prism/include/prism.h

@ -24,6 +24,8 @@
//
//==============================================================================
#include <jni.h>
// Flags for building Windows DLLs
#ifdef __MINGW32__
#define EXPORT __declspec(dllexport)
@ -41,6 +43,8 @@ typedef struct FoxGlynnWeights
} FoxGlynnWeights;
// Function prototypes
EXPORT void get_string_array_from_java(JNIEnv *env, jobject strings_list, jstring *&strings_jstrings, const char **&strings, jint &size);
EXPORT void release_string_array_from_java(JNIEnv *env, jstring *strings_jstrings, const char **strings, jint size);
EXPORT FoxGlynnWeights fox_glynn(double q_tmax, double underflow, double overflow, double accuracy);
//------------------------------------------------------------------------------

1
prism/include/sparse.h

@ -172,6 +172,7 @@ CMSRSparseMatrix *build_cmsr_sparse_matrix(DdManager *ddman, DdNode *matrix, DdN
CMSCSparseMatrix *build_cmsc_sparse_matrix(DdManager *ddman, DdNode *matrix, DdNode **rvars, DdNode **cvars, int num_vars, ODDNode *odd);
NDSparseMatrix *build_nd_sparse_matrix(DdManager *ddman, DdNode *matrix, DdNode **rvars, DdNode **cvars, int num_vars, DdNode **ndvars, int num_ndvars, ODDNode *odd);
NDSparseMatrix *build_sub_nd_sparse_matrix(DdManager *ddman, DdNode *mdp, DdNode *submdp, DdNode **rvars, DdNode **cvars, int num_vars, DdNode **ndvars, int num_ndvars, ODDNode *odd);
int *build_nd_action_vector(DdManager *ddman, DdNode *trans_actions, NDSparseMatrix *mdp_ndsm, DdNode **rvars, int num_vars, DdNode **ndvars, int num_ndvars, ODDNode *odd);
RMSparseMatrix *build_rm_sparse_matrix(DdManager *ddman, DdNode *matrix, DdNode **rvars, DdNode **cvars, int num_vars, ODDNode *odd, bool transpose);
CMSparseMatrix *build_cm_sparse_matrix(DdManager *ddman, DdNode *matrix, DdNode **rvars, DdNode **cvars, int num_vars, ODDNode *odd, bool transpose);

6
prism/src/prism/Model.java

@ -27,7 +27,7 @@
package prism;
import java.io.*;
import java.util.Vector;
import java.util.*;
import jdd.*;
import odd.*;
@ -50,6 +50,7 @@ public interface Model
int getVarHigh(int i);
int getVarRange(int i);
Values getConstantValues();
List<String> getSynchs();
String globalToLocal(long x);
int globalToLocal(long x, int l);
State convertBddToState(JDDNode dd);
@ -78,6 +79,7 @@ public interface Model
JDDNode getTransRewards();
JDDNode getTransRewards(int i);
JDDNode getTransRewards(String s);
JDDNode getTransActions();
JDDVars[] getVarDDRowVars();
JDDVars[] getVarDDColVars();
JDDVars getVarDDRowVars(int i);
@ -95,12 +97,14 @@ public interface Model
ODDNode getODD();
void setSynchs(Vector<String> synchs);
void resetTrans(JDDNode trans);
void resetTransRewards(int i, JDDNode transRewards);
void doReachability();
void doReachability(boolean extraReachInfo);
void skipReachability();
void setReach(JDDNode reach);
void setTransActions(JDDNode transActions);
void filterReachableStates();
void findDeadlocks();
void fixDeadlocks();

34
prism/src/prism/Modules2MTBDD.java

@ -77,6 +77,7 @@ public class Modules2MTBDD
private JDDNode start; // dd for start state
private JDDNode stateRewards[]; // dds for state rewards
private JDDNode transRewards[]; // dds for transition rewards
private JDDNode transActions; // dd for transition action labels
private JDDNode transInd; // dds for independent bits of trans
private JDDNode transSynch[]; // dds for synch action parts of trans
private JDDVars allDDRowVars; // all dd vars (rows)
@ -116,8 +117,10 @@ public class Modules2MTBDD
private int numModulesAfterSymm; // number of modules in the PRISM file after the symmetric ones
private int numSymmModules; // number of symmetric components
// hidden option - do we also store each part of the transition matrix separately?
private boolean storeTransParts = true;
// hidden option - do we also store each part of the transition matrix separately? (now defunct)
private boolean storeTransParts = false;
// hidden option - do we also store action info for the transition matrix? (supersedes the above)
private boolean storeTransActions = true;
// data structure used to store mtbdds and related info
// for some component of the whole model
@ -164,8 +167,9 @@ public class Modules2MTBDD
doSymmetry = !(s == null || s == "");
}
// main method - translate
@SuppressWarnings("unchecked") // for clone of vector in translate()
// main method - translate
public Model translate() throws PrismException
{
Model model = null;
@ -261,14 +265,21 @@ public class Modules2MTBDD
numVars, varList, varDDRowVars, varDDColVars, constantValues);
}
// We also store a copy of the list of action label names
model.setSynchs((Vector<String>)synchs.clone());
// For MDPs, we also store the DDs used to construct the part
// of the transition matrix that corresponds to each action
if (modelType == ModelType.MDP && storeTransParts) {
((NondetModel)model).setSynchs((Vector<String>)synchs.clone());
((NondetModel)model).setTransInd(transInd);
((NondetModel)model).setTransSynch(transSynch);
}
// if required, we also store info about action labels
if (storeTransActions) {
model.setTransActions(transActions);
}
// do reachability (or not)
if (prism.getDoReach()) {
mainLog.print("\nComputing reachable states...\n");
@ -838,6 +849,21 @@ public class Modules2MTBDD
}
}
// If required, we also build an MTBDD to store the action labels for each transition
if (storeTransActions) {
if (modelType == ModelType.MDP) {
transActions = JDD.Constant(0);
JDD.Ref(sysDDs.ind.trans);
tmp = JDD.ThereExists(JDD.GreaterThan(sysDDs.ind.trans, 0), allDDColVars);
transActions = JDD.Apply(JDD.PLUS, transActions, JDD.Apply(JDD.TIMES, tmp, JDD.Constant(1)));
for (i = 0; i < numSynchs; i++) {
JDD.Ref(sysDDs.synchs[i].trans);
tmp = JDD.ThereExists(JDD.GreaterThan(sysDDs.synchs[i].trans, 0), allDDColVars);
transActions = JDD.Apply(JDD.PLUS, transActions, JDD.Apply(JDD.TIMES, tmp, JDD.Constant(2+i)));
}
}
}
// deref bits of ComponentDD objects - we don't need them any more
JDD.Deref(sysDDs.ind.guards);
JDD.Deref(sysDDs.ind.trans);

18
prism/src/prism/NondetModel.java

@ -41,8 +41,6 @@ import sparse.*;
public class NondetModel extends ProbModel
{
// Extra info
protected Vector<String> synchs; // synchronising action labels
protected double numSynchs; // number of synchronising actions
protected double numChoices; // number of choices
// Extra dd stuff
@ -97,11 +95,6 @@ public class NondetModel extends ProbModel
return allDDNondetVars;
}
public Vector<String> getSynchs()
{
return synchs;
}
public JDDNode getTransInd()
{
return transInd;
@ -135,12 +128,6 @@ public class NondetModel extends ProbModel
// set methods for things not set up in constructor
public void setSynchs(Vector<String> synchs)
{
this.synchs = synchs;
this.numSynchs = synchs.size();
}
public void setTransInd(JDDNode transInd)
{
this.transInd = transInd;
@ -336,6 +323,11 @@ public class NondetModel extends ProbModel
}
}
}
if (transActions != null && !transActions.equals(JDD.ZERO)) {
log.print("Action label indices: ");
log.print(JDD.GetNumNodes(transActions) + " nodes (");
log.print(JDD.GetNumTerminals(transActions) + " terminal)\n");
}
}
// export transition matrix to a file

6
prism/src/prism/NondetModelChecker.java

@ -645,7 +645,7 @@ public class NondetModelChecker extends NonProbModelChecker
else {
// for fairness, we compute max here
try {
probs = computeUntilProbs(trans, trans01, newb1, newb2, min && !fairness);
probs = computeUntilProbs(trans, transActions, trans01, newb1, newb2, min && !fairness);
} catch (PrismException e) {
JDD.Deref(newb1);
JDD.Deref(newb2);
@ -934,7 +934,7 @@ public class NondetModelChecker extends NonProbModelChecker
// note: this function doesn't need to know anything about fairness
// it is just told whether to compute min or max probabilities
protected StateProbs computeUntilProbs(JDDNode tr, JDDNode tr01, JDDNode b1, JDDNode b2, boolean min)
protected StateProbs computeUntilProbs(JDDNode tr, JDDNode tra, JDDNode tr01, JDDNode b1, JDDNode b2, boolean min)
throws PrismException
{
JDDNode yes, no, maybe;
@ -1016,7 +1016,7 @@ public class NondetModelChecker extends NonProbModelChecker
probs = new StateProbsMTBDD(probsMTBDD, model);
break;
case Prism.SPARSE:
probsDV = PrismSparse.NondetUntil(tr, odd, allDDRowVars, allDDColVars, allDDNondetVars, yes, maybe,
probsDV = PrismSparse.NondetUntil(tr, tra, model.getSynchs(), odd, allDDRowVars, allDDColVars, allDDNondetVars, yes, maybe,
min);
probs = new StateProbsDV(probsDV, model);
break;

58
prism/src/prism/ProbModel.java

@ -27,8 +27,7 @@
package prism;
import java.io.*;
import java.util.BitSet;
import java.util.Vector;
import java.util.*;
import jdd.*;
import odd.*;
@ -51,6 +50,9 @@ public class ProbModel implements Model
protected VarList varList; // list of module variables
protected long[] gtol; // numbers for use by globalToLocal
protected Values constantValues; // values of constants
// actions
protected int numSynchs; // number of synchronising actions
protected Vector<String> synchs; // synchronising action labels
// rewards
protected int numRewardStructs; // number of reward structs
protected String[] rewardStructNames; // reward struct names
@ -71,6 +73,8 @@ public class ProbModel implements Model
protected JDDNode fixdl; // fixed deadlock states dd
protected JDDNode stateRewards[]; // state rewards dds
protected JDDNode transRewards[]; // transition rewards dds
protected JDDNode transActions; // dd for transition action labels
// dd vars
protected JDDVars[] varDDRowVars; // dd vars for each module variable (rows)
protected JDDVars[] varDDColVars; // dd vars for each module variable (cols)
@ -155,6 +159,14 @@ public class ProbModel implements Model
return constantValues;
}
/**
* Get vector of action label names.
*/
public List<String> getSynchs()
{
return synchs;
}
// rewards
public int getNumRewardStructs()
{
@ -277,6 +289,11 @@ public class ProbModel implements Model
return null;
}
public JDDNode getTransActions()
{
return transActions;
}
// dd vars
public JDDVars[] getVarDDRowVars()
{
@ -391,7 +408,10 @@ public class ProbModel implements Model
varDDRowVars = vrv;
varDDColVars = vcv;
constantValues = cv;
// action label info (optional) is initially null
transActions = null;
// compute numbers for globalToLocal converter
gtol = new long[numVars];
for (i = 0; i < numVars; i++) {
@ -409,6 +429,15 @@ public class ProbModel implements Model
numStartStates = JDD.GetNumMinterms(start, allDDRowVars.n());
}
/**
* Set vector of action label names.
*/
public void setSynchs(Vector<String> synchs)
{
this.synchs = synchs;
this.numSynchs = synchs.size();
}
/**
* Reset transition matrix DD
*/
@ -474,7 +503,15 @@ public class ProbModel implements Model
// build odd
odd = ODDUtils.BuildODD(reach, allDDRowVars);
}
/**
* Set the DD used to store transitoin action label indices.
*/
public void setTransActions(JDDNode transActions)
{
this.transActions = transActions;
}
// remove non-reachable states from various dds
// (and calculate num transitions)
@ -508,6 +545,13 @@ public class ProbModel implements Model
transRewards[i] = JDD.Apply(JDD.TIMES, tmp, transRewards[i]);
}
// Action label indices matrix
// (just filter rows here; subclasses, e.g. CTMCs, may do more subsequently)
if (transActions != null) {
JDD.Ref(reach);
transActions = JDD.Apply(JDD.TIMES, reach, transActions);
}
// filter start states, work out number of initial states
JDD.Ref(reach);
start = JDD.Apply(JDD.TIMES, reach, start);
@ -619,6 +663,11 @@ public class ProbModel implements Model
}
}
}
if (transActions != null && !transActions.equals(JDD.ZERO)) {
log.print("Action label indices: ");
log.print(JDD.GetNumNodes(transActions) + " nodes (");
log.print(JDD.GetNumTerminals(transActions) + " terminal)\n");
}
}
// export transition matrix to a file
@ -795,5 +844,6 @@ public class ProbModel implements Model
JDD.Deref(stateRewards[i]);
JDD.Deref(transRewards[i]);
}
if (transActions != null) JDD.Deref(transActions);
}
}

2
prism/src/prism/StateModelChecker.java

@ -53,6 +53,7 @@ public class StateModelChecker implements ModelChecker
protected VarList varList;
protected JDDNode trans;
protected JDDNode trans01;
protected JDDNode transActions;
protected JDDNode start;
protected JDDNode reach;
protected ODDNode odd;
@ -89,6 +90,7 @@ public class StateModelChecker implements ModelChecker
varList = model.getVarList();
trans = model.getTrans();
trans01 = model.getTrans01();
transActions = model.getTransActions();
start = model.getStart();
reach = model.getReach();
odd = model.getODD();

18
prism/src/prism/StochModel.java

@ -61,4 +61,22 @@ public class StochModel extends ProbModel
{
super(tr, s, sr, trr, rsn, arv, acv, ddvn, nm, mn, mrv, mcv, nv, vl, vrv, vcv, cv);
}
/**
* Remove non-reachable states from various DDs.
* Most of the work is done in the superclass (ProbModel);
* just a few extra things to do here.
*/
public void filterReachableStates()
{
super.filterReachableStates();
// If required, also filter columns of action label indices matrix
// (for the superclass - DTMCs - we only store info per state).
if (transActions != null) {
JDD.Ref(reach);
JDDNode tmp = JDD.PermuteVariables(reach, allDDRowVars, allDDColVars);
transActions = JDD.Apply(JDD.TIMES, tmp, transActions);
}
}
}

46
prism/src/prism/prism.cc

@ -29,6 +29,52 @@
#include "prism.h"
#include <stdio.h>
#include <math.h>
#include <new>
//------------------------------------------------------------------------------
// convert a list of strings (from java/jni) to an array of c strings.
// actually stores arrays of both jstring objects and c strings, and also size
// (because need these to free memory afterwards).
void get_string_array_from_java(JNIEnv *env, jobject strings_list, jstring *&strings_jstrings, const char **&strings, jint &size)
{
int i, j;
jclass vn_cls;
jmethodID vn_mid;
// get size of vector of strings
vn_cls = env->GetObjectClass(strings_list);
vn_mid = env->GetMethodID(vn_cls, "size", "()I");
if (vn_mid == 0) {
return;
}
size = env->CallIntMethod(strings_list,vn_mid);
// put strings from vector into array
strings_jstrings = new jstring[size];
strings = new const char*[size];
vn_mid = env->GetMethodID(vn_cls, "get", "(I)Ljava/lang/Object;");
if (vn_mid == 0) {
return;
}
for (i = 0; i < size; i++) {
strings_jstrings[i] = (jstring)env->CallObjectMethod(strings_list, vn_mid, i);
strings[i] = env->GetStringUTFChars(strings_jstrings[i], 0);
}
}
//------------------------------------------------------------------------------
// release the memory from a list of strings created by get_string_array_from_java
void release_string_array_from_java(JNIEnv *env, jstring *strings_jstrings, const char **strings, jint size)
{
// release memory
for (int i = 0; i < size; i++) {
env->ReleaseStringUTFChars(strings_jstrings[i], strings[i]);
}
delete[] strings_jstrings;
delete[] strings;
}
//------------------------------------------------------------------------------

53
prism/src/sparse/PS_NondetUntil.cc

@ -33,6 +33,7 @@
#include <odd.h>
#include <dv.h>
#include "sparse.h"
#include "prism.h"
#include "PrismSparseGlob.h"
#include "jnipointer.h"
#include <new>
@ -43,7 +44,9 @@ JNIEXPORT jlong __jlongpointer JNICALL Java_sparse_PrismSparse_PS_1NondetUntil
(
JNIEnv *env,
jclass cls,
jlong __jlongpointer t, // trans matrix
jlong __jlongpointer t, // trans matrix
jlong __jlongpointer ta, // trans action labels
jobject synchs,
jlong __jlongpointer od, // odd
jlong __jlongpointer rv, // row vars
jint num_rvars,
@ -51,22 +54,23 @@ jlong __jlongpointer cv, // col vars
jint num_cvars,
jlong __jlongpointer ndv, // nondet vars
jint num_ndvars,
jlong __jlongpointer y, // 'yes' states
jlong __jlongpointer m, // 'maybe' states
jboolean min // min or max probabilities (true = min, false = max)
jlong __jlongpointer y, // 'yes' states
jlong __jlongpointer m, // 'maybe' states
jboolean min // min or max probabilities (true = min, false = max)
)
{
// cast function parameters
DdNode *trans = jlong_to_DdNode(t); // trans matrix
ODDNode *odd = jlong_to_ODDNode(od); // reachable states
DdNode *trans = jlong_to_DdNode(t); // trans matrix
DdNode *trans_actions = jlong_to_DdNode(ta); // trans action labels
ODDNode *odd = jlong_to_ODDNode(od); // reachable states
DdNode **rvars = jlong_to_DdNode_array(rv); // row vars
DdNode **cvars = jlong_to_DdNode_array(cv); // col vars
DdNode **ndvars = jlong_to_DdNode_array(ndv); // nondet vars
DdNode *yes = jlong_to_DdNode(y); // 'yes' states
DdNode *maybe = jlong_to_DdNode(m); // 'maybe' states
DdNode *yes = jlong_to_DdNode(y); // 'yes' states
DdNode *maybe = jlong_to_DdNode(m); // 'maybe' states
// mtbdds
DdNode *a = NULL;
DdNode *a = NULL, *tmp = NULL;
// model stats
int n, nc;
long nnz;
@ -81,6 +85,10 @@ jboolean min // min or max probabilities (true = min, false = max)
bool adv = false, adv_loop = false;
FILE *fp_adv = NULL;
int adv_l, adv_h;
int *actions;
jstring *action_names_jstrings;
const char** action_names;
int num_actions;
// misc
int i, j, k, l1, h1, l2, h2, iters;
double d1, d2, kb, kbt;
@ -112,6 +120,25 @@ jboolean min // min or max probabilities (true = min, false = max)
PS_PrintToMainLog(env, "[n=%d, nc=%d, nnz=%d, k=%d] ", n, nc, nnz, ndsm->k);
PS_PrintMemoryToMainLog(env, "[", kb, "]\n");
// if needed, and if info is available, build a vector of action indices for the mdp
if (adv && trans_actions != NULL) {
PS_PrintToMainLog(env, "Building action information... ");
// first need to filter out unwanted rows
Cudd_Ref(trans_actions);
Cudd_Ref(maybe);
tmp = DD_Apply(ddman, APPLY_TIMES, trans_actions, maybe);
// then convert to a vector of integer indices
actions = build_nd_action_vector(ddman, tmp, ndsm, rvars, num_rvars, ndvars, num_ndvars, odd);
Cudd_RecursiveDeref(ddman, tmp);
kb = n*4.0/1024.0;
kbt += kb;
PS_PrintMemoryToMainLog(env, "[", kb, "]\n");
} else {
actions = NULL;
}
// also extract list of action name
get_string_array_from_java(env, synchs, action_names_jstrings, action_names, num_actions);
// get vector for yes
PS_PrintToMainLog(env, "Creating vector for yes... ");
yes_vec = mtbdd_to_double_vector(ddman, yes, rvars, num_rvars, odd);
@ -193,7 +220,11 @@ jboolean min // min or max probabilities (true = min, false = max)
soln2[i] = (h1 > l1) ? d1 : yes_vec[i];
// store adversary info (if required)
if (adv_loop) if (h1 > l1)
for (k = adv_l; k < adv_h; k++) fprintf(fp_adv, "%d %d %g\n", i, cols[k], non_zeros[k]);
for (k = adv_l; k < adv_h; k++) {
fprintf(fp_adv, "%d %d %g", i, cols[k], non_zeros[k]);
if (actions != NULL) fprintf(fp_adv, " %s", actions[l1]>1?action_names[actions[l1]-2]:"");
fprintf(fp_adv, "\n");
}
}
// check convergence
@ -260,8 +291,10 @@ jboolean min // min or max probabilities (true = min, false = max)
if (ndsm) delete ndsm;
if (yes_vec) delete[] yes_vec;
if (soln2) delete[] soln2;
release_string_array_from_java(env, action_names_jstrings, action_names, num_actions);
return ptr_to_jlong(soln);
}
//------------------------------------------------------------------------------

7
prism/src/sparse/PrismSparse.java

@ -27,6 +27,7 @@
package sparse;
import java.io.FileNotFoundException;
import java.util.List;
import prism.*;
import jdd.*;
@ -255,10 +256,10 @@ public class PrismSparse
}
// pctl until (nondeterministic/mdp)
private static native long PS_NondetUntil(long trans, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long yes, long maybe, boolean minmax);
public static DoubleVector NondetUntil(JDDNode trans, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, JDDNode yes, JDDNode maybe, boolean minmax) throws PrismException
private static native long PS_NondetUntil(long trans, long trans_actions, List<String> synchs, long odd, long rv, int nrv, long cv, int ncv, long ndv, int nndv, long yes, long maybe, boolean minmax);
public static DoubleVector NondetUntil(JDDNode trans, JDDNode transActions, List<String> synchs, ODDNode odd, JDDVars rows, JDDVars cols, JDDVars nondet, JDDNode yes, JDDNode maybe, boolean minmax) throws PrismException
{
long ptr = PS_NondetUntil(trans.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), yes.ptr(), maybe.ptr(), minmax);
long ptr = PS_NondetUntil(trans.ptr(), (transActions == null) ? 0 : transActions.ptr(), synchs, odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), nondet.array(), nondet.n(), yes.ptr(), maybe.ptr(), minmax);
if (ptr == 0) throw new PrismException(getErrorMessage());
return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff()));
}

94
prism/src/sparse/sparse.cc

@ -43,6 +43,7 @@ static void traverse_mtbdd_vect_rec(DdManager *ddman, DdNode *dd, DdNode **vars,
// global variables (used by local functions)
static int count;
static int *starts, *starts2;
static int *actions;
static RMSparseMatrix *rmsm;
static CMSparseMatrix *cmsm;
static RCSparseMatrix *rcsm;
@ -668,7 +669,7 @@ NDSparseMatrix *build_nd_sparse_matrix(DdManager *ddman, DdNode *mdp, DdNode **r
// try/catch for memory allocation/deallocation
} catch(std::bad_alloc e) {
if (ndsm) delete cmscsm;
if (ndsm) delete ndsm;
if (matrices) delete[] matrices;
if (matrices_bdds) {
for (i = 0; i < nm; i++) Cudd_RecursiveDeref(ddman, matrices_bdds[i]);
@ -808,7 +809,7 @@ NDSparseMatrix *build_sub_nd_sparse_matrix(DdManager *ddman, DdNode *mdp, DdNode
// try/catch for memory allocation/deallocation
} catch(std::bad_alloc e) {
if (ndsm) delete cmscsm;
if (ndsm) delete ndsm;
if (matrices) delete[] matrices;
if (matrices_bdds) {
for (i = 0; i < nm; i++) Cudd_RecursiveDeref(ddman, matrices_bdds[i]);
@ -833,6 +834,88 @@ NDSparseMatrix *build_sub_nd_sparse_matrix(DdManager *ddman, DdNode *mdp, DdNode
//------------------------------------------------------------------------------
// build nondeterministic (mdp) action vector to accompany a sparse matrix
// (i.e. a vector containing for every state and nondet choice, an index
// into the list of all action labels)
// throws std::bad_alloc on out-of-memory
int *build_nd_action_vector(DdManager *ddman, DdNode *trans_actions, NDSparseMatrix *mdp_ndsm, DdNode **rvars, int num_vars, DdNode **ndvars, int num_ndvars, ODDNode *odd)
{
int i, n, nm, nc;
DdNode *tmp = NULL, **matrices = NULL, **matrices_bdds = NULL;
// try/catch for memory allocation/deallocation
try {
// get stats from mdp sparse storage (num states/choices)
n = mdp_ndsm->n;
nc = mdp_ndsm->nc;
// break the mtbdd storing the action info into several (nm) mtbdds
// (this number nm should match the figure for the corresponding mdp;
// but we can't do this sanity check because that statistic is not retained.)
Cudd_Ref(trans_actions);
tmp = DD_Not(ddman, DD_Equals(ddman, trans_actions, 0));
tmp = DD_ThereExists(ddman, tmp, rvars, num_vars);
nm = (int)DD_GetNumMinterms(ddman, tmp, num_ndvars);
Cudd_RecursiveDeref(ddman, tmp);
matrices = new DdNode*[nm];
count = 0;
split_mdp_rec(ddman, trans_actions, ndvars, num_ndvars, 0, matrices);
// and for each one create a bdd storing which rows are non-empty
matrices_bdds = new DdNode*[nm];
for (i = 0; i < nm; i++) {
Cudd_Ref(matrices[i]);
matrices_bdds[i] = DD_Not(ddman, DD_Equals(ddman, matrices[i], 0));
}
// create arrays
actions = NULL; actions = new int[nc];
starts = NULL; starts = new int[n+1];
// build the (temporary) array 'starts' (like was done when building the sparse matrix for the mdp).
// in fact, this information is retrievable from the sparse matrix, but it may have
// been converted to counts, rather than offsets, so its easier to rebuild it.
// first traverse mtbdds to compute how many choices are in each row
for (i = 0; i < n+1; i++) starts[i] = 0;
for (i = 0; i < nm; i++) {
traverse_mtbdd_vect_rec(ddman, matrices_bdds[i], rvars, num_vars, 0, odd, 0, 1);
}
// and use this to compute the starts information
for (i = 1 ; i < n+1; i++) {
starts[i] += starts[i-1];
}
// now traverse the mtbdd to get the actual entries (action indices)
for (i = 0; i < nm; i++) {
traverse_mtbdd_vect_rec(ddman, matrices[i], rvars, num_vars, 0, odd, 0, 3);
}
// try/catch for memory allocation/deallocation
} catch(std::bad_alloc e) {
if (actions) delete[] actions;
if (matrices) delete[] matrices;
if (matrices_bdds) {
for (i = 0; i < nm; i++) Cudd_RecursiveDeref(ddman, matrices_bdds[i]);
delete[] matrices_bdds;
}
if (starts) delete[] starts;
throw e;
}
// clear up memory
for (i = 0; i < nm; i++) {
Cudd_RecursiveDeref(ddman, matrices_bdds[i]);
// nb: don't deref matrices array because that was just pointers, not new copies
}
delete[] starts;
delete[] matrices;
delete[] matrices_bdds;
return actions;
}
//------------------------------------------------------------------------------
void split_mdp_rec(DdManager *ddman, DdNode *dd, DdNode **ndvars, int num_ndvars, int level, DdNode **matrices)
{
DdNode *e, *t;
@ -1062,7 +1145,14 @@ void traverse_mtbdd_vect_rec(DdManager *ddman, DdNode *dd, DdNode **vars, int nu
case 2:
starts[i]++;
break;
// mdp action vector - single pass
case 3:
actions[starts[i]] = (int)Cudd_V(dd);
starts[i]++;
break;
}
return;
}

Loading…
Cancel
Save