You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
309 lines
11 KiB
309 lines
11 KiB
package explicit;
|
|
|
|
import java.util.ArrayList;
|
|
import java.util.BitSet;
|
|
import java.util.Vector;
|
|
|
|
import explicit.rewards.MCRewards;
|
|
import explicit.rewards.MDPRewards;
|
|
import explicit.rewards.Rewards;
|
|
import parser.ast.AccumulationFactor;
|
|
import parser.ast.ExpressionAccumulation;
|
|
import prism.IntegerBound;
|
|
import prism.PrismException;
|
|
|
|
/**
|
|
* An AccumulationProduct has ProductStates, where the first component is the
|
|
* stateId in the original model, and the second component is the index of an
|
|
* AccumulationTracker.
|
|
*
|
|
* @author Sascha Wunderlich
|
|
*
|
|
* @param <M>
|
|
*/
|
|
|
|
public class AccumulationProductCounting<M extends Model> extends AccumulationProduct<M,Integer>
|
|
{
|
|
|
|
public AccumulationProductCounting(M originalModel) {
|
|
super(originalModel);
|
|
}
|
|
|
|
public static AccumulationProductCounting<DTMC> generate(final DTMC graph, final ExpressionAccumulation accexp, final Vector<MCRewards> rewards, final ProbModelChecker mc, BitSet statesOfInterest) throws PrismException {
|
|
final AccumulationProductCounting<DTMC> result = new AccumulationProductCounting<DTMC>(graph);
|
|
// Create auxiliary data
|
|
result.createAuxData(graph, accexp, rewards, mc);
|
|
|
|
// Build an operator
|
|
class AccumulationDTMCProductOperator implements DTMCProductOperator
|
|
{
|
|
@Override
|
|
public ProductState getInitialState(Integer dtmc_state)
|
|
throws PrismException {
|
|
int initialAccStateId = result.createInitialStateId(rewards.size());
|
|
return new ProductState(dtmc_state, initialAccStateId);
|
|
}
|
|
|
|
@Override
|
|
public ProductState getSuccessor(ProductState from_state,
|
|
Integer dtmc_to_state) throws PrismException {
|
|
// Get the current accumulation state
|
|
AccumulationState<Integer> from_accstate = result.accStates.getById(from_state.getSecondState());
|
|
|
|
// Get step weights
|
|
double[] weights = new double[rewards.size()];
|
|
|
|
for (int i=0; i < rewards.size(); i++) {
|
|
weights[i] = rewards.get(i).getStateReward(from_state.getFirstState());
|
|
}
|
|
|
|
// Update accumulation product state, store it and get its ID.
|
|
AccumulationState<Integer> to_accproduct = result.updateAccumulationState(from_state.getFirstState(), from_accstate, accexp, weights, mc);
|
|
int to_accproduct_id = result.accStates.findOrAdd(to_accproduct);
|
|
|
|
return new ProductState(dtmc_to_state, to_accproduct_id);
|
|
}
|
|
|
|
@Override
|
|
public void notify(ProductState state, Integer index)
|
|
throws PrismException {
|
|
AccumulationState<Integer> accState = result.accStates.getById(state.getSecondState());
|
|
if (result.isGoodAccState(accState, accexp, mc)) {
|
|
result.goodStates.set(index);
|
|
}
|
|
}
|
|
|
|
@Override
|
|
public void finish() throws PrismException {
|
|
// Do nothing
|
|
}
|
|
|
|
@Override
|
|
public DTMC getGraph() {
|
|
return graph;
|
|
}
|
|
}
|
|
|
|
// Apply the operator
|
|
AccumulationDTMCProductOperator op = new AccumulationDTMCProductOperator();
|
|
ProductWithProductStates.generate(op, result, statesOfInterest);
|
|
|
|
return result;
|
|
}
|
|
|
|
public static AccumulationProductCounting<MDP> generate(final MDP graph, final ExpressionAccumulation accexp, final Vector<MDPRewards> rewards, final ProbModelChecker mc, BitSet statesOfInterest) throws PrismException {
|
|
// This is basically the same thing as for DTMCs
|
|
final AccumulationProductCounting<MDP> result = new AccumulationProductCounting<MDP>(graph);
|
|
|
|
// Create auxiliary data
|
|
result.createAuxData(graph, accexp, rewards, mc);
|
|
|
|
class AccumulationMDPProductOperator implements MDPProductOperator
|
|
{
|
|
|
|
@Override
|
|
public ProductState getInitialState(final Integer MDP_state)
|
|
throws PrismException {
|
|
int initialAccStateId = result.createInitialStateId(rewards.size());
|
|
return new ProductState(MDP_state, initialAccStateId);
|
|
}
|
|
|
|
@Override
|
|
public ProductState getSuccessor(final ProductState from_state,
|
|
final int choice_i, final Integer mdp_to_state) throws PrismException {
|
|
// Get the current accumulation state
|
|
AccumulationState<Integer> from_accstate = result.accStates.getById(from_state.getSecondState());
|
|
|
|
// Get step weights
|
|
// THIS IS DIFFERENT FROM ABOVE!
|
|
double[] weights = new double[rewards.size()];
|
|
|
|
for (int i=0; i < rewards.size(); i++) {
|
|
double currentWeight = rewards.get(i).getStateReward(from_state.getFirstState());
|
|
currentWeight += rewards.get(i).getTransitionReward(from_state.getFirstState(), choice_i);
|
|
weights[i] = currentWeight;
|
|
}
|
|
|
|
// Update accumulation product state, store it and get its ID.
|
|
AccumulationState<Integer> to_accproduct = result.updateAccumulationState(from_state.getFirstState(), from_accstate, accexp, weights, mc);
|
|
int to_accproduct_id = result.accStates.findOrAdd(to_accproduct);
|
|
return new ProductState(mdp_to_state, to_accproduct_id);
|
|
}
|
|
|
|
@Override
|
|
public void notify(final ProductState state, final Integer index)
|
|
throws PrismException {
|
|
AccumulationState<Integer> accState = result.accStates.getById(state.getSecondState());
|
|
if (result.isGoodAccState(accState, accexp, mc)) {
|
|
result.goodStates.set(index);
|
|
}
|
|
}
|
|
|
|
@Override
|
|
public void finish() throws PrismException {
|
|
// Do nothing
|
|
mc.getLog().println(".");
|
|
}
|
|
|
|
@Override
|
|
public MDP getGraph() {
|
|
return graph;
|
|
}
|
|
}
|
|
|
|
AccumulationMDPProductOperator op = new AccumulationMDPProductOperator();
|
|
ProductWithProductStates.generate(op, result, statesOfInterest);
|
|
return result;
|
|
}
|
|
|
|
private boolean isFinalTrack(final AccumulationTrack<Integer> track, final ExpressionAccumulation accexp, final ProbModelChecker mc)
|
|
throws PrismException {
|
|
boolean isFinal = false;
|
|
if ( track != null ) {
|
|
IntegerBound stepBound = IntegerBound.fromTemporalOperatorBound(accexp.getBoundExpression(), mc.getConstantValues(), true);
|
|
isFinal = stepBound.isInBounds(track.getComponent());
|
|
}
|
|
return isFinal;
|
|
}
|
|
|
|
private boolean isGoodAccState(final AccumulationState<Integer> state, final ExpressionAccumulation accexp, final ProbModelChecker mc)
|
|
throws PrismException {
|
|
return state.hasGoodTrack();
|
|
}
|
|
|
|
private boolean isGoodTrack(final AccumulationTrack<Integer> track, final ExpressionAccumulation accexp, final ProbModelChecker mc)
|
|
throws PrismException {
|
|
// Only final tracks can be good
|
|
if (!isFinalTrack(track,accexp,mc)) { return false; }
|
|
boolean isGood = false;
|
|
|
|
// Collect the weight linear combination, factor*weight+...
|
|
int lhs = 0;
|
|
int factorNr = 0;
|
|
for (AccumulationFactor factor : accexp.getConstraint().getFactors()) {
|
|
lhs += factor.getFactor().evaluateInt(mc.getConstantValues())
|
|
* track.getWeight(factorNr);
|
|
}
|
|
|
|
// Check the bound
|
|
IntegerBound rhs = IntegerBound.fromTemporalOperatorBound(accexp.getConstraint().getBound(), mc.getConstantValues(), true);
|
|
|
|
// For DIA operators, we just check the bound.
|
|
// For BOX operators, we check the INVERTED bound.
|
|
switch(accexp.getSymbol()) {
|
|
case ACCBOXMINUS:
|
|
case ACCBOXPLUS:
|
|
if (!rhs.isInBounds(lhs)) {
|
|
isGood = true;
|
|
}
|
|
break;
|
|
case ACCDIAMINUS:
|
|
case ACCDIAPLUS:
|
|
if (rhs.isInBounds(lhs)) {
|
|
isGood = true;
|
|
}
|
|
break;
|
|
default:
|
|
throw new RuntimeException("Oh boy!");
|
|
}
|
|
if(isGood) {mc.getLog().print("+");} else {mc.getLog().print("-");}
|
|
return isGood;
|
|
}
|
|
|
|
private AccumulationState<Integer> updateAccumulationState(final int modelFromStateId,
|
|
final AccumulationState<Integer> accstate, final ExpressionAccumulation accexp,
|
|
final double[] weights, final ProbModelChecker mc) throws PrismException {
|
|
// We have the current accumulation state, the current model id and the accumulation expression.
|
|
|
|
// Get the old tracker and tracks.
|
|
AccumulationTracker<Integer> oldTracker = accstate.getTracker(trackers);
|
|
ArrayList<AccumulationTrack<Integer>> oldTracks = oldTracker.getTracks();
|
|
|
|
BitSet oldGoodTracks = accstate.getGoodTracks();
|
|
BitSet newGoodTracks = (BitSet) oldGoodTracks.clone();
|
|
|
|
// This restart will be...
|
|
int newLastRestartNr = accstate.getNextRestartNr();
|
|
mc.getLog().print(newLastRestartNr);
|
|
|
|
// Build the new tracks.
|
|
ArrayList<AccumulationTrack<Integer>> newTracks = new ArrayList<>();
|
|
|
|
int trackNr = 0;
|
|
for(AccumulationTrack<Integer> oldTrack : oldTracks) {
|
|
AccumulationTrack<Integer> newTrack;
|
|
|
|
// restart or advance
|
|
if(trackNr == newLastRestartNr) {
|
|
//assert oldTrack == null : "Track " + newLastRestartNr + " is not null!";
|
|
newTrack = new AccumulationTrack<Integer>(numberOfWeights, 0); //TODO: off-by-one?
|
|
newGoodTracks.clear(trackNr);
|
|
} else if (oldTrack == null) {
|
|
newTrack = null;
|
|
} else {
|
|
assert oldTrack != null;
|
|
newTrack = updateTrackBounds(oldTrack, accexp, weights, mc);
|
|
}
|
|
|
|
// check whether the track is good
|
|
if(!newGoodTracks.get(trackNr)) {
|
|
newGoodTracks.set(trackNr, isGoodTrack(newTrack, accexp, mc));
|
|
}
|
|
|
|
newTracks.add(newTrack);
|
|
trackNr++;
|
|
}
|
|
|
|
AccumulationTracker<Integer> newTracker = new AccumulationTracker<>(newTracks);
|
|
|
|
|
|
int newTrackerId = trackers.findOrAdd(newTracker);
|
|
|
|
return new AccumulationState<>(newTrackerId, newLastRestartNr, numberOfTracks, newGoodTracks);
|
|
}
|
|
|
|
private AccumulationTrack<Integer> updateTrackBounds(final AccumulationTrack<Integer> track,
|
|
final ExpressionAccumulation accexp, final double[] weights, final StateModelChecker mc) throws PrismException {
|
|
int currentStep = track.getComponent();
|
|
int maxStep = IntegerBound.fromTemporalOperatorBound(accexp.getBoundExpression(), mc.getConstantValues(), true).getHighestInteger();
|
|
|
|
// If we are done, return null-Track
|
|
if (currentStep >= maxStep) { return null; }
|
|
|
|
// Otherwise, we update the weights and increase the step.
|
|
double[] newweights = new double[weights.length];
|
|
for (int i = 0; i < weights.length; i++) {
|
|
newweights[i] = weights[i] + track.getWeights()[i];
|
|
}
|
|
|
|
return new AccumulationTrack<Integer>(newweights, currentStep+1);
|
|
}
|
|
|
|
protected int createInitialStateId(final int numberOfRewards) {
|
|
// The initial active track is the first one, all tracks are non-good by default
|
|
int initialActiveTrack = 0;
|
|
BitSet initialGoodTracks = new BitSet();
|
|
|
|
// Generate the initial tracker and product state
|
|
AccumulationTracker<Integer> initialTracker = new AccumulationTracker<>(numberOfTracks, numberOfRewards, 0);
|
|
int initialTrackerId = trackers.findOrAdd(initialTracker);
|
|
AccumulationState<Integer> initialAccState = new AccumulationState<>(initialTrackerId, initialActiveTrack, numberOfTracks, initialGoodTracks);
|
|
int initialAccStateId = accStates.findOrAdd(initialAccState);
|
|
|
|
return initialAccStateId;
|
|
}
|
|
|
|
/**
|
|
* Populates fields:
|
|
* - numberOfTracks
|
|
* @param graph
|
|
* @param accexp
|
|
* @param rewards
|
|
* @param mc
|
|
* @throws PrismException
|
|
*/
|
|
protected void createAuxData(final Model graph, final ExpressionAccumulation accexp, final Vector<? extends Rewards> rewards, final ProbModelChecker mc) throws PrismException {
|
|
numberOfTracks = IntegerBound.fromTemporalOperatorBound(accexp.getBoundExpression(), mc.getConstantValues(), true).getHighestInteger()+1;
|
|
numberOfWeights = rewards.size();
|
|
}
|
|
}
|