diff --git a/prism/include/DoubleVector.h b/prism/include/DoubleVector.h
index 4ff98740..3d7ea5a1 100644
--- a/prism/include/DoubleVector.h
+++ b/prism/include/DoubleVector.h
@@ -98,10 +98,10 @@ JNIEXPORT jdouble JNICALL Java_dv_DoubleVector_DV_1DotProduct
/*
* Class: dv_DoubleVector
* Method: DV_Filter
- * Signature: (JJJIJ)V
+ * Signature: (JJDJIJ)V
*/
JNIEXPORT void JNICALL Java_dv_DoubleVector_DV_1Filter
- (JNIEnv *, jobject, jlong, jlong, jlong, jint, jlong);
+ (JNIEnv *, jobject, jlong, jlong, jdouble, jlong, jint, jlong);
/*
* Class: dv_DoubleVector
diff --git a/prism/include/dv.h b/prism/include/dv.h
index b92dc6fd..2b39ea1a 100644
--- a/prism/include/dv.h
+++ b/prism/include/dv.h
@@ -72,7 +72,7 @@ EXPORT DdNode *double_vector_to_mtbdd(DdManager *ddman, double *vec, DdNode **va
EXPORT DdNode *double_vector_to_bdd(DdManager *ddman, double *vec, int rel_op, double value, DdNode **vars, int num_vars, ODDNode *odd);
EXPORT DdNode *double_vector_to_bdd(DdManager *ddman, double *vec, int rel_op, double value1, double value2, DdNode **vars, int num_vars, ODDNode *odd);
-EXPORT void filter_double_vector(DdManager *ddman, double *vec, DdNode *filter, DdNode **vars, int num_vars, ODDNode *odd);
+EXPORT void filter_double_vector(DdManager *ddman, double *vec, DdNode *filter, double d, DdNode **vars, int num_vars, ODDNode *odd);
EXPORT void max_double_vector_mtbdd(DdManager *ddman, double *vec, DdNode *vec2, DdNode **vars, int num_vars, ODDNode *odd);
EXPORT double get_first_from_bdd(DdManager *ddman, double *vec, DdNode *filter, DdNode **vars, int num_vars, ODDNode *odd);
EXPORT double min_double_vector_over_bdd(DdManager *ddman, double *vec, DdNode *filter, DdNode **vars, int num_vars, ODDNode *odd);
diff --git a/prism/src/dv/DoubleVector.cc b/prism/src/dv/DoubleVector.cc
index bdac6bbb..a861b337 100644
--- a/prism/src/dv/DoubleVector.cc
+++ b/prism/src/dv/DoubleVector.cc
@@ -258,6 +258,7 @@ JNIEnv *env,
jobject obj,
jlong __jlongpointer vector,
jlong __jlongpointer filter,
+jdouble d,
jlong __jlongpointer vars,
jint num_vars,
jlong __jlongpointer odd
@@ -267,6 +268,7 @@ jlong __jlongpointer odd
ddman,
jlong_to_double(vector),
jlong_to_DdNode(filter),
+ d,
jlong_to_DdNode_array(vars), num_vars,
jlong_to_ODDNode(odd)
);
diff --git a/prism/src/dv/DoubleVector.java b/prism/src/dv/DoubleVector.java
index 85f05120..abec1a8a 100644
--- a/prism/src/dv/DoubleVector.java
+++ b/prism/src/dv/DoubleVector.java
@@ -199,11 +199,24 @@ public class DoubleVector
return DV_DotProduct(v, n, dv.v);
}
- // filter vector using a bdd (set elements not in filter to 0)
- private native void DV_Filter(long v, long filter, long vars, int num_vars, long odd);
+
+ private native void DV_Filter(long v, long filter, double d, long vars, int num_vars, long odd);
+ /**
+ * Filter vector using a bdd (set elements not in filter to d)
+ *
[ REFS: result, DEREFS: none ]
+ */
+ public void filter(JDDNode filter, double d, JDDVars vars, ODDNode odd)
+ {
+ DV_Filter(v, filter.ptr(), d, vars.array(), vars.n(), odd.ptr());
+ }
+
+ /**
+ * Filter vector using a bdd (set elements not in filter to 0)
+ *
[ REFS: none, DEREFS: none ]
+ */
public void filter(JDDNode filter, JDDVars vars, ODDNode odd)
{
- DV_Filter(v, filter.ptr(), vars.array(), vars.n(), odd.ptr());
+ DV_Filter(v, filter.ptr(), 0.0, vars.array(), vars.n(), odd.ptr());
}
// apply max operator, i.e. v[i] = max(v[i], v2[i]), where v2 is an mtbdd
diff --git a/prism/src/dv/dv.cc b/prism/src/dv/dv.cc
index f28e999a..d9e79af4 100644
--- a/prism/src/dv/dv.cc
+++ b/prism/src/dv/dv.cc
@@ -34,7 +34,7 @@
static void mtbdd_to_double_vector_rec(DdManager *ddman, DdNode *dd, DdNode **vars, int num_vars, int level, ODDNode *odd, long o, double *res);
static DdNode *double_vector_to_mtbdd_rec(DdManager *ddman, double *vec, DdNode **vars, int num_vars, int level, ODDNode *odd, long o);
static DdNode *double_vector_to_bdd_rec(DdManager *ddman, double *vec, int rel_op, double value1, double value2, DdNode **vars, int num_vars, int level, ODDNode *odd, long o);
-static void filter_double_vector_rec(DdManager *ddman, double *vec, DdNode *filter, DdNode **vars, int num_vars, int level, ODDNode *odd, long o);
+static void filter_double_vector_rec(DdManager *ddman, double *vec, DdNode *filter, double d, DdNode **vars, int num_vars, int level, ODDNode *odd, long o);
static void max_double_vector_mtbdd_rec(DdManager *ddman, double *vec, DdNode *vec2, DdNode **vars, int num_vars, int level, ODDNode *odd, long o);
static double get_first_from_bdd_rec(DdManager *ddman, double *vec, DdNode *filter, DdNode **vars, int num_vars, int level, ODDNode *odd, long o);
static double min_double_vector_over_bdd_rec(DdManager *ddman, double *vec, DdNode *filter, DdNode **vars, int num_vars, int level, ODDNode *odd, long o);
@@ -208,33 +208,33 @@ DdNode *double_vector_to_bdd_rec(DdManager *ddman, double *vec, int rel_op, doub
//------------------------------------------------------------------------------
-// filter vector using a bdd (set elements not in filter to 0)
+// filter vector using a bdd (set elements not in filter to constant d)
-EXPORT void filter_double_vector(DdManager *ddman, double *vec, DdNode *filter, DdNode **vars, int num_vars, ODDNode *odd)
+EXPORT void filter_double_vector(DdManager *ddman, double *vec, DdNode *filter, double d, DdNode **vars, int num_vars, ODDNode *odd)
{
- filter_double_vector_rec(ddman, vec, filter, vars, num_vars, 0, odd, 0);
+ filter_double_vector_rec(ddman, vec, filter, d, vars, num_vars, 0, odd, 0);
}
-void filter_double_vector_rec(DdManager *ddman, double *vec, DdNode *filter, DdNode **vars, int num_vars, int level, ODDNode *odd, long o)
+void filter_double_vector_rec(DdManager *ddman, double *vec, DdNode *filter, double d, DdNode **vars, int num_vars, int level, ODDNode *odd, long o)
{
DdNode *dd;
if (level == num_vars) {
if (Cudd_V(filter) == 0) {
- vec[o] = 0;
+ vec[o] = d;
}
}
else {
if (odd->eoff > 0) {
dd = (filter->index > vars[level]->index) ? filter : Cudd_E(filter);
- filter_double_vector_rec(ddman, vec, dd, vars, num_vars, level+1, odd->e, o);
+ filter_double_vector_rec(ddman, vec, dd, d, vars, num_vars, level+1, odd->e, o);
}
if (odd->toff > 0) {
dd = (filter->index > vars[level]->index) ? filter : Cudd_T(filter);
- filter_double_vector_rec(ddman, vec, dd, vars, num_vars, level+1, odd->t, o+odd->eoff);
+ filter_double_vector_rec(ddman, vec, dd, d, vars, num_vars, level+1, odd->t, o+odd->eoff);
}
}
-}
+}
//------------------------------------------------------------------------------
diff --git a/prism/src/prism/StateValues.java b/prism/src/prism/StateValues.java
index 299ce131..ac4b876f 100644
--- a/prism/src/prism/StateValues.java
+++ b/prism/src/prism/StateValues.java
@@ -77,6 +77,12 @@ public interface StateValues extends StateVector
*/
void filter(JDDNode filter);
+ /**
+ * Filter this vector using a BDD (set elements not in filter to constant {@code d}).
+ *
[ DEREFS: none ]
+ */
+ void filter(JDDNode filter, double d);
+
/**
* Apply max operator, i.e. vec[i] = max(vec[i], vec2[i]), where vec2 is an MTBDD
*
[ DEREFS: none ]
diff --git a/prism/src/prism/StateValuesDV.java b/prism/src/prism/StateValuesDV.java
index e625ef3e..d318a2f4 100644
--- a/prism/src/prism/StateValuesDV.java
+++ b/prism/src/prism/StateValuesDV.java
@@ -233,6 +233,12 @@ public class StateValuesDV implements StateValues
values.filter(filter, vars, odd);
}
+ @Override
+ public void filter(JDDNode filter, double d)
+ {
+ values.filter(filter, d, vars, odd);
+ }
+
@Override
public void maxMTBDD(JDDNode vec2)
{
diff --git a/prism/src/prism/StateValuesMTBDD.java b/prism/src/prism/StateValuesMTBDD.java
index 3bfd2211..0a49ac16 100644
--- a/prism/src/prism/StateValuesMTBDD.java
+++ b/prism/src/prism/StateValuesMTBDD.java
@@ -246,7 +246,15 @@ public class StateValuesMTBDD implements StateValues
JDD.Ref(filter);
values = JDD.Apply(JDD.TIMES, values, filter);
}
-
+
+ @Override
+ public void filter(JDDNode filter, double d)
+ {
+ // If filter, then keep value, else constant d,
+ // but only for the reachable states
+ values = JDD.Times(reach.copy(), JDD.ITE(filter.copy(), values, JDD.Constant(d)));
+ }
+
@Override
public void maxMTBDD(JDDNode vec2)
{
diff --git a/prism/src/prism/StateValuesVoid.java b/prism/src/prism/StateValuesVoid.java
index 19644324..c184f68a 100644
--- a/prism/src/prism/StateValuesVoid.java
+++ b/prism/src/prism/StateValuesVoid.java
@@ -125,6 +125,12 @@ public class StateValuesVoid implements StateValues
throw new UnsupportedOperationException();
}
+ @Override
+ public void filter(JDDNode filter, double d)
+ {
+ throw new UnsupportedOperationException();
+ }
+
@Override
public void maxMTBDD(JDDNode vec2)
{