You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
312 lines
8.3 KiB
312 lines
8.3 KiB
//==============================================================================
|
|
//
|
|
// Copyright (c) 2016-
|
|
// Authors:
|
|
// * Joachim Klein <klein@tcs.inf.tu-dresden.de> (TU Dresden)
|
|
//
|
|
//------------------------------------------------------------------------------
|
|
//
|
|
// This file is part of PRISM.
|
|
//
|
|
// PRISM is free software; you can redistribute it and/or modify
|
|
// it under the terms of the GNU General Public License as published by
|
|
// the Free Software Foundation; either version 2 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// PRISM is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU General Public License
|
|
// along with PRISM; if not, write to the Free Software Foundation,
|
|
// Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
|
|
//
|
|
//==============================================================================
|
|
|
|
package explicit;
|
|
|
|
import java.util.BitSet;
|
|
import java.util.HashMap;
|
|
import java.util.HashSet;
|
|
import java.util.PriorityQueue;
|
|
import java.util.function.IntPredicate;
|
|
|
|
import common.StopWatch;
|
|
import common.IterableBitSet;
|
|
import explicit.IncomingChoiceRelation.Choice;
|
|
import explicit.rewards.MDPRewards;
|
|
import prism.PrismComponent;
|
|
|
|
/**
|
|
* An implementation of the upper bound computation for Rmin as detailed in the
|
|
* paper McMahan, Likhachev, Gordon "Bounded Real-Time Dynamic Programming:
|
|
* RTDP with monotone upper bounds and performance guarantees" (International
|
|
* Conference on Machine Learning, 2005).
|
|
* */
|
|
public class DijkstraSweepMPI {
|
|
|
|
private static class QueueEntry implements Comparable<QueueEntry> {
|
|
public int y;
|
|
public double p;
|
|
public double w;
|
|
|
|
public QueueEntry(int y, double p, double w)
|
|
{
|
|
this.y = y;
|
|
this.p = p;
|
|
this.w = w;
|
|
}
|
|
|
|
@Override
|
|
public int compareTo(QueueEntry o)
|
|
{
|
|
int r = Double.compare(p, o.p);
|
|
if (r == 0) {
|
|
return Double.compare(w, o.w);
|
|
} else {
|
|
return r;
|
|
}
|
|
}
|
|
}
|
|
|
|
private static class ChoiceValues {
|
|
public double p;
|
|
public double w;
|
|
|
|
public ChoiceValues(double p, double w)
|
|
{
|
|
this.p = p;
|
|
this.w = w;
|
|
}
|
|
}
|
|
|
|
private static boolean debug = false;
|
|
private MDP mdp;
|
|
private MDPRewards rewards;
|
|
private PriorityQueue<QueueEntry> queue;
|
|
private double[] pState;
|
|
private double[] wState;
|
|
private HashMap<Choice, ChoiceValues> choiceValues = new HashMap<Choice, ChoiceValues>();
|
|
private QueueEntry[] pri;
|
|
private int[] pi;
|
|
private BitSet unknown, target;
|
|
private BitSet fin = new BitSet();
|
|
private IncomingChoiceRelation incoming;
|
|
private double lambda;
|
|
|
|
private DijkstraSweepMPI(PrismComponent parent, MDP mdp, MDPRewards rewards, BitSet target, BitSet unknown)
|
|
{
|
|
this.mdp = mdp;
|
|
this.unknown = unknown;
|
|
this.target = target;
|
|
this.rewards = rewards;
|
|
|
|
incoming = IncomingChoiceRelation.forModel(parent, mdp);
|
|
|
|
queue = new PriorityQueue<QueueEntry>();
|
|
pState = new double[mdp.getNumStates()];
|
|
wState = new double[mdp.getNumStates()];
|
|
pri = new QueueEntry[mdp.getNumStates()];
|
|
pi = new int[mdp.getNumStates()];
|
|
|
|
for (int s : IterableBitSet.getSetBits(unknown)) {
|
|
for (int choice = 0, numChoices = mdp.getNumChoices(s); choice < numChoices; choice++) {
|
|
Choice c = new Choice(s, choice);
|
|
double rew = rewards.getStateReward(s);
|
|
rew += rewards.getTransitionReward(s, choice);
|
|
choiceValues.put(c, new ChoiceValues(0.0, rew));
|
|
}
|
|
}
|
|
|
|
for (int s : IterableBitSet.getSetBits(target)) {
|
|
pState[s] = 1.0;
|
|
}
|
|
|
|
HashSet<Choice> preTarget = new HashSet<Choice>();
|
|
for (int t : IterableBitSet.getSetBits(target)) {
|
|
for (Choice c : incoming.getIncomingChoices(t)) {
|
|
boolean newChoice = preTarget.add(c);
|
|
if (newChoice) {
|
|
if (!unknown.get(c.getState())) {
|
|
continue;
|
|
}
|
|
if (!validChoice(c)) {
|
|
continue;
|
|
}
|
|
|
|
update(c, target);
|
|
}
|
|
}
|
|
}
|
|
preTarget.clear();
|
|
|
|
sweep();
|
|
|
|
computeLambda();
|
|
}
|
|
|
|
private void sweep()
|
|
{
|
|
while (!queue.isEmpty()) {
|
|
int x = queue.poll().y;
|
|
if (fin.get(x)) {
|
|
// already handled
|
|
continue;
|
|
}
|
|
|
|
fin.set(x);
|
|
ChoiceValues v = choiceValues.get(new Choice(x, pi[x]));
|
|
wState[x] = v.w;
|
|
pState[x] = v.p;
|
|
|
|
for (Choice c : incoming.getIncomingChoices(x)) {
|
|
if (fin.get(c.getState())) {
|
|
// already handled, skip
|
|
continue;
|
|
}
|
|
|
|
if (!unknown.get(c.getState())) {
|
|
// uninteresting state
|
|
continue;
|
|
}
|
|
|
|
if (!validChoice(c)) {
|
|
// some successor go outside unknown U target (e.g., to some infinity or undefined state)
|
|
// skip
|
|
continue;
|
|
}
|
|
|
|
// a relevant choice, update
|
|
update(c, x);
|
|
}
|
|
}
|
|
}
|
|
|
|
private boolean validChoice(Choice choice)
|
|
{
|
|
IntPredicate outsideRelevant = (int t) -> {
|
|
if (unknown.get(t) || target.get(t)) return false;
|
|
return true;
|
|
};
|
|
|
|
return !mdp.someSuccessorsMatch(choice.getState(), choice.getChoice(), outsideRelevant);
|
|
}
|
|
|
|
private void update(Choice choice, int x)
|
|
{
|
|
double w_x = wState[x];
|
|
// compute P^a_yx * w(x)
|
|
double Pw = mdp.sumOverTransitions(choice.getState(), choice.getChoice(), (int s, int t, double p) -> {
|
|
if (t != x) return 0.0;
|
|
return p * w_x;
|
|
});
|
|
|
|
double p_x = pState[x];
|
|
// compute P^a_yx * p_g(x)
|
|
double Pp = mdp.sumOverTransitions(choice.getState(), choice.getChoice(), (int s, int t, double p) -> {
|
|
if (t != x) return 0.0;
|
|
return p * p_x;
|
|
});
|
|
|
|
ChoiceValues c = choiceValues.get(choice);
|
|
assert(c != null);
|
|
c.p += Pp;
|
|
c.w += Pw;
|
|
|
|
QueueEntry newPri = new QueueEntry(choice.getState(), 1 - c.p, c.w);
|
|
if (pri[choice.getState()] == null || newPri.compareTo(pri[choice.getState()]) < 0) {
|
|
pri[choice.getState()] = newPri;
|
|
pi[choice.getState()] = choice.getChoice();
|
|
queue.add(newPri);
|
|
}
|
|
}
|
|
|
|
private void update(Choice choice, BitSet target)
|
|
{
|
|
// compute P^a_y->target
|
|
double Pp = mdp.sumOverTransitions(choice.getState(), choice.getChoice(), (int s, int t, double p) -> {
|
|
if (target.get(t)) return p;
|
|
return 0.0;
|
|
});
|
|
|
|
ChoiceValues c = choiceValues.get(choice);
|
|
c.p += Pp;
|
|
|
|
QueueEntry newPri = new QueueEntry(choice.getState(), 1 - c.p, c.w);
|
|
if (pri[choice.getState()] == null || newPri.compareTo(pri[choice.getState()]) < 0) {
|
|
pri[choice.getState()] = newPri;
|
|
pi[choice.getState()] = choice.getChoice();
|
|
queue.add(newPri);
|
|
}
|
|
}
|
|
|
|
private double computeLambda()
|
|
{
|
|
lambda = 0.0;
|
|
|
|
for (int x : IterableBitSet.getSetBits(unknown)) {
|
|
int a = pi[x];
|
|
double lambda_x_a = Double.POSITIVE_INFINITY;
|
|
|
|
// check condition (I)
|
|
double I_sum = mdp.sumOverTransitions(x, a, (int s, int t, double p) -> {
|
|
return p * pState[t];
|
|
});
|
|
|
|
if (pState[x] < I_sum) {
|
|
// condition (I) holds
|
|
double den = rewards.getStateReward(x) + rewards.getTransitionReward(x, a); // c(x,a)
|
|
den += mdp.sumOverTransitions(x, a, (int s, int t, double p) -> {
|
|
return p * wState[t];
|
|
});
|
|
den -= wState[x];
|
|
|
|
double num = mdp.sumOverTransitions(x, a, (int s, int t, double p) -> {
|
|
return p * pState[t];
|
|
});
|
|
num -= pState[x];
|
|
|
|
lambda_x_a = den / num;
|
|
} else {
|
|
// TODO: check condition (II)
|
|
lambda_x_a = 0;
|
|
}
|
|
|
|
lambda = Double.max(lambda, lambda_x_a);
|
|
}
|
|
|
|
return lambda;
|
|
}
|
|
|
|
public static double[] computeUpperBounds(PrismComponent parent, MDP mdp, MDPRewards rewards, BitSet target, BitSet unknown)
|
|
{
|
|
StopWatch timer = new StopWatch(parent.getLog());
|
|
timer.start("computing upper bound(s) for Rmin using the DSI-MP algorithm");
|
|
|
|
parent.getLog().println("Computing upper bound(s) for Rmin using the Dijkstra Sweep for Monotone Pessimistic Initialization algorithm...");
|
|
double[] upperBounds = new double[mdp.getNumStates()];
|
|
DijkstraSweepMPI dsmpi = new DijkstraSweepMPI(parent, mdp, rewards, target, unknown);
|
|
|
|
for (int x : IterableBitSet.getSetBits(unknown)) {
|
|
upperBounds[x] = dsmpi.wState[x] + dsmpi.lambda*(1 - dsmpi.pState[x]);
|
|
}
|
|
|
|
if (debug) {
|
|
parent.getLog().println(upperBounds);
|
|
}
|
|
|
|
timer.stop();
|
|
return upperBounds;
|
|
}
|
|
|
|
public static double computeUpperBound(PrismComponent parent, MDP mdp, MDPRewards rewards, BitSet target, BitSet unknown)
|
|
{
|
|
double bound = 0.0;
|
|
final double[] upperBounds = computeUpperBounds(parent, mdp, rewards, target, unknown);
|
|
for (int s : IterableBitSet.getSetBits(unknown)) {
|
|
bound = Double.max(bound, upperBounds[s]);
|
|
}
|
|
return bound;
|
|
}
|
|
}
|