package explicit;

import common.IterableBitSet;
import common.StopWatch;
import common.iterable.FunctionalPrimitiveIterator;
import explicit.IncomingChoiceRelation;
import explicit.rewards.MDPRewards;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.PriorityQueue;
import prism.PrismComponent;
import prism.PrismSettings;

/**
 * Upper bounds for minimum expected total reward (Rmin) in an MDP, computed with the
 * "Dijkstra Sweep for Monotone Pessimistic Initialization" (DSI-MP) algorithm.
 * <p>
 * The sweep works backwards from the target states. Every choice {@code (s, j)} of an unknown
 * state carries a pair {@code (p, w)}:
 * <ul>
 * <li>{@code p} starts at 0 and accumulates {@code P(s,j,t) * p(t)} for each finalised successor
 * {@code t} (targets have {@code p = 1});</li>
 * <li>{@code w} starts at the state reward of {@code s} plus the transition reward of {@code j}
 * and accumulates {@code P(s,j,t) * w(t)} for each finalised successor {@code t} (targets have
 * {@code w = 0}).</li>
 * </ul>
 * States are finalised in priority order (largest {@code p} first, ties broken by smallest
 * {@code w}) and take their {@code (p, w)} from the best choice offered so far, which also forms the
 * policy {@link #pi}. Only choices whose successors all lie in {@code unknown} or {@code target}
 * are considered. A scalar {@link #lambda} is then derived from that policy and each unknown state
 * {@code s} is bounded by {@code w(s) + lambda * (1 - p(s))}.
 * <p>
 * NOTE(review): unknown states that the sweep never reaches keep {@code p = w = 0} and policy
 * choice 0. The algorithm presumably relies on every unknown state being able to reach the
 * target through valid choices. Confirm this precondition with the callers.
 */
public class DijkstraSweepMPI {
    /** If set, {@link #computeUpperBounds} prints the computed bound vector to the log. */
    private static boolean debug = false;

    /** The model being analysed. */
    private final MDP<Double> mdp;
    /** State and transition rewards of {@link #mdp}. */
    private final MDPRewards<Double> rewards;
    /** Per-state p-value: 1 for targets; copied from the selected choice when a state is finalised. */
    private final double[] pState;
    /** Per-state w-value: 0 for targets; copied from the selected choice when a state is finalised. */
    private final double[] wState;
    /** Best (smallest) queue entry offered so far for each state, or {@code null} if none yet. */
    private final QueueEntry[] pri;
    /** For each state, the index of the choice that produced {@code pri[s]} (the selected policy). */
    private final int[] pi;
    /** States for which a bound is computed. */
    private final BitSet unknown;
    /** Target states. */
    private final BitSet target;
    /** Incoming-choice relation of {@link #mdp}, used to walk backwards from finalised states. */
    private final IncomingChoiceRelation incoming;
    /** Scaling factor computed by {@link #computeLambda()}. */
    private double lambda;

    /** Accumulated (p, w) values for every choice of every unknown state. */
    private final HashMap<IncomingChoiceRelation.Choice, ChoiceValues> choiceValues = new HashMap<>();
    /** States already finalised by the sweep. */
    private final BitSet fin = new BitSet();
    /** Candidate entries; entries for already-finalised states are stale and skipped. */
    private final PriorityQueue<QueueEntry> queue = new PriorityQueue<>();

    /** Mutable accumulated values of a single choice. */
    public static class ChoiceValues {
        /** Probability-weighted sum of the p-values of finalised successors. */
        public double p;
        /** One-step reward plus probability-weighted sum of the w-values of finalised successors. */
        public double w;

        public ChoiceValues(double p, double w) {
            this.p = p;
            this.w = w;
        }
    }

    /**
     * Priority-queue entry for state {@code y}, ordered by {@code p} ascending, then {@code w}
     * ascending. Entries created by this class store {@code 1 - p} of the offering choice in
     * {@code p}, so the choice with the highest probability is dequeued first.
     * The ordering is not consistent with {@code equals}.
     */
    public static class QueueEntry implements Comparable<QueueEntry> {
        public int y;
        public double p;
        public double w;

        public QueueEntry(int y, double p, double w) {
            this.y = y;
            this.p = p;
            this.w = w;
        }

        @Override
        public int compareTo(QueueEntry other) {
            int byP = Double.compare(this.p, other.p);
            return byP != 0 ? byP : Double.compare(this.w, other.w);
        }
    }

    /**
     * Runs the complete sweep and computes {@link #lambda}.
     *
     * @param parent  component providing the log and settings for the incoming-choice relation
     * @param mdp     the model
     * @param rewards state and transition rewards
     * @param target  target states
     * @param unknown states whose bounds are to be computed
     */
    private DijkstraSweepMPI(PrismComponent parent, MDP<Double> mdp, MDPRewards<Double> rewards, BitSet target, BitSet unknown) {
        this.mdp = mdp;
        this.unknown = unknown;
        this.target = target;
        this.rewards = rewards;
        this.incoming = IncomingChoiceRelation.forModel(parent, mdp);
        int numStates = mdp.getNumStates();
        this.pState = new double[numStates];
        this.wState = new double[numStates];
        this.pri = new QueueEntry[numStates];
        this.pi = new int[numStates];

        // Every choice of an unknown state starts with p = 0 and w = its one-step reward.
        for (int s = unknown.nextSetBit(0); s >= 0; s = unknown.nextSetBit(s + 1)) {
            int numChoices = mdp.getNumChoices(s);
            for (int j = 0; j < numChoices; j++) {
                double oneStepReward = rewards.getStateReward(s).doubleValue() + rewards.getTransitionReward(s, j).doubleValue();
                choiceValues.put(new IncomingChoiceRelation.Choice(s, j), new ChoiceValues(0.0, oneStepReward));
            }
        }

        for (int t = target.nextSetBit(0); t >= 0; t = target.nextSetBit(t + 1)) {
            pState[t] = 1.0;
        }

        // Seed the queue from the predecessors of the targets. A choice may lead to several
        // targets, so it is processed only once, summing over all of its target successors.
        HashSet<IncomingChoiceRelation.Choice> seen = new HashSet<>();
        for (int t = target.nextSetBit(0); t >= 0; t = target.nextSetBit(t + 1)) {
            for (IncomingChoiceRelation.Choice choice : incoming.getIncomingChoices(t)) {
                if (seen.add(choice) && unknown.get(choice.getState()) && validChoice(choice)) {
                    updateFromTargets(choice, target);
                }
            }
        }

        sweep();
        computeLambda();
    }

    /**
     * Finalises states in priority order until the queue is empty, propagating each finalised
     * state's values to the not-yet-finalised unknown predecessors.
     */
    private void sweep() {
        while (!queue.isEmpty()) {
            int s = queue.poll().y;
            if (fin.get(s)) {
                // Stale entry: s was already finalised via a better entry.
                continue;
            }
            fin.set(s);
            ChoiceValues best = choiceValues.get(new IncomingChoiceRelation.Choice(s, pi[s]));
            wState[s] = best.w;
            pState[s] = best.p;
            for (IncomingChoiceRelation.Choice choice : incoming.getIncomingChoices(s)) {
                int pred = choice.getState();
                if (!fin.get(pred) && unknown.get(pred) && validChoice(choice)) {
                    updateFromState(choice, s);
                }
            }
        }
    }

    /** Returns true iff every successor of {@code choice} lies in {@code unknown} or {@code target}. */
    private boolean validChoice(IncomingChoiceRelation.Choice choice) {
        return !mdp.someSuccessorsMatch(choice.getState(), choice.getChoice(), succ -> !unknown.get(succ) && !target.get(succ));
    }

    /**
     * Adds the contribution of the newly finalised state {@code t} to the values of {@code choice}
     * and offers the result to the queue.
     */
    private void updateFromState(IncomingChoiceRelation.Choice choice, int t) {
        int s = choice.getState();
        int j = choice.getChoice();
        double wT = wState[t];
        double pT = pState[t];
        double deltaW = mdp.sumOverDoubleTransitions(s, j, (ignored, succ, prob) -> succ != t ? 0.0 : prob * wT);
        double deltaP = mdp.sumOverDoubleTransitions(s, j, (ignored, succ, prob) -> succ != t ? 0.0 : prob * pT);
        ChoiceValues values = choiceValues.get(choice);
        assert values != null;
        values.p += deltaP;
        values.w += deltaW;
        offer(choice, values);
    }

    /**
     * Adds the total probability of moving directly into {@code targets} to the p-value of
     * {@code choice} and offers the result to the queue. Targets contribute w = 0, so w is unchanged.
     */
    private void updateFromTargets(IncomingChoiceRelation.Choice choice, BitSet targets) {
        double reach = mdp.sumOverDoubleTransitions(choice.getState(), choice.getChoice(), (ignored, succ, prob) -> targets.get(succ) ? prob : 0.0);
        ChoiceValues values = choiceValues.get(choice);
        assert values != null;
        values.p += reach;
        offer(choice, values);
    }

    /**
     * Enqueues {@code choice}'s current values for its source state if they beat the best entry
     * offered so far, and records the choice as that state's policy.
     */
    private void offer(IncomingChoiceRelation.Choice choice, ChoiceValues values) {
        int s = choice.getState();
        QueueEntry entry = new QueueEntry(s, 1.0 - values.p, values.w);
        if (pri[s] == null || entry.compareTo(pri[s]) < 0) {
            pri[s] = entry;
            pi[s] = choice.getChoice();
            queue.add(entry);
        }
    }

    /**
     * Computes and stores {@link #lambda} as the maximum of 0 and, for every unknown state
     * {@code s} with selected choice {@code j} and {@code p(s) < sum_t P(s,j,t) p(t)}:
     * {@code (r(s,j) + sum_t P(s,j,t) w(t) - w(s)) / (sum_t P(s,j,t) p(t) - p(s))},
     * where {@code r(s,j)} is the state reward plus the transition reward.
     *
     * @return the computed lambda
     */
    private double computeLambda() {
        lambda = 0.0;
        for (int s = unknown.nextSetBit(0); s >= 0; s = unknown.nextSetBit(s + 1)) {
            int j = pi[s];
            double expectedP = mdp.sumOverDoubleTransitions(s, j, (ignored, succ, prob) -> prob * pState[succ]);
            double candidate = 0.0;
            if (pState[s] < expectedP) {
                double oneStepReward = rewards.getStateReward(s).doubleValue() + rewards.getTransitionReward(s, j).doubleValue();
                double expectedW = mdp.sumOverDoubleTransitions(s, j, (ignored, succ, prob) -> prob * wState[succ]);
                candidate = (oneStepReward + expectedW - wState[s]) / (expectedP - pState[s]);
            }
            lambda = Double.max(lambda, candidate);
        }
        return lambda;
    }

    /**
     * Computes per-state upper bounds for Rmin using DSI-MP.
     *
     * @param parent  component providing the log
     * @param mdp     the model
     * @param rewards state and transition rewards
     * @param target  target states
     * @param unknown states for which a bound is computed
     * @return array indexed by state: {@code w(s) + lambda * (1 - p(s))} for unknown states, 0 elsewhere
     */
    public static double[] computeUpperBounds(PrismComponent parent, MDP<Double> mdp, MDPRewards<Double> rewards, BitSet target, BitSet unknown) {
        StopWatch timer = new StopWatch(parent.getLog());
        timer.start("computing upper bound(s) for Rmin using the DSI-MP algorithm");
        parent.getLog().println("Computing upper bound(s) for Rmin using the Dijkstra Sweep for Monotone Pessimistic Initialization algorithm...");
        double[] bounds = new double[mdp.getNumStates()];
        DijkstraSweepMPI dsi = new DijkstraSweepMPI(parent, mdp, rewards, target, unknown);
        for (int s = unknown.nextSetBit(0); s >= 0; s = unknown.nextSetBit(s + 1)) {
            bounds[s] = dsi.wState[s] + dsi.lambda * (1.0 - dsi.pState[s]);
        }
        if (debug) {
            parent.getLog().println(bounds);
        }
        timer.stop();
        return bounds;
    }

    /**
     * Computes a single upper bound for Rmin: the maximum of 0 and the per-state bounds of
     * {@link #computeUpperBounds} over the unknown states.
     */
    public static double computeUpperBound(PrismComponent parent, MDP<Double> mdp, MDPRewards<Double> rewards, BitSet target, BitSet unknown) {
        double[] bounds = computeUpperBounds(parent, mdp, rewards, target, unknown);
        double max = 0.0;
        for (int s = unknown.nextSetBit(0); s >= 0; s = unknown.nextSetBit(s + 1)) {
            max = Double.max(max, bounds[s]);
        }
        return max;
    }
}
