From 90511afb27787b14829937121dee6cba88a1d3a7 Mon Sep 17 00:00:00 2001 From: Sascha Wunderlich Date: Mon, 14 Nov 2016 14:51:35 +0100 Subject: [PATCH] accumulation: refactor some common function in product --- prism/src/explicit/AccumulationProduct.java | 107 +++++++++++++++++ .../explicit/AccumulationProductCounting.java | 69 +++-------- .../explicit/AccumulationProductRegular.java | 111 ++---------------- .../explicit/AccumulationTransformation.java | 6 +- 4 files changed, 137 insertions(+), 156 deletions(-) diff --git a/prism/src/explicit/AccumulationProduct.java b/prism/src/explicit/AccumulationProduct.java index 74a83960..6c0f2b9a 100644 --- a/prism/src/explicit/AccumulationProduct.java +++ b/prism/src/explicit/AccumulationProduct.java @@ -5,6 +5,9 @@ import java.util.BitSet; import java.util.Iterator; import java.util.Map; +import parser.ast.AccumulationFactor; +import parser.ast.ExpressionAccumulation; +import prism.IntegerBound; import prism.PrismException; import prism.PrismFileLog; import prism.PrismLog; @@ -47,6 +50,110 @@ public abstract class AccumulationProduct extends Pro public int getNumberOfTracks() { return numberOfTracks; } + + protected abstract boolean isFinalTrack(final AccumulationTrack track, final ExpressionAccumulation accexp, final ProbModelChecker mc) + throws PrismException; + + protected boolean isGoodTrack(final AccumulationTrack track, final ExpressionAccumulation accexp, final ProbModelChecker mc) + throws PrismException { + // Only final tracks can be good + if (!isFinalTrack(track,accexp,mc)) { return false; } + boolean isGood = false; + + // Collect the weight linear combination, factor*weight+... + int lhs = 0; + int factorNr = 0; + for (AccumulationFactor factor : accexp.getConstraint().getFactors()) { + lhs += factor.getFactor().evaluateInt(mc.getConstantValues()) + * track.getWeight(factorNr); + } + + // Check the bound + IntegerBound rhs = IntegerBound.fromTemporalOperatorBound(accexp.getConstraint().getBound(), mc.getConstantValues(), true); + + // For DIA operators, we just check the bound. + // For BOX operators, we check the INVERTED bound. + switch(accexp.getSymbol()) { + case ACCBOXMINUS: + case ACCBOXPLUS: + if (!rhs.isInBounds(lhs)) { + isGood = true; + } + break; + case ACCDIAMINUS: + case ACCDIAPLUS: + if (rhs.isInBounds(lhs)) { + isGood = true; + } + break; + default: + throw new RuntimeException("Oh boy!"); + } + //if(isGood) {mc.getLog().print("+");} else {mc.getLog().print("-");} + return isGood; + } + + protected boolean isGoodAccState(final AccumulationState state, final ExpressionAccumulation accexp, final ProbModelChecker mc) + throws PrismException { + return state.hasGoodTrack(); + } + + protected abstract Component getInitialComponent(); + + protected abstract AccumulationTrack updateTrack(Integer modelFromStateId, final AccumulationTrack track, + final ExpressionAccumulation accexp, final double[] weights, final StateModelChecker mc); + + protected AccumulationState updateAccumulationState(final int modelFromStateId, + final AccumulationState accstate, final ExpressionAccumulation accexp, + final double[] weights, final ProbModelChecker mc) throws PrismException { + // We have the current accumulation state, the current model id and the accumulation expression. + + // Get the old tracker and tracks. + AccumulationTracker oldTracker = accstate.getTracker(trackers); + ArrayList> oldTracks = oldTracker.getTracks(); + + BitSet oldGoodTracks = accstate.getGoodTracks(); + BitSet newGoodTracks = (BitSet) oldGoodTracks.clone(); + + // This restart will be... + int newLastRestartNr = accstate.getNextRestartNr(); + //mc.getLog().print(newLastRestartNr); + + // Build the new tracks. + ArrayList> newTracks = new ArrayList<>(); + + int trackNr = 0; + for(AccumulationTrack oldTrack : oldTracks) { + AccumulationTrack newTrack; + + // restart or advance + if(trackNr == newLastRestartNr) { + //assert oldTrack == null : "Track " + newLastRestartNr + " is not null!"; + newTrack = new AccumulationTrack(numberOfWeights, getInitialComponent()); + newGoodTracks.clear(trackNr); + } else if (oldTrack == null) { + newTrack = null; + } else { + assert oldTrack != null; + newTrack = updateTrack(modelFromStateId, oldTrack, accexp, weights, mc); + } + + // check whether the track is good + if(!newGoodTracks.get(trackNr)) { + newGoodTracks.set(trackNr, isGoodTrack(newTrack, accexp, mc)); + } + + newTracks.add(newTrack); + trackNr++; + } + + AccumulationTracker newTracker = new AccumulationTracker<>(newTracks); + + + int newTrackerId = trackers.findOrAdd(newTracker); + + return new AccumulationState<>(newTrackerId, newLastRestartNr, numberOfTracks, newGoodTracks); + } public void exportToDotFile(String filename) throws PrismException { try (PrismFileLog log = PrismFileLog.create(filename)) { diff --git a/prism/src/explicit/AccumulationProductCounting.java b/prism/src/explicit/AccumulationProductCounting.java index 8cd63924..2afc149f 100644 --- a/prism/src/explicit/AccumulationProductCounting.java +++ b/prism/src/explicit/AccumulationProductCounting.java @@ -7,7 +7,6 @@ import java.util.Vector; import explicit.rewards.MCRewards; import explicit.rewards.MDPRewards; import explicit.rewards.Rewards; -import parser.ast.AccumulationFactor; import parser.ast.ExpressionAccumulation; import prism.IntegerBound; import prism.PrismException; @@ -142,7 +141,7 @@ public class AccumulationProductCounting extends AccumulationPr @Override public void finish() throws PrismException { // Do nothing - mc.getLog().println("."); + //mc.getLog().println("."); } @Override @@ -156,7 +155,8 @@ public class AccumulationProductCounting extends AccumulationPr return result; } - private boolean isFinalTrack(final AccumulationTrack track, final ExpressionAccumulation accexp, final ProbModelChecker mc) + @Override + protected boolean isFinalTrack(final AccumulationTrack track, final ExpressionAccumulation accexp, final ProbModelChecker mc) throws PrismException { boolean isFinal = false; if ( track != null ) { @@ -166,51 +166,12 @@ public class AccumulationProductCounting extends AccumulationPr return isFinal; } - private boolean isGoodAccState(final AccumulationState state, final ExpressionAccumulation accexp, final ProbModelChecker mc) - throws PrismException { - return state.hasGoodTrack(); + @Override + protected Integer getInitialComponent() { + return 0; } - - private boolean isGoodTrack(final AccumulationTrack track, final ExpressionAccumulation accexp, final ProbModelChecker mc) - throws PrismException { - // Only final tracks can be good - if (!isFinalTrack(track,accexp,mc)) { return false; } - boolean isGood = false; - - // Collect the weight linear combination, factor*weight+... - int lhs = 0; - int factorNr = 0; - for (AccumulationFactor factor : accexp.getConstraint().getFactors()) { - lhs += factor.getFactor().evaluateInt(mc.getConstantValues()) - * track.getWeight(factorNr); - } - - // Check the bound - IntegerBound rhs = IntegerBound.fromTemporalOperatorBound(accexp.getConstraint().getBound(), mc.getConstantValues(), true); - // For DIA operators, we just check the bound. - // For BOX operators, we check the INVERTED bound. - switch(accexp.getSymbol()) { - case ACCBOXMINUS: - case ACCBOXPLUS: - if (!rhs.isInBounds(lhs)) { - isGood = true; - } - break; - case ACCDIAMINUS: - case ACCDIAPLUS: - if (rhs.isInBounds(lhs)) { - isGood = true; - } - break; - default: - throw new RuntimeException("Oh boy!"); - } - if(isGood) {mc.getLog().print("+");} else {mc.getLog().print("-");} - return isGood; - } - - private AccumulationState updateAccumulationState(final int modelFromStateId, + protected AccumulationState updateAccumulationState(final int modelFromStateId, final AccumulationState accstate, final ExpressionAccumulation accexp, final double[] weights, final ProbModelChecker mc) throws PrismException { // We have the current accumulation state, the current model id and the accumulation expression. @@ -224,7 +185,7 @@ public class AccumulationProductCounting extends AccumulationPr // This restart will be... int newLastRestartNr = accstate.getNextRestartNr(); - mc.getLog().print(newLastRestartNr); + //mc.getLog().print(newLastRestartNr); // Build the new tracks. ArrayList> newTracks = new ArrayList<>(); @@ -242,7 +203,7 @@ public class AccumulationProductCounting extends AccumulationPr newTrack = null; } else { assert oldTrack != null; - newTrack = updateTrackBounds(oldTrack, accexp, weights, mc); + newTrack = updateTrack(modelFromStateId, oldTrack, accexp, weights, mc); } // check whether the track is good @@ -262,10 +223,16 @@ public class AccumulationProductCounting extends AccumulationPr return new AccumulationState<>(newTrackerId, newLastRestartNr, numberOfTracks, newGoodTracks); } - private AccumulationTrack updateTrackBounds(final AccumulationTrack track, - final ExpressionAccumulation accexp, final double[] weights, final StateModelChecker mc) throws PrismException { + @Override + protected AccumulationTrack updateTrack(Integer modelFromStateId, final AccumulationTrack track, + final ExpressionAccumulation accexp, final double[] weights, final StateModelChecker mc) { int currentStep = track.getComponent(); - int maxStep = IntegerBound.fromTemporalOperatorBound(accexp.getBoundExpression(), mc.getConstantValues(), true).getHighestInteger(); + int maxStep = 0; + try { + maxStep = IntegerBound.fromTemporalOperatorBound(accexp.getBoundExpression(), mc.getConstantValues(), true).getHighestInteger(); + } catch(PrismException e) { + throw new RuntimeException("..."); + } // If we are done, return null-Track if (currentStep >= maxStep) { return null; } diff --git a/prism/src/explicit/AccumulationProductRegular.java b/prism/src/explicit/AccumulationProductRegular.java index ecd0b855..cefa805b 100644 --- a/prism/src/explicit/AccumulationProductRegular.java +++ b/prism/src/explicit/AccumulationProductRegular.java @@ -11,10 +11,8 @@ import automata.finite.State; import explicit.rewards.MCRewards; import explicit.rewards.MDPRewards; import explicit.rewards.Rewards; -import parser.ast.AccumulationFactor; import parser.ast.ExpressionAccumulation; import parser.ast.ExpressionRegular; -import prism.IntegerBound; import prism.PrismException; /** @@ -163,7 +161,8 @@ public class AccumulationProductRegular extends AccumulationPro return result; } - private boolean isFinalTrack(AccumulationTrack track, ExpressionAccumulation accexp, ProbModelChecker mc) throws PrismException { + @Override + protected boolean isFinalTrack(AccumulationTrack track, ExpressionAccumulation accexp, ProbModelChecker mc) throws PrismException { boolean isFinal = false; if ( track != null ) { isFinal = automaton.isAcceptingState(track.getComponent()); @@ -172,106 +171,13 @@ public class AccumulationProductRegular extends AccumulationPro return isFinal; } - private boolean isGoodAccState(AccumulationState state, ExpressionAccumulation accexp, ProbModelChecker mc) throws PrismException { - return state.hasGoodTrack(); + @Override + protected State getInitialComponent() { + return automaton.getInitialState(); } - private boolean isGoodTrack(AccumulationTrack track, ExpressionAccumulation accexp, ProbModelChecker mc) throws PrismException { - mc.getLog().println("Checking " + track + " for goodness..."); - // Only final tracks can be good - mc.getLog().println("Final? " + isFinalTrack(track,accexp,mc)); - if (!isFinalTrack(track,accexp,mc)) { return false; } - boolean isGood = false; - - //System.out.println("Final: " + track); - - // Collect the weight linear combination, factor*weight+... - int lhs = 0; - int factorNr = 0; - for (AccumulationFactor factor : accexp.getConstraint().getFactors()) { - lhs += factor.getFactor().evaluateInt(mc.getConstantValues()) - * track.getWeight(factorNr); - } - - // Check the bound - IntegerBound rhs = IntegerBound.fromTemporalOperatorBound(accexp.getConstraint().getBound(), mc.getConstantValues(), true); - - // For DIA operators, we just check the bound. - // For BOX operators, we check the INVERTED bound. - switch(accexp.getSymbol()) { - case ACCBOXMINUS: - case ACCBOXPLUS: - if (!rhs.isInBounds(lhs)) { - isGood = true; - } - break; - case ACCDIAMINUS: - case ACCDIAPLUS: - if (rhs.isInBounds(lhs)) { - isGood = true; - } - break; - default: - throw new RuntimeException("Oh boy!"); - } - - mc.getLog().println("Good? " + lhs + rhs + " : " + isGood); - - return isGood; - } - - private AccumulationState updateAccumulationState(final int modelFromStateId, - final AccumulationState accstate, final ExpressionAccumulation accexp, - final double[] weights, final ProbModelChecker mc) throws PrismException { - // We have the current accumulation state, the current model id and the accumulation expression. - - // Get the old tracker and tracks. - AccumulationTracker oldTracker = trackers.getById(accstate.getTrackerId()); - ArrayList> oldTracks = oldTracker.getTracks(); - - BitSet oldGoodTracks = accstate.getGoodTracks(); - BitSet newGoodTracks = (BitSet) oldGoodTracks.clone(); - - // This restart will be... - int newLastRestartNr = accstate.getNextRestartNr(); - mc.getLog().print(newLastRestartNr); - - // Build the new tracks. - ArrayList> newTracks = new ArrayList<>(); - - int trackNr = 0; - for(AccumulationTrack oldTrack : oldTracks) { - AccumulationTrack newTrack; - - // restart or advance - if(trackNr == newLastRestartNr) { - //assert oldTrack == null : "Track " + newLastRestartNr + " is not null!"; - newTrack = new AccumulationTrack(numberOfWeights, automaton.getInitialState()); //TODO: off-by-one? - newGoodTracks.clear(trackNr); - } else if (oldTrack == null) { - newTrack = null; - } else { - assert oldTrack != null; - newTrack = updateTrackRegular(modelFromStateId, oldTrack, accexp, weights, mc); - } - - // check whether the track is good - if(!newGoodTracks.get(trackNr)) { - newGoodTracks.set(trackNr, isGoodTrack(newTrack, accexp, mc)); - } - - newTracks.add(newTrack); - trackNr++; - } - - AccumulationTracker newTracker = new AccumulationTracker<>(newTracks); - - - int newTrackerId = trackers.findOrAdd(newTracker); - - return new AccumulationState(newTrackerId, newLastRestartNr, numberOfTracks, newGoodTracks); - } - private AccumulationTrack updateTrackRegular(Integer modelFromStateId, AccumulationTrack track, ExpressionAccumulation accexp, double[] weights, StateModelChecker mc) { + @Override + protected AccumulationTrack updateTrack(Integer modelFromStateId, AccumulationTrack track, ExpressionAccumulation accexp, double[] weights, StateModelChecker mc) { State currentState = track.getComponent(); // Build EdgeLabel from labels. @@ -345,7 +251,8 @@ public class AccumulationProductRegular extends AccumulationPro automaton = nfa.determinize(); automaton.trim(); // This should remove cycles. - nfa.exportToDotFile("DEBUG-automaton-nfa.dot"); + //DEBUG + //nfa.exportToDotFile("DEBUG-automaton-nfa.dot"); if (!automaton.isAcyclic()) { throw new PrismException("Cannot handle cyclic automata!"); diff --git a/prism/src/explicit/AccumulationTransformation.java b/prism/src/explicit/AccumulationTransformation.java index 21d94604..cc2fa40b 100644 --- a/prism/src/explicit/AccumulationTransformation.java +++ b/prism/src/explicit/AccumulationTransformation.java @@ -133,10 +133,10 @@ public class AccumulationTransformation implements ModelExpress // Transform the expression ReplaceAccumulationExpression replace = new ReplaceAccumulationExpression(accexp, label, product.getNumberOfTracks()-1); transformedExpression = (Expression)transformedExpression.accept(replace); - mc.getLog().println("Transformed " + originalExpression.toString() + - "\n into " + transformedExpression.toString()); + mc.getLog().println(" [AT] " + originalExpression.toString() + "\n" + + " -> " + transformedExpression.toString()); //DEBUG: output dotfile - product.exportToDotFile("DEBUG-product.dot"); + //product.exportToDotFile("DEBUG-product.dot"); } public String gensymLabel(String prefix, Model model) {