@ -106,6 +106,14 @@ public class Modules2MTBDD
/ / flags for keeping track of which variables have been used
private boolean [ ] varsUsed ;
/ / symmetry info
private boolean doSymmetry ; / / use symmetry reduction
private JDDNode symm ; / / dd of symmetric states
private JDDNode nonSymms [ ] ; / / dds of non - ( i , i + 1 ) - symmetric states ( i = 0 . . . numSymmModules - 1 )
private int numModulesBeforeSymm ; / / number of modules in the PRISM file before the symmetric ones
private int numModulesAfterSymm ; / / number of modules in the PRISM file after the symmetric ones
private int numSymmModules ; / / number of symmetric components
/ / data structure used to store mtbdds and related info
/ / for some component of the whole model
@ -146,6 +154,9 @@ public class Modules2MTBDD
mainLog = p . getMainLog ( ) ;
techLog = p . getTechLog ( ) ;
modulesFile = mf ;
/ / get symmetry reduction info
String s = prism . getSettings ( ) . getString ( PrismSettings . PRISM_SYMM_RED_PARAMS ) ;
doSymmetry = ! ( s = = null | | s = = "" ) ;
}
/ / main method - translate
@ -262,6 +273,9 @@ public class Modules2MTBDD
mainLog . print ( "Reach: " + JDD . GetNumNodes ( model . getReach ( ) ) + " nodes\n" ) ;
}
/ / symmetrification
if ( doSymmetry ) doSymmetry ( model ) ;
/ / find any deadlocks
model . findDeadlocks ( ) ;
@ -289,6 +303,12 @@ public class Modules2MTBDD
JDD . Deref ( ddChoiceVars [ i ] ) ;
}
}
if ( doSymmetry ) {
JDD . Deref ( symm ) ;
for ( i = 0 ; i < numSymmModules - 1 ; i + + ) {
JDD . Deref ( nonSymms [ i ] ) ;
}
}
expr2mtbdd . clearDummyModel ( ) ;
@ -1968,6 +1988,234 @@ public class Modules2MTBDD
}
}
}
/ / symmetrification
private void doSymmetry ( Model model ) throws PrismException
{
JDDNode tmp , transNew , reach , trans , transRewards [ ] ;
int i , j , k , numSwaps ;
boolean done ;
long clock ;
String ss [ ] ;
/ / parse symmetry reduction parameters
ss = prism . getSettings ( ) . getString ( PrismSettings . PRISM_SYMM_RED_PARAMS ) . split ( " " ) ;
if ( ss . length ! = 2 ) throw new PrismException ( "Invalid parameters for symmetry reduction" ) ;
try {
numModulesBeforeSymm = Integer . parseInt ( ss [ 0 ] . trim ( ) ) ;
numModulesAfterSymm = Integer . parseInt ( ss [ 1 ] . trim ( ) ) ;
}
catch ( NumberFormatException e ) {
throw new PrismException ( "Invalid parameters for symmetry reduction" ) ;
}
clock = System . currentTimeMillis ( ) ;
/ / get a copies of model ( MT ) BDDs
reach = model . getReach ( ) ;
JDD . Ref ( reach ) ;
trans = model . getTrans ( ) ;
JDD . Ref ( trans ) ;
transRewards = new JDDNode [ numRewardStructs ] ;
for ( i = 0 ; i < numRewardStructs ; i + + ) {
transRewards [ i ] = model . getTransRewards ( i ) ;
JDD . Ref ( transRewards [ i ] ) ;
}
mainLog . print ( "\nApplying symmetry reduction...\n" ) ;
/ / identifySymmetricModules ( ) ;
numSymmModules = numModules - ( numModulesBeforeSymm + numModulesAfterSymm ) ;
computeSymmetryFilters ( reach ) ;
/ / compute number of local states
/ / JDD . Ref ( reach ) ;
/ / tmp = reach ;
/ / for ( i = 0 ; i < numModules ; i + + ) {
/ / if ( i ! = numModulesBeforeSymm ) tmp = JDD . ThereExists ( tmp , moduleDDRowVars [ i ] ) ;
/ / }
/ / tmp = JDD . ThereExists ( tmp , globalDDRowVars ) ;
/ / mainLog . println ( "Local states: " + ( int ) JDD . GetNumMinterms ( tmp , moduleDDRowVars [ numModulesBeforeSymm ] . n ( ) ) ) ;
/ / JDD . Deref ( tmp ) ;
/ / ODDNode odd = ODDUtils . BuildODD ( reach , allDDRowVars ) ;
/ / try { sparse . PrismSparse . NondetExport ( trans , allDDRowVars , allDDColVars , allDDNondetVars , odd , Prism . EXPORT_PLAIN , "trans-full.tra" ) ; } catch ( FileNotFoundException e ) { }
/ / trans - rows
mainLog . print ( "trans (full): " ) ;
mainLog . println ( JDD . GetInfoString ( trans , ( type = = ModulesFile . NONDETERMINISTIC ) ? ( allDDRowVars . n ( ) * 2 + allDDNondetVars . n ( ) ) : ( allDDRowVars . n ( ) * 2 ) ) ) ;
JDD . Ref ( symm ) ;
trans = JDD . Apply ( JDD . TIMES , trans , symm ) ;
mainLog . print ( "trans (symm): " ) ;
mainLog . println ( JDD . GetInfoString ( trans , ( type = = ModulesFile . NONDETERMINISTIC ) ? ( allDDRowVars . n ( ) * 2 + allDDNondetVars . n ( ) ) : ( allDDRowVars . n ( ) * 2 ) ) ) ;
/ / trans rewards - rows
for ( k = 0 ; k < numRewardStructs ; k + + ) {
mainLog . print ( "transrew[" + k + "] (full): " ) ;
mainLog . println ( JDD . GetInfoString ( transRewards [ k ] , ( type = = ModulesFile . NONDETERMINISTIC ) ? ( allDDRowVars . n ( ) * 2 + allDDNondetVars . n ( ) ) : ( allDDRowVars . n ( ) * 2 ) ) ) ;
JDD . Ref ( symm ) ;
transRewards [ k ] = JDD . Apply ( JDD . TIMES , transRewards [ k ] , symm ) ;
mainLog . print ( "transrew[" + k + "] (symm): " ) ;
mainLog . println ( JDD . GetInfoString ( transRewards [ k ] , ( type = = ModulesFile . NONDETERMINISTIC ) ? ( allDDRowVars . n ( ) * 2 + allDDNondetVars . n ( ) ) : ( allDDRowVars . n ( ) * 2 ) ) ) ;
}
mainLog . println ( "Starting quicksort..." ) ;
done = false ;
numSwaps = 0 ;
for ( i = numSymmModules ; i > 1 & & ! done ; i - - ) {
/ / store trans from previous iter
JDD . Ref ( trans ) ;
transNew = trans ;
for ( j = 0 ; j < i - 1 ; j + + ) {
/ / are there any states where j + 1 > j + 2 ?
if ( nonSymms [ j ] . equals ( JDD . ZERO ) ) continue ;
/ / identify offending block in trans
JDD . Ref ( transNew ) ;
JDD . Ref ( nonSymms [ j ] ) ;
tmp = JDD . Apply ( JDD . TIMES , transNew , JDD . PermuteVariables ( nonSymms [ j ] , allDDRowVars , allDDColVars ) ) ;
/ / mainLog . print ( "bad block: " ) ;
/ / mainLog . println ( JDD . GetInfoString ( tmp , ( type = = ModulesFile . NONDETERMINISTIC ) ? ( allDDRowVars . n ( ) * 2 + allDDNondetVars . n ( ) ) : ( allDDRowVars . n ( ) * 2 ) ) ) ;
if ( tmp . equals ( JDD . ZERO ) ) { JDD . Deref ( tmp ) ; continue ; }
numSwaps + + ;
mainLog . println ( "Iteration " + ( numSymmModules - i + 1 ) + "." + ( j + 1 ) ) ;
/ / swap
tmp = JDD . SwapVariables ( tmp , moduleDDColVars [ numModulesBeforeSymm + j ] , moduleDDColVars [ numModulesBeforeSymm + j + 1 ] ) ;
/ / mainLog . print ( "bad block (swapped): " ) ;
/ / mainLog . println ( JDD . GetInfoString ( tmp , ( type = = ModulesFile . NONDETERMINISTIC ) ? ( allDDRowVars . n ( ) * 2 + allDDNondetVars . n ( ) ) : ( allDDRowVars . n ( ) * 2 ) ) ) ;
/ / insert swapped block
JDD . Ref ( nonSymms [ j ] ) ;
JDD . Ref ( tmp ) ;
transNew = JDD . ITE ( JDD . PermuteVariables ( nonSymms [ j ] , allDDRowVars , allDDColVars ) , JDD . Constant ( 0 ) , JDD . Apply ( JDD . PLUS , transNew , tmp ) ) ;
/ / mainLog . print ( "trans (symm): " ) ;
/ / mainLog . println ( JDD . GetInfoString ( transNew , ( type = = ModulesFile . NONDETERMINISTIC ) ? ( allDDRowVars . n ( ) * 2 + allDDNondetVars . n ( ) ) : ( allDDRowVars . n ( ) * 2 ) ) ) ;
JDD . Deref ( tmp ) ;
for ( k = 0 ; k < numRewardStructs ; k + + ) {
/ / identify offending block in trans rewards
JDD . Ref ( transRewards [ k ] ) ;
JDD . Ref ( nonSymms [ j ] ) ;
tmp = JDD . Apply ( JDD . TIMES , transRewards [ k ] , JDD . PermuteVariables ( nonSymms [ j ] , allDDRowVars , allDDColVars ) ) ;
/ / mainLog . print ( "bad block: " ) ;
/ / mainLog . println ( JDD . GetInfoString ( tmp , ( type = = ModulesFile . NONDETERMINISTIC ) ? ( allDDRowVars . n ( ) * 2 + allDDNondetVars . n ( ) ) : ( allDDRowVars . n ( ) * 2 ) ) ) ;
/ / swap
tmp = JDD . SwapVariables ( tmp , moduleDDColVars [ numModulesBeforeSymm + j ] , moduleDDColVars [ numModulesBeforeSymm + j + 1 ] ) ;
/ / mainLog . print ( "bad block (swapped): " ) ;
/ / mainLog . println ( JDD . GetInfoString ( tmp , ( type = = ModulesFile . NONDETERMINISTIC ) ? ( allDDRowVars . n ( ) * 2 + allDDNondetVars . n ( ) ) : ( allDDRowVars . n ( ) * 2 ) ) ) ;
/ / insert swapped block
JDD . Ref ( nonSymms [ j ] ) ;
JDD . Ref ( tmp ) ;
transRewards [ k ] = JDD . ITE ( JDD . PermuteVariables ( nonSymms [ j ] , allDDRowVars , allDDColVars ) , JDD . Constant ( 0 ) , JDD . Apply ( JDD . PLUS , transRewards [ k ] , tmp ) ) ;
/ / mainLog . print ( "transrew[" + k + "] (symm): " ) ;
/ / mainLog . println ( JDD . GetInfoString ( transRewards [ k ] , ( type = = ModulesFile . NONDETERMINISTIC ) ? ( allDDRowVars . n ( ) * 2 + allDDNondetVars . n ( ) ) : ( allDDRowVars . n ( ) * 2 ) ) ) ;
JDD . Deref ( tmp ) ;
}
}
if ( transNew . equals ( trans ) ) {
done = true ;
}
JDD . Deref ( trans ) ;
trans = transNew ;
}
/ / reset ( MT ) BDDs in model
model . resetTrans ( trans ) ;
for ( i = 0 ; i < numRewardStructs ; i + + ) {
model . resetTransRewards ( i , transRewards [ i ] ) ;
}
/ / reset reach bdd , etc .
JDD . Ref ( symm ) ;
reach = JDD . And ( reach , symm ) ;
model . setReach ( reach ) ;
model . filterReachableStates ( ) ;
clock = System . currentTimeMillis ( ) - clock ;
mainLog . println ( "Symmetry complete: " + ( numSymmModules - i ) + " iterations, " + numSwaps + " swaps, " + clock / 1000 . 0 + " seconds" ) ;
}
private void computeSymmetryFilters ( JDDNode reach ) throws PrismException
{
int i ;
JDDNode tmp ;
/ / array for non - symmetric parts
nonSymms = new JDDNode [ numSymmModules - 1 ] ;
/ / dd for all symmetric states
JDD . Ref ( reach ) ;
symm = reach ;
/ / loop through symmetric module pairs
for ( i = 0 ; i < numSymmModules - 1 ; i + + ) {
/ / ( locally ) symmetric states , i . e . where i + 1 < = i + 2
tmp = JDD . VariablesLessThanEquals ( moduleDDRowVars [ numModulesBeforeSymm + i ] , moduleDDRowVars [ numModulesBeforeSymm + i + 1 ] ) ;
/ / non - ( locally ) - symmetric states
JDD . Ref ( tmp ) ;
JDD . Ref ( reach ) ;
nonSymms [ i ] = JDD . And ( JDD . Not ( tmp ) , reach ) ;
/ / all symmetric states
symm = JDD . And ( symm , tmp ) ;
}
}
/ / old version of computeSymmetryFilters ( )
/ * private void computeSymmetryFilters ( ) throws PrismException
{
int i , j , k , n ;
String varNames [ ] [ ] = null ;
JDDNode tmp ;
Expression expr , exprTmp ;
/ / get var names for each symm module
n = modulesFile . getModule ( numModulesBeforeSymm ) . getNumDeclarations ( ) ;
varNames = new String [ numModules ] [ ] ;
for ( i = numModulesBeforeSymm ; i < numModulesBeforeSymm + numSymmModules ; i + + ) {
varNames [ i - numModulesBeforeSymm ] = new String [ n ] ;
j = 0 ;
while ( j < numVars & & varList . getModule ( j ) ! = i ) j + + ;
for ( k = 0 ; k < n ; k + + ) {
varNames [ i - numModulesBeforeSymm ] [ k ] = varList . getName ( j + k ) ;
}
}
/ / array for non - symmetric parts
nonSymms = new JDDNode [ numSymmModules - 1 ] ;
/ / dd for all symmetric states
JDD . Ref ( reach ) ;
symm = reach ;
/ / loop through symmetric module pairs
for ( i = 0 ; i < numSymmModules - 1 ; i + + ) {
/ / expression for ( locally ) symmetric states , i . e . where i + 1 < = i + 2
expr = new ExpressionTrue ( ) ;
for ( j = varNames [ 0 ] . length - 1 ; j > = 0 ; j - - ) {
exprTmp = new ExpressionAnd ( ) ;
( ( ExpressionAnd ) exprTmp ) . addOperand ( new ExpressionBrackets ( new ExpressionRelOp ( new ExpressionVar ( varNames [ i ] [ j ] , 0 ) , "=" , new ExpressionVar ( varNames [ i + 1 ] [ j ] , 0 ) ) ) ) ;
( ( ExpressionAnd ) exprTmp ) . addOperand ( new ExpressionBrackets ( expr ) ) ;
expr = exprTmp ;
exprTmp = new ExpressionOr ( ) ;
( ( ExpressionOr ) exprTmp ) . addOperand ( new ExpressionBrackets ( new ExpressionRelOp ( new ExpressionVar ( varNames [ i ] [ j ] , 0 ) , "<" , new ExpressionVar ( varNames [ i + 1 ] [ j ] , 0 ) ) ) ) ;
( ( ExpressionOr ) exprTmp ) . addOperand ( expr ) ;
expr = exprTmp ;
}
mainLog . println ( expr ) ;
/ / bdd for ( locally ) symmetric states , i . e . where i + 1 < = i + 2
tmp = expr2mtbdd . translateExpression ( expr ) ;
/ / non - ( locally ) - symmetric states
JDD . Ref ( tmp ) ;
JDD . Ref ( reach ) ;
nonSymms [ i ] = JDD . And ( JDD . Not ( tmp ) , reach ) ;
/ / all symmetric states
symm = JDD . And ( symm , tmp ) ;
}
} * /
}
/ / - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -