From 32086274a2d1de0803d2e1e3f01f769894cfe0e1 Mon Sep 17 00:00:00 2001 From: Dave Parker Date: Thu, 3 Apr 2008 11:39:06 +0000 Subject: [PATCH] Added transient probabilities computation for DTMCs. git-svn-id: https://www.prismmodelchecker.org/svn/prism/prism/trunk@720 bbc10eb1-c90d-0410-af57-cb519fbb1720 --- prism/include/PrismHybrid.h | 8 + prism/include/PrismMTBDD.h | 8 + prism/include/PrismSparse.h | 8 + prism/src/hybrid/PH_ProbTransient.cc | 294 ++++++++++++++++++ prism/src/hybrid/PrismHybrid.java | 9 + prism/src/mtbdd/PM_ProbTransient.cc | 136 ++++++++ prism/src/mtbdd/PM_StochTransient.cc | 2 +- prism/src/mtbdd/PrismMTBDD.java | 9 + prism/src/prism/Prism.java | 37 ++- prism/src/prism/PrismCL.java | 32 +- prism/src/prism/ProbModelChecker.java | 73 +++++ prism/src/sparse/PS_ProbTransient.cc | 229 ++++++++++++++ prism/src/sparse/PrismSparse.java | 9 + .../userinterface/model/GUIMultiModel.java | 2 +- .../computation/ComputeTransientThread.java | 2 +- 15 files changed, 841 insertions(+), 17 deletions(-) create mode 100644 prism/src/hybrid/PH_ProbTransient.cc create mode 100644 prism/src/mtbdd/PM_ProbTransient.cc create mode 100644 prism/src/sparse/PS_ProbTransient.cc diff --git a/prism/include/PrismHybrid.h b/prism/include/PrismHybrid.h index d99ed503..dafe6d3b 100644 --- a/prism/include/PrismHybrid.h +++ b/prism/include/PrismHybrid.h @@ -183,6 +183,14 @@ JNIEXPORT jlong JNICALL Java_hybrid_PrismHybrid_PH_1ProbInstReward JNIEXPORT jlong JNICALL Java_hybrid_PrismHybrid_PH_1ProbReachReward (JNIEnv *, jclass, jlong, jlong, jlong, jlong, jlong, jint, jlong, jint, jlong, jlong, jlong); +/* + * Class: hybrid_PrismHybrid + * Method: PH_ProbTransient + * Signature: (JJJJIJII)J + */ +JNIEXPORT jlong JNICALL Java_hybrid_PrismHybrid_PH_1ProbTransient + (JNIEnv *, jclass, jlong, jlong, jlong, jlong, jint, jlong, jint, jint); + /* * Class: hybrid_PrismHybrid * Method: PH_NondetBoundedUntil diff --git a/prism/include/PrismMTBDD.h b/prism/include/PrismMTBDD.h index 3e87a1db..5d5e1791 100644 --- a/prism/include/PrismMTBDD.h +++ b/prism/include/PrismMTBDD.h @@ -191,6 +191,14 @@ JNIEXPORT jlong JNICALL Java_mtbdd_PrismMTBDD_PM_1ProbInstReward JNIEXPORT jlong JNICALL Java_mtbdd_PrismMTBDD_PM_1ProbReachReward (JNIEnv *, jclass, jlong, jlong, jlong, jlong, jlong, jint, jlong, jint, jlong, jlong, jlong); +/* + * Class: mtbdd_PrismMTBDD + * Method: PM_ProbTransient + * Signature: (JJJJIJII)J + */ +JNIEXPORT jlong JNICALL Java_mtbdd_PrismMTBDD_PM_1ProbTransient + (JNIEnv *, jclass, jlong, jlong, jlong, jlong, jint, jlong, jint, jint); + /* * Class: mtbdd_PrismMTBDD * Method: PM_NondetBoundedUntil diff --git a/prism/include/PrismSparse.h b/prism/include/PrismSparse.h index e2bfef9f..947a7e99 100644 --- a/prism/include/PrismSparse.h +++ b/prism/include/PrismSparse.h @@ -143,6 +143,14 @@ JNIEXPORT jlong JNICALL Java_sparse_PrismSparse_PS_1ProbInstReward JNIEXPORT jlong JNICALL Java_sparse_PrismSparse_PS_1ProbReachReward (JNIEnv *, jclass, jlong, jlong, jlong, jlong, jlong, jint, jlong, jint, jlong, jlong, jlong); +/* + * Class: sparse_PrismSparse + * Method: PS_ProbTransient + * Signature: (JJJJIJII)J + */ +JNIEXPORT jlong JNICALL Java_sparse_PrismSparse_PS_1ProbTransient + (JNIEnv *, jclass, jlong, jlong, jlong, jlong, jint, jlong, jint, jint); + /* * Class: sparse_PrismSparse * Method: PS_NondetBoundedUntil diff --git a/prism/src/hybrid/PH_ProbTransient.cc b/prism/src/hybrid/PH_ProbTransient.cc new file mode 100644 index 00000000..44dcb270 --- /dev/null +++ b/prism/src/hybrid/PH_ProbTransient.cc @@ -0,0 +1,294 @@ +//============================================================================== +// +// Copyright (c) 2002- +// Authors: +// * Dave Parker (University of Oxford, formerly University of Birmingham) +// +//------------------------------------------------------------------------------ +// +// This file is part of PRISM. +// +// PRISM is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation; either version 2 of the License, or +// (at your option) any later version. +// +// PRISM is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with PRISM; if not, write to the Free Software Foundation, +// Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +// +//============================================================================== + +// includes +#include "PrismHybrid.h" +#include +#include +#include +#include +#include +#include +#include +#include "sparse.h" +#include "hybrid.h" +#include "PrismHybridGlob.h" +#include "jnipointer.h" + +// local prototypes +static void mult_rec(HDDNode *hdd, int level, int row_offset, int col_offset); +static void mult_cm(CMSparseMatrix *cmsm, int row_offset, int col_offset); +static void mult_cmsc(CMSCSparseMatrix *cmscsm, int row_offset, int col_offset); + +// globals (used by local functions) +static HDDNode *zero; +static int num_levels; +static bool compact_sm; +static double *sm_dist; +static int sm_dist_shift; +static int sm_dist_mask; +static double *soln, *soln2; + +//------------------------------------------------------------------------------ + +JNIEXPORT jlong __pointer JNICALL Java_hybrid_PrismHybrid_PH_1ProbTransient +( +JNIEnv *env, +jclass cls, +jlong __pointer tr, // trans matrix +jlong __pointer od, // odd +jlong __pointer in, // initial distribution +jlong __pointer rv, // row vars +jint num_rvars, +jlong __pointer cv, // col vars +jint num_cvars, +jint time // time +) +{ + // cast function parameters + DdNode *trans = jlong_to_DdNode(tr); // trans matrix + ODDNode *odd = jlong_to_ODDNode(od); // odd + DdNode *init = jlong_to_DdNode(in); // initial distribution + DdNode **rvars = jlong_to_DdNode_array(rv); // row vars + DdNode **cvars = jlong_to_DdNode_array(cv); // col vars + + // model stats + int n; + // matrix mtbdd + HDDMatrix *hddm; + HDDNode *hdd; + // vectors + double *tmpsoln, *sum; + // timing stuff + long start1, start2, start3, stop; + double time_taken, time_for_setup, time_for_iters; + // misc + bool done; + int i, iters; + double kb, kbt; + + // start clocks + start1 = start2 = util_cpu_time(); + + // get number of states from odd + n = odd->eoff + odd->toff; + + // build hdd for matrix + PH_PrintToMainLog(env, "\nBuilding hybrid MTBDD matrix... "); + hddm = build_hdd_matrix(trans, rvars, cvars, num_rvars, odd, false); + hdd = hddm->top; + zero = hddm->zero; + num_levels = hddm->num_levels; + kb = hddm->mem_nodes; + kbt = kb; + PH_PrintToMainLog(env, "[levels=%d, nodes=%d] [%.1f KB]\n", hddm->num_levels, hddm->num_nodes, kb); + + // add sparse matrices + PH_PrintToMainLog(env, "Adding explicit sparse matrices... "); + add_sparse_matrices(hddm, compact, false); + compact_sm = hddm->compact_sm; + if (compact_sm) { + sm_dist = hddm->dist; + sm_dist_shift = hddm->dist_shift; + sm_dist_mask = hddm->dist_mask; + } + kb = hddm->mem_sm; + kbt += kb; + PH_PrintToMainLog(env, "[levels=%d, num=%d%s] [%.1f KB]\n", hddm->l_sm, hddm->num_sm, compact_sm?", compact":"", kb); + + // create solution/iteration vectors + PH_PrintToMainLog(env, "Allocating iteration vectors... "); + soln = mtbdd_to_double_vector(ddman, init, rvars, num_rvars, odd); + soln2 = new double[n]; + sum = new double[n]; + kb = n*8.0/1024.0; + kbt += 3*kb; + PH_PrintToMainLog(env, "[3 x %.1f KB]\n", kb); + + // print total memory usage + PH_PrintToMainLog(env, "TOTAL: [%.1f KB]\n", kbt); + + // get setup time + stop = util_cpu_time(); + time_for_setup = (double)(stop - start2)/1000; + start2 = stop; + + // start transient analysis + iters = 0; + done = false; + PH_PrintToMainLog(env, "\nStarting iterations...\n"); + + // note that we ignore max_iters as we know how any iterations _should_ be performed + for (iters = 0; iters < time && !done; iters++) { + +// PH_PrintToMainLog(env, "Iteration %d: ", iters); +// start3 = util_cpu_time(); + + // initialise vector + for (i = 0; i < n; i++) soln2[i] = 0.0; + + // do matrix vector multiply bit + mult_rec(hdd, 0, 0, 0); + + // check for steady state convergence + if (do_ss_detect) switch (term_crit) { + case TERM_CRIT_ABSOLUTE: + done = true; + for (i = 0; i < n; i++) { + if (fabs(soln2[i] - soln[i]) > term_crit_param) { + done = false; + break; + } + } + break; + case TERM_CRIT_RELATIVE: + done = true; + for (i = 0; i < n; i++) { + if (fabs((soln2[i] - soln[i])/soln2[i]) > term_crit_param) { + done = false; + break; + } + } + break; + } + + // prepare for next iteration + tmpsoln = soln; + soln = soln2; + soln2 = tmpsoln; + +// PH_PrintToMainLog(env, "%.2f %.2f sec\n", ((double)(util_cpu_time() - start3)/1000), ((double)(util_cpu_time() - start2)/1000)/iters); + } + + // stop clocks + stop = util_cpu_time(); + time_for_iters = (double)(stop - start2)/1000; + time_taken = (double)(stop - start1)/1000; + + // print iters/timing info + if (done) PH_PrintToMainLog(env, "\nSteady state detected at iteration %d\n", iters); + PH_PrintToMainLog(env, "\nIterative method: %d iterations in %.2f seconds (average %.6f, setup %.2f)\n", iters, time_taken, time_for_iters/iters, time_for_setup); + + // free memory + free_hdd_matrix(hddm); + delete soln2; + + return ptr_to_jlong(soln); +} + +//------------------------------------------------------------------------------ + +static void mult_rec(HDDNode *hdd, int level, int row_offset, int col_offset) +{ + HDDNode *e, *t; + + // if it's the zero node + if (hdd == zero) { + return; + } + // or if we've reached a submatrix + // (check for non-null ptr but, equivalently, we could just check if level==l_sm) + else if (hdd->sm.ptr) { + if (!compact_sm) { + mult_cm((CMSparseMatrix *)hdd->sm.ptr, row_offset, col_offset); + } else { + mult_cmsc((CMSCSparseMatrix *)hdd->sm.ptr, row_offset, col_offset); + } + return; + } + // or if we've reached the bottom + else if (level == num_levels) { + //printf("(%d,%d)=%f\n", col_offset, row_offset, hdd->type.val); + soln2[col_offset] += soln[row_offset] * (hdd->type.val); + return; + } + // otherwise recurse + e = hdd->type.kids.e; + if (e != zero) { + mult_rec(e->type.kids.e, level+1, row_offset, col_offset); + mult_rec(e->type.kids.t, level+1, row_offset, col_offset+e->off.val); + } + t = hdd->type.kids.t; + if (t != zero) { + mult_rec(t->type.kids.e, level+1, row_offset+hdd->off.val, col_offset); + mult_rec(t->type.kids.t, level+1, row_offset+hdd->off.val, col_offset+t->off.val); + } +} + +//----------------------------------------------------------------------------------- + +static void mult_cm(CMSparseMatrix *cmsm, int row_offset, int col_offset) +{ + int i2, j2, l2, h2; + int sm_n = cmsm->n; + int sm_nnz = cmsm->nnz; + double *sm_non_zeros = cmsm->non_zeros; + unsigned char *sm_col_counts = cmsm->col_counts; + int *sm_col_starts = (int *)cmsm->col_counts; + bool sm_use_counts = cmsm->use_counts; + unsigned int *sm_rows = cmsm->rows; + + // loop through columns of submatrix + l2 = sm_nnz; h2 = 0; + for (i2 = 0; i2 < sm_n; i2++) { + + // loop through entries in this column + if (!sm_use_counts) { l2 = sm_col_starts[i2]; h2 = sm_col_starts[i2+1]; } + else { l2 = h2; h2 += sm_col_counts[i2]; } + for (j2 = l2; j2 < h2; j2++) { + soln2[col_offset + i2] += soln[row_offset + sm_rows[j2]] * (sm_non_zeros[j2]); + //printf("(%d,%d)=%f\n", col_offset + sm_rows[j2], row_offset + i2, sm_non_zeros[j2]); + } + } +} + +//----------------------------------------------------------------------------------- + +static void mult_cmsc(CMSCSparseMatrix *cmscsm, int row_offset, int col_offset) +{ + int i2, j2, l2, h2; + int sm_n = cmscsm->n; + int sm_nnz = cmscsm->nnz; + unsigned char *sm_col_counts = cmscsm->col_counts; + int *sm_col_starts = (int *)cmscsm->col_counts; + bool sm_use_counts = cmscsm->use_counts; + unsigned int *sm_rows = cmscsm->rows; + + // loop through columns of submatrix + l2 = sm_nnz; h2 = 0; + for (i2 = 0; i2 < sm_n; i2++) { + + // loop through entries in this column + if (!sm_use_counts) { l2 = sm_col_starts[i2]; h2 = sm_col_starts[i2+1]; } + else { l2 = h2; h2 += sm_col_counts[i2]; } + for (j2 = l2; j2 < h2; j2++) { + soln2[col_offset + i2] += soln[row_offset + (int)(sm_rows[j2] >> sm_dist_shift)] * (sm_dist[(int)(sm_rows[j2] & sm_dist_mask)]); + //printf("(%d,%d)=%f\n", col_offset + (int)(sm_rows[j2] >> sm_dist_shift), row_offset + i2, sm_dist[(int)(sm_rows[j2] & sm_dist_mask)]); + } + } +} + +//------------------------------------------------------------------------------ diff --git a/prism/src/hybrid/PrismHybrid.java b/prism/src/hybrid/PrismHybrid.java index 6071254b..5717f5e3 100644 --- a/prism/src/hybrid/PrismHybrid.java +++ b/prism/src/hybrid/PrismHybrid.java @@ -272,6 +272,15 @@ public class PrismHybrid return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); } + // transient (probabilistic/dtmc) + private static native long PH_ProbTransient(long trans, long odd, long init, long rv, int nrv, long cv, int ncv, int time); + public static DoubleVector ProbTransient(JDDNode trans, ODDNode odd, JDDNode init, JDDVars rows, JDDVars cols, int time) throws PrismException + { + long ptr = PH_ProbTransient(trans.ptr(), odd.ptr(), init.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), time); + if (ptr == 0) throw new PrismException(getErrorMessage()); + return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); + } + //---------------------------------------------------------------------------------------------- // nondeterministic/mdp stuff //---------------------------------------------------------------------------------------------- diff --git a/prism/src/mtbdd/PM_ProbTransient.cc b/prism/src/mtbdd/PM_ProbTransient.cc new file mode 100644 index 00000000..61ac1c35 --- /dev/null +++ b/prism/src/mtbdd/PM_ProbTransient.cc @@ -0,0 +1,136 @@ +//============================================================================== +// +// Copyright (c) 2002- +// Authors: +// * Dave Parker (University of Oxford, formerly University of Birmingham) +// +//------------------------------------------------------------------------------ +// +// This file is part of PRISM. +// +// PRISM is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation; either version 2 of the License, or +// (at your option) any later version. +// +// PRISM is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with PRISM; if not, write to the Free Software Foundation, +// Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +// +//============================================================================== + +// includes +#include "PrismMTBDD.h" +#include +#include +#include +#include +#include +#include +#include "PrismMTBDDGlob.h" +#include "jnipointer.h" + +//------------------------------------------------------------------------------ + +JNIEXPORT jlong __pointer JNICALL Java_mtbdd_PrismMTBDD_PM_1ProbTransient +( +JNIEnv *env, +jclass cls, +jlong __pointer tr, // rate matrix +jlong __pointer od, // odd +jlong __pointer in, // initial distribution +jlong __pointer rv, // row vars +jint num_rvars, +jlong __pointer cv, // col vars +jint num_cvars, +jint time // time +) +{ + // cast function parameters + DdNode *trans = jlong_to_DdNode(tr); // trans matrix + ODDNode *odd = jlong_to_ODDNode(od); // odd + DdNode *init = jlong_to_DdNode(in); // initial distribution + DdNode **rvars = jlong_to_DdNode_array(rv); // row vars + DdNode **cvars = jlong_to_DdNode_array(cv); // col vars + + // mtbdds + DdNode *sol, *tmp; + // timing stuff + long start1, start2, start3, stop; + double time_taken, time_for_setup, time_for_iters; + // misc + int iters; + bool done; + + // start clocks + start1 = start2 = util_cpu_time(); + + // set up vectors + Cudd_Ref(init); + sol = init; + sol = DD_PermuteVariables(ddman, sol, rvars, cvars, num_rvars); + + // get setup time + stop = util_cpu_time(); + time_for_setup = (double)(stop - start2)/1000; + start2 = stop; + + // start iterations + iters = 0; + done = false; + PM_PrintToMainLog(env, "\nStarting iterations...\n"); + + // note that we ignore max_iters as we know how any iterations _should_ be performed + for (iters = 0; iters < time && !done; iters++) { + +// PM_PrintToMainLog(env, "Iteration %d: ", iters); +// start3 = util_cpu_time(); + + //matrix-vector multiply + Cudd_Ref(sol); + tmp = DD_PermuteVariables(ddman, sol, cvars, rvars, num_rvars); + Cudd_Ref(trans); + tmp = DD_MatrixMultiply(ddman, tmp, trans, rvars, num_rvars, MM_BOULDER); + + // check for steady state convergence + if (do_ss_detect) switch (term_crit) { + case TERM_CRIT_ABSOLUTE: + if (DD_EqualSupNorm(ddman, tmp, sol, term_crit_param)) { + done = true; + } + break; + case TERM_CRIT_RELATIVE: + if (DD_EqualSupNormRel(ddman, tmp, sol, term_crit_param)) { + done = true; + } + break; + } + + // prepare for next iteration + Cudd_RecursiveDeref(ddman, sol); + sol = tmp; + +// PM_PrintToMainLog(env, "%.2f %.2f sec\n", ((double)(util_cpu_time() - start3)/1000), ((double)(util_cpu_time() - start2)/1000)/iters); + } + + // convert to row vector + sol = DD_PermuteVariables(ddman, sol, cvars, rvars, num_rvars); + + // stop clocks + stop = util_cpu_time(); + time_for_iters = (double)(stop - start2)/1000; + time_taken = (double)(stop - start1)/1000; + + // print iterations/timing info + if (done) PM_PrintToMainLog(env, "\nSteady state detected at iteration %d\n", iters); + PM_PrintToMainLog(env, "\nIterative method: %d iterations in %.2f seconds (average %.6f, setup %.2f)\n", iters, time_taken, time_for_iters/iters, time_for_setup); + + return ptr_to_jlong(sol); +} + +//------------------------------------------------------------------------------ diff --git a/prism/src/mtbdd/PM_StochTransient.cc b/prism/src/mtbdd/PM_StochTransient.cc index e74b9683..7ff24416 100644 --- a/prism/src/mtbdd/PM_StochTransient.cc +++ b/prism/src/mtbdd/PM_StochTransient.cc @@ -48,7 +48,7 @@ jlong __pointer rv, // row vars jint num_rvars, jlong __pointer cv, // col vars jint num_cvars, -jdouble time // time bound +jint time // time ) { // cast function parameters diff --git a/prism/src/mtbdd/PrismMTBDD.java b/prism/src/mtbdd/PrismMTBDD.java index 00035b7c..55e1dfeb 100644 --- a/prism/src/mtbdd/PrismMTBDD.java +++ b/prism/src/mtbdd/PrismMTBDD.java @@ -290,6 +290,15 @@ public class PrismMTBDD return new JDDNode(ptr); } + // transient (probabilistic/dtmc) + private static native long PM_ProbTransient(long trans, long odd, long init, long rv, int nrv, long cv, int ncv, int time); + public static JDDNode ProbTransient(JDDNode trans, ODDNode odd, JDDNode init, JDDVars rows, JDDVars cols, int time) throws PrismException + { + long ptr = PM_ProbTransient(trans.ptr(), odd.ptr(), init.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), time); + if (ptr == 0) throw new PrismException(getErrorMessage()); + return new JDDNode(ptr); + } + //------------------------------------------------------------------------------ // nondeterministic/mdp stuff //------------------------------------------------------------------------------ diff --git a/prism/src/prism/Prism.java b/prism/src/prism/Prism.java index 56e63195..2c20f656 100644 --- a/prism/src/prism/Prism.java +++ b/prism/src/prism/Prism.java @@ -1280,12 +1280,20 @@ public class Prism implements PrismSettingsListener StateProbs probs = null; mainLog.println("\nComputing steady-state probabilities..."); - - // create new model checker object - mc = new StochModelChecker(this, model, null); - // do steady state calculation l = System.currentTimeMillis(); - probs = ((StochModelChecker)mc).doSteadyState(); + + if (model instanceof ProbModel) { + mc = new ProbModelChecker(this, model, null); + probs = ((ProbModelChecker)mc).doSteadyState(); + } + else if (model instanceof StochModel) { + mc = new StochModelChecker(this, model, null); + probs = ((StochModelChecker)mc).doSteadyState(); + } + else { + throw new PrismException("Steady-state probabilities only computed for DTMCs/CTMCs"); + } + l = System.currentTimeMillis() - l; // print out probabilities @@ -1304,14 +1312,27 @@ public class Prism implements PrismSettingsListener long l = 0; // timer StateProbs probs = null; - mainLog.println("\nComputing transient probabilities (time = " + time + ")..."); + if (time < 0) throw new PrismException("Cannot compute transient probabilities for negative time value"); // create new model checker object mc = new StochModelChecker(this, model, null); - // do steady state calculation l = System.currentTimeMillis(); - probs = ((StochModelChecker)mc).doTransient(time); + + if (model instanceof ProbModel) { + mainLog.println("\nComputing transient probabilities (time = " + (int)time + ")..."); + mc = new ProbModelChecker(this, model, null); + probs = ((ProbModelChecker)mc).doTransient((int)time); + } + else if (model instanceof StochModel) { + mainLog.println("\nComputing transient probabilities (time = " + time + ")..."); + mc = new StochModelChecker(this, model, null); + probs = ((StochModelChecker)mc).doTransient(time); + } + else { + throw new PrismException("Transient probabilities only computed for DTMCs/CTMCs"); + } + l = System.currentTimeMillis() - l; // print out probabilities diff --git a/prism/src/prism/PrismCL.java b/prism/src/prism/PrismCL.java index ebd9f4cc..c0d9fe96 100644 --- a/prism/src/prism/PrismCL.java +++ b/prism/src/prism/PrismCL.java @@ -116,7 +116,7 @@ public class PrismCL private ResultsCollection results[] = null; // time for transient computation - private double transientTime; + private String transientTime; // simulation info private double simApprox; @@ -691,12 +691,30 @@ public class PrismCL private void doTransient() throws PrismException { + double d; + int i; + // compute transient probabilities - if (model instanceof StochModel) { - prism.doTransient(model, transientTime); + if (model instanceof StochModel || model instanceof ProbModel) { + try { + d = Double.parseDouble(transientTime); + } + catch (NumberFormatException e) { + throw new PrismException("Invalid value \""+transientTime+"\" for transient probability computation"); + } + prism.doTransient(model, d); + } + else if (model instanceof ProbModel) { + try { + i = Integer.parseInt(transientTime); + } + catch (NumberFormatException e) { + throw new PrismException("Invalid value \""+transientTime+"\" for transient probability computation"); + } + prism.doTransient(model, i); } else { - mainLog.println("\nWarning: Transient probabilities only computed for CTMC models."); + mainLog.println("\nWarning: Transient probabilities only computed for DTMCs/CTMCs."); } } @@ -754,8 +772,10 @@ public class PrismCL if (i < args.length-1) { try { dotransient = true; - transientTime = Double.parseDouble(args[++i]); - if (transientTime < 0) throw new NumberFormatException(""); + transientTime = args[++i]; + // Make sure transient time parses as a +ve double + d = Double.parseDouble(transientTime); + if (d < 0) throw new NumberFormatException(""); } catch (NumberFormatException e) { errorAndExit("Invalid value for -"+sw+" switch"); diff --git a/prism/src/prism/ProbModelChecker.java b/prism/src/prism/ProbModelChecker.java index 38935423..154a1e6f 100644 --- a/prism/src/prism/ProbModelChecker.java +++ b/prism/src/prism/ProbModelChecker.java @@ -1096,6 +1096,47 @@ public class ProbModelChecker extends StateModelChecker return solnProbs; } + // ----------------------------------------------------------------------------------- + // do transient computation + // ----------------------------------------------------------------------------------- + + // transient computation (from initial states) + + public StateProbs doTransient(int time) throws PrismException + { + // mtbdd stuff + JDDNode start, init; + // other stuff + StateProbs probs = null; + + // get initial states of model + start = model.getStart(); + + // and hence compute initial probability distribution (equiprobable over + // all start states) + JDD.Ref(start); + init = JDD.Apply(JDD.DIVIDE, start, JDD.Constant(JDD.GetNumMinterms(start, allDDRowVars.n()))); + + // compute transient probabilities + try { + // special case: time = 0 + if (time == 0) { + JDD.Ref(init); + probs = new StateProbsMTBDD(init, model); + } else { + probs = computeTransientProbs(trans, init, time); + } + } catch (PrismException e) { + JDD.Deref(init); + throw e; + } + + // derefs + JDD.Deref(init); + + return probs; + } + // ----------------------------------------------------------------------------------- // probability computation methods // ----------------------------------------------------------------------------------- @@ -1614,6 +1655,38 @@ public class ProbModelChecker extends StateModelChecker return probs; } + + // compute transient probabilities + + protected StateProbs computeTransientProbs(JDDNode tr, JDDNode init, int time) throws PrismException + { + JDDNode probsMTBDD; + DoubleVector probsDV; + StateProbs probs = null; + + try { + switch (engine) { + case Prism.MTBDD: + probsMTBDD = PrismMTBDD.ProbTransient(tr, odd, init, allDDRowVars, allDDColVars, time); + probs = new StateProbsMTBDD(probsMTBDD, model); + break; + case Prism.SPARSE: + probsDV = PrismSparse.ProbTransient(tr, odd, init, allDDRowVars, allDDColVars, time); + probs = new StateProbsDV(probsDV, model); + break; + case Prism.HYBRID: + probsDV = PrismHybrid.ProbTransient(tr, odd, init, allDDRowVars, allDDColVars, time); + probs = new StateProbsDV(probsDV, model); + break; + default: + throw new PrismException("Engine does not support this numerical method"); + } + } catch (PrismException e) { + throw e; + } + + return probs; + } } // ------------------------------------------------------------------------------ diff --git a/prism/src/sparse/PS_ProbTransient.cc b/prism/src/sparse/PS_ProbTransient.cc new file mode 100644 index 00000000..4b737dcc --- /dev/null +++ b/prism/src/sparse/PS_ProbTransient.cc @@ -0,0 +1,229 @@ +//============================================================================== +// +// Copyright (c) 2002- +// Authors: +// * Dave Parker (University of Oxford, formerly University of Birmingham) +// +//------------------------------------------------------------------------------ +// +// This file is part of PRISM. +// +// PRISM is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation; either version 2 of the License, or +// (at your option) any later version. +// +// PRISM is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with PRISM; if not, write to the Free Software Foundation, +// Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +// +//============================================================================== + +// includes +#include "PrismSparse.h" +#include +#include +#include +#include +#include +#include +#include +#include "sparse.h" +#include "PrismSparseGlob.h" +#include "jnipointer.h" + +//------------------------------------------------------------------------------ + +JNIEXPORT jlong __pointer JNICALL Java_sparse_PrismSparse_PS_1ProbTransient +( +JNIEnv *env, +jclass cls, +jlong __pointer tr, // trans matrix +jlong __pointer od, // odd +jlong __pointer in, // initial distribution +jlong __pointer rv, // row vars +jint num_rvars, +jlong __pointer cv, // col vars +jint num_cvars, +jint time // time +) +{ + // cast function parameters + DdNode *trans = jlong_to_DdNode(tr); // trans matrix + ODDNode *odd = jlong_to_ODDNode(od); // odd + DdNode *init = jlong_to_DdNode(in); // initial distribution + DdNode **rvars = jlong_to_DdNode_array(rv); // row vars + DdNode **cvars = jlong_to_DdNode_array(cv); // col vars + + // model stats + int n; + long nnz; + // flags + bool compact_tr; + // sparse matrix + CMSparseMatrix *cmsm; + CMSCSparseMatrix *cmscsm; + // vectors + double *soln, *soln2, *tmpsoln; + // timing stuff + long start1, start2, start3, stop; + double time_taken, time_for_setup, time_for_iters; + // misc + bool done; + int i, j, l, h, iters; + double d, kb, kbt; + + // start clocks + start1 = start2 = util_cpu_time(); + + // get number of states + n = odd->eoff + odd->toff; + + // build sparse matrix + PS_PrintToMainLog(env, "\nBuilding sparse matrix... "); + // if requested, try and build a "compact" version + compact_tr = true; + cmscsm = NULL; + if (compact) cmscsm = build_cmsc_sparse_matrix(ddman, trans, rvars, cvars, num_rvars, odd); + if (cmscsm != NULL) { + nnz = cmscsm->nnz; + kb = cmscsm->mem; + } + // if not or if it wasn't possible, built a normal one + else { + compact_tr = false; + cmsm = build_cm_sparse_matrix(ddman, trans, rvars, cvars, num_rvars, odd); + nnz = cmsm->nnz; + kb = cmsm->mem; + } + // print some info + PS_PrintToMainLog(env, "[n=%d, nnz=%d%s] ", n, nnz, compact_tr?", compact":""); + kbt = kb; + PS_PrintToMainLog(env, "[%.1f KB]\n", kb); + + // create solution/iteration vectors + PS_PrintToMainLog(env, "Allocating iteration vectors... "); + soln = mtbdd_to_double_vector(ddman, init, rvars, num_rvars, odd); + soln2 = new double[n]; + kb = n*8.0/1024.0; + kbt += 2*kb; + PS_PrintToMainLog(env, "[2 x %.1f KB]\n", kb); + + // print total memory usage + PS_PrintToMainLog(env, "TOTAL: [%.1f KB]\n", kbt); + + // get setup time + stop = util_cpu_time(); + time_for_setup = (double)(stop - start2)/1000; + start2 = stop; + + // start iterations + iters = 0; + done = false; + PS_PrintToMainLog(env, "\nStarting iterations...\n"); + + // note that we ignore max_iters as we know how any iterations _should_ be performed + for (iters = 0; iters < time && !done; iters++) { + +// PS_PrintToMainLog(env, "Iteration %d: ", iters); +// start3 = util_cpu_time(); + + // store local copies of stuff + double *non_zeros; + unsigned char *col_counts; + int *col_starts; + bool use_counts; + unsigned int *rows; + double *dist; + int dist_shift; + int dist_mask; + if (!compact_tr) { + non_zeros = cmsm->non_zeros; + col_counts = cmsm->col_counts; + col_starts = (int *)cmsm->col_counts; + use_counts = cmsm->use_counts; + rows = cmsm->rows; + } else { + col_counts = cmscsm->col_counts; + col_starts = (int *)cmscsm->col_counts; + use_counts = cmscsm->use_counts; + rows = cmscsm->rows; + dist = cmscsm->dist; + dist_shift = cmscsm->dist_shift; + dist_mask = cmscsm->dist_mask; + } + + // do matrix vector multiply bit + h = 0; + for (i = 0; i < n; i++) { + d = 0.0; + if (!use_counts) { l = col_starts[i]; h = col_starts[i+1]; } + else { l = h; h += col_counts[i]; } + // "column major" version + if (!compact_tr) { + for (j = l; j < h; j++) { + d += non_zeros[j] * soln[rows[j]]; + } + // "compact msc" version + } else { + for (j = l; j < h; j++) { + d += dist[(int)(rows[j] & dist_mask)] * soln[(int)(rows[j] >> dist_shift)]; + } + } + // set vector element + soln2[i] = d; + } + + // check for steady state convergence + // (note: doing outside loop means may not need to check all elements) + if (do_ss_detect) switch (term_crit) { + case TERM_CRIT_ABSOLUTE: + done = true; + for (i = 0; i < n; i++) { + if (fabs(soln2[i] - soln[i]) > term_crit_param) { + done = false; + break; + } + } + break; + case TERM_CRIT_RELATIVE: + done = true; + for (i = 0; i < n; i++) { + if (fabs((soln2[i] - soln[i])/soln2[i]) > term_crit_param) { + done = false; + break; + } + } + break; + } + + // prepare for next iteration + tmpsoln = soln; + soln = soln2; + soln2 = tmpsoln; + +// PS_PrintToMainLog(env, "%.2f %.2f sec\n", ((double)(util_cpu_time() - start3)/1000), ((double)(util_cpu_time() - start2)/1000)/iters); + } + + // stop clocks + stop = util_cpu_time(); + time_for_iters = (double)(stop - start2)/1000; + time_taken = (double)(stop - start1)/1000; + + // print iters/timing info + if (done) PS_PrintToMainLog(env, "\nSteady state detected at iteration %d\n", iters); + PS_PrintToMainLog(env, "\nIterative method: %d iterations in %.2f seconds (average %.6f, setup %.2f)\n", iters, time_taken, time_for_iters/iters, time_for_setup); + + // free memory + if (compact_tr) free_cmsc_sparse_matrix(cmscsm); else free_cm_sparse_matrix(cmsm); + delete soln2; + + return ptr_to_jlong(soln); +} + +//------------------------------------------------------------------------------ diff --git a/prism/src/sparse/PrismSparse.java b/prism/src/sparse/PrismSparse.java index d271fb4c..b31c4561 100644 --- a/prism/src/sparse/PrismSparse.java +++ b/prism/src/sparse/PrismSparse.java @@ -232,6 +232,15 @@ public class PrismSparse return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); } + // transient (probabilistic/dtmc) + private static native long PS_ProbTransient(long trans, long odd, long init, long rv, int nrv, long cv, int ncv, int time); + public static DoubleVector ProbTransient(JDDNode trans, ODDNode odd, JDDNode init, JDDVars rows, JDDVars cols, int time) throws PrismException + { + long ptr = PS_ProbTransient(trans.ptr(), odd.ptr(), init.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), time); + if (ptr == 0) throw new PrismException(getErrorMessage()); + return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); + } + //---------------------------------------------------------------------------------------------- // nondeterministic/mdp stuff //---------------------------------------------------------------------------------------------- diff --git a/prism/src/userinterface/model/GUIMultiModel.java b/prism/src/userinterface/model/GUIMultiModel.java index 6fe5b816..7c707c85 100644 --- a/prism/src/userinterface/model/GUIMultiModel.java +++ b/prism/src/userinterface/model/GUIMultiModel.java @@ -137,7 +137,7 @@ public class GUIMultiModel extends GUIPlugin implements PrismSettingsListener viewTransRewards.setEnabled(!computing); viewPrismCode.setEnabled(!computing && handler.getParseState() == GUIMultiModelTree.TREE_SYNCHRONIZED_GOOD); computeSS.setEnabled(!computing && (handler.getParsedModelType() == ModulesFile.STOCHASTIC || handler.getParsedModelType() == ModulesFile.PROBABILISTIC)); - computeTr.setEnabled(!computing && (handler.getParsedModelType() == ModulesFile.STOCHASTIC)); + computeTr.setEnabled(!computing && (handler.getParsedModelType() == ModulesFile.STOCHASTIC || handler.getParsedModelType() == ModulesFile.PROBABILISTIC)); exportStatesPlain.setEnabled(!computing); exportStatesMatlab.setEnabled(!computing); exportTransPlain.setEnabled(!computing); diff --git a/prism/src/userinterface/model/computation/ComputeTransientThread.java b/prism/src/userinterface/model/computation/ComputeTransientThread.java index f6afc56a..bc7e792c 100644 --- a/prism/src/userinterface/model/computation/ComputeTransientThread.java +++ b/prism/src/userinterface/model/computation/ComputeTransientThread.java @@ -66,7 +66,7 @@ public class ComputeTransientThread extends GUIComputationThread //Do Computation try { - if(!(computeThis instanceof StochModel)) throw new PrismException("Can only compute transient probabilities for CTMCs"); + if(!(computeThis instanceof StochModel || computeThis instanceof ProbModel)) throw new PrismException("Can only compute transient probabilities for DTMCs/CTMCs"); plug.getPrism().doTransient(computeThis, transientTime); } catch(PrismException e)