diff --git a/prism/src/hybrid/PH_StochTransient.cc b/prism/src/hybrid/PH_StochTransient.cc index dc79eb25..007ff1e5 100644 --- a/prism/src/hybrid/PH_StochTransient.cc +++ b/prism/src/hybrid/PH_StochTransient.cc @@ -62,7 +62,7 @@ JNIEnv *env, jclass cls, jlong __jlongpointer tr, // trans matrix jlong __jlongpointer od, // odd -jlong __jlongpointer in, // initial distribution +jlong __jlongpointer in, // initial distribution (note: this will be deleted afterwards) jlong __jlongpointer rv, // row vars jint num_rvars, jlong __jlongpointer cv, // col vars @@ -73,7 +73,7 @@ jdouble time // time bound // cast function parameters DdNode *trans = jlong_to_DdNode(tr); // trans matrix ODDNode *odd = jlong_to_ODDNode(od); // odd - DdNode *init = jlong_to_DdNode(in); // initial distribution + double *init = jlong_to_double(in); // initial distribution DdNode **rvars = jlong_to_DdNode_array(rv); // row vars DdNode **cvars = jlong_to_DdNode_array(cv); // col vars @@ -171,7 +171,9 @@ jdouble time // time bound // create solution/iteration vectors PH_PrintToMainLog(env, "Allocating iteration vectors... "); - soln = mtbdd_to_double_vector(ddman, init, rvars, num_rvars, odd); + // for soln, we just use init (since we are free to modify/delete this vector) + // we also report the memory usage of this vector here, even though it has already been created + soln = init; soln2 = new double[n]; sum = new double[n]; kb = n*8.0/1024.0; @@ -306,6 +308,7 @@ jdouble time // time bound if (hddm) delete hddm; if (diags) delete[] diags; if (diags_dist) delete diags_dist; + // nb: we *do* free soln (which was originally init) if (soln) delete[] soln; if (soln2) delete[] soln2; diff --git a/prism/src/hybrid/PrismHybrid.java b/prism/src/hybrid/PrismHybrid.java index 5717f5e3..b3152b8b 100644 --- a/prism/src/hybrid/PrismHybrid.java +++ b/prism/src/hybrid/PrismHybrid.java @@ -346,9 +346,9 @@ public class PrismHybrid // transient (stochastic/ctmc) private static native long PH_StochTransient(long trans, long odd, long init, long rv, int nrv, long cv, int ncv, double time); - public static DoubleVector StochTransient(JDDNode trans, ODDNode odd, JDDNode init, JDDVars rows, JDDVars cols, double time) throws PrismException + public static DoubleVector StochTransient(JDDNode trans, ODDNode odd, DoubleVector init, JDDVars rows, JDDVars cols, double time) throws PrismException { - long ptr = PH_StochTransient(trans.ptr(), odd.ptr(), init.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), time); + long ptr = PH_StochTransient(trans.ptr(), odd.ptr(), init.getPtr(), rows.array(), rows.n(), cols.array(), cols.n(), time); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); } diff --git a/prism/src/mtbdd/PM_StochTransient.cc b/prism/src/mtbdd/PM_StochTransient.cc index 5ddb6e85..e36a0ba2 100644 --- a/prism/src/mtbdd/PM_StochTransient.cc +++ b/prism/src/mtbdd/PM_StochTransient.cc @@ -43,7 +43,7 @@ JNIEnv *env, jclass cls, jlong __jlongpointer tr, // rate matrix jlong __jlongpointer od, // odd -jlong __jlongpointer in, // initial distribution +jlong __jlongpointer in, // initial distribution (note: this will be derefed afterwards) jlong __jlongpointer rv, // row vars jint num_rvars, jlong __jlongpointer cv, // col vars @@ -307,6 +307,8 @@ jdouble time // time } Cudd_RecursiveDeref(ddman, diags); Cudd_RecursiveDeref(ddman, sol); + // nb: we deref init, even though it is passed in as a param + Cudd_RecursiveDeref(ddman, init); return ptr_to_jlong(sum); } diff --git a/prism/src/prism/Prism.java b/prism/src/prism/Prism.java index 83dd0547..dc63ed23 100644 --- a/prism/src/prism/Prism.java +++ b/prism/src/prism/Prism.java @@ -1455,15 +1455,36 @@ public class Prism implements PrismSettingsListener mainLog.println("\nTime for steady-state probability computation: " + l/1000.0 + " seconds."); } - // do transient - + /** + * Compute transient probabilities (for DTMC or CTMC). + * Output probability distribution to log. + */ public void doTransient(Model model, double time) throws PrismException + { + doTransient(model, time, EXPORT_PLAIN, null); + } + + /** + * Compute transient probabilities (for DTMC or CTMC). + * Output probability distribution to a file (or, if file is null, to log). + * The exportType should be EXPORT_PLAIN or EXPORT_MATLAB. + */ + public void doTransient(Model model, double time, int exportType, File file) throws PrismException { long l = 0; // timer StateProbs probs = null; + PrismLog tmpLog; if (time < 0) throw new PrismException("Cannot compute transient probabilities for negative time value"); + if (file != null && getEngine() == MTBDD) + throw new PrismException("Transient probability export only supported for sparse/hybrid engines"); + + // no specific states format for MRMC + if (exportType == EXPORT_MRMC) exportType = EXPORT_PLAIN; + // rows format does not apply to states output + if (exportType == EXPORT_ROWS) exportType = EXPORT_PLAIN; + // create new model checker object mc = new StochModelChecker(this, model, null); @@ -1485,13 +1506,33 @@ public class Prism implements PrismSettingsListener l = System.currentTimeMillis() - l; - // print out probabilities - mainLog.print("\nProbabilities: \n"); - probs.print(mainLog); - probs.clear(); + // print message + mainLog.print("\nPrinting transient probabilities "); + switch (exportType) { + case EXPORT_PLAIN: mainLog.print("in plain text format "); break; + case EXPORT_MATLAB: mainLog.print("in Matlab format "); break; + } + if (file != null) mainLog.println("to file \"" + file + "\"..."); else mainLog.println("below:"); + + // create new file log or use main log + if (file != null) { + tmpLog = new PrismFileLog(file.getPath()); + if (!tmpLog.ready()) { + throw new PrismException("Could not open file \"" + file + "\" for output"); + } + } else { + tmpLog = mainLog; + } + + // print out or export probabilities + probs.print(tmpLog, file == null, exportType == EXPORT_MATLAB, file == null); // print out model checking time mainLog.println("\nTime for transient probability computation: " + l/1000.0 + " seconds."); + + // tidy up + probs.clear(); + if (file != null) tmpLog.close(); } // clear up and close down diff --git a/prism/src/prism/PrismCL.java b/prism/src/prism/PrismCL.java index 5996daac..0e79d980 100644 --- a/prism/src/prism/PrismCL.java +++ b/prism/src/prism/PrismCL.java @@ -92,6 +92,7 @@ public class PrismCL private String exportTransDotStatesFilename = null; private String exportBSCCsFilename = null; private String exportResultsFilename = null; + private String exportTransientFilename = null; private String exportPrismFilename = null; private String simpathFilename = null; @@ -291,8 +292,8 @@ public class PrismCL try { doTransient(); } + // in case of error, report it and proceed catch (PrismException e) { - // in case of error, report it and proceed error(e.getMessage()); } } @@ -773,6 +774,13 @@ public class PrismCL { double d; int i; + File exportTransientFile = null; + + // choose destination for output (file or log) + if (exportTransientFilename == null || exportTransientFilename.equals("stdout")) + exportTransientFile = null; + else + exportTransientFile =new File(exportTransientFilename); // compute transient probabilities if (model.getModelType() == ModelType.CTMC) { @@ -782,7 +790,7 @@ public class PrismCL catch (NumberFormatException e) { throw new PrismException("Invalid value \""+transientTime+"\" for transient probability computation"); } - prism.doTransient(model, d); + prism.doTransient(model, d, exportType, exportTransientFile); } else if (model.getModelType() == ModelType.DTMC) { try { @@ -791,7 +799,7 @@ public class PrismCL catch (NumberFormatException e) { throw new PrismException("Invalid value \""+transientTime+"\" for transient probability computation"); } - prism.doTransient(model, i); + prism.doTransient(model, i, exportType, exportTransientFile); } else { mainLog.println("\nWarning: Transient probabilities only computed for DTMCs/CTMCs."); @@ -1012,6 +1020,15 @@ public class PrismCL errorAndExit("No file specified for -"+sw+" switch"); } } + // export transient probs (as opposed to displaying on screen) + else if (sw.equals("exporttransient") || sw.equals("exporttr")) { + if (i < args.length-1) { + exportTransientFilename = args[++i]; + } + else { + errorAndExit("No file specified for -"+sw+" switch"); + } + } // switch export mode to "matlab" else if (sw.equals("exportmatlab")) { exportType = Prism.EXPORT_MATLAB; @@ -1684,6 +1701,7 @@ public class PrismCL mainLog.println("-exporttransdotstates ... Export the transition matrix graph to a dot file, with state info"); mainLog.println("-exportdot .............. Export the transition matrix MTBDD to a dot file"); mainLog.println("-exportbsccs ............ Compute and export all BSCCs of the model"); + mainLog.println("-exporttransient ......... Export transient probabilities to a file"); mainLog.println("-exportprism ............ Export final PRISM model to a file"); mainLog.println(); mainLog.println("-mtbdd (or -m) ................. Use the MTBDD engine"); diff --git a/prism/src/prism/StochModelChecker.java b/prism/src/prism/StochModelChecker.java index 9781e0d3..b4e94f82 100644 --- a/prism/src/prism/StochModelChecker.java +++ b/prism/src/prism/StochModelChecker.java @@ -312,39 +312,52 @@ public class StochModelChecker extends ProbModelChecker // do transient computation // ----------------------------------------------------------------------------------- - // transient computation (from initial states) - + /** + * Compute transient probability distribution (forwards). + * Start from initial state (or uniform distribution over multiple initial states). + */ public StateProbs doTransient(double time) throws PrismException + { + return doTransient(time, null); + } + + /** + * Compute transient probability distribution (forwards). + * Optionally, use the passed in vector initDist as the initial probability distribution (time 0). + * If null, start from initial state (or uniform distribution over multiple initial states). + * For reasons of efficiency, when a vector is passed in, it will be trampled over and + * then deleted afterwards, so if you wanted it, take a copy. + */ + public StateProbs doTransient(double time, StateProbs initDist) throws PrismException { // mtbdd stuff JDDNode start, init; // other stuff - StateProbs probs = null; - - // get initial states of model - start = model.getStart(); - - // and hence compute initial probability distribution (equiprobable over - // all start states) - JDD.Ref(start); - init = JDD.Apply(JDD.DIVIDE, start, JDD.Constant(JDD.GetNumMinterms(start, allDDRowVars.n()))); - - // compute transient probabilities - try { - // special case: time = 0 - if (time == 0.0) { - JDD.Ref(init); - probs = new StateProbsMTBDD(init, model); - } else { - probs = computeTransientProbs(trans, init, time); + StateProbs initDistNew = null, probs = null; + + // build initial distribution (if not specified) + if (initDist == null) { + // first construct as MTBDD + // get initial states of model + start = model.getStart(); + // compute initial probability distribution (equiprobable over all start states) + JDD.Ref(start); + init = JDD.Apply(JDD.DIVIDE, start, JDD.Constant(JDD.GetNumMinterms(start, allDDRowVars.n()))); + // if using MTBDD engine, distribution needs to be an MTBDD + if (engine == Prism.MTBDD) { + initDistNew = new StateProbsMTBDD(init, model); } - } catch (PrismException e) { - JDD.Deref(init); - throw e; + // for sparse/hybrid engines, distribution needs to be a double vector + else { + initDistNew = new StateProbsDV(init, model); + JDD.Deref(init); + } + } else { + initDistNew = initDist; } - - // derefs - JDD.Deref(init); + + // compute transient probabilities + probs = computeTransientProbs(trans, initDistNew, time); return probs; } @@ -546,26 +559,39 @@ public class StochModelChecker extends ProbModelChecker return rewards; } - // compute transient probabilities - - protected StateProbs computeTransientProbs(JDDNode tr, JDDNode init, double time) throws PrismException + /** + * Compute transient probability distribution (forwards). + * Use the passed in vector initDist as the initial probability distribution (time 0). + * The type of this should match the current engine + * (i.e. StateProbsMTBDD for MTBDD, StateProbsDV for sparse/hybrid). + * For reasons of efficiency, this vector will be trampled over and + * then deleted afterwards, so if you wanted it, take a copy. + */ + protected StateProbs computeTransientProbs(JDDNode tr, StateProbs initDist, double time) throws PrismException { JDDNode probsMTBDD; DoubleVector probsDV; StateProbs probs = null; + // special case: time = 0 + if (time == 0.0) { + // we are allowed to keep the init vector, so no need to clone + return initDist; + } + + // general case try { switch (engine) { case Prism.MTBDD: - probsMTBDD = PrismMTBDD.StochTransient(tr, odd, init, allDDRowVars, allDDColVars, time); + probsMTBDD = PrismMTBDD.StochTransient(tr, odd, ((StateProbsMTBDD) initDist).getJDDNode(), allDDRowVars, allDDColVars, time); probs = new StateProbsMTBDD(probsMTBDD, model); break; case Prism.SPARSE: - probsDV = PrismSparse.StochTransient(tr, odd, init, allDDRowVars, allDDColVars, time); + probsDV = PrismSparse.StochTransient(tr, odd, ((StateProbsDV) initDist).getDoubleVector(), allDDRowVars, allDDColVars, time); probs = new StateProbsDV(probsDV, model); break; case Prism.HYBRID: - probsDV = PrismHybrid.StochTransient(tr, odd, init, allDDRowVars, allDDColVars, time); + probsDV = PrismHybrid.StochTransient(tr, odd, ((StateProbsDV) initDist).getDoubleVector(), allDDRowVars, allDDColVars, time); probs = new StateProbsDV(probsDV, model); break; default: diff --git a/prism/src/sparse/PS_StochTransient.cc b/prism/src/sparse/PS_StochTransient.cc index 3665c0fd..0bfc420a 100644 --- a/prism/src/sparse/PS_StochTransient.cc +++ b/prism/src/sparse/PS_StochTransient.cc @@ -46,7 +46,7 @@ JNIEnv *env, jclass cls, jlong __jlongpointer tr, // trans matrix jlong __jlongpointer od, // odd -jlong __jlongpointer in, // initial distribution +jlong __jlongpointer in, // initial distribution (note: this will be deleted afterwards) jlong __jlongpointer rv, // row vars jint num_rvars, jlong __jlongpointer cv, // col vars @@ -57,7 +57,7 @@ jdouble time // time bound // cast function parameters DdNode *trans = jlong_to_DdNode(tr); // trans matrix ODDNode *odd = jlong_to_ODDNode(od); // odd - DdNode *init = jlong_to_DdNode(in); // initial distribution + double *init = jlong_to_double(in); // initial distribution DdNode **rvars = jlong_to_DdNode_array(rv); // row vars DdNode **cvars = jlong_to_DdNode_array(cv); // col vars @@ -158,7 +158,9 @@ jdouble time // time bound // create solution/iteration vectors PS_PrintToMainLog(env, "Allocating iteration vectors... "); - soln = mtbdd_to_double_vector(ddman, init, rvars, num_rvars, odd); + // for soln, we just use init (since we are free to modify/delete this vector) + // we also report the memory usage of this vector here, even though it has already been created + soln = init; soln2 = new double[n]; sum = new double[n]; kb = n*8.0/1024.0; @@ -331,6 +333,7 @@ jdouble time // time bound if (cmsm) delete cmsm; if (diags) delete[] diags; if (diags_dist) delete diags_dist; + // nb: we *do* free soln (which was originally init) if (soln) delete[] soln; if (soln2) delete[] soln2; diff --git a/prism/src/sparse/PrismSparse.java b/prism/src/sparse/PrismSparse.java index 32b51c88..b14b0bf1 100644 --- a/prism/src/sparse/PrismSparse.java +++ b/prism/src/sparse/PrismSparse.java @@ -316,9 +316,9 @@ public class PrismSparse // transient (stochastic/ctmc) private static native long PS_StochTransient(long trans, long odd, long init, long rv, int nrv, long cv, int ncv, double time); - public static DoubleVector StochTransient(JDDNode trans, ODDNode odd, JDDNode init, JDDVars rows, JDDVars cols, double time) throws PrismException + public static DoubleVector StochTransient(JDDNode trans, ODDNode odd, DoubleVector init, JDDVars rows, JDDVars cols, double time) throws PrismException { - long ptr = PS_StochTransient(trans.ptr(), odd.ptr(), init.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), time); + long ptr = PS_StochTransient(trans.ptr(), odd.ptr(), init.getPtr(), rows.array(), rows.n(), cols.array(), cols.n(), time); if (ptr == 0) throw new PrismException(getErrorMessage()); return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); }