diff --git a/prism/include/PrismSparse.h b/prism/include/PrismSparse.h index 33de26dd..e2bfef9f 100644 --- a/prism/include/PrismSparse.h +++ b/prism/include/PrismSparse.h @@ -127,6 +127,14 @@ JNIEXPORT jlong JNICALL Java_sparse_PrismSparse_PS_1ProbUntil JNIEXPORT jlong JNICALL Java_sparse_PrismSparse_PS_1ProbCumulReward (JNIEnv *, jclass, jlong, jlong, jlong, jlong, jlong, jint, jlong, jint, jint); +/* + * Class: sparse_PrismSparse + * Method: PS_ProbInstReward + * Signature: (JJJJIJII)J + */ +JNIEXPORT jlong JNICALL Java_sparse_PrismSparse_PS_1ProbInstReward + (JNIEnv *, jclass, jlong, jlong, jlong, jlong, jint, jlong, jint, jint); + /* * Class: sparse_PrismSparse * Method: PS_ProbReachReward diff --git a/prism/src/prism/ProbModelChecker.java b/prism/src/prism/ProbModelChecker.java index 9a51752a..b9fbc3d7 100644 --- a/prism/src/prism/ProbModelChecker.java +++ b/prism/src/prism/ProbModelChecker.java @@ -1446,6 +1446,10 @@ public class ProbModelChecker extends StateModelChecker rewardsMTBDD = PrismMTBDD.ProbInstReward(tr, sr, odd, allDDRowVars, allDDColVars, time); rewards = new StateProbsMTBDD(rewardsMTBDD, model); break; + case Prism.SPARSE: + rewardsDV = PrismSparse.ProbInstReward(tr, sr, odd, allDDRowVars, allDDColVars, time); + rewards = new StateProbsDV(rewardsDV, model); + break; default: throw new PrismException("Engine does not support this numerical method"); } diff --git a/prism/src/sparse/PS_ProbInstReward.cc b/prism/src/sparse/PS_ProbInstReward.cc new file mode 100644 index 00000000..3bfeadfb --- /dev/null +++ b/prism/src/sparse/PS_ProbInstReward.cc @@ -0,0 +1,205 @@ +//============================================================================== +// +// Copyright (c) 2002- +// Authors: +// * Dave Parker (University of Oxford, formerly University of Birmingham) +// +//------------------------------------------------------------------------------ +// +// This file is part of PRISM. +// +// PRISM is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation; either version 2 of the License, or +// (at your option) any later version. +// +// PRISM is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with PRISM; if not, write to the Free Software Foundation, +// Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +// +//============================================================================== + +// includes +#include "PrismSparse.h" +#include +#include +#include +#include +#include +#include +#include "sparse.h" +#include "PrismSparseGlob.h" +#include "jnipointer.h" + +//------------------------------------------------------------------------------ + +JNIEXPORT jlong __pointer JNICALL Java_sparse_PrismSparse_PS_1ProbInstReward +( +JNIEnv *env, +jclass cls, +jlong __pointer t, // trans matrix +jlong __pointer sr, // state rewards +jlong __pointer od, // odd +jlong __pointer rv, // row vars +jint num_rvars, +jlong __pointer cv, // col vars +jint num_cvars, +jint bound // time bound +) +{ + // cast function parameters + DdNode *trans = jlong_to_DdNode(t); // trans matrix + DdNode *state_rewards = jlong_to_DdNode(sr); // state rewards + ODDNode *odd = jlong_to_ODDNode(od); // reachable states + DdNode **rvars = jlong_to_DdNode_array(rv); // row vars + DdNode **cvars = jlong_to_DdNode_array(cv); // col vars + + // mtbdds + DdNode *tmp; + // model stats + int n; + long nnz; + // flags + bool compact_tr, compact_r; + // sparse matrix + RMSparseMatrix *rmsm; + CMSRSparseMatrix *cmsrsm; + // vectors + double *soln, *soln2, *tmpsoln; + // timing stuff + long start1, start2, start3, stop; + double time_taken, time_for_setup, time_for_iters; + // misc + int i, j, l, h, iters; + double d, kb, kbt; + bool first; + + // start clocks + start1 = start2 = util_cpu_time(); + + // get number of states + n = odd->eoff + odd->toff; + + // build sparse matrix + PS_PrintToMainLog(env, "\nBuilding sparse matrix... "); + // if requested, try and build a "compact" version + compact_tr = true; + cmsrsm = NULL; + if (compact) cmsrsm = build_cmsr_sparse_matrix(ddman, trans, rvars, cvars, num_rvars, odd); + if (cmsrsm != NULL) { + nnz = cmsrsm->nnz; + kb = cmsrsm->mem; + } + // if not or if it wasn't possible, built a normal one + else { + compact_tr = false; + rmsm = build_rm_sparse_matrix(ddman, trans, rvars, cvars, num_rvars, odd); + nnz = rmsm->nnz; + kb = rmsm->mem; + } + // print some info + PS_PrintToMainLog(env, "[n=%d, nnz=%d%s] ", n, nnz, compact_tr?", compact":""); + kbt = kb; + PS_PrintToMainLog(env, "[%.1f KB]\n", kb); + + // create solution/iteration vectors + // (solution is initialised to the state rewards) + PS_PrintToMainLog(env, "Allocating iteration vectors... "); + soln = mtbdd_to_double_vector(ddman, state_rewards, rvars, num_rvars, odd); + soln2 = new double[n]; + kb = n*8.0/1024.0; + kbt += 2*kb; + PS_PrintToMainLog(env, "[2 x %.1f KB]\n", kb); + + // print total memory usage + PS_PrintToMainLog(env, "TOTAL: [%.1f KB]\n", kbt); + + // get setup time + stop = util_cpu_time(); + time_for_setup = (double)(stop - start2)/1000; + start2 = stop; + + // start iterations + PS_PrintToMainLog(env, "\nStarting iterations...\n"); + + // note that we ignore max_iters as we know how any iterations _should_ be performed + for (iters = 0; iters < bound; iters++) { + +// PS_PrintToMainLog(env, "iter %d\n", iters); +// start3 = util_cpu_time(); + + // store local copies of stuff + double *non_zeros; + unsigned char *row_counts; + int *row_starts; + bool use_counts; + unsigned int *cols; + double *dist; + int dist_shift; + int dist_mask; + if (!compact_tr) { + non_zeros = rmsm->non_zeros; + row_counts = rmsm->row_counts; + row_starts = (int *)rmsm->row_counts; + use_counts = rmsm->use_counts; + cols = rmsm->cols; + } else { + row_counts = cmsrsm->row_counts; + row_starts = (int *)cmsrsm->row_counts; + use_counts = cmsrsm->use_counts; + cols = cmsrsm->cols; + dist = cmsrsm->dist; + dist_shift = cmsrsm->dist_shift; + dist_mask = cmsrsm->dist_mask; + } + + // matrix multiply + h = 0; + for (i = 0; i < n; i++) { + d = 0.0; + if (!use_counts) { l = row_starts[i]; h = row_starts[i+1]; } + else { l = h; h += row_counts[i]; } + // "row major" version + if (!compact_tr) { + for (j = l; j < h; j++) { + d += non_zeros[j] * soln[cols[j]]; + } + // "compact msr" version + } else { + for (j = l; j < h; j++) { + d += dist[(int)(cols[j] & dist_mask)] * soln[(int)(cols[j] >> dist_shift)]; + } + } + // set vector element + soln2[i] = d; + } + + // prepare for next iteration + tmpsoln = soln; + soln = soln2; + soln2 = tmpsoln; + +// PS_PrintToMainLog(env, "%.2f %.2f sec\n", ((double)(util_cpu_time() - start3)/1000), ((double)(util_cpu_time() - start2)/1000)/iters); + } + + // stop clocks + stop = util_cpu_time(); + time_for_iters = (double)(stop - start2)/1000; + time_taken = (double)(stop - start1)/1000; + + // print iterations/timing info + PS_PrintToMainLog(env, "\nIterative method: %d iterations in %.2f seconds (average %.6f, setup %.2f)\n", iters, time_taken, time_for_iters/iters, time_for_setup); + + // free memory + if (compact_tr) free_cmsr_sparse_matrix(cmsrsm); else free_rm_sparse_matrix(rmsm); + delete soln2; + + return ptr_to_jlong(soln); +} + +//------------------------------------------------------------------------------ diff --git a/prism/src/sparse/PrismSparse.java b/prism/src/sparse/PrismSparse.java index a603083b..d271fb4c 100644 --- a/prism/src/sparse/PrismSparse.java +++ b/prism/src/sparse/PrismSparse.java @@ -214,6 +214,15 @@ public class PrismSparse return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); } + // pctl inst reward (probabilistic/dtmc) + private static native long PS_ProbInstReward(long trans, long sr, long odd, long rv, int nrv, long cv, int ncv, int time); + public static DoubleVector ProbInstReward(JDDNode trans, JDDNode sr, ODDNode odd, JDDVars rows, JDDVars cols, int time) throws PrismException + { + long ptr = PS_ProbInstReward(trans.ptr(), sr.ptr(), odd.ptr(), rows.array(), rows.n(), cols.array(), cols.n(), time); + if (ptr == 0) throw new PrismException(getErrorMessage()); + return new DoubleVector(ptr, (int)(odd.getEOff() + odd.getTOff())); + } + // pctl reach reward (probabilistic/dtmc) private static native long PS_ProbReachReward(long trans, long sr, long trr, long odd, long rv, int nrv, long cv, int ncv, long goal, long inf, long maybe); public static DoubleVector ProbReachReward(JDDNode trans, JDDNode sr, JDDNode trr, ODDNode odd, JDDVars rows, JDDVars cols, JDDNode goal, JDDNode inf, JDDNode maybe) throws PrismException