From 1616ded04623e36fc4565f095579faa0bd9a373a Mon Sep 17 00:00:00 2001 From: Sascha Wunderlich Date: Tue, 18 Dec 2018 13:02:50 +0100 Subject: [PATCH] Clean up AccumulationTransformation --- .../explicit/AccumulationProductCounting.java | 12 ++++ .../explicit/AccumulationProductRegular.java | 15 +++- .../explicit/AccumulationTransformation.java | 70 ++++++------------- 3 files changed, 44 insertions(+), 53 deletions(-) diff --git a/prism/src/explicit/AccumulationProductCounting.java b/prism/src/explicit/AccumulationProductCounting.java index ca434f93..0a21b5d5 100644 --- a/prism/src/explicit/AccumulationProductCounting.java +++ b/prism/src/explicit/AccumulationProductCounting.java @@ -27,6 +27,18 @@ public class AccumulationProductCounting extends AccumulationPr super(originalModel); } + @SuppressWarnings("unchecked") + public static AccumulationProductCounting generate(final Model graph, final ExpressionAccumulation accexp, final Vector rewards, final ProbModelChecker mc, BitSet statesOfInterest) throws PrismException { + switch(graph.getModelType()) { + case DTMC: + return (AccumulationProductCounting)generate((DTMC) graph, accexp, (Vector) rewards, mc, statesOfInterest); + case MDP: + return (AccumulationProductCounting)generate((MDP) graph, accexp, (Vector) rewards, mc, statesOfInterest); + default: + throw new PrismException("Can't handle accumulation product for " + graph.getModelType()); + } + } + public static AccumulationProductCounting generate(final DTMC graph, final ExpressionAccumulation accexp, final Vector rewards, final ProbModelChecker mc, BitSet statesOfInterest) throws PrismException { final AccumulationProductCounting result = new AccumulationProductCounting(graph); // Create auxiliary data diff --git a/prism/src/explicit/AccumulationProductRegular.java b/prism/src/explicit/AccumulationProductRegular.java index 0d944633..367974c6 100644 --- a/prism/src/explicit/AccumulationProductRegular.java +++ b/prism/src/explicit/AccumulationProductRegular.java @@ -34,6 +34,18 @@ public class AccumulationProductRegular extends AccumulationPro super(originalModel); } + @SuppressWarnings("unchecked") + public static AccumulationProductRegular generate(final Model graph, final ExpressionAccumulation accexp, final Vector rewards, final ProbModelChecker mc, BitSet statesOfInterest) throws PrismException { + switch(graph.getModelType()) { + case DTMC: + return (AccumulationProductRegular)generate((DTMC) graph, accexp, (Vector) rewards, mc, statesOfInterest); + case MDP: + return (AccumulationProductRegular)generate((MDP) graph, accexp, (Vector) rewards, mc, statesOfInterest); + default: + throw new PrismException("Can't handle accumulation product for " + graph.getModelType()); + } + } + public static AccumulationProductRegular generate(final DTMC graph, final ExpressionAccumulation accexp, final Vector rewards, final ProbModelChecker mc, BitSet statesOfInterest) throws PrismException { final AccumulationProductRegular result = new AccumulationProductRegular(graph); // Create auxiliary data @@ -96,11 +108,8 @@ public class AccumulationProductRegular extends AccumulationPro public static AccumulationProductRegular generate(final MDP graph, final ExpressionAccumulation accexp, final Vector rewards, final ProbModelChecker mc, BitSet statesOfInterest) throws PrismException { // This is basically the same thing as for DTMCs final AccumulationProductRegular result = new AccumulationProductRegular(graph); - // Create auxiliary data - mc.getLog().println(" [AP] generating aux data..."); result.createAuxData(graph, accexp, rewards, mc); - mc.getLog().println(" done."); class AccumulationMDPProductOperator implements MDPProductOperator { diff --git a/prism/src/explicit/AccumulationTransformation.java b/prism/src/explicit/AccumulationTransformation.java index ee4318ad..ac2bf4d0 100644 --- a/prism/src/explicit/AccumulationTransformation.java +++ b/prism/src/explicit/AccumulationTransformation.java @@ -5,8 +5,7 @@ import java.util.BitSet; import java.util.Vector; import explicit.rewards.ConstructRewards; -import explicit.rewards.MCRewards; -import explicit.rewards.MDPRewards; +import explicit.rewards.Rewards; import parser.ast.Expression; import parser.ast.ExpressionAccumulation; import parser.ast.ExpressionReward; @@ -77,57 +76,28 @@ public class AccumulationTransformation implements ModelExpress mc.getLog().println(" ... a simple expression."); } - // Get the rewards and build the product - switch(originalModel.getModelType()) { - case DTMC: - Vector dtmc_rewards = new Vector(); - - for (int i=0; i < accexp.getConstraint().getFactors().size(); i++) { - Object rewardIndex = accexp.getConstraint().getFactors().get(i).getFunction().getRewardIndex(); - - RewardStruct rewStruct = ExpressionReward.getRewardStructByIndexObject(rewardIndex, mc.modulesFile, originalModel.getConstantValues()); - ConstructRewards constructRewards = new ConstructRewards(); - constructRewards.allowNegativeRewards(); - - MCRewards dtmc_reward = constructRewards.buildMCRewardStructure((DTMC)originalModel, rewStruct, mc.getConstantValues()); - dtmc_rewards.add(i,dtmc_reward); - } - mc.getLog().println(" [AT] performing product construction..."); - if(accexp.hasRegularExpression()) { - product = (AccumulationProductRegular) AccumulationProductRegular.generate((DTMC)originalModel, accexp, dtmc_rewards, mc, statesOfInterest); - } else if (accexp.hasBoundExpression()) { - product = (AccumulationProductCounting) AccumulationProductCounting.generate((DTMC)originalModel, accexp, dtmc_rewards, mc, statesOfInterest); - } else { - throw new PrismException("Accumulation Expression has no valid monitor!"); - } - break; - case MDP: - Vector mdp_rewards = new Vector(); + // Get the rewards and build the product + Vector rewards = new Vector(); + + for (int i=0; i < accexp.getConstraint().getFactors().size(); i++) { + Object rewardIndex = accexp.getConstraint().getFactors().get(i).getFunction().getRewardIndex(); - for (int i=0; i < accexp.getConstraint().getFactors().size(); i++) { - Object rewardIndex = accexp.getConstraint().getFactors().get(i).getFunction().getRewardIndex(); - - RewardStruct rewStruct = ExpressionReward.getRewardStructByIndexObject(rewardIndex, mc.modulesFile, originalModel.getConstantValues()); - ConstructRewards constructRewards = new ConstructRewards(); - constructRewards.allowNegativeRewards(); - - MDPRewards mdp_reward = constructRewards.buildMDPRewardStructure((MDP)originalModel, rewStruct, mc.getConstantValues()); - mdp_rewards.add(i,mdp_reward); - } - mc.getLog().println(" [AT] performing product construction..."); - if(accexp.hasRegularExpression()) { - product = (AccumulationProductRegular) AccumulationProductRegular.generate((MDP)originalModel, accexp, mdp_rewards, mc, statesOfInterest); - } else if (accexp.hasBoundExpression()) { - product = (AccumulationProductCounting) AccumulationProductCounting.generate((MDP)originalModel, accexp, mdp_rewards, mc, statesOfInterest); - } else { - throw new PrismException("Accumulation Expression has no valid monitor!"); - } - mc.getLog().println(" [AT] finished product construction: " + product.getTransformedModel().getNumStates()); + RewardStruct rewStruct = ExpressionReward.getRewardStructByIndexObject(rewardIndex, mc.modulesFile, originalModel.getConstantValues()); + ConstructRewards constructRewards = new ConstructRewards(); + constructRewards.allowNegativeRewards(); - break; - default: - throw new PrismException("Can't handle weight functions for " + originalModel.getModelType()); + Rewards dtmc_reward = constructRewards.buildRewardStructure(originalModel, rewStruct, mc.getConstantValues()); + rewards.add(i,dtmc_reward); } + if(accexp.hasRegularExpression()) { + product = (AccumulationProductRegular) AccumulationProductRegular.generate(originalModel, accexp, rewards, mc, statesOfInterest); + } else if (accexp.hasBoundExpression()) { + product = (AccumulationProductCounting) AccumulationProductCounting.generate(originalModel, accexp, rewards, mc, statesOfInterest); + } else { + throw new PrismException("Accumulation Expression has no valid monitor!"); + } + + mc.getLog().println(" [AT] finished product construction: " + product.getTransformedModel().getNumStates()); // Transform the model mc.getLog().println(" [AT] getting the init/run/goal states: ");