Browse Source

Enable simulation of POMDPs via SimulatorEngine.

accumulation-v4.7
Dave Parker 5 years ago
parent
commit
d78b627bf4
  1. 6
      prism/src/cex/CexPathAsBDDs.java
  2. 6
      prism/src/cex/CexPathStates.java
  3. 15
      prism/src/simulator/Path.java
  4. 27
      prism/src/simulator/PathFull.java
  5. 6
      prism/src/simulator/PathFullInfo.java
  6. 12
      prism/src/simulator/PathFullPrefix.java
  7. 25
      prism/src/simulator/PathOnTheFly.java
  8. 23
      prism/src/simulator/SimulatorEngine.java

6
prism/src/cex/CexPathAsBDDs.java

@@ -85,6 +85,12 @@ public class CexPathAsBDDs implements PathFullInfo
 		return model.convertBddToState(states.get(step));
 	}
 	
+	@Override
+	public State getObservation(int step)
+	{
+		return null;
+	}
+	
 	@Override
 	public double getStateReward(int step, int rsi)
 	{

6
prism/src/cex/CexPathStates.java

@@ -79,6 +79,12 @@ public class CexPathStates implements PathFullInfo
 		return states.get(step);
 	}
 	
+	@Override
+	public State getObservation(int step)
+	{
+		return null;
+	}
+	
 	@Override
 	public double getStateReward(int step, int rsi)
 	{

15
prism/src/simulator/Path.java

@@ -37,22 +37,22 @@ public abstract class Path
 	// MUTATORS
 	
 	/**
-	 * Initialise the path with an initial state and rewards.
-	 * Note: State object and array will be copied, not stored directly.
+	 * Initialise the path with an initial state, observation and rewards.
+	 * Note: State objects and array will be copied, not stored directly.
 	 */
-	public abstract void initialise(State initialState, double[] initialStateRewards);
+	public abstract void initialise(State initialState, State initialObs, double[] initialStateRewards);
 	
 	/**
 	 * Add a step to the path.
 	 * Note: State object and arrays will be copied, not stored directly.
 	 */
-	public abstract void addStep(int choice, Object action, String actionString, double probability, double[] transRewards, State newState, double[] newStateRewards, ModelGenerator modelGen);
+	public abstract void addStep(int choice, Object action, String actionString, double probability, double[] transRewards, State newState, State newObs, double[] newStateRewards, ModelGenerator modelGen);
 	
 	/**
 	 * Add a timed step to the path.
 	 * Note: State object and arrays will be copied, not stored directly.
 	 */
-	public abstract void addStep(double time, int choice, Object action, String actionString, double probability, double[] transRewards, State newState, double[] newStateRewards, ModelGenerator modelGen);
+	public abstract void addStep(double time, int choice, Object action, String actionString, double probability, double[] transRewards, State newState, State newObs, double[] newStateRewards, ModelGenerator modelGen);
 	
 	// ACCESSORS
@@ -76,6 +76,11 @@ public abstract class Path
 	 */
 	public abstract State getCurrentState();
 	
+	/**
+	 * Get the observation for the current state, i.e. for the current final state of the path.
+	 */
+	public abstract State getCurrentObservation();
+	
 	/**
 	 * Get the action taken in the previous step.
 	 */

27
prism/src/simulator/PathFull.java

@@ -89,7 +89,7 @@ public class PathFull extends Path implements PathFullInfo
 	// MUTATORS (for Path)
 	
 	@Override
-	public void initialise(State initialState, double[] initialStateRewards)
+	public void initialise(State initialState, State initialObs, double[] initialStateRewards)
 	{
 		clear();
 		// Add new step item to the path
@@ -97,6 +97,7 @@ public class PathFull extends Path implements PathFullInfo
 		steps.add(step);
 		// Add (copies of) initial state and state rewards to new step
 		step.state = new State(initialState);
+		step.obs = initialObs == null ? null : new State(initialObs);
 		step.stateRewards = initialStateRewards.clone();
 		// Set cumulative time/reward (up until entering this state)
 		step.timeCumul = 0.0;
@@ -108,15 +109,15 @@ public class PathFull extends Path implements PathFullInfo
 	}
 	
 	@Override
-	public void addStep(int choice, Object action, String actionString, double probability, double[] transitionRewards, State newState, double[] newStateRewards,
+	public void addStep(int choice, Object action, String actionString, double probability, double[] transitionRewards, State newState, State newObs, double[] newStateRewards,
 			ModelGenerator modelGen)
 	{
-		addStep(1.0, choice, action, actionString, probability, transitionRewards, newState, newStateRewards, modelGen);
+		addStep(1.0, choice, action, actionString, probability, transitionRewards, newState, newObs, newStateRewards, modelGen);
 	}
 	
 	@Override
 	public void addStep(double time, int choice, Object action, String actionString, double probability, double[] transitionRewards, State newState,
-			double[] newStateRewards, ModelGenerator modelGen)
+			State newObs, double[] newStateRewards, ModelGenerator modelGen)
 	{
 		Step stepOld, stepNew;
 		// Add info to last existing step
@@ -130,8 +131,9 @@ public class PathFull extends Path implements PathFullInfo
 		// Add new step item to the path
 		stepNew = new Step();
 		steps.add(stepNew);
-		// Add (copies of) new state and state rewards to new step
+		// Add (copies of) new state, observation and state rewards to new step
 		stepNew.state = new State(newState);
+		stepNew.obs = newObs == null ? null : new State(newObs);
 		stepNew.stateRewards = newStateRewards.clone();
 		// Set cumulative time/rewards (up until entering this state)
 		stepNew.timeCumul = stepOld.timeCumul + time;
@@ -241,6 +243,12 @@ public class PathFull extends Path implements PathFullInfo
 		return steps.get(steps.size() - 1).state;
 	}
 	
+	@Override
+	public State getCurrentObservation()
+	{
+		return steps.get(steps.size() - 1).obs;
+	}
+	
 	@Override
 	public Object getPreviousAction()
 	{
@@ -339,6 +347,12 @@ public class PathFull extends Path implements PathFullInfo
 		return steps.get(step).state;
 	}
 	
+	@Override
+	public State getObservation(int step)
+	{
+		return steps.get(step).obs;
+	}
+	
 	@Override
 	public double getStateReward(int step, int rsi)
 	{
@@ -545,6 +559,7 @@ public class PathFull extends Path implements PathFullInfo
 		{
 			// Set (unknown) defaults and initialise arrays
 			state = null;
+			obs = null;
 			stateRewards = new double[numRewardStructs];
 			timeCumul = 0.0;
 			rewardsCumul = new double[numRewardStructs];
@@ -558,6 +573,8 @@ public class PathFull extends Path implements PathFullInfo
 		
 		// Current state (before transition)
 		public State state;
+		// Observation for current state
+		public State obs;
 		// State rewards for current state
 		public double stateRewards[];
 		// Cumulative time spent up until entering this state

6
prism/src/simulator/PathFullInfo.java

@@ -44,6 +44,12 @@ public interface PathFullInfo
 	 */
 	public abstract State getState(int step);
 	
+	/**
+	 * Get the observation at a given step of the path.
+	 * @param step Step index (0 = initial state/step of path)
+	 */
+	public abstract State getObservation(int step);
+	
 	/**
 	 * Get a state reward for the state at a given step of the path.
 	 * If no reward info is stored ({@link #hasRewardInfo()} is false), returns 0.0.

12
prism/src/simulator/PathFullPrefix.java

@@ -52,19 +52,19 @@ public class PathFullPrefix extends Path
 	// MUTATORS (for Path)
 	
 	@Override
-	public void initialise(State initialState, double[] initialStateRewards)
+	public void initialise(State initialState, State initialObs, double[] initialStateRewards)
 	{
 		// Do nothing (we are not allowed to modify the underlying PathFull)
 	}
 	
 	@Override
-	public void addStep(int choice, Object action, String actionString, double probability, double[] transitionRewards, State newState, double[] newStateRewards, ModelGenerator modelGen)
+	public void addStep(int choice, Object action, String actionString, double probability, double[] transitionRewards, State newState, State newObs, double[] newStateRewards, ModelGenerator modelGen)
 	{
 		// Do nothing (we are not allowed to modify the underlying PathFull)
 	}
 	
 	@Override
-	public void addStep(double time, int choice, Object action, String actionString, double probability, double[] transitionRewards, State newState, double[] newStateRewards, ModelGenerator modelGen)
+	public void addStep(double time, int choice, Object action, String actionString, double probability, double[] transitionRewards, State newState, State newObs, double[] newStateRewards, ModelGenerator modelGen)
 	{
 		// Do nothing (we are not allowed to modify the underlying PathFull)
 	}
@@ -102,6 +102,12 @@ public class PathFullPrefix extends Path
 		return pathFull.getState(prefixLength);
 	}
 	
+	@Override
+	public State getCurrentObservation()
+	{
+		return pathFull.getObservation(prefixLength);
+	}
+	
 	@Override
 	public Object getPreviousAction()
 	{

25
prism/src/simulator/PathOnTheFly.java

@@ -48,6 +48,7 @@ public class PathOnTheFly extends Path
 	protected long size;
 	protected State previousState;
 	protected State currentState;
+	protected State currentObs;
 	protected Object previousAction;
 	protected String previousActionString;
 	protected double previousProbability;
@@ -73,6 +74,10 @@ public class PathOnTheFly extends Path
 		// Create State objects for current/previous state
 		previousState = new State(modelInfo.getNumVars());
 		currentState = new State(modelInfo.getNumVars());
+		currentObs = null;
+		if (modelInfo.getModelType().partiallyObservable()) {
+			currentObs = new State(modelInfo.getNumObservables());
+		}
 		// Create arrays to store totals
 		totalRewards = new double[numRewardStructs];
 		previousStateRewards = new double[numRewardStructs];
@@ -106,10 +111,13 @@ public class PathOnTheFly extends Path
 	// MUTATORS (for Path)
 	
 	@Override
-	public void initialise(State initialState, double[] initialStateRewards)
+	public void initialise(State initialState, State initialObs, double[] initialStateRewards)
 	{
 		clear();
 		currentState.copy(initialState);
+		if (initialObs != null) {
+			currentObs.copy(initialObs);
+		}
 		for (int i = 0; i < numRewardStructs; i++) {
 			currentStateRewards[i] = initialStateRewards[i];
 		}
@@ -118,17 +126,20 @@ public class PathOnTheFly extends Path
 	}
 	
 	@Override
-	public void addStep(int choice, Object action, String actionString, double probability, double[] transRewards, State newState, double[] newStateRewards, ModelGenerator modelGen)
+	public void addStep(int choice, Object action, String actionString, double probability, double[] transRewards, State newState, State newObs, double[] newStateRewards, ModelGenerator modelGen)
 	{
-		addStep(1.0, choice, action, actionString, probability, transRewards, newState, newStateRewards, modelGen);
+		addStep(1.0, choice, action, actionString, probability, transRewards, newState, newObs, newStateRewards, modelGen);
 	}
 	
 	@Override
-	public void addStep(double time, int choice, Object action, String actionString, double probability, double[] transRewards, State newState, double[] newStateRewards, ModelGenerator modelGen)
+	public void addStep(double time, int choice, Object action, String actionString, double probability, double[] transRewards, State newState, State newObs, double[] newStateRewards, ModelGenerator modelGen)
 	{
 		size++;
 		previousState.copy(currentState);
 		currentState.copy(newState);
+		if (newObs != null) {
+			currentObs.copy(newObs);
+		}
 		previousAction = action;
 		previousActionString = actionString;
 		previousProbability = probability;
@@ -174,6 +185,12 @@ public class PathOnTheFly extends Path
 		return currentState;
 	}
 	
+	@Override
+	public State getCurrentObservation()
+	{
+		return currentObs;
+	}
+	
 	@Override
 	public Object getPreviousAction()
 	{

23
prism/src/simulator/SimulatorEngine.java

@@ -303,10 +303,12 @@ public class SimulatorEngine extends PrismComponent
 				throw new PrismNotSupportedException("Random choice of multiple initial states not yet supported");
 			}
 		}
+		// Get initial observation
+		State currentObs = modelGen.getObservation(currentState);
 		// Get initial state reward
 		calculateStateRewards(currentState, tmpStateRewards);
 		// Initialise stored path
-		path.initialise(currentState, tmpStateRewards);
+		path.initialise(currentState, currentObs, tmpStateRewards);
 		// Explore initial state in model generator
 		computeTransitionsForState(currentState);
 		// Reset and then update samplers for any loaded properties
@@ -373,6 +375,7 @@ public class SimulatorEngine extends PrismComponent
 			executeTransition(ref.i, ref.offset, -1);
 			break;
 		case MDP:
+		case POMDP:
 			// Pick a random choice
 			i = rng.randomUnifInt(modelGen.getNumChoices());
 			// Pick a random transition from this choice
@@ -858,10 +861,12 @@ public class SimulatorEngine extends PrismComponent
 		calculateTransitionRewards(path.getCurrentState(), action, tmpTransitionRewards);
 		// Compute next state
 		currentState.copy(modelGen.computeTransitionTarget(i, offset));
+		// Compute observation for new state
+		State currentObs = modelGen.getObservation(currentState);
 		// Compute state rewards for new state
 		calculateStateRewards(currentState, tmpStateRewards);
 		// Update path
-		path.addStep(index, action, actionString, p, tmpTransitionRewards, currentState, tmpStateRewards, modelGen);
+		path.addStep(index, action, actionString, p, tmpTransitionRewards, currentState, currentObs, tmpStateRewards, modelGen);
 		// Explore new state in model generator
 		computeTransitionsForState(currentState);
 		// Update samplers for any loaded properties
@@ -896,10 +901,12 @@ public class SimulatorEngine extends PrismComponent
 		calculateTransitionRewards(path.getCurrentState(), action, tmpTransitionRewards);
 		// Compute next state
 		currentState.copy(modelGen.computeTransitionTarget(i, offset));
+		// Compute observation for new state
+		State currentObs = modelGen.getObservation(currentState);
 		// Compute state rewards for new state
 		calculateStateRewards(currentState, tmpStateRewards);
 		// Update path
-		path.addStep(time, index, action, actionString, p, tmpTransitionRewards, currentState, currentObs, tmpStateRewards, modelGen);
+		path.addStep(time, index, action, actionString, p, tmpTransitionRewards, currentState, currentObs, tmpStateRewards, modelGen);
 		// Explore new state in model generator
 		computeTransitionsForState(currentState);
 		// Update samplers for any loaded properties
@@ -1278,6 +1285,16 @@ public class SimulatorEngine extends PrismComponent
 		return ((PathFull) path).getState(step);
 	}
 	
+	/**
+	 * Get the observation at a given step of the path.
+	 * (Not applicable for on-the-fly paths)
+	 * @param step Step index (0 = initial state/step of path)
+	 */
+	public State getObservationOfPathStep(int step)
+	{
+		return ((PathFull) path).getObservation(step);
+	}
+	
 	/**
 	 * Get a state reward for the state at a given step of the path.
 	 * (Not applicable for on-the-fly paths)

Loading…
Cancel
Save