@ -27,18 +27,22 @@
package explicit ;
import java.util.PrimitiveIterator ;
import common.IntSet ;
import common.PeriodicTimer ;
import explicit.rewards.MCRewards ;
import explicit.rewards.MDPRewards ;
import prism.OptionsIntervalIteration ;
import prism.PrismException ;
import prism.PrismSettings ;
import prism.PrismUtils ;
/ * *
* Abstract class that encapsulates the functionality for the different iteration methods
* ( e . g . , Power , Jacobi , Gauss - Seidel , . . . ) .
* < p >
* Provides methods as well to do the actual work in a ( topological ) value iteration .
* Provides methods as well to do the actual work in a ( topological ) value or interval iteration .
* /
public abstract class IterationMethod {
@ -64,6 +68,30 @@ public abstract class IterationMethod {
public Model getModel ( ) ;
}
/ * *
* Interface for an object that provides the atomic steps for a value iteration
* in the context of interval iteration , i . e . , for the interval iteration
* there are two IterationIntervalIter objects , one from below or from above .
* /
public interface IterationIntervalIter {
/** Initialise the value iteration with the given solution vector */
public void init ( double [ ] soln ) ;
/** Get the current solution vector */
public double [ ] getSolnVector ( ) ;
/** Perform one iteration (over the set of states) */
public void iterate ( IntSet states ) throws PrismException ;
/ * *
* Solve for a given singleton SCC consisting of { @code state } using { @code solver } ,
* store the result in the solution vector ( s ) .
* /
public void solveSingletonSCC ( int s , SingletonSCCSolver solver ) ;
/** Return the underlying model */
public Model getModel ( ) ;
}
/** Storage for a single solution vector */
public class IterationBasic {
protected final Model model ;
@ -104,6 +132,15 @@ public abstract class IterationMethod {
}
}
/** Abstract base class for an IterationIntervalIter with a single solution vector */
protected abstract class SingleVectorIterationIntervalIter extends IterationBasic implements IterationIntervalIter
{
public SingleVectorIterationIntervalIter ( Model model )
{
super ( model ) ;
}
}
/ * *
* Functional interface for a post - processing step after an iteration that involves
* a pair of solution vectors .
@ -120,11 +157,11 @@ public abstract class IterationMethod {
}
/ * *
* Abstract base class for an IterationValIter that
* Abstract base class for an IterationValIter / IterationIntervalIter that
* requires two solution vectors .
* Optionally , a post processing step is performed after each iteration .
* /
protected abstract class TwoVectorIteration extends IterationBasic implements IterationValIter {
protected abstract class TwoVectorIteration extends IterationBasic implements IterationValIter , IterationIntervalIter {
/** The solution vector that serves as the target vector in the iteration step */
protected double [ ] soln2 ;
/** Post processing, may be null */
@ -150,6 +187,22 @@ public abstract class IterationMethod {
/** Perform one iteration */
public abstract void doIterate ( IntSet states ) throws PrismException ;
@Override
public void iterate ( IntSet states ) throws PrismException
{
/ / do the iteration
doIterate ( states ) ;
/ / optionally , post processing
if ( postProcessor ! = null ) {
postProcessor . apply ( soln , soln2 , states ) ;
}
/ / switch vectors
double [ ] tmp = soln ;
soln = soln2 ;
soln2 = tmp ;
}
@Override
public boolean iterateAndCheckConvergence ( IntSet states ) throws PrismException
{
@ -218,10 +271,26 @@ public abstract class IterationMethod {
/** Obtain an Iteration object using mvMult (matrix-vector multiplication) in a DTMC */
public abstract IterationValIter forMvMult ( DTMC dtmc ) throws PrismException ;
/ * *
* Obtain an Iteration object ( for interval iteration ) using mvMult
* ( matrix - vector multiplication ) in a DTMC .
* @param fromBelow for interval iteration from below ?
* @param enforceMonotonic enforce element - wise monotonicity of the solution vector
* @param checkMonotonic check the element - wise monotonicity of the solution vector , throw exception if violated
* /
public abstract IterationIntervalIter forMvMultInterval ( DTMC dtmc , boolean fromBelow , boolean enforceMonotonicity , boolean checkMonotonicity ) throws PrismException ;
/** Obtain an Iteration object using mvMultRew (matrix-vector multiplication with rewards) in a DTMC */
public abstract IterationValIter forMvMultRew ( DTMC dtmc , MCRewards rew ) throws PrismException ;
/ * *
* Obtain an Iteration object ( for interval iteration ) using mvMultRew
* ( matrix - vector multiplication with rewards ) in a DTMC .
* @param fromBelow for interval iteration from below ?
* @param enforceMonotonic enforce element - wise monotonicity of the solution vector
* @param checkMonotonic check the element - wise monotonicity of the solution vector , throw exception if violated
* /
public abstract IterationIntervalIter forMvMultRewInterval ( DTMC dtmc , MCRewards rew , boolean fromBelow , boolean enforceMonotonicity , boolean checkMonotonicity ) throws PrismException ;
/ / - - - - - - - - - - - - Abstract MDP methods - - - - - - - - - - - - - - - - - - - - - - - - - - - -
@ -234,6 +303,18 @@ public abstract class IterationMethod {
* /
public abstract IterationValIter forMvMultMinMax ( MDP mdp , boolean min , int [ ] strat ) throws PrismException ;
/ * *
* Obtain an Iteration object using mvMultMinMax ( matrix - vector multiplication , followed by min / max )
* in an MDP , for interval iteration .
* @param mdp the MDP
* @param min do min ?
* @param strat optional , storage for strategy , ignored if null
* @param fromBelow for interval iteration from below ?
* @param enforceMonotonic enforce element - wise monotonicity of the solution vector
* @param checkMonotonic check the element - wise monotonicity of the solution vector , throw exception if violated
* /
public abstract IterationIntervalIter forMvMultMinMaxInterval ( MDP mdp , boolean min , int [ ] strat , boolean fromBelow , boolean enforceMonotonicity , boolean checkMonotonicity ) throws PrismException ;
/ * *
* Obtain an Iteration object using mvMultRewMinMax ( matrix - vector multiplication with rewards , followed by min / max )
* in an MDP .
@ -244,6 +325,18 @@ public abstract class IterationMethod {
* /
public abstract IterationValIter forMvMultRewMinMax ( MDP mdp , MDPRewards rewards , boolean min , int [ ] strat ) throws PrismException ;
/ * *
* Obtain an Iteration object using mvMultRewMinMax ( matrix - vector multiplication with rewards , followed by min / max )
* in an MDP , for interval iteration .
* @param mdp the MDP
* @param rewards the reward structure
* @param min do min ?
* @param strat optional , storage for strategy , ignored if null
* @param fromBelow for interval iteration from below ?
* @param enforceMonotonic enforce element - wise monotonicity of the solution vector
* @param checkMonotonic check the element - wise monotonicity of the solution vector , throw exception if violated
* /
public abstract IterationIntervalIter forMvMultRewMinMaxInterval ( MDP mdp , MDPRewards rewards , boolean min , int [ ] strat , boolean fromBelow , boolean enforceMonotonicity , boolean checkMonotonicity ) throws PrismException ;
/ / - - - - - - - - - - - - Abstract generic methods - - - - - - - - - - - - - - - - - - - - - - - - - - - -
@ -416,6 +509,289 @@ public abstract class IterationMethod {
return res ;
}
/ * *
* Perform the actual work of an interval iteration , i . e . , iterate until convergence or abort .
*
* @param mc ProbModelChecker ( for log and settings )
* @param description Description ( for logging )
* @param below The iteration object for the iteration from below
* @param above The iteration object for the iteration from above
* @param unknownStates The set of unknown states , i . e . , whose value should be determined
* @param startTime The start time ( for logging purposes , obtained from a call to System . currentTimeMillis ( ) )
* @param iterationsExport an ExportIterations object ( optional , ignored if null )
* @return a ModelChecker result with the solution vector and statistics
* @throws PrismException on non - convergence ( if mc . errorOnNonConverge is set )
* /
public ModelCheckerResult doIntervalIteration ( ProbModelChecker mc , String description , IterationIntervalIter below , IterationIntervalIter above , IntSet unknownStates , long timer , ExportIterations iterationsExport ) throws PrismException {
try {
/ / Start iterations
int iters = 0 ;
final int maxIters = mc . maxIters ;
boolean done = false ;
PeriodicTimer updatesTimer = new PeriodicTimer ( ProbModelChecker . UPDATE_DELAY ) ;
updatesTimer . start ( ) ;
while ( ! done & & iters < maxIters ) {
iters + + ;
/ / Matrix - vector multiply
below . iterate ( unknownStates ) ;
above . iterate ( unknownStates ) ;
if ( iterationsExport ! = null ) {
iterationsExport . exportVector ( below . getSolnVector ( ) , 0 ) ;
iterationsExport . exportVector ( above . getSolnVector ( ) , 1 ) ;
}
intervalIterationCheckForProblems ( below . getSolnVector ( ) , above . getSolnVector ( ) , unknownStates . iterator ( ) ) ;
/ / Check termination
done = PrismUtils . doublesAreClose ( below . getSolnVector ( ) , above . getSolnVector ( ) , termCritParam , absolute ) ;
if ( done ) {
double diff = PrismUtils . measureSupNormInterval ( below . getSolnVector ( ) , above . getSolnVector ( ) , absolute ) ;
mc . getLog ( ) . println ( "Max " + ( ! absolute ? "relative " : "" ) +
"diff between upper and lower bound on convergence: " + PrismUtils . formatDouble ( diff ) ) ;
done = true ;
}
if ( ! done & & updatesTimer . triggered ( ) ) {
double diff = PrismUtils . measureSupNormInterval ( below . getSolnVector ( ) , above . getSolnVector ( ) , absolute ) ;
mc . getLog ( ) . print ( "Iteration " + iters + ": " ) ;
mc . getLog ( ) . print ( "max " + ( absolute ? "" : "relative " ) + "diff=" + PrismUtils . formatDouble ( diff ) ) ;
mc . getLog ( ) . println ( ", " + PrismUtils . formatDouble2dp ( updatesTimer . elapsedMillisTotal ( ) / 1000 . 0 ) + " sec so far" ) ;
}
}
/ / Finished value iteration
long mvCount = 2 * iters * countTransitions ( below . getModel ( ) , unknownStates ) ;
timer = System . currentTimeMillis ( ) - timer ;
mc . getLog ( ) . print ( "Interval iteration (" + description + ")" ) ;
mc . getLog ( ) . print ( " took " + iters + " iterations, " ) ;
mc . getLog ( ) . print ( mvCount + " MV-multiplications" ) ;
mc . getLog ( ) . println ( " and " + timer / 1000 . 0 + " seconds." ) ;
if ( done & & OptionsIntervalIteration . from ( mc . getSettings ( ) ) . isSelectMidpointForResult ( ) ) {
PrismUtils . selectMidpoint ( below . getSolnVector ( ) , above . getSolnVector ( ) ) ;
if ( iterationsExport ! = null ) {
/ / export midpoint
iterationsExport . exportVector ( below . getSolnVector ( ) , 0 ) ;
iterationsExport . exportVector ( below . getSolnVector ( ) , 1 ) ;
}
}
/ / Non - convergence is an error ( usually )
if ( ! done & & mc . errorOnNonConverge ) {
String msg = "Iterative method (interval iteration) did not converge within " + iters + " iterations." ;
msg + = "\nConsider using a different numerical method or increasing the maximum number of iterations" ;
throw new PrismException ( msg ) ;
}
/ / Return results
ModelCheckerResult res = new ModelCheckerResult ( ) ;
res . soln = below . getSolnVector ( ) ;
res . numIters = iters ;
res . timeTaken = timer / 1000 . 0 ;
return res ;
} finally {
if ( iterationsExport ! = null )
iterationsExport . close ( ) ;
}
}
/ * *
* Perform the actual work of a topological interval iteration , i . e . , iterate until convergence or abort .
*
* @param mc ProbModelChecker ( for log and settings )
* @param description Description ( for logging )
* @param sccs The information about the SCCs and topological order
* @param below The iteration object for the value iteration from below
* @param above The iteration object for the value iteration from above
* @param singletonSCCSolver The solver for singleton SCCs
* @param startTime The start time ( for logging purposes , obtained from a call to System . currentTimeMillis ( ) )
* @param iterationsExport an ExportIterations object ( optional , ignored if null )
* @return a ModelChecker result with the solution vector and statistics
* @throws PrismException on non - convergence ( if mc . errorOnNonConverge is set )
* /
public ModelCheckerResult doTopologicalIntervalIteration ( ProbModelChecker mc , String description , SCCInfo sccs , IterationIntervalIter below , IterationIntervalIter above , SingletonSCCSolver singletonSCCSolver , long timer , ExportIterations iterationsExport ) throws PrismException {
try {
/ / Start iterations
int iters = 0 ;
long mvCount = 0 ;
final int maxIters = mc . maxIters ;
PeriodicTimer updatesTimer = new PeriodicTimer ( ProbModelChecker . UPDATE_DELAY ) ;
updatesTimer . start ( ) ;
int numSCCs = sccs . getNumSCCs ( ) ;
int numNonSingletonSCCs = sccs . countNonSingletonSCCs ( ) ;
int finishedNonSingletonSCCs = 0 ;
boolean done = true ;
for ( int scc = 0 ; scc < numSCCs ; scc + + ) {
boolean doneSCC ;
if ( sccs . isSingletonSCC ( scc ) ) {
/ / get the single state in this SCC
int state = sccs . getStatesForSCC ( scc ) . iterator ( ) . nextInt ( ) ;
below . solveSingletonSCC ( state , singletonSCCSolver ) ;
above . solveSingletonSCC ( state , singletonSCCSolver ) ;
iters + + ;
mvCount + = 2 * countTransitions ( below . getModel ( ) , IntSet . asIntSet ( state ) ) ;
if ( iterationsExport ! = null ) {
iterationsExport . exportVector ( below . getSolnVector ( ) , 0 ) ;
iterationsExport . exportVector ( above . getSolnVector ( ) , 1 ) ;
}
intervalIterationCheckForProblems ( below . getSolnVector ( ) , above . getSolnVector ( ) , IntSet . asIntSet ( state ) . iterator ( ) ) ;
doneSCC = true ;
} else {
/ / complex SCC : do VI
doneSCC = false ;
int itersInSCC = 0 ;
IntSet statesForSCC = sccs . getStatesForSCC ( scc ) ;
/ / Adjust upper bound by adding 2 * epsilon ,
/ / adding 1 * epsilon would be fine , but we are a bit more conservative .
/ / TODO : We also don ' t really need to do adjustment for bottom SCCs . . .
PrimitiveIterator . OfInt it = statesForSCC . iterator ( ) ;
final double [ ] solnAbove = above . getSolnVector ( ) ;
final double adjustment = 2 * termCritParam ;
while ( it . hasNext ( ) ) {
solnAbove [ it . nextInt ( ) ] + = adjustment ;
}
/ / abort on convergence or if iterations * in this SCC * are above maxIters
while ( ! doneSCC & & itersInSCC < maxIters ) {
iters + + ;
itersInSCC + + ;
/ / do iteration step
below . iterate ( statesForSCC ) ;
above . iterate ( statesForSCC ) ;
if ( iterationsExport ! = null ) {
iterationsExport . exportVector ( below . getSolnVector ( ) , 0 ) ;
iterationsExport . exportVector ( above . getSolnVector ( ) , 1 ) ;
}
intervalIterationCheckForProblems ( below . getSolnVector ( ) , above . getSolnVector ( ) , statesForSCC . iterator ( ) ) ;
/ / Check termination ( inside SCC )
doneSCC = PrismUtils . doublesAreClose ( below . getSolnVector ( ) , above . getSolnVector ( ) , statesForSCC . iterator ( ) , termCritParam , absolute ) ;
if ( ! doneSCC & & updatesTimer . triggered ( ) ) {
double diff = PrismUtils . measureSupNormInterval ( below . getSolnVector ( ) , above . getSolnVector ( ) , absolute , statesForSCC . iterator ( ) ) ;
mc . getLog ( ) . print ( "Iteration " + iters + ": " ) ;
mc . getLog ( ) . print ( "max " + ( absolute ? "" : "relative " ) + "diff (for iteration " + itersInSCC + " in current SCC " + ( finishedNonSingletonSCCs + 1 ) + " of " + numNonSingletonSCCs + ") = " + PrismUtils . formatDouble ( diff ) ) ;
mc . getLog ( ) . println ( ", " + PrismUtils . formatDouble2dp ( updatesTimer . elapsedMillisTotal ( ) / 1000 . 0 ) + " sec so far" ) ;
}
}
mvCount + = 2 * itersInSCC * countTransitions ( below . getModel ( ) , statesForSCC ) ;
finishedNonSingletonSCCs + + ;
}
if ( ! doneSCC ) {
done = false ;
break ;
}
}
if ( done ) {
double diff = PrismUtils . measureSupNormInterval ( below . getSolnVector ( ) , above . getSolnVector ( ) , absolute ) ;
mc . getLog ( ) . println ( "Max " + ( absolute ? "" : "relative " ) +
"diff between upper and lower bound on convergence: " + PrismUtils . formatDouble ( diff ) ) ;
done = true ;
}
/ / Finished value iteration
timer = System . currentTimeMillis ( ) - timer ;
mc . getLog ( ) . print ( "Interval iteration (" + description + ", with " + numNonSingletonSCCs + " non-singleton SCCs)" ) ;
mc . getLog ( ) . print ( " took " + iters + " iterations, " ) ;
mc . getLog ( ) . print ( mvCount + " MV-multiplications" ) ;
mc . getLog ( ) . println ( " and " + timer / 1000 . 0 + " seconds." ) ;
if ( done & & OptionsIntervalIteration . from ( mc . getSettings ( ) ) . isSelectMidpointForResult ( ) ) {
PrismUtils . selectMidpoint ( below . getSolnVector ( ) , above . getSolnVector ( ) ) ;
if ( iterationsExport ! = null ) {
/ / export midpoint
iterationsExport . exportVector ( below . getSolnVector ( ) , 0 ) ;
iterationsExport . exportVector ( below . getSolnVector ( ) , 1 ) ;
}
}
if ( iterationsExport ! = null )
iterationsExport . close ( ) ;
/ / Non - convergence is an error ( usually )
if ( ! done & & mc . errorOnNonConverge ) {
String msg = "Iterative method (interval iteration) did not converge within " + iters + " iterations." ;
msg + = "\nConsider using a different numerical method or increasing the maximum number of iterations" ;
throw new PrismException ( msg ) ;
}
/ / Return results
ModelCheckerResult res = new ModelCheckerResult ( ) ;
res . soln = below . getSolnVector ( ) ;
res . numIters = iters ;
res . timeTaken = timer / 1000 . 0 ;
return res ;
} finally {
if ( iterationsExport ! = null )
iterationsExport . close ( ) ;
}
}
/ * *
* Compares the current lower and upper solution vectors in an interval iteration
* and throws an exception if lower bound values are larger than upper bound values ,
* as this indicates problems .
* @param lower the current lower iteration solution vector
* @param upper the current upper iteration solution vector
* @param states iterator over the states in question
* /
private static void intervalIterationCheckForProblems ( double [ ] lower , double [ ] upper , PrimitiveIterator . OfInt states ) throws PrismException
{
while ( states . hasNext ( ) ) {
int s = states . nextInt ( ) ;
if ( lower [ s ] > upper [ s ] ) {
throw new PrismException ( "In interval iteration, the lower value (" + lower [ s ] + ") is larger than the upper value (" + upper [ s ] + ").\n"
+ "This indicates either problems with numerical stability (rounding, precision of the floating-point representation) or that the initial bounds (for reward computations) are incorrect" ) ;
}
}
}
/ * *
* Perform a post - processing for a two - vector value iteration .
* @param solnOld the previous solution vector
* @param solnNew the newly computed solution vector
* @param states the relevant set of states
* @param fromBelow are we iterating from below ?
* @param enforceMonotonicity if true , enforces monotonicity
* @param checkMonotonicity if true , checks for monotonicity ( and throws error when non - monotonic )
* /
public static void twoVectorPostProcessing ( double [ ] solnOld , double [ ] solnNew , IntSet states , boolean fromBelow , boolean enforceMonotonicity , boolean checkMonotonicity ) throws PrismException
{
/ / TODO : use IntSet states
if ( enforceMonotonicity )
if ( fromBelow ) {
PrismUtils . ensureMonotonicityFromBelow ( solnOld , solnNew ) ;
} else {
PrismUtils . ensureMonotonicityFromAbove ( solnOld , solnNew ) ;
}
if ( checkMonotonicity ) {
PrismUtils . checkMonotonicity ( solnOld , solnNew , ! fromBelow ) ;
}
}
protected long countTransitions ( Model model , IntSet unknownStates )
{
if ( model instanceof DTMC ) {