You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

195 lines
6.4 KiB

package explicit;
import java.util.BitSet;
import explicit.rewards.Rewards;
import prism.IntegerBound;
import prism.PrismException;
/**
 * An accumulation product whose ProductStates pair a state id of the original
 * model (first component) with the index of an AccumulationTracker (second
 * component).
 *
 * In this "counting" variant the component carried by each track is an
 * {@code Integer} step counter: it starts at 0, is incremented by one per step,
 * and the track is dropped (the update yields {@code null}) once the counter
 * would exceed the highest integer admitted by the step bound of the
 * accumulation expression. A track is final when its counter is positive and
 * lies within that bound.
 *
 * @author Sascha Wunderlich
 *
 * @param <M> the type of the original model (DTMC or MDP, see {@code generate})
 */
public class AccumulationProductSimpleCounting<M extends Model> extends AccumulationProductSimple<M,Integer>
{
	/**
	 * Step bound of the accumulation expression, evaluated on first use by
	 * {@link #getStepBound()}. No explicit initializer, so nothing resets it
	 * after the superclass constructor has run.
	 * NOTE(review): caching assumes the bound expression and the model checker's
	 * constant values do not change during the lifetime of this product — confirm.
	 */
	private IntegerBound stepBound;

	/**
	 * Creates an empty counting accumulation product.
	 *
	 * @param originalModel the model the product is built over
	 * @param ctx the accumulation context (expression, rewards, model checker)
	 */
	public AccumulationProductSimpleCounting(M originalModel, AccumulationContext ctx)
	{
		super(originalModel, ctx);
	}

	/**
	 * Builds the counting accumulation product for a DTMC or an MDP,
	 * dispatching on {@code graph.getModelType()}.
	 *
	 * The type parameter {@code R} is not used by this method; it is kept for
	 * source compatibility with callers that pass explicit type arguments.
	 *
	 * @param graph the original model; must be a DTMC or an MDP
	 * @param ctx the accumulation context
	 * @param statesOfInterest forwarded to {@code ProductWithProductStates.generate}
	 * @return the product, cast to the caller-chosen model type {@code T}
	 * @throws PrismException if the model type is neither DTMC nor MDP, or if
	 *         building the product fails
	 */
	@SuppressWarnings("unchecked") // unchecked cast: the caller must choose T consistent with graph's model type
	public static <T extends Model, R extends Rewards> AccumulationProductSimpleCounting<T> generate(final Model graph, AccumulationContext ctx, BitSet statesOfInterest) throws PrismException
	{
		switch (graph.getModelType()) {
		case DTMC:
			return (AccumulationProductSimpleCounting<T>) generate((DTMC) graph, ctx, statesOfInterest);
		case MDP:
			return (AccumulationProductSimpleCounting<T>) generate((MDP) graph, ctx, statesOfInterest);
		default:
			throw new PrismException("Can't handle accumulation product for " + graph.getModelType());
		}
	}

	/**
	 * Builds the counting accumulation product of a DTMC.
	 *
	 * @param graph the original DTMC
	 * @param ctx the accumulation context; supplies the per-state step weights
	 * @param statesOfInterest forwarded to {@code ProductWithProductStates.generate}
	 * @return the constructed product
	 * @throws PrismException if evaluating the step bound or building the product fails
	 */
	public static AccumulationProductSimpleCounting<DTMC> generate(final DTMC graph, AccumulationContext ctx, BitSet statesOfInterest) throws PrismException
	{
		final AccumulationProductSimpleCounting<DTMC> result = new AccumulationProductSimpleCounting<DTMC>(graph, ctx);
		// Create auxiliary data (number of tracks and weights)
		result.createAuxData(graph);

		// Operator that tells the generic product construction how to step the accumulation state
		class AccumulationDTMCProductOperator implements DTMCProductOperator
		{
			@Override
			public ProductState getInitialState(Integer dtmc_state) throws PrismException
			{
				int initialAccStateId = result.createInitialStateId();
				return new ProductState(dtmc_state, initialAccStateId);
			}

			@Override
			public ProductState getSuccessor(ProductState from_state, Integer dtmc_to_state) throws PrismException
			{
				// Get the current accumulation state
				AccumulationState<Integer> from_accstate = result.accStates.getById(from_state.getSecondState());
				// In a DTMC the step weights depend only on the source state
				double[] weights = ctx.getWeights(from_state.getFirstState());
				// Update accumulation product state, store it and get its ID
				AccumulationState<Integer> to_accproduct = result.updateAccumulationState(from_state.getFirstState(), from_accstate, weights);
				int to_accproduct_id = result.accStates.findOrAdd(to_accproduct);
				return new ProductState(dtmc_to_state, to_accproduct_id);
			}

			@Override
			public void notify(ProductState state, Integer index) throws PrismException
			{
				result.generateStateInfo(state, index);
			}

			@Override
			public void finish() throws PrismException
			{
				// Nothing to do
			}

			@Override
			public DTMC getGraph()
			{
				return graph;
			}
		}

		// Apply the operator
		AccumulationDTMCProductOperator op = new AccumulationDTMCProductOperator();
		ProductWithProductStates.generate(op, result, statesOfInterest);
		return result;
	}

	/**
	 * Builds the counting accumulation product of an MDP. Same construction as
	 * for DTMCs, except that step weights depend on the chosen action.
	 *
	 * @param graph the original MDP
	 * @param ctx the accumulation context; supplies the per-(state, choice) step weights
	 * @param statesOfInterest forwarded to {@code ProductWithProductStates.generate}
	 * @return the constructed product
	 * @throws PrismException if evaluating the step bound or building the product fails
	 */
	public static AccumulationProductSimpleCounting<MDP> generate(final MDP graph, AccumulationContext ctx, BitSet statesOfInterest) throws PrismException
	{
		final AccumulationProductSimpleCounting<MDP> result = new AccumulationProductSimpleCounting<MDP>(graph, ctx);
		// Create auxiliary data (number of tracks and weights)
		result.createAuxData(graph);

		// Operator that tells the generic product construction how to step the accumulation state
		class AccumulationMDPProductOperator implements MDPProductOperator
		{
			@Override
			public ProductState getInitialState(final Integer MDP_state) throws PrismException
			{
				int initialAccStateId = result.createInitialStateId();
				return new ProductState(MDP_state, initialAccStateId);
			}

			@Override
			public ProductState getSuccessor(final ProductState from_state, final int choice_i, final Integer mdp_to_state) throws PrismException
			{
				// Get the current accumulation state
				AccumulationState<Integer> from_accstate = result.accStates.getById(from_state.getSecondState());
				// Unlike the DTMC case, the step weights depend on the chosen action as well as the source state
				double[] weights = ctx.getWeights(from_state.getFirstState(), choice_i);
				// Update accumulation product state, store it and get its ID
				AccumulationState<Integer> to_accproduct = result.updateAccumulationState(from_state.getFirstState(), from_accstate, weights);
				int to_accproduct_id = result.accStates.findOrAdd(to_accproduct);
				return new ProductState(mdp_to_state, to_accproduct_id);
			}

			@Override
			public void notify(final ProductState state, final Integer index) throws PrismException
			{
				result.generateStateInfo(state, index);
			}

			@Override
			public void finish() throws PrismException
			{
				// Nothing to do
			}

			@Override
			public MDP getGraph()
			{
				return graph;
			}
		}

		// Apply the operator
		AccumulationMDPProductOperator op = new AccumulationMDPProductOperator();
		ProductWithProductStates.generate(op, result, statesOfInterest);
		return result;
	}

	/**
	 * Returns the step bound of the accumulation expression, evaluating it with
	 * the model checker's constant values on first use and caching it afterwards.
	 *
	 * @return the step bound
	 * @throws PrismException if the bound expression cannot be evaluated
	 */
	private IntegerBound getStepBound() throws PrismException
	{
		if (stepBound == null) {
			stepBound = IntegerBound.fromTemporalOperatorBound(ctx.accexp.getBoundExpression(), ctx.mc.getConstantValues(), true);
		}
		return stepBound;
	}

	/**
	 * A track is final if it exists, has taken at least one step, and its step
	 * counter lies within the step bound.
	 *
	 * @param track the track to check; may be {@code null}
	 * @return {@code true} iff the track is final
	 * @throws PrismException if the step bound cannot be evaluated
	 */
	@Override
	protected boolean isFinalTrack(final AccumulationTrack<Integer> track) throws PrismException
	{
		if (track == null) {
			return false;
		}
		final Integer step = track.getComponent();
		// If the step is 0, we cannot have a goal state.
		return step > 0 && getStepBound().isInBounds(step);
	}

	/**
	 * The step counter of a fresh track starts at 0.
	 */
	@Override
	protected Integer getInitialComponent()
	{
		return 0;
	}

	/**
	 * Advances the step counter of {@code track} by one.
	 *
	 * @param modelFromStateId the source state in the original model (unused here)
	 * @param track the track being advanced
	 * @return the incremented counter, or {@code null} if it would exceed the
	 *         highest integer of the step bound (the track ends)
	 */
	@Override
	protected Integer updateComponent(Integer modelFromStateId, final AccumulationTrack<Integer> track)
	{
		int currentStep = track.getComponent();
		final int maxStep;
		try {
			maxStep = getStepBound().getHighestInteger();
		} catch (PrismException e) {
			// This override does not declare PrismException, so wrap it, keeping the cause.
			throw new RuntimeException("Could not evaluate the step bound of the accumulation expression", e);
		}
		int nextStep = currentStep + 1;
		if (nextStep > maxStep) {
			return null;
		}
		return nextStep;
	}

	/**
	 * Populates fields:
	 * - numberOfTracks with the highest integer of the step bound plus one, and
	 * - numberOfWeights with the size of the reward vector ({@code ctx.rewards}).
	 *
	 * @param graph the original model (not used by this implementation)
	 * @throws PrismException if the step bound cannot be evaluated
	 */
	protected void createAuxData(final Model graph) throws PrismException
	{
		numberOfTracks = getStepBound().getHighestInteger() + 1;
		numberOfWeights = ctx.rewards.size();
	}
}