diff --git a/prism/src/mtbdd/PM_NondetReachReward.cc b/prism/src/mtbdd/PM_NondetReachReward.cc index 15f72882..1af46266 100644 --- a/prism/src/mtbdd/PM_NondetReachReward.cc +++ b/prism/src/mtbdd/PM_NondetReachReward.cc @@ -111,8 +111,11 @@ jboolean min // min or max probabilities (true = min, false = max) Cudd_Ref(mask); new_mask = DD_ITE(ddman, mask, DD_PlusInfinity(ddman), DD_Constant(ddman, 0)); - // initial solution is zero - sol = DD_Constant(ddman, 0); + // initial solution is infinity in 'inf' states, zero elsewhere + // note: ok to do this because cudd matrix-multiply (and other ops) + // treat 0 * inf as 0, unlike in IEEE 754 rules + Cudd_Ref(inf); + sol = DD_ITE(ddman, inf, DD_PlusInfinity(ddman), DD_Constant(ddman, 0)); // print memory usage i = DD_GetNumNodes(ddman, a); @@ -158,6 +161,10 @@ jboolean min // min or max probabilities (true = min, false = max) tmp = DD_MaxAbstract(ddman, tmp, ndvars, num_ndvars); } + // put infinities (for 'inf' states) back into into solution vector + Cudd_Ref(inf); + tmp = DD_ITE(ddman, inf, DD_PlusInfinity(ddman), tmp); + // check convergence switch (term_crit) { case TERM_CRIT_ABSOLUTE: @@ -179,10 +186,6 @@ jboolean min // min or max probabilities (true = min, false = max) // PM_PrintToMainLog(env, "%.2f %.2f sec\n", ((double)(util_cpu_time() - start3)/1000), ((double)(util_cpu_time() - start2)/1000)/iters); } - // set reward for infinity states to infinity - Cudd_Ref(inf); - sol = DD_ITE(ddman, inf, DD_PlusInfinity(ddman), sol); - // stop clocks stop = util_cpu_time(); time_for_iters = (double)(stop - start2)/1000; diff --git a/prism/src/sparse/PS_NondetReachReward.cc b/prism/src/sparse/PS_NondetReachReward.cc index 7e4b3e86..c071960e 100644 --- a/prism/src/sparse/PS_NondetReachReward.cc +++ b/prism/src/sparse/PS_NondetReachReward.cc @@ -147,6 +147,13 @@ jboolean min // min or max probabilities (true = min, false = max) kbt += kb; PS_PrintMemoryToMainLog(env, "[", kb, "]\n"); + // get vector for yes + PS_PrintToMainLog(env, "Creating vector for inf... "); + inf_vec = mtbdd_to_double_vector(ddman, inf, rvars, num_rvars, odd); + kb = n*8.0/1024.0; + kbt += kb; + PS_PrintMemoryToMainLog(env, "[", kb, "]\n"); + // create solution/iteration vectors PS_PrintToMainLog(env, "Allocating iteration vectors... "); soln = new double[n]; @@ -158,9 +165,9 @@ jboolean min // min or max probabilities (true = min, false = max) // print total memory usage PS_PrintMemoryToMainLog(env, "TOTAL: [", kbt, "]\n"); - // initial solution is zero + // initial solution is infinity in 'inf' states, zero elsewhere for (i = 0; i < n; i++) { - soln[i] = 0; + soln[i] = (inf_vec[i] > 0) ? HUGE_VAL : 0.0; } // get setup time @@ -241,8 +248,8 @@ jboolean min // min or max probabilities (true = min, false = max) first = false; } // set vector element - // (if there were no choices from this state, reward is zero) - soln2[i] = (h1 > l1) ? d1 : 0; + // (if there were no choices from this state, reward is zero/infinity) + soln2[i] = (h1 > l1) ? d1 : inf_vec[i] > 0 ? HUGE_VAL : 0; // store adversary info (if required) if (adv_loop) if (h1 > l1) for (k = adv_l; k < adv_h; k++) fprintf(fp_adv, "%d %d %g\n", i, cols[k], non_zeros[k]); @@ -295,14 +302,6 @@ jboolean min // min or max probabilities (true = min, false = max) // if the iterative method didn't terminate, this is an error if (!done) { delete soln; soln = NULL; PS_SetErrorMessage("Iterative method did not converge within %d iterations.\nConsider using a different numerical method or increasing the maximum number of iterations", iters); } - // set reward for infinity states to infinity - if (soln != NULL) { - // first, generate vector for inf - inf_vec = mtbdd_to_double_vector(ddman, inf, rvars, num_rvars, odd); - // go thru setting elements of soln to infinity - for (i = 0; i < n; i++) if (inf_vec[i] > 0) soln[i] = HUGE_VAL; - } - // close file to store adversary (if required) if (adv) { fclose(fp_adv); @@ -321,6 +320,7 @@ jboolean min // min or max probabilities (true = min, false = max) if (trans_rewards) Cudd_RecursiveDeref(ddman, trans_rewards); if (ndsm) delete ndsm; if (ndsm_r) delete ndsm_r; + if (inf_vec) delete[] inf_vec; if (sr_vec) delete[] sr_vec; if (soln2) delete[] soln2;