package explicit;

import explicit.ProbModelChecker;
import explicit.rewards.MDPRewards;
import explicit.rewards.StateRewardsSimple;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import prism.Accuracy;
import prism.AccuracyFactory;
import prism.Pair;
import prism.PrismComponent;
import prism.PrismException;
import prism.PrismLog;
import prism.PrismNotSupportedException;
import prism.PrismSettings;
import prism.PrismUtils;
import strat.FMDObsStrategyBeliefs;

/* loaded from: input_file:explicit/POMDPModelChecker.class */
public class POMDPModelChecker extends ProbModelChecker {

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Backup operation used when solving the belief MDP: given a belief and the
     * {@link BeliefMDPState} built for it, produces a pair whose first component is
     * used in this class as the backed-up value and whose second component is used
     * as the index of the selected choice.
     */
    @FunctionalInterface
    /* loaded from: input_file:explicit/POMDPModelChecker$BeliefMDPBackUp.class */
    public interface BeliefMDPBackUp extends BiFunction<Belief, BeliefMDPState, Pair<Double, Integer>> {
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:explicit/POMDPModelChecker$BeliefMDPState.class */
    /**
     * The outgoing behaviour of a single belief in the belief MDP.
     * Both lists are indexed by choice and are filled by {@code buildBeliefMDPState}.
     */
    public class BeliefMDPState {
        /** For each choice: the successor beliefs and the probability of moving to each. */
        public List<HashMap<Belief, Double>> trans = new ArrayList<>();
        /** For each choice: its reward (left empty when no reward structure is supplied). */
        public List<Double> rewards = new ArrayList<>();

        public BeliefMDPState() {
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:explicit/POMDPModelChecker$POMDPStrategyModel.class */
    /**
     * Bundle of results produced by {@code buildStrategyModel}: the strategy-induced
     * MDP together with the data needed to relate its states back to beliefs.
     */
    public class POMDPStrategyModel {
        /** The MDP obtained by following the chosen action in each reachable belief. */
        public MDP<Double> mdp;
        /** One entry per MDP state: {observation, index into {@link #unobsBeliefs}}. */
        public List<int[]> mdpStates;
        /** The distinct {@code bu} arrays of the beliefs encountered while building the model. */
        public List<double[]> unobsBeliefs;
        /** State rewards attached to {@link #mdp}. */
        public MDPRewards<Double> mdpRewards;

        POMDPStrategyModel() {
        }
    }

    /**
     * Create a new POMDP model checker, passing the given parent component on to
     * the {@link ProbModelChecker} constructor.
     *
     * @param prismComponent the parent component handed to the superclass ({@code main} passes {@code null})
     * @throws PrismException declared because the superclass constructor may throw it
     */
    public POMDPModelChecker(PrismComponent prismComponent) throws PrismException {
        super(prismComponent);
    }

    /**
     * Compute (until-style) reachability probabilities for a POMDP from a single start state.
     * Delegates the numerical work to {@link #computeReachProbsFixedGrid} and records
     * the overall wall-clock time in the returned result.
     *
     * @param pomdp   the POMDP
     * @param bitSet  states that must be remained in (may be {@code null} for plain reachability)
     * @param bitSet2 target states
     * @param z       {@code true} for minimum probabilities, {@code false} for maximum
     * @param bitSet3 start state(s); {@code null} means the first initial state of the model
     * @return the result produced by the fixed-grid computation, with {@code timeTaken} overwritten
     * @throws PrismNotSupportedException if more than one start state is given
     * @throws PrismException             if the underlying computation fails
     */
    public ModelCheckerResult computeReachProbs(POMDP<Double> pomdp, BitSet bitSet, BitSet bitSet2, boolean z, BitSet bitSet3) throws PrismException {
        BitSet startStates = bitSet3;
        if (startStates != null && startStates.cardinality() > 1) {
            throw new PrismNotSupportedException("POMDPs can only be solved from a single start state");
        }
        if (startStates == null) {
            // No explicit start state: fall back to the model's first initial state
            startStates = new BitSet();
            startStates.set(pomdp.getFirstInitialState());
        }

        final long begin = System.currentTimeMillis();
        final String direction = z ? "min" : "max";
        this.mainLog.println("\nStarting probabilistic reachability (" + direction + ")...");

        final ModelCheckerResult result = computeReachProbsFixedGrid(pomdp, bitSet, bitSet2, z, startStates.nextSetBit(0));

        final double seconds = (System.currentTimeMillis() - begin) / 1000.0d;
        this.mainLog.println("Probabilistic reachability took " + seconds + " seconds.");
        result.timeTaken = seconds;
        return result;
    }

    /**
     * Compute (until-style) reachability probabilities for a POMDP using a
     * fixed-resolution grid approximation of the belief space.
     * <p>
     * Steps: (1) value iteration over the grid points yields an "outer" bound;
     * (2) the strategy extracted from that solution induces an MDP, whose solution
     * yields an "inner" bound; (3) both bounds are combined into a value and accuracy.
     *
     * @param pomdp   the POMDP
     * @param bitSet  states that must be remained in ({@code null} for plain reachability)
     * @param bitSet2 target states
     * @param z       {@code true} for minimum probabilities, {@code false} for maximum
     * @param i       index of the start state; only this entry of the solution vector is filled in
     * @return result whose {@code soln[i]} holds the combined value
     * @throws PrismException if the target or remain set is not determined by observations,
     *                        or if value iteration does not converge and errors are enabled
     */
    protected ModelCheckerResult computeReachProbsFixedGrid(POMDP<Double> pomdp, BitSet bitSet, BitSet bitSet2, boolean z, int i) throws PrismException {
        final long startTime = System.currentTimeMillis();
        final String minMax = z ? "min" : "max";
        final boolean relative = this.termCrit == ProbModelChecker.TermCrit.RELATIVE;
        this.mainLog.println("Starting fixed-resolution grid approximation (" + minMax + ")...");

        // Target / remain sets must be expressible purely in terms of observations
        final BitSet targetObs = getObservationsMatchingStates(pomdp, bitSet2);
        if (targetObs == null) {
            throw new PrismException("Target for reachability is not observable");
        }
        final BitSet remainObs = bitSet == null ? null : getObservationsMatchingStates(pomdp, bitSet);
        if (bitSet != null && remainObs == null) {
            throw new PrismException("Left-hand side of until is not observable");
        }
        this.mainLog.println("target obs=" + targetObs.cardinality() + (remainObs == null ? PrismSettings.DEFAULT_STRING : ", remain obs=" + remainObs.cardinality()));

        // Observations whose value is not known up front: non-target, and (if given) inside remain
        final BitSet unknownObs = new BitSet();
        unknownObs.set(0, pomdp.getNumObservations());
        unknownObs.andNot(targetObs);
        if (remainObs != null) {
            unknownObs.and(remainObs);
        }

        // Build the grid and the belief MDP fragment for each grid point
        final List<Belief> gridPoints = initialiseGridPoints(pomdp, unknownObs);
        this.mainLog.println("Grid statistics: resolution=" + this.gridResolution + ", points=" + gridPoints.size());
        this.mainLog.println("Building belief space approximation...");
        final List<BeliefMDPState> beliefMDP = buildBeliefMDP(pomdp, null, gridPoints);

        // Current and previous value estimates for the grid points
        final HashMap<Belief, Double> values = new HashMap<>();
        final HashMap<Belief, Double> prevValues = new HashMap<>();
        for (Belief gridBelief : gridPoints) {
            values.put(gridBelief, Double.valueOf(PrismSettings.DEFAULT_DOUBLE));
            prevValues.put(gridBelief, Double.valueOf(PrismSettings.DEFAULT_DOUBLE));
        }
        // Value of an arbitrary belief, read off the previous iteration's grid values
        final Function<Belief, Double> valueFn = b -> Double.valueOf(approximateReachProb(b, prevValues, targetObs, unknownObs));
        final BeliefMDPBackUp backUp = (b, st) -> approximateReachProbBackup(b, st, valueFn, z);

        // Value iteration over the grid
        this.mainLog.println("Solving belief space approximation...");
        final long solveStart = System.currentTimeMillis();
        int iters = 0;
        boolean done = false;
        while (!done && iters < this.maxIters) {
            final int numPoints = gridPoints.size();
            for (int p = 0; p < numPoints; p++) {
                Belief gridBelief = gridPoints.get(p);
                values.put(gridBelief, backUp.apply(gridBelief, beliefMDP.get(p)).first);
            }
            done = PrismUtils.doublesAreClose(values, prevValues, this.termCritParam, relative);
            prevValues.putAll(values);
            iters++;
        }
        if (!done && this.errorOnNonConverge) {
            throw new PrismException(("Iterative method did not converge within " + iters + " iterations.") + "\nConsider using a different numerical method or increasing the maximum number of iterations");
        }
        final long solveTime = System.currentTimeMillis() - solveStart;
        this.mainLog.print("Belief space value iteration (" + minMax + ")");
        this.mainLog.println(" took " + iters + " iterations and " + (solveTime / 1000.0d) + " seconds.");

        // Outer bound: grid value for the start belief
        final double outerBound = valueFn.apply(Belief.pointDistribution(i, pomdp)).doubleValue();
        final Accuracy outerAcc = AccuracyFactory.valueIteration(this.termCritParam, PrismUtils.measureSupNorm(values, prevValues, relative), relative);
        this.mainLog.println("Outer bound: " + outerBound + " (" + outerAcc.toString(Double.valueOf(outerBound)) + ")");

        // Inner bound: solve the MDP induced by the extracted strategy
        this.mainLog.println("\nBuilding strategy-induced model...");
        final POMDPStrategyModel stratModel = buildStrategyModel(pomdp, i, null, targetObs, unknownObs, backUp);
        final MDP<Double> mdp = stratModel.mdp;
        this.mainLog.print("Strategy-induced model: " + mdp.infoString());
        FMDObsStrategyBeliefs strategy = this.genStrat ? new FMDObsStrategyBeliefs(pomdp, stratModel.mdp, stratModel.mdpStates, stratModel.unobsBeliefs) : null;
        final MDPModelChecker mdpChecker = new MDPModelChecker(this);
        mdpChecker.setGenStrat(false);
        final ModelCheckerResult mdpResult = mdpChecker.computeReachProbs(mdp, mdp.getLabelStates("target"), true);
        final double innerBound = mdpResult.soln[0];
        final Accuracy innerAcc = mdpResult.accuracy;
        String innerBoundStr = String.valueOf(innerBound);
        if (innerAcc != null) {
            innerBoundStr = innerBoundStr + " (" + innerAcc.toString(Double.valueOf(innerBound)) + ")";
        }
        this.mainLog.println("Inner bound: " + innerBoundStr);

        final long totalTime = System.currentTimeMillis() - startTime;
        this.mainLog.print("\nFixed-resolution grid approximation (" + minMax + ")");
        this.mainLog.println(" took " + (totalTime / 1000.0d) + " seconds.");

        // Combine the two bounds; which one is passed first depends on min/max
        final Pair<Double, Accuracy> combined = z
                ? AccuracyFactory.valueAndAccuracyFromInterval(outerBound, outerAcc, innerBound, innerAcc)
                : AccuracyFactory.valueAndAccuracyFromInterval(innerBound, innerAcc, outerBound, outerAcc);
        final double resultValue = combined.first.doubleValue();
        final Accuracy resultAcc = combined.second;
        this.mainLog.println("Result bounds: [" + resultAcc.getResultLowerBound(resultValue) + "," + resultAcc.getResultUpperBound(resultValue) + "]");

        // Only the start state's entry of the solution vector is meaningful
        final double[] soln = new double[pomdp.getNumStates()];
        soln[i] = resultValue;
        final ModelCheckerResult result = new ModelCheckerResult();
        if (this.genStrat) {
            result.f3strat = strategy;
        }
        result.soln = soln;
        result.accuracy = resultAcc;
        result.numIters = iters;
        result.timeTaken = totalTime / 1000.0d;
        return result;
    }

    /**
     * Compute expected reachability rewards for a POMDP from a single start state.
     * Delegates the numerical work to {@link #computeReachRewardsFixedGrid} and records
     * the overall wall-clock time in the returned result.
     *
     * @param pomdp      the POMDP
     * @param mDPRewards the reward structure
     * @param bitSet     target states
     * @param z          {@code true} for minimum rewards, {@code false} for maximum
     * @param bitSet2    start state(s); {@code null} means the first initial state of the model
     * @return the result produced by the fixed-grid computation, with {@code timeTaken} overwritten
     * @throws PrismNotSupportedException if more than one start state is given
     * @throws PrismException             if the underlying computation fails
     */
    public ModelCheckerResult computeReachRewards(POMDP<Double> pomdp, MDPRewards<Double> mDPRewards, BitSet bitSet, boolean z, BitSet bitSet2) throws PrismException {
        BitSet startStates = bitSet2;
        if (startStates != null && startStates.cardinality() > 1) {
            throw new PrismNotSupportedException("POMDPs can only be solved from a single start state");
        }
        if (startStates == null) {
            // No explicit start state: fall back to the model's first initial state
            startStates = new BitSet();
            startStates.set(pomdp.getFirstInitialState());
        }

        final long begin = System.currentTimeMillis();
        final String direction = z ? "min" : "max";
        this.mainLog.println("\nStarting expected reachability (" + direction + ")...");

        final ModelCheckerResult result = computeReachRewardsFixedGrid(pomdp, mDPRewards, bitSet, z, startStates.nextSetBit(0));

        final double seconds = (System.currentTimeMillis() - begin) / 1000.0d;
        this.mainLog.println("Expected reachability took " + seconds + " seconds.");
        result.timeTaken = seconds;
        return result;
    }

    /**
     * Compute expected reachability rewards for a POMDP using a fixed-resolution
     * grid approximation of the belief space.
     * <p>
     * Steps: (1) value iteration over the grid points yields an "outer" bound;
     * (2) the strategy extracted from that solution induces an MDP, whose solution
     * yields an "inner" bound; (3) both bounds are combined into a value and accuracy.
     *
     * @param pomdp      the POMDP
     * @param mDPRewards the reward structure
     * @param bitSet     target states
     * @param z          {@code true} for minimum rewards, {@code false} for maximum
     * @param i          index of the start state; only this entry of the solution vector is filled in
     * @return result whose {@code soln[i]} holds the combined value
     * @throws PrismException if the target set is not determined by observations,
     *                        or if value iteration does not converge and errors are enabled
     */
    protected ModelCheckerResult computeReachRewardsFixedGrid(POMDP<Double> pomdp, MDPRewards<Double> mDPRewards, BitSet bitSet, boolean z, int i) throws PrismException {
        final long startTime = System.currentTimeMillis();
        final String minMax = z ? "min" : "max";
        final boolean relative = this.termCrit == ProbModelChecker.TermCrit.RELATIVE;
        this.mainLog.println("Starting fixed-resolution grid approximation (" + minMax + ")...");

        // The target set must be expressible purely in terms of observations
        final BitSet targetObs = getObservationsMatchingStates(pomdp, bitSet);
        if (targetObs == null) {
            throw new PrismException("Target for expected reachability is not observable");
        }

        // States outside the set returned by prob1; observations made up entirely of
        // such states are assigned an infinite value by approximateReachReward
        final BitSet infStates = new MDPModelChecker(this).prob1(pomdp, null, bitSet, false, null);
        infStates.flip(0, pomdp.getNumStates());
        final BitSet infObs = getObservationsCoveredByStates(pomdp, infStates);
        this.mainLog.println("target obs=" + targetObs.cardinality() + ", inf obs=" + infObs.cardinality());

        // Observations whose value is not known up front
        final BitSet unknownObs = new BitSet();
        unknownObs.set(0, pomdp.getNumObservations());
        unknownObs.andNot(targetObs);
        unknownObs.andNot(infObs);

        // Build the grid and the belief MDP fragment for each grid point
        final List<Belief> gridPoints = initialiseGridPoints(pomdp, unknownObs);
        this.mainLog.println("Grid statistics: resolution=" + this.gridResolution + ", points=" + gridPoints.size());
        this.mainLog.println("Building belief space approximation...");
        final List<BeliefMDPState> beliefMDP = buildBeliefMDP(pomdp, mDPRewards, gridPoints);

        // Current and previous value estimates for the grid points
        final HashMap<Belief, Double> values = new HashMap<>();
        final HashMap<Belief, Double> prevValues = new HashMap<>();
        for (Belief gridBelief : gridPoints) {
            values.put(gridBelief, Double.valueOf(PrismSettings.DEFAULT_DOUBLE));
            prevValues.put(gridBelief, Double.valueOf(PrismSettings.DEFAULT_DOUBLE));
        }
        // Value of an arbitrary belief, read off the previous iteration's grid values
        final Function<Belief, Double> valueFn = b -> Double.valueOf(approximateReachReward(b, prevValues, targetObs, infObs));
        final BeliefMDPBackUp backUp = (b, st) -> approximateReachRewardBackup(b, st, valueFn, z);

        // Value iteration over the grid
        this.mainLog.println("Solving belief space approximation...");
        final long solveStart = System.currentTimeMillis();
        int iters = 0;
        boolean done = false;
        while (!done && iters < this.maxIters) {
            final int numPoints = gridPoints.size();
            for (int p = 0; p < numPoints; p++) {
                Belief gridBelief = gridPoints.get(p);
                values.put(gridBelief, backUp.apply(gridBelief, beliefMDP.get(p)).first);
            }
            done = PrismUtils.doublesAreClose(values, prevValues, this.termCritParam, relative);
            prevValues.putAll(values);
            iters++;
        }
        if (!done && this.errorOnNonConverge) {
            throw new PrismException(("Iterative method did not converge within " + iters + " iterations.") + "\nConsider using a different numerical method or increasing the maximum number of iterations");
        }
        final long solveTime = System.currentTimeMillis() - solveStart;
        this.mainLog.print("Belief space value iteration (" + minMax + ")");
        this.mainLog.println(" took " + iters + " iterations and " + (solveTime / 1000.0d) + " seconds.");

        // Outer bound: grid value for the start belief
        final double outerBound = valueFn.apply(Belief.pointDistribution(i, pomdp)).doubleValue();
        final Accuracy outerAcc = AccuracyFactory.valueIteration(this.termCritParam, PrismUtils.measureSupNorm(values, prevValues, relative), relative);
        this.mainLog.println("Outer bound: " + outerBound + " (" + outerAcc.toString(Double.valueOf(outerBound)) + ")");

        // Inner bound: solve the MDP induced by the extracted strategy
        this.mainLog.println("\nBuilding strategy-induced model...");
        final POMDPStrategyModel stratModel = buildStrategyModel(pomdp, i, mDPRewards, targetObs, unknownObs, backUp);
        final MDP<Double> mdp = stratModel.mdp;
        final MDPRewards<Double> stratRewards = stratModel.mdpRewards;
        this.mainLog.print("Strategy-induced model: " + mdp.infoString());
        FMDObsStrategyBeliefs strategy = this.genStrat ? new FMDObsStrategyBeliefs(pomdp, stratModel.mdp, stratModel.mdpStates, stratModel.unobsBeliefs) : null;
        final MDPModelChecker mdpChecker = new MDPModelChecker(this);
        mdpChecker.setGenStrat(false);
        final ModelCheckerResult mdpResult = mdpChecker.computeReachRewards(mdp, stratRewards, mdp.getLabelStates("target"), true);
        final double innerBound = mdpResult.soln[0];
        final Accuracy innerAcc = mdpResult.accuracy;
        String innerBoundStr = String.valueOf(innerBound);
        if (innerAcc != null) {
            innerBoundStr = innerBoundStr + " (" + innerAcc.toString(Double.valueOf(innerBound)) + ")";
        }
        this.mainLog.println("Inner bound: " + innerBoundStr);

        final long totalTime = System.currentTimeMillis() - startTime;
        this.mainLog.print("\nFixed-resolution grid approximation (" + minMax + ")");
        this.mainLog.println(" took " + (totalTime / 1000.0d) + " seconds.");

        // Combine the two bounds; which one is passed first depends on min/max
        final Pair<Double, Accuracy> combined = z
                ? AccuracyFactory.valueAndAccuracyFromInterval(outerBound, outerAcc, innerBound, innerAcc)
                : AccuracyFactory.valueAndAccuracyFromInterval(innerBound, innerAcc, outerBound, outerAcc);
        final double resultValue = combined.first.doubleValue();
        final Accuracy resultAcc = combined.second;
        this.mainLog.println("Result bounds: [" + resultAcc.getResultLowerBound(resultValue) + "," + resultAcc.getResultUpperBound(resultValue) + "]");

        // Only the start state's entry of the solution vector is meaningful
        final double[] soln = new double[pomdp.getNumStates()];
        soln[i] = resultValue;
        final ModelCheckerResult result = new ModelCheckerResult();
        if (this.genStrat) {
            result.f3strat = strategy;
        }
        result.soln = soln;
        result.accuracy = resultAcc;
        result.numIters = iters;
        result.timeTaken = totalTime / 1000.0d;
        return result;
    }

    /**
     * Find the set of observations that corresponds exactly to a set of states.
     *
     * @param pomdp  the POMDP
     * @param bitSet the set of states (must not be {@code null})
     * @return the observations of the given states, or {@code null} if some state
     *         outside the set shares an observation with a state inside it
     */
    protected BitSet getObservationsMatchingStates(POMDP<Double> pomdp, BitSet bitSet) {
        // Observations seen on at least one state of the set
        final BitSet observations = new BitSet();
        for (int s = bitSet.nextSetBit(0); s >= 0; s = bitSet.nextSetBit(s + 1)) {
            observations.set(pomdp.getObservation(s));
        }

        // All states carrying one of those observations
        final BitSet statesWithObs = new BitSet();
        final int numStates = pomdp.getNumStates();
        for (int s = 0; s < numStates; s++) {
            if (observations.get(pomdp.getObservation(s))) {
                statesWithObs.set(s);
            }
        }

        // The observations characterise the set only if both state sets coincide
        return bitSet.equals(statesWithObs) ? observations : null;
    }

    /**
     * Find the observations all of whose states belong to a given set of states.
     *
     * @param pomdp  the POMDP
     * @param bitSet the set of states
     * @return the observations that occur on some state of the set and on no state outside it
     * @throws PrismException declared by the signature; not thrown by this implementation directly
     */
    protected BitSet getObservationsCoveredByStates(POMDP<Double> pomdp, BitSet bitSet) throws PrismException {
        // Start with every observation seen on a state of the set
        final BitSet observations = new BitSet();
        for (int s = bitSet.nextSetBit(0); s >= 0; s = bitSet.nextSetBit(s + 1)) {
            observations.set(pomdp.getObservation(s));
        }

        // Drop each observation that also occurs on a state outside the set
        final int numStates = pomdp.getNumStates();
        for (int obs = observations.nextSetBit(0); obs >= 0; obs = observations.nextSetBit(obs + 1)) {
            for (int s = 0; s < numStates; s++) {
                if (pomdp.getObservation(s) == obs && !bitSet.get(s)) {
                    observations.clear(obs);
                    break;
                }
            }
        }
        return observations;
    }

    /**
     * Construct the grid points (beliefs) for the given observations.
     * For each observation, one belief is created per assignment returned by
     * {@code fullAssignment} over the states carrying that observation.
     *
     * @param pomdp  the POMDP
     * @param bitSet the observations for which grid points are generated
     * @return the list of grid-point beliefs, grouped by observation in increasing order
     */
    protected List<Belief> initialiseGridPoints(POMDP<Double> pomdp, BitSet bitSet) {
        final List<Belief> gridPoints = new ArrayList<>();
        final int numUnobservations = pomdp.getNumUnobservations();
        final int numStates = pomdp.getNumStates();

        for (int obs = bitSet.nextSetBit(0); obs >= 0; obs = bitSet.nextSetBit(obs + 1)) {
            // Unobservable components of the states that share this observation
            final List<Integer> unobsForObs = new ArrayList<>();
            for (int s = 0; s < numStates; s++) {
                if (obs == pomdp.getObservation(s)) {
                    unobsForObs.add(Integer.valueOf(pomdp.getUnobservation(s)));
                }
            }

            // One belief per assignment of probabilities to those components
            for (ArrayList<Double> assignment : fullAssignment(unobsForObs.size(), this.gridResolution)) {
                final double[] bu = new double[numUnobservations];
                int pos = 0;
                for (Integer unobs : unobsForObs) {
                    bu[unobs.intValue()] = assignment.get(pos).doubleValue();
                    pos++;
                }
                gridPoints.add(new Belief(obs, bu));
            }
        }
        return gridPoints;
    }

    /**
     * Build the belief-MDP fragment for every belief in a list.
     *
     * @param pomdp      the POMDP
     * @param mDPRewards the reward structure (may be {@code null})
     * @param list       the beliefs to process
     * @return one {@link BeliefMDPState} per belief, in the same order as {@code list}
     */
    protected List<BeliefMDPState> buildBeliefMDP(POMDP<Double> pomdp, MDPRewards<Double> mDPRewards, List<Belief> list) {
        final List<BeliefMDPState> states = new ArrayList<>();
        for (Belief belief : list) {
            states.add(buildBeliefMDPState(pomdp, mDPRewards, belief));
        }
        return states;
    }

    /**
     * Build the outgoing behaviour of a single belief: for every choice available
     * under the belief's observation, the distribution over successor beliefs and,
     * when a reward structure is given, the reward of that choice.
     *
     * @param pomdp      the POMDP
     * @param mDPRewards the reward structure, or {@code null} to leave rewards empty
     * @param belief     the belief to expand
     * @return the transitions (and rewards) of {@code belief}, indexed by choice
     */
    protected BeliefMDPState buildBeliefMDPState(POMDP<Double> pomdp, MDPRewards<Double> mDPRewards, Belief belief) {
        final double[] stateDistribution = belief.toDistributionOverStates(pomdp);
        final BeliefMDPState state = new BeliefMDPState();
        final int numChoices = pomdp.getNumChoicesForObservation(belief.so);

        for (int choice = 0; choice < numChoices; choice++) {
            // Probability of each observation after taking this choice
            final HashMap<Integer, Double> obsProbs = pomdp.computeObservationProbsAfterAction(stateDistribution, choice);

            // Successor belief for each possible observation
            final HashMap<Belief, Double> successors = new HashMap<>();
            for (Map.Entry<Integer, Double> obsProb : obsProbs.entrySet()) {
                final int nextObs = obsProb.getKey().intValue();
                final Belief nextBelief = pomdp.getBeliefAfterChoiceAndObservation(belief, choice, nextObs);
                successors.put(nextBelief, obsProb.getValue());
            }
            state.trans.add(successors);

            if (mDPRewards != null) {
                state.rewards.add(Double.valueOf(pomdp.getRewardAfterChoice(belief, choice, mDPRewards)));
            }
        }
        return state;
    }

    /**
     * Perform one Bellman backup for reachability probabilities on a belief.
     * Choices whose value lies within 1e-6 of the current best replace the
     * selected index without changing the best value, so the last such choice wins.
     *
     * @param belief         the belief being backed up (not read by this method)
     * @param beliefMDPState the belief's outgoing transitions
     * @param function       value estimate for successor beliefs
     * @param z              {@code true} to minimise, {@code false} to maximise
     * @return pair of (best value, index of selected choice); the index is -1 if no choice was selected
     */
    protected Pair<Double, Integer> approximateReachProbBackup(Belief belief, BeliefMDPState beliefMDPState, Function<Belief, Double> function, boolean z) {
        double best = z ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
        int bestChoice = -1;
        final int numChoices = beliefMDPState.trans.size();

        for (int choice = 0; choice < numChoices; choice++) {
            // Expected value of the successors under this choice
            double value = 0.0d;
            for (Map.Entry<Belief, Double> succ : beliefMDPState.trans.get(choice).entrySet()) {
                value += succ.getValue().doubleValue() * function.apply(succ.getKey()).doubleValue();
            }

            final boolean strictlyBetter = z ? (best - value > 1.0E-6d) : (value - best > 1.0E-6d);
            if (strictlyBetter) {
                best = value;
                bestChoice = choice;
            } else if (Math.abs(value - best) < 1.0E-6d) {
                bestChoice = choice;
            }
        }
        return new Pair<>(Double.valueOf(best), Integer.valueOf(bestChoice));
    }

    /**
     * Perform one Bellman backup for expected reachability rewards on a belief.
     * Choices whose value lies within 1e-6 of the current best replace the
     * selected index without changing the best value, so the last such choice wins.
     *
     * @param belief         the belief being backed up (not read by this method)
     * @param beliefMDPState the belief's outgoing transitions and per-choice rewards
     * @param function       value estimate for successor beliefs
     * @param z              {@code true} to minimise, {@code false} to maximise
     * @return pair of (best value, index of selected choice); the index defaults to 0
     */
    protected Pair<Double, Integer> approximateReachRewardBackup(Belief belief, BeliefMDPState beliefMDPState, Function<Belief, Double> function, boolean z) {
        double best = z ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
        int bestChoice = 0;
        final int numChoices = beliefMDPState.trans.size();

        for (int choice = 0; choice < numChoices; choice++) {
            // Reward of the choice plus the expected value of its successors
            double value = beliefMDPState.rewards.get(choice).doubleValue();
            for (Map.Entry<Belief, Double> succ : beliefMDPState.trans.get(choice).entrySet()) {
                value += succ.getValue().doubleValue() * function.apply(succ.getKey()).doubleValue();
            }

            final boolean strictlyBetter = z ? (best - value > 1.0E-6d) : (value - best > 1.0E-6d);
            if (strictlyBetter) {
                best = value;
                bestChoice = choice;
            } else if (Math.abs(value - best) < 1.0E-6d) {
                bestChoice = choice;
            }
        }
        return new Pair<>(Double.valueOf(best), Integer.valueOf(bestChoice));
    }

    /**
     * Estimate the reachability probability of a belief.
     *
     * @param belief  the belief
     * @param hashMap current values of the grid points
     * @param bitSet  target observations (value 1)
     * @param bitSet2 observations whose value is obtained by interpolation over the grid
     * @return 1 for target observations, the interpolated grid value for observations in
     *         {@code bitSet2}, and {@code PrismSettings.DEFAULT_DOUBLE} otherwise
     */
    protected double approximateReachProb(Belief belief, HashMap<Belief, Double> hashMap, BitSet bitSet, BitSet bitSet2) {
        final int obs = belief.so;
        if (bitSet.get(obs)) {
            return 1.0d;
        }
        if (bitSet2.get(obs)) {
            return interpolateOverGrid(belief, hashMap);
        }
        return PrismSettings.DEFAULT_DOUBLE;
    }

    /**
     * Estimate the expected reachability reward of a belief.
     *
     * @param belief  the belief
     * @param hashMap current values of the grid points
     * @param bitSet  target observations (value {@code PrismSettings.DEFAULT_DOUBLE})
     * @param bitSet2 observations with infinite value
     * @return the fixed value for target / infinite observations (target takes precedence),
     *         otherwise the value interpolated over the grid
     */
    protected double approximateReachReward(Belief belief, HashMap<Belief, Double> hashMap, BitSet bitSet, BitSet bitSet2) {
        final int obs = belief.so;
        final boolean isTarget = bitSet.get(obs);
        if (!isTarget && !bitSet2.get(obs)) {
            return interpolateOverGrid(belief, hashMap);
        }
        return isTarget ? PrismSettings.DEFAULT_DOUBLE : Double.POSITIVE_INFINITY;
    }

    /**
     * Approximate the value of a belief as a weighted sum of grid-point values,
     * using the vertices and weights supplied by {@link #getSubSimplexAndLambdas}.
     * Weights below 1e-6 are ignored, so their vertices need not be present in the map.
     *
     * @param belief  the belief to evaluate
     * @param hashMap values of the grid points
     * @return the interpolated value
     */
    protected double interpolateOverGrid(Belief belief, HashMap<Belief, Double> hashMap) {
        final double[] lambdas = new double[belief.bu.length];
        final ArrayList<double[]> vertices = new ArrayList<>();
        getSubSimplexAndLambdas(belief.bu, vertices, lambdas, this.gridResolution);

        double value = 0.0d;
        for (int k = 0; k < lambdas.length; k++) {
            final double lambda = lambdas[k];
            if (!(lambda >= 1.0E-6d)) {
                continue;
            }
            final Belief vertex = new Belief(belief.so, vertices.get(k));
            value += lambda * hashMap.get(vertex).doubleValue();
        }
        return value;
    }

    /**
     * Build the MDP induced by the strategy that, in each reachable belief, takes the
     * choice selected by the given backup operation. Beliefs are explored breadth-first
     * from the point distribution on the start state; MDP state indices follow the order
     * in which beliefs are discovered.
     *
     * @param pomdp           the POMDP
     * @param i               index of the start state
     * @param mDPRewards      reward structure, or {@code null} (all state rewards are then
     *                        {@code PrismSettings.DEFAULT_DOUBLE})
     * @param bitSet          target observations; matching MDP states are labelled "target"
     * @param bitSet2         observations whose beliefs are expanded; beliefs with other
     *                        observations get no outgoing choice
     * @param beliefMDPBackUp backup operation used to pick the choice in each belief
     * @return the induced MDP, its rewards, and the belief data behind each MDP state
     * @throws PrismException if constructing the model fails
     */
    protected POMDPStrategyModel buildStrategyModel(POMDP<Double> pomdp, int i, MDPRewards<Double> mDPRewards, BitSet bitSet, BitSet bitSet2, BeliefMDPBackUp beliefMDPBackUp) throws PrismException {
        final MDPSimple<Double> mdp = new MDPSimple<>();
        // Distinct unobservable belief parts, and (observation, belief-part index) pairs = MDP states
        final IndexedSet<double[]> unobsBeliefs = new IndexedSet<>((x, y) -> Arrays.compare(x, y));
        final IndexedSet<int[]> mdpStates = new IndexedSet<>((x, y) -> Arrays.compare(x, y));
        final LinkedList<Belief> toExplore = new LinkedList<>();
        final BitSet targetStates = new BitSet();
        final StateRewardsSimple<Double> stateRewards = new StateRewardsSimple<>();

        // Seed the exploration with the start belief
        final Belief initialBelief = Belief.pointDistribution(i, pomdp);
        unobsBeliefs.add(initialBelief.bu);
        mdpStates.add(new int[]{initialBelief.so, unobsBeliefs.getIndexOfLastAdd()});
        toExplore.offer(initialBelief);
        mdp.addState();
        mdp.addInitialState(0);

        int src = -1;
        while (!toExplore.isEmpty()) {
            final Belief belief = toExplore.pollFirst();
            src++;
            if (bitSet.get(belief.so)) {
                targetStates.set(src);
            }
            if (!bitSet2.get(belief.so)) {
                // Not expanded: no outgoing choice
                stateRewards.setStateReward(src, Double.valueOf(PrismSettings.DEFAULT_DOUBLE));
                continue;
            }

            // Pick the choice dictated by the backup and add its distribution
            final BeliefMDPState beliefState = buildBeliefMDPState(pomdp, mDPRewards, belief);
            final int choice = beliefMDPBackUp.apply(belief, beliefState).second.intValue();
            final Distribution<Double> distr = Distribution.ofDouble();
            for (Map.Entry<Belief, Double> succ : beliefState.trans.get(choice).entrySet()) {
                final double prob = succ.getValue().doubleValue();
                final Belief nextBelief = succ.getKey();
                unobsBeliefs.add(nextBelief.bu);
                // A newly seen (observation, belief-part) pair becomes a new MDP state
                if (mdpStates.add(new int[]{nextBelief.so, unobsBeliefs.getIndexOfLastAdd()})) {
                    toExplore.add(nextBelief);
                    mdp.addState();
                }
                distr.add(mdpStates.getIndexOfLastAdd(), Double.valueOf(prob));
            }
            mdp.addActionLabelledChoice(src, distr, pomdp.getActionForObservation(belief.so, choice));

            if (mDPRewards != null) {
                stateRewards.setStateReward(src, Double.valueOf(pomdp.getRewardAfterChoice(belief, choice, mDPRewards)));
            } else {
                stateRewards.setStateReward(src, Double.valueOf(PrismSettings.DEFAULT_DOUBLE));
            }
        }

        mdp.findDeadlocks(true);
        mdp.addLabel("target", targetStates);

        final POMDPStrategyModel model = new POMDPStrategyModel();
        model.mdp = mdp;
        model.mdpStates = new ArrayList<>();
        model.mdpStates.addAll(mdpStates.toArrayList());
        model.unobsBeliefs = new ArrayList<>();
        model.unobsBeliefs.addAll(unobsBeliefs.toArrayList());
        model.mdpRewards = stateRewards;
        return model;
    }

    /**
     * Enumerate integer sequences for positions {@code i .. i4-1}. The value at
     * position {@code i} ranges over {@code [i2, i3]}; every later position ranges
     * from 0 up to the value chosen for the position before it, so each sequence is
     * non-increasing after its first element.
     *
     * @param i  index of the first position to fill
     * @param i2 smallest value allowed at position {@code i}
     * @param i3 largest value allowed at position {@code i}
     * @param i4 total number of positions
     * @return all such sequences, each of length {@code i4 - i}
     */
    protected ArrayList<ArrayList<Integer>> assignGPrime(int i, int i2, int i3, int i4) {
        final ArrayList<ArrayList<Integer>> sequences = new ArrayList<>();
        final boolean lastPosition = (i == i4 - 1);

        for (int value = i2; value <= i3; value++) {
            if (lastPosition) {
                final ArrayList<Integer> single = new ArrayList<>();
                single.add(Integer.valueOf(value));
                sequences.add(single);
                continue;
            }
            // Prepend the current value to every admissible suffix
            for (ArrayList<Integer> suffix : assignGPrime(i + 1, 0, value, i4)) {
                final ArrayList<Integer> sequence = new ArrayList<>();
                sequence.add(Integer.valueOf(value));
                sequence.addAll(suffix);
                sequences.add(sequence);
            }
        }
        return sequences;
    }

    /**
     * Enumerate the grid points of the given dimension at the given resolution.
     * Each integer sequence from {@link #assignGPrime} (starting at {@code i2}) is
     * turned into a point whose coordinates are the differences of consecutive
     * entries divided by the resolution, with the last entry divided directly.
     * <p>
     * The divisions are carried out in floating point; integer division would
     * truncate every coordinate to 0 or 1.
     *
     * @param i  number of coordinates per point
     * @param i2 grid resolution
     * @return the list of grid points
     */
    private ArrayList<ArrayList<Double>> fullAssignment(int i, int i2) {
        final ArrayList<ArrayList<Double>> gridPoints = new ArrayList<>();
        for (ArrayList<Integer> sequence : assignGPrime(0, i2, i2, i)) {
            final ArrayList<Double> point = new ArrayList<>();
            int k = 0;
            for (; k < i - 1; k++) {
                final int diff = sequence.get(k).intValue() - sequence.get(k + 1).intValue();
                point.add(Double.valueOf((double) diff / i2));
            }
            point.add(Double.valueOf(sequence.get(k).doubleValue() / i2));
            gridPoints.add(point);
        }
        return gridPoints;
    }

    /**
     * Compute a permutation of indices that orders the input values in descending
     * order. Entries equal to {@code PrismSettings.DEFAULT_DOUBLE} are placed at the
     * end (in increasing index order); the remaining entries are bubble-sorted, so
     * equal values keep their original relative order.
     *
     * @param dArr the values to order (not modified)
     * @return array of indices into {@code dArr}
     */
    private int[] getSortedPermutation(double[] dArr) {
        final int n = dArr.length;
        final double[] keys = new double[n];
        final int[] permutation = new int[n];

        // Entries equal to the default value fill the tail, scanning from the back
        int tail = n - 1;
        for (int idx = n - 1; idx >= 0; idx--) {
            if (dArr[idx] == PrismSettings.DEFAULT_DOUBLE) {
                keys[tail] = 0.0d;
                permutation[tail] = idx;
                tail--;
            }
        }

        // All other entries fill the head in their original order
        int headCount = 0;
        for (int idx = 0; idx < n; idx++) {
            if (dArr[idx] != PrismSettings.DEFAULT_DOUBLE) {
                keys[headCount] = dArr[idx];
                permutation[headCount] = idx;
                headCount++;
            }
        }

        // Bubble sort the head into descending order; each pass fixes one more trailing slot
        boolean sorted = false;
        for (int pass = 0; !sorted; pass++) {
            sorted = true;
            final int limit = headCount - pass - 1;
            for (int k = 0; k < limit; k++) {
                if (keys[k] < keys[k + 1]) {
                    swap(keys, k, k + 1);
                    swap(permutation, k, k + 1);
                    sorted = false;
                }
            }
        }
        return permutation;
    }

    /**
     * Exchange two elements of an int array in place.
     *
     * @param iArr the array
     * @param i    index of the first element
     * @param i2   index of the second element
     */
    private void swap(int[] iArr, int i, int i2) {
        final int second = iArr[i2];
        iArr[i2] = iArr[i];
        iArr[i] = second;
    }

    /**
     * Exchange two elements of a double array in place.
     *
     * @param dArr the array
     * @param i    index of the first element
     * @param i2   index of the second element
     */
    private void swap(double[] dArr, int i, int i2) {
        final double second = dArr[i2];
        dArr[i2] = dArr[i];
        dArr[i] = second;
    }

    /**
     * For a point given as a probability vector, compute the vertices of a grid
     * sub-simplex around it and the weights (lambdas) that express the point as a
     * combination of those vertices.
     * <p>
     * Vertex coordinates are computed with floating-point division; integer
     * division would truncate them and the final consistency check would fail.
     *
     * @param dArr      the point (not modified)
     * @param arrayList output: the vertices are appended to this list, one per dimension
     * @param dArr2     output: the weight of each vertex, same length as {@code dArr}
     * @param i         grid resolution
     * @return {@code true} if the weighted vertices reproduce {@code dArr} to within 1e-4
     *         in every coordinate
     */
    protected boolean getSubSimplexAndLambdas(double[] dArr, ArrayList<double[]> arrayList, double[] dArr2, int i) {
        final int n = dArr.length;
        final double[] scaled = new double[n];
        final int[] base = new int[n];
        final double[] frac = new double[n];

        // Transform to suffix sums scaled by the resolution, then split into integer and fractional parts
        for (int j = 0; j < n; j++) {
            double suffixSum = 0.0d;
            for (int k = j; k < n; k++) {
                suffixSum += i * dArr[k];
            }
            // Round to 6 decimals so near-integers land exactly on the grid
            scaled[j] = Math.round(suffixSum * 1000000.0d) / 1000000.0d;
            base[j] = (int) Math.floor(scaled[j]);
            frac[j] = scaled[j] - base[j];
        }

        // Walk from the base vertex, incrementing one coordinate at a time in order of decreasing fractional part
        final int[] order = getSortedPermutation(frac);
        final List<int[]> intVertices = new ArrayList<>(n);
        for (int v = 0; v < n; v++) {
            final int[] vertex;
            if (v == 0) {
                vertex = base.clone();
            } else {
                vertex = intVertices.get(v - 1).clone();
                vertex[order[v - 1]]++;
            }
            intVertices.add(vertex);
        }

        // Map each integer vertex back to the original coordinates
        for (int[] vertex : intVertices) {
            final double[] point = new double[n];
            int j = 0;
            for (; j < n - 1; j++) {
                point[j] = (double) (vertex[j] - vertex[j + 1]) / i;
            }
            point[j] = (double) vertex[j] / i;
            arrayList.add(point);
        }

        // Weights: differences of consecutive sorted fractional parts; the first takes the remainder
        double lambdaSum = 0.0d;
        for (int v = 1; v < n; v++) {
            final double lambda = frac[order[v - 1]] - frac[order[v]];
            dArr2[v] = lambda;
            lambdaSum += lambda;
        }
        dArr2[0] = 1.0d - lambdaSum;

        // Consistency check: the weighted vertices must reproduce the input point
        for (int j = 0; j < n; j++) {
            double reconstructed = 0.0d;
            for (int v = 0; v < n; v++) {
                reconstructed += dArr2[v] * arrayList.get(v)[j];
            }
            if (Math.abs(dArr[j] - reconstructed) > 1.0E-4d) {
                return false;
            }
        }
        return true;
    }

    /**
     * Check whether a belief places (numerically) all of its probability mass on a set of states.
     *
     * @param dArr   the belief, as a distribution over states
     * @param bitSet the set of target states
     * @return {@code true} if the mass on {@code bitSet} is within 1e-6 of 1
     */
    public static boolean isTargetBelief(double[] dArr, BitSet bitSet) {
        double targetMass = 0.0d;
        for (int s = bitSet.nextSetBit(0); s >= 0; s = bitSet.nextSetBit(s + 1)) {
            targetMass += dArr[s];
        }
        return Math.abs(targetMass - 1.0d) < 1.0E-6d;
    }

    /**
     * Simple command-line test driver.
     * <p>
     * Arguments: {@code <explicit-model-file> <labels-file> <target-label> [-min|-max] [-nopre]}.
     * Builds a POMDP from the explicit model file, runs {@link #computeReachRewards}
     * for the target label and prints the value for the first state labelled "init".
     * Any {@link PrismException} is printed to standard output.
     *
     * @param strArr command-line arguments
     */
    public static void main(String[] strArr) {
        // The first three arguments are mandatory
        if (strArr == null || strArr.length < 3) {
            System.out.println("Usage: POMDPModelChecker <explicit-model-file> <labels-file> <target-label> [-min|-max] [-nopre]");
            return;
        }
        boolean min = true;
        try {
            POMDPModelChecker checker = new POMDPModelChecker(null);
            MDPSimple<Double> mdp = new MDPSimple<>();
            mdp.buildFromPrismExplicit(strArr[0]);
            Map<String, BitSet> labels = loadLabelsFile(strArr[1]);
            BitSet init = labels.get("init");
            // Fail before the computation instead of with a NullPointerException after it
            if (init == null) {
                throw new PrismException("Unknown label \"init\"");
            }
            BitSet target = labels.get(strArr[2]);
            if (target == null) {
                throw new PrismException("Unknown label \"" + strArr[2] + "\"");
            }
            for (int a = 3; a < strArr.length; a++) {
                if (strArr[a].equals("-min")) {
                    min = true;
                } else if (strArr[a].equals("-max")) {
                    min = false;
                } else if (strArr[a].equals("-nopre")) {
                    checker.setPrecomp(false);
                }
            }
            POMDPSimple<Double> pomdp = new POMDPSimple<>(mdp);
            ModelCheckerResult result = checker.computeReachRewards(pomdp, null, target, min, null);
            System.out.println(result.soln[init.nextSetBit(0)]);
        } catch (PrismException e) {
            System.out.println(e);
        }
    }
}
