From ad294aa98135ae4404fc253f3960aff5afabf4a4 Mon Sep 17 00:00:00 2001 From: Dave Parker Date: Sat, 11 Aug 2012 07:43:21 +0000 Subject: [PATCH] Added dot product method to symbolic StateValue classes. git-svn-id: https://www.prismmodelchecker.org/svn/prism/prism/trunk@5543 bbc10eb1-c90d-0410-af57-cb519fbb1720 --- prism/include/DoubleVector.h | 8 ++++++++ prism/src/dv/DoubleVector.cc | 23 +++++++++++++++++++++++ prism/src/dv/DoubleVector.java | 7 +++++++ prism/src/prism/StateValues.java | 1 + prism/src/prism/StateValuesDV.java | 7 +++++++ prism/src/prism/StateValuesMTBDD.java | 14 ++++++++++++++ 6 files changed, 60 insertions(+) diff --git a/prism/include/DoubleVector.h b/prism/include/DoubleVector.h index bb88d54d..ffde33db 100644 --- a/prism/include/DoubleVector.h +++ b/prism/include/DoubleVector.h @@ -79,6 +79,14 @@ JNIEXPORT void JNICALL Java_dv_DoubleVector_DV_1Add JNIEXPORT void JNICALL Java_dv_DoubleVector_DV_1TimesConstant (JNIEnv *, jobject, jlong, jint, jdouble); +/* + * Class: dv_DoubleVector + * Method: DV_DotProduct + * Signature: (JIJ)D + */ +JNIEXPORT jdouble JNICALL Java_dv_DoubleVector_DV_1DotProduct + (JNIEnv *, jobject, jlong, jint, jlong); + /* * Class: dv_DoubleVector * Method: DV_Filter diff --git a/prism/src/dv/DoubleVector.cc b/prism/src/dv/DoubleVector.cc index 88f98dd9..8514af8c 100644 --- a/prism/src/dv/DoubleVector.cc +++ b/prism/src/dv/DoubleVector.cc @@ -212,6 +212,29 @@ jdouble d //------------------------------------------------------------------------------ +JNIEXPORT jdouble JNICALL Java_dv_DoubleVector_DV_1DotProduct +( +JNIEnv *env, +jobject obj, +jlong __jlongpointer v, +jint n, +jlong __jlongpointer v2 +) +{ + double *vector = jlong_to_double(v); + double *vector2 = jlong_to_double(v2); + int i; + double d = 0.0; + + for (i = 0; i < n; i++) { + d += vector[i] * vector2[i]; + } + + return d; +} + +//------------------------------------------------------------------------------ + JNIEXPORT void JNICALL Java_dv_DoubleVector_DV_1Filter ( JNIEnv *env, diff --git a/prism/src/dv/DoubleVector.java b/prism/src/dv/DoubleVector.java index f0d216e2..4dbe8bf5 100644 --- a/prism/src/dv/DoubleVector.java +++ b/prism/src/dv/DoubleVector.java @@ -146,6 +146,13 @@ public class DoubleVector DV_TimesConstant(v, n, d); } + // compute dot (inner) product of this and another vector + private native double DV_DotProduct(long v, int n, long v2); + public double dotProduct(DoubleVector dv) + { + return DV_DotProduct(v, n, dv.v); + } + // filter vector using a bdd (set elements not in filter to 0) private native void DV_Filter(long v, long filter, long vars, int num_vars, long odd); public void filter(JDDNode filter, JDDVars vars, ODDNode odd) diff --git a/prism/src/prism/StateValues.java b/prism/src/prism/StateValues.java index 19889fa3..5fdca00f 100644 --- a/prism/src/prism/StateValues.java +++ b/prism/src/prism/StateValues.java @@ -42,6 +42,7 @@ public interface StateValues void subtractFromOne(); void add(StateValues sp); void timesConstant(double d); + double dotProduct(StateValues sp); void filter(JDDNode filter); public void maxMTBDD(JDDNode vec2); void clear(); diff --git a/prism/src/prism/StateValuesDV.java b/prism/src/prism/StateValuesDV.java index f2ab0b87..5375883d 100644 --- a/prism/src/prism/StateValuesDV.java +++ b/prism/src/prism/StateValuesDV.java @@ -193,6 +193,13 @@ public class StateValuesDV implements StateValues values.timesConstant(d); } + // compute dot (inner) product of this and another vector + + public double dotProduct(StateValues sv) + { + return values.dotProduct(((StateValuesDV) sv).values); + } + // filter vector using a bdd (set elements not in filter to 0) public void filter(JDDNode filter) diff --git a/prism/src/prism/StateValuesMTBDD.java b/prism/src/prism/StateValuesMTBDD.java index 63791fd5..75486ab2 100644 --- a/prism/src/prism/StateValuesMTBDD.java +++ b/prism/src/prism/StateValuesMTBDD.java @@ -206,6 +206,20 @@ public class StateValuesMTBDD implements StateValues values = JDD.Apply(JDD.TIMES, values, JDD.Constant(d)); } + // compute dot (inner) product of this and another vector + + public double dotProduct(StateValues sp) + { + StateValuesMTBDD spm = (StateValuesMTBDD) sp; + JDD.Ref(values); + JDD.Ref(spm.values); + JDDNode tmp = JDD.Apply(JDD.TIMES, values, spm.values); + tmp = JDD.SumAbstract(tmp, vars); + double d = JDD.FindMax(tmp); + JDD.Deref(tmp); + return d; + } + // filter vector using a bdd (set elements not in filter to 0) public void filter(JDDNode filter)