package explicit;

import common.iterable.FunctionalIterator;
import explicit.rewards.MDPRewards;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import parser.State;
import prism.PrismException;
import prism.PrismSettings;
import prism.PrismUtils;

/**
 * Simple explicit-state POMDP: an {@link MDPSimple} whose states additionally carry an
 * observation index and an unobservation index.
 *
 * <p>Indices refer to {@link #observationsList} / {@link #unobservationsList}. In the
 * per-state maps and in {@link #observationStates}, the value {@code -1} means "not yet
 * assigned".
 */
public class POMDPSimple<Value> extends MDPSimple<Value> implements POMDP<Value> {

    /** Distinct observation values; a state's observation is an index into this list. */
    protected List<State> observationsList;
    /** Distinct unobservation values; a state's unobservation is an index into this list. */
    protected List<State> unobservationsList;
    /** For each observation index, one representative state with it ({@code -1} if none yet). */
    protected List<Integer> observationStates;
    /** For each state, its observation index ({@code -1} if unassigned). */
    protected List<Integer> observablesMap;
    /** For each state, its unobservation index ({@code -1} if unassigned). */
    protected List<Integer> unobservablesMap;

    /** Constructs a POMDP via {@code MDPSimple()} with empty observation data. */
    public POMDPSimple() {
        initialiseObservables();
    }

    /**
     * Constructs a POMDP via {@code MDPSimple(numStates)}; no observations are assigned yet.
     *
     * @param numStates number of states passed to the superclass constructor
     */
    public POMDPSimple(int numStates) {
        super(numStates);
        // The superclass has already created states, so the per-state maps must have one
        // entry per state; otherwise setObservation/clearState would index out of bounds.
        initialiseObservables(this.numStates);
    }

    /**
     * Copy constructor: copies the MDP part via the superclass and makes shallow copies of
     * all observation lists/maps.
     *
     * @param pomdp the POMDP to copy
     */
    public POMDPSimple(POMDPSimple<Value> pomdp) {
        super((MDPSimple<Value>) pomdp);
        observationsList = new ArrayList<>(pomdp.observationsList);
        unobservationsList = new ArrayList<>(pomdp.unobservationsList);
        observationStates = new ArrayList<>(pomdp.observationStates);
        observablesMap = new ArrayList<>(pomdp.observablesMap);
        unobservablesMap = new ArrayList<>(pomdp.unobservablesMap);
    }

    /**
     * Copy constructor with state renumbering: state {@code s} of {@code pomdp} becomes
     * state {@code permut[s]} of the copy, and observation representatives are remapped
     * accordingly.
     *
     * @param pomdp  the POMDP to copy
     * @param permut state permutation (old index to new index)
     */
    public POMDPSimple(POMDPSimple<Value> pomdp, int[] permut) {
        super(pomdp, permut);
        observationsList = new ArrayList<>(pomdp.observationsList);
        unobservationsList = new ArrayList<>(pomdp.unobservationsList);
        int numObservations = pomdp.getNumObservations();
        observationStates = new ArrayList<>(numObservations);
        for (int o = 0; o < numObservations; o++) {
            int rep = pomdp.observationStates.get(o);
            observationStates.add(rep == -1 ? -1 : permut[rep]);
        }
        observablesMap = new ArrayList<>(Collections.nCopies(numStates, -1));
        unobservablesMap = new ArrayList<>(Collections.nCopies(numStates, -1));
        for (int s = 0; s < numStates; s++) {
            observablesMap.set(permut[s], pomdp.observablesMap.get(s));
            unobservablesMap.set(permut[s], pomdp.unobservablesMap.get(s));
        }
    }

    /**
     * Constructs a POMDP from an MDP, giving each state {@code s} its own observation
     * index {@code s}. Unobservations are left unassigned ({@code -1}).
     *
     * <p>NOTE(review): {@link #observationsList} stays empty here, so anything that derives
     * the observation count from it (e.g. {@code getNumObservations()}, if it does) will
     * not reflect these observation indices — confirm this is intended.
     *
     * @param mdp the MDP to copy
     */
    public POMDPSimple(MDPSimple<Value> mdp) {
        super(mdp);
        initialiseObservables(mdp.numStates);
        for (int s = 0; s < numStates; s++) {
            // Register observation s (no representative yet) before assigning it:
            // setObservation reads observationStates.get(s).
            observationStates.add(-1);
            try {
                setObservation(s, s);
            } catch (PrismException e) {
                // setObservation only throws when an existing representative has different
                // actions; observation s is fresh, so reaching here indicates a logic error.
                throw new IllegalStateException(e);
            }
            // Unobservation is deliberately left at -1 (rather than null, which would make
            // getUnobservation/toString fail when unboxing).
        }
    }

    /** Resets all observation data to empty lists (no per-state entries). */
    protected void initialiseObservables() {
        initialiseObservables(0);
    }

    /**
     * Resets all observation data, sizing the per-state maps for {@code n} states with
     * every entry unassigned ({@code -1}).
     *
     * @param n number of states
     */
    protected void initialiseObservables(int n) {
        observationsList = new ArrayList<>();
        unobservationsList = new ArrayList<>();
        observationStates = new ArrayList<>();
        observablesMap = new ArrayList<>(Collections.nCopies(n, -1));
        unobservablesMap = new ArrayList<>(Collections.nCopies(n, -1));
    }

    /** Clears state {@code s} in the superclass and marks its (un)observation as unassigned. */
    @Override
    public void clearState(int s) {
        super.clearState(s);
        observablesMap.set(s, -1);
        unobservablesMap.set(s, -1);
    }

    /** Adds {@code numToAdd} states, each with unassigned observation and unobservation. */
    @Override
    public void addStates(int numToAdd) {
        super.addStates(numToAdd);
        for (int k = 0; k < numToAdd; k++) {
            observablesMap.add(-1);
            unobservablesMap.add(-1);
        }
    }

    /** Replaces the observation list (stored by reference, not copied). */
    public void setObservationsList(List<State> list) {
        this.observationsList = list;
    }

    /** Replaces the unobservation list (stored by reference, not copied). */
    public void setUnobservationsList(List<State> list) {
        this.unobservationsList = list;
    }

    /**
     * Assigns an observation and unobservation to state {@code s}. New observation and
     * unobservation values are appended to their lists.
     *
     * @param s             state index
     * @param observ        observation value
     * @param unobserv      unobservation value
     * @param varNames      optional variable names used to format {@code observ} in errors
     * @throws PrismException if {@code s} does not have exactly the same actions as an
     *                        existing state with the same observation
     */
    public void setObservation(int s, State observ, State unobserv, List<String> varNames) throws PrismException {
        int oIndex = observationsList.indexOf(observ);
        if (oIndex == -1) {
            observationsList.add(observ);
            oIndex = observationsList.size() - 1;
            observationStates.add(-1);
        }
        try {
            setObservation(s, oIndex);
            int uIndex = unobservationsList.indexOf(unobserv);
            if (uIndex == -1) {
                unobservationsList.add(unobserv);
                uIndex = unobservationsList.size() - 1;
            }
            unobservablesMap.set(s, uIndex);
        } catch (PrismException e) {
            throw new PrismException("Problem with observation " + (varNames == null ? observ.toString() : observ.toString(varNames)) + ": " + e.getMessage());
        }
    }

    /**
     * Sets state {@code s} to observation index {@code o}. The first state given an
     * observation becomes its representative; later states must match its actions exactly.
     *
     * @throws PrismException if actions differ from the observation's representative
     */
    protected void setObservation(int s, int o) throws PrismException {
        observablesMap.set(s, o);
        int rep = observationStates.get(o);
        if (rep == -1) {
            observationStates.set(o, s);
        } else {
            checkActionsMatchExactly(s, rep);
        }
    }

    /**
     * Checks that states {@code s1} and {@code s2} have the same actions in the same
     * choice order ({@code null} actions match only {@code null}).
     *
     * @throws PrismException if the choice counts or any action differ
     */
    protected void checkActionsMatchExactly(int s1, int s2) throws PrismException {
        int numChoices = getNumChoices(s1);
        boolean match = numChoices == getNumChoices(s2);
        for (int j = 0; match && j < numChoices; j++) {
            Object a1 = getAction(s1, j);
            Object a2 = getAction(s2, j);
            match = (a1 == null) ? a2 == null : a1.equals(a2);
        }
        if (!match) {
            throw new PrismException("Differing actions found in states: " + getAvailableActions(s1) + " vs. " + getAvailableActions(s2));
        }
    }

    /**
     * Checks that states {@code s1} and {@code s2} have the same multiset of actions,
     * ignoring choice order.
     *
     * @throws PrismException if the sorted action-name lists differ
     */
    protected void checkActionsMatch(int s1, int s2) throws PrismException {
        List<String> names1 = sortedActionNames(s1);
        List<String> names2 = sortedActionNames(s2);
        if (!names1.equals(names2)) {
            throw new PrismException("Differing actions found in states: " + names1 + " vs. " + names2);
        }
    }

    /** Returns the action names of state {@code s}, sorted; {@code null} maps to DEFAULT_STRING. */
    private List<String> sortedActionNames(int s) {
        int numChoices = getNumChoices(s);
        List<String> names = new ArrayList<>(numChoices);
        for (int j = 0; j < numChoices; j++) {
            Object action = getAction(s, j);
            names.add(action == null ? PrismSettings.DEFAULT_STRING : action.toString());
        }
        Collections.sort(names);
        return names;
    }

    @Override
    public List<State> getObservationsList() {
        return observationsList;
    }

    @Override
    public List<State> getUnobservationsList() {
        return unobservationsList;
    }

    /** Returns the observation index of state {@code s}, or -1 if unassigned or no map exists. */
    @Override
    public int getObservation(int s) {
        if (observablesMap == null) {
            return -1;
        }
        return observablesMap.get(s);
    }

    /** Returns the unobservation index of state {@code s} ({@code -1} if unassigned). */
    @Override
    public int getUnobservation(int s) {
        return unobservablesMap.get(s);
    }

    /** Returns the number of choices of the representative state of observation {@code o}. */
    @Override
    public int getNumChoicesForObservation(int o) {
        return getNumChoices(observationStates.get(o));
    }

    /** Returns the action of choice {@code i} in the representative state of observation {@code o}. */
    @Override
    public Object getActionForObservation(int o, int i) {
        return getAction(observationStates.get(o), i);
    }

    /** Returns the initial belief built from {@link #getInitialBeliefInDist()}. */
    @Override
    public Belief getInitialBelief() {
        return new Belief(getInitialBeliefInDist(), this);
    }

    /**
     * Returns a distribution over states that gives weight 1.0 to each initial state and is
     * then normalised with {@code PrismUtils.normalise} (presumably to sum 1).
     */
    @Override
    public double[] getInitialBeliefInDist() {
        double[] dist = new double[numStates];
        for (int s : initialStates) {
            dist[s] = 1.0;
        }
        PrismUtils.normalise(dist);
        return dist;
    }

    /** Returns the belief after taking choice index {@code i} from {@code belief}. */
    @Override
    public Belief getBeliefAfterChoice(Belief belief, int i) {
        return new Belief(getBeliefInDistAfterChoice(belief.toDistributionOverStates(this), i), this);
    }

    /**
     * Pushes a distribution over states through choice index {@code i} of each state.
     * States whose probability is below 1e-6 are skipped. The result is not normalised.
     * Assumes {@code Value} is {@code Double} (values are cast to it).
     */
    @Override
    public double[] getBeliefInDistAfterChoice(double[] beliefInDist, int i) {
        int n = beliefInDist.length;
        double[] result = new double[n];
        for (int s = 0; s < n; s++) {
            double p = beliefInDist[s];
            if (p < 1.0E-6d) {
                continue;
            }
            FunctionalIterator<Map.Entry<Integer, Value>> it = getChoice(s, i).mo31iterator();
            while (it.hasNext()) {
                Map.Entry<Integer, Value> e = it.next();
                result[e.getKey()] += p * ((Double) e.getValue()).doubleValue();
            }
        }
        return result;
    }

    /** Returns the belief after choice index {@code i} and observing {@code o}. */
    @Override
    public Belief getBeliefAfterChoiceAndObservation(Belief belief, int i, int o) {
        Belief next = new Belief(getBeliefInDistAfterChoiceAndObservation(belief.toDistributionOverStates(this), i, o), this);
        assert next.so == o;
        return next;
    }

    /**
     * Distribution after choice {@code i}, weighted by {@code getObservationProb(s, o)}
     * for each state and then normalised via {@code PrismUtils.normalise}.
     */
    @Override
    public double[] getBeliefInDistAfterChoiceAndObservation(double[] beliefInDist, int i, int o) {
        int n = beliefInDist.length;
        double[] result = new double[n];
        double[] afterChoice = getBeliefInDistAfterChoice(beliefInDist, i);
        for (int s = 0; s < n; s++) {
            result[s] = afterChoice[s] * getObservationProb(s, o);
        }
        PrismUtils.normalise(result);
        return result;
    }

    @Override
    public double getObservationProbAfterChoice(Belief belief, int i, int o) {
        return getObservationProbAfterChoice(belief.toDistributionOverStates(this), i, o);
    }

    /** Returns the sum over states of (prob. after choice {@code i}) * {@code getObservationProb(s, o)}. */
    @Override
    public double getObservationProbAfterChoice(double[] beliefInDist, int i, int o) {
        double[] afterChoice = getBeliefInDistAfterChoice(beliefInDist, i);
        double total = 0.0d;
        for (int s = 0; s < afterChoice.length; s++) {
            total += afterChoice[s] * getObservationProb(s, o);
        }
        return total;
    }

    /**
     * Groups the distribution after choice {@code i} by observation index, summing the
     * probabilities of states above 1e-6.
     */
    @Override
    public HashMap<Integer, Double> computeObservationProbsAfterAction(double[] beliefInDist, int i) {
        HashMap<Integer, Double> probs = new HashMap<>();
        double[] afterChoice = getBeliefInDistAfterChoice(beliefInDist, i);
        for (int s = 0; s < afterChoice.length; s++) {
            double p = afterChoice[s];
            if (p > 1.0E-6d) {
                probs.merge(getObservation(s), p, Double::sum);
            }
        }
        return probs;
    }

    @Override
    public double getRewardAfterChoice(Belief belief, int i, MDPRewards<Double> mdpRewards) {
        return getRewardAfterChoice(belief.toDistributionOverStates(this), i, mdpRewards);
    }

    /**
     * Expected immediate reward of choice {@code i}: for each state, its probability times
     * (transition reward + state reward). States whose probability equals
     * {@code PrismSettings.DEFAULT_DOUBLE} contribute that constant and their rewards are not queried.
     */
    @Override
    public double getRewardAfterChoice(double[] beliefInDist, int i, MDPRewards<Double> mdpRewards) {
        double total = 0.0d;
        for (int s = 0; s < beliefInDist.length; s++) {
            double p = beliefInDist[s];
            total += (p == PrismSettings.DEFAULT_DOUBLE)
                    ? PrismSettings.DEFAULT_DOUBLE
                    : p * (mdpRewards.getTransitionReward(s, i).doubleValue() + mdpRewards.getStateReward(s).doubleValue());
        }
        return total;
    }

    /**
     * Converts a distribution over states into a {@link Belief} by summing probabilities per
     * unobservation index. The observation is taken from the last state whose probability
     * differs from {@code PrismSettings.DEFAULT_DOUBLE}; returns {@code null} (after printing
     * to stderr) if there is no such state.
     */
    protected Belief beliefInDistToBelief(double[] beliefInDist) {
        int so = -1;
        double[] bu = new double[getNumUnobservations()];
        for (int s = 0; s < beliefInDist.length; s++) {
            if (beliefInDist[s] != PrismSettings.DEFAULT_DOUBLE) {
                so = getObservation(s);
                bu[getUnobservation(s)] += beliefInDist[s];
            }
        }
        if (so == -1) {
            System.err.println("Something wrong in POMDPSimple.beliefInDistToBelief(double[] beliefInDist)");
            return null;
        }
        return new Belief(so, bu);
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("[ ");
        for (int s = 0; s < numStates; s++) {
            if (s > 0) {
                sb.append(", ");
            }
            sb.append(s).append('(').append(getObservation(s)).append('/').append(getUnobservation(s)).append("): ");
            sb.append('[');
            int numChoices = getNumChoices(s);
            for (int j = 0; j < numChoices; j++) {
                if (j > 0) {
                    sb.append(',');
                }
                Object action = getAction(s, j);
                if (action != null) {
                    sb.append(action).append(':');
                }
                sb.append(trans.get(s).get(j));
            }
            sb.append(']');
        }
        return sb.append(" ]\n").toString();
    }

    /**
     * Equality on state count, initial states and transitions.
     * NOTE(review): observation/unobservation data is not compared, and hashCode is not
     * overridden here (presumably inherited) — confirm this is intended.
     */
    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof POMDPSimple)) {
            return false;
        }
        POMDPSimple<?> other = (POMDPSimple<?>) obj;
        return numStates == other.numStates && initialStates.equals(other.initialStates) && trans.equals(other.trans);
    }
}
