From 8a5b2f001dc7e1a960b3cad76430ed1c3e67143e Mon Sep 17 00:00:00 2001 From: Dave Parker Date: Wed, 3 Mar 2021 21:53:21 +0000 Subject: [PATCH] POMDP refactoring: new Belief constructors. --- prism/src/explicit/Belief.java | 35 +++++++++++++++++++++++++++-- prism/src/explicit/POMDPSimple.java | 6 ++--- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/prism/src/explicit/Belief.java b/prism/src/explicit/Belief.java index 0af35219..797c81c5 100644 --- a/prism/src/explicit/Belief.java +++ b/prism/src/explicit/Belief.java @@ -51,8 +51,8 @@ public class Belief implements Comparable /** * Constructor - * @param Observable part (index of observation) - * @param Distribution over unobservable part (probability for each unobservation) + * @param so Observable part (index of observation) + * @param bu Distribution over unobservable part (probability for each unobservation) */ public Belief(int so, double[] bu) { @@ -60,6 +60,37 @@ public class Belief implements Comparable this.bu = bu; } + /** + * Constructor + * @param dist Distribution over states of a model (probability for each state) + * @param model The (partially observable) model + * If {@code dist} is not a valid distribution, the resulting belief will be invalid too. 
+ */ + protected Belief(double[] dist, PartiallyObservableModel model) + { + so = -1; + bu = new double[model.getNumUnobservations()]; + for (int s = 0; s < dist.length; s++) { + if (dist[s] != 0) { + so = model.getObservation(s); + bu[model.getUnobservation(s)] += dist[s]; + } + } + } + + /** + * Construct a point distribution over a single model state + * @param s A model state + * @param model The (partially observable) model + */ + public static Belief pointDistribution(int s, PartiallyObservableModel model) + { + int so = model.getObservation(s); + double[] bu = new double[model.getNumUnobservations()]; + bu[model.getUnobservation(s)] = 1.0; + return new Belief(so, bu); + } + /** * Convert to a probability distribution over all model states * (represented as an array of probabilities). diff --git a/prism/src/explicit/POMDPSimple.java b/prism/src/explicit/POMDPSimple.java index b19dbc48..1ab3769a 100644 --- a/prism/src/explicit/POMDPSimple.java +++ b/prism/src/explicit/POMDPSimple.java @@ -378,7 +378,7 @@ public class POMDPSimple extends MDPSimple implements POMDP initialBeliefInDist[i] = 1; } PrismUtils.normalise(initialBeliefInDist); - return beliefInDistToBelief(initialBeliefInDist); + return new Belief(initialBeliefInDist, this); } @Override @@ -397,7 +397,7 @@ public class POMDPSimple extends MDPSimple implements POMDP { double[] beliefInDist = belief.toDistributionOverStates(this); double[] nextBeliefInDist = getBeliefInDistAfterChoice(beliefInDist, i); - return beliefInDistToBelief(nextBeliefInDist); + return new Belief(nextBeliefInDist, this); } @Override @@ -423,7 +423,7 @@ public class POMDPSimple extends MDPSimple implements POMDP { double[] beliefInDist = belief.toDistributionOverStates(this); double[] nextBeliefInDist = getBeliefInDistAfterChoiceAndObservation(beliefInDist, i, o); - Belief nextBelief = beliefInDistToBelief(nextBeliefInDist); + Belief nextBelief = new Belief(nextBeliefInDist, this); assert(nextBelief.so == o); return nextBelief; }