/*
 * Decompiled with CFR 0.152.
 */
package aima.probability.decision;

import aima.probability.Randomizer;
import aima.probability.decision.MDPPerception;
import aima.probability.decision.MDPPolicy;
import aima.probability.decision.MDPRewardFunction;
import aima.probability.decision.MDPSource;
import aima.probability.decision.MDPTransition;
import aima.probability.decision.MDPTransitionModel;
import aima.probability.decision.MDPUtilityFunction;
import aima.util.Pair;
import java.util.List;

/**
 * A Markov decision process assembled from an {@link MDPSource}, together with
 * value-iteration and policy-iteration solvers that update the utilities of its
 * non-final states.
 *
 * @param <STATE_TYPE>  the type used to represent states
 * @param <ACTION_TYPE> the type used to represent actions
 */
public class MDP<STATE_TYPE, ACTION_TYPE> {
    /** Simplified Bellman sweeps per policy-evaluation step in {@link #policyIteration(double)}. */
    private static final int DEFAULT_POLICY_EVALUATION_ITERATIONS = 3;

    private final STATE_TYPE initialState;
    // Not final: emptyMdp() replaces these two on the copy it builds.
    private MDPTransitionModel<STATE_TYPE, ACTION_TYPE> transitionModel;
    private MDPRewardFunction<STATE_TYPE> rewardFunction;
    private final List<STATE_TYPE> nonFinalStates;
    private final List<STATE_TYPE> terminalStates;
    private final MDPSource<STATE_TYPE, ACTION_TYPE> source;

    /**
     * Builds an MDP by reading the initial state, transition model, reward function and the
     * non-final / final state lists from {@code source}. The source itself is retained so that
     * {@link #execute}, {@link #getAllActions()} and {@link #emptyMdp()} can delegate to it.
     *
     * @param source provider of every component of this MDP
     */
    public MDP(MDPSource<STATE_TYPE, ACTION_TYPE> source) {
        this.initialState = source.getInitialState();
        this.transitionModel = source.getTransitionModel();
        this.rewardFunction = source.getRewardFunction();
        this.nonFinalStates = source.getNonFinalStates();
        this.terminalStates = source.getFinalStates();
        this.source = source;
    }

    /**
     * Creates a new MDP over the same source. Its reward function holds only this MDP's reward
     * for the initial state. Its transition model is freshly constructed from the terminal
     * states; what that model contains initially is up to {@code MDPTransitionModel}.
     *
     * @return a new MDP with a fresh reward function and transition model
     */
    public MDP<STATE_TYPE, ACTION_TYPE> emptyMdp() {
        MDP<STATE_TYPE, ACTION_TYPE> mdp = new MDP<STATE_TYPE, ACTION_TYPE>(this.source);
        mdp.rewardFunction = new MDPRewardFunction<STATE_TYPE>();
        mdp.rewardFunction.setReward(this.initialState, this.rewardFunction.getRewardFor(this.initialState));
        mdp.transitionModel = new MDPTransitionModel<STATE_TYPE, ACTION_TYPE>(this.terminalStates);
        return mdp;
    }

    /**
     * Runs value iteration (AIMA Fig. 17.4). Sweeps are repeated until the largest per-state
     * utility change no longer exceeds {@code error * (1 - gamma) / gamma}. By AIMA's bound, the
     * latest estimate is then within {@code error} of the true utilities.
     *
     * @param gamma discount factor. The AIMA bound needs {@code gamma < 1}; for
     *              {@code gamma >= 1} the loop instead stops once the largest change
     *              no longer exceeds {@code error}.
     * @param error maximum tolerated error in the returned utilities
     * @param delta ignored; retained only so existing callers keep compiling
     * @return the utility estimate produced by the final sweep
     */
    public MDPUtilityFunction<STATE_TYPE> valueIteration(double gamma, double error, double delta) {
        double threshold = gamma < 1.0 ? error * (1.0 - gamma) / gamma : error;
        MDPUtilityFunction<STATE_TYPE> next = this.initialUtilityFunction();
        double maxChange;
        do {
            // Every state in this sweep reads from the snapshot of the previous sweep.
            MDPUtilityFunction<STATE_TYPE> previous = next.copy();
            maxChange = 0.0;
            for (STATE_TYPE s : this.nonFinalStates) {
                double utility = this.valueIterateOnceForGivenState(gamma, previous, s);
                next.setUtility(s, utility);
                maxChange = Math.max(maxChange, Math.abs(utility - previous.getUtility(s)));
            }
        } while (maxChange > threshold);
        return next;
    }

    /**
     * Applies {@link #valueIterateOnce} exactly {@code numberOfIterations} times, starting from
     * {@link #initialUtilityFunction()}.
     *
     * @param numberOfIterations number of sweeps; a value {@code <= 0} returns the initial utilities
     * @param gamma              discount factor
     * @return the utilities after the requested number of sweeps
     */
    public MDPUtilityFunction<STATE_TYPE> valueIterationForFixedIterations(int numberOfIterations, double gamma) {
        MDPUtilityFunction<STATE_TYPE> utilityFunction = this.initialUtilityFunction();
        for (int i = 0; i < numberOfIterations; ++i) {
            utilityFunction = this.valueIterateOnce(gamma, utilityFunction).getFirst();
        }
        return utilityFunction;
    }

    /**
     * Applies {@link #valueIterateOnce} repeatedly, starting from {@link #initialUtilityFunction()},
     * until the largest per-state utility change of a sweep is at most {@code errorMargin}.
     *
     * @param gamma       discount factor
     * @param errorMargin largest per-sweep change at which iteration stops
     * @return the utilities after the last sweep
     */
    public MDPUtilityFunction<STATE_TYPE> valueIterationTillMAximumUtilityGrowthFallsBelowErrorMargin(double gamma, double errorMargin) {
        MDPUtilityFunction<STATE_TYPE> utilityFunction = this.initialUtilityFunction();
        double maxUtilityGrowth;
        do {
            Pair<MDPUtilityFunction<STATE_TYPE>, Double> result = this.valueIterateOnce(gamma, utilityFunction);
            utilityFunction = result.getFirst();
            maxUtilityGrowth = result.getSecond();
        } while (maxUtilityGrowth > errorMargin);
        return utilityFunction;
    }

    /**
     * Performs one Bellman sweep. Each non-final state gets
     * {@code R(s) + gamma * maxExpectedUtility(s)}, computed from {@code presentUtilityFunction}.
     * Terminal states keep their present utility.
     *
     * @param gamma                  discount factor
     * @param presentUtilityFunction utilities from the previous sweep; not modified here
     * @return a pair of (new utility function, largest absolute change over the non-final states)
     */
    public Pair<MDPUtilityFunction<STATE_TYPE>, Double> valueIterateOnce(double gamma, MDPUtilityFunction<STATE_TYPE> presentUtilityFunction) {
        double maxUtilityGrowth = 0.0;
        MDPUtilityFunction<STATE_TYPE> newUtilityFunction = new MDPUtilityFunction<STATE_TYPE>();
        for (STATE_TYPE s : this.nonFinalStates) {
            double utility = this.valueIterateOnceForGivenState(gamma, presentUtilityFunction, s);
            maxUtilityGrowth = Math.max(maxUtilityGrowth, Math.abs(utility - presentUtilityFunction.getUtility(s)));
            newUtilityFunction.setUtility(s, utility);
        }
        // Copied once after the sweep (not once per state), so terminal utilities are carried
        // over even when there are no non-final states.
        for (STATE_TYPE state : this.terminalStates) {
            newUtilityFunction.setUtility(state, presentUtilityFunction.getUtility(state));
        }
        return new Pair<MDPUtilityFunction<STATE_TYPE>, Double>(newUtilityFunction, maxUtilityGrowth);
    }

    /** Computes the Bellman update {@code R(state) + gamma * maxExpectedUtility(state)} for one state. */
    private double valueIterateOnceForGivenState(double gamma, MDPUtilityFunction<STATE_TYPE> presentUtilityFunction, STATE_TYPE state) {
        Pair<ACTION_TYPE, Double> highestUtilityTransition = this.transitionModel.getTransitionWithMaximumExpectedUtility(state, presentUtilityFunction);
        return this.rewardFunction.getRewardFor(state) + gamma * highestUtilityTransition.getSecond();
    }

    /**
     * Runs policy iteration. Each policy is evaluated with
     * {@value #DEFAULT_POLICY_EVALUATION_ITERATIONS} simplified Bellman sweeps.
     *
     * @param gamma discount factor
     * @return a policy that no single-state greedy improvement changes
     */
    public MDPPolicy<STATE_TYPE, ACTION_TYPE> policyIteration(double gamma) {
        return this.policyIteration(gamma, DEFAULT_POLICY_EVALUATION_ITERATIONS);
    }

    /**
     * Runs (modified) policy iteration, starting from {@link #randomPolicy()}. Each round evaluates
     * the current policy with {@code evaluationIterations} sweeps. A state's action is switched to
     * the best action only when that action's expected utility strictly exceeds the current
     * policy's. Rounds repeat until a round makes no switch.
     *
     * @param gamma                discount factor
     * @param evaluationIterations simplified Bellman sweeps per policy evaluation
     * @return a policy that no single-state greedy improvement changes
     */
    public MDPPolicy<STATE_TYPE, ACTION_TYPE> policyIteration(double gamma, int evaluationIterations) {
        MDPUtilityFunction<STATE_TYPE> U = this.initialUtilityFunction();
        MDPPolicy<STATE_TYPE, ACTION_TYPE> pi = this.randomPolicy();
        boolean unchanged;
        do {
            U = this.policyEvaluation(pi, U, gamma, evaluationIterations);
            unchanged = true;
            for (STATE_TYPE s : this.nonFinalStates) {
                Pair<ACTION_TYPE, Double> maxTransit = this.transitionModel.getTransitionWithMaximumExpectedUtility(s, U);
                Pair<ACTION_TYPE, Double> maxPolicyTransit = this.transitionModel.getTransitionWithMaximumExpectedUtilityUsingPolicy(pi, s, U);
                // Strict comparison so ties never flip the action (avoids oscillating forever).
                if (maxTransit.getSecond() > maxPolicyTransit.getSecond()) {
                    pi.setAction(s, maxTransit.getFirst());
                    unchanged = false;
                }
            }
        } while (!unchanged);
        return pi;
    }

    /**
     * Approximates the utilities of following {@code pi}. It applies {@code iterations}
     * simplified Bellman updates, with the action fixed by the policy, to a copy of {@code U}.
     *
     * @param pi         policy to evaluate
     * @param U          starting utilities; the updates are applied to a copy
     * @param gamma      discount factor
     * @param iterations number of update sweeps
     * @return the updated utility estimate
     */
    public MDPUtilityFunction<STATE_TYPE> policyEvaluation(MDPPolicy<STATE_TYPE, ACTION_TYPE> pi, MDPUtilityFunction<STATE_TYPE> U, double gamma, int iterations) {
        MDPUtilityFunction<STATE_TYPE> U_dash = U.copy();
        for (int i = 0; i < iterations; ++i) {
            U_dash = this.valueIterateOnceWith(gamma, pi, U_dash);
        }
        return U_dash;
    }

    /**
     * One policy-evaluation sweep. Each non-final state gets {@code R(s) + gamma * e}, where
     * {@code e} comes from the transition model's {@code ...UsingPolicy} lookup for {@code pi}.
     * The update is written to a copy of {@code U}.
     */
    private MDPUtilityFunction<STATE_TYPE> valueIterateOnceWith(double gamma, MDPPolicy<STATE_TYPE, ACTION_TYPE> pi, MDPUtilityFunction<STATE_TYPE> U) {
        MDPUtilityFunction<STATE_TYPE> U_dash = U.copy();
        for (STATE_TYPE s : this.nonFinalStates) {
            Pair<ACTION_TYPE, Double> highestPolicyTransition = this.transitionModel.getTransitionWithMaximumExpectedUtilityUsingPolicy(pi, s, U);
            double utility = this.rewardFunction.getRewardFor(s) + gamma * highestPolicyTransition.getSecond();
            U_dash.setUtility(s, utility);
        }
        return U_dash;
    }

    /**
     * Builds a policy that maps every non-final state to the action returned by the transition
     * model's {@code randomActionFor}.
     *
     * @return a new policy covering all non-final states
     */
    public MDPPolicy<STATE_TYPE, ACTION_TYPE> randomPolicy() {
        MDPPolicy<STATE_TYPE, ACTION_TYPE> policy = new MDPPolicy<STATE_TYPE, ACTION_TYPE>();
        for (STATE_TYPE s : this.nonFinalStates) {
            policy.setAction(s, this.transitionModel.randomActionFor(s));
        }
        return policy;
    }

    /**
     * Returns the starting utilities used by the solvers, as produced by the reward function's
     * {@code asUtilityFunction()}.
     *
     * @return the initial utility function
     */
    public MDPUtilityFunction<STATE_TYPE> initialUtilityFunction() {
        return this.rewardFunction.asUtilityFunction();
    }

    /** @return the initial state read from the source */
    public STATE_TYPE getInitialState() {
        return this.initialState;
    }

    /**
     * @param state state to look up
     * @return the reward the reward function assigns to {@code state}
     */
    public double getRewardFor(STATE_TYPE state) {
        return this.rewardFunction.getRewardFor(state);
    }

    /**
     * Sets the reward for {@code state} in this MDP's reward function.
     *
     * @param state  state to update
     * @param reward new reward value
     */
    public void setReward(STATE_TYPE state, double reward) {
        this.rewardFunction.setReward(state, reward);
    }

    /**
     * Records {@code probability} for the (initial state, action, destination) triple of
     * {@code transition} in the transition model.
     */
    public void setTransitionProbability(MDPTransition<STATE_TYPE, ACTION_TYPE> transition, double probability) {
        this.transitionModel.setTransitionProbability(transition.getInitialState(), transition.getAction(), transition.getDestinationState(), probability);
    }

    /**
     * @param transition the (initial state, action, destination) triple to look up
     * @return the probability the transition model stores for that triple
     */
    public double getTransitionProbability(MDPTransition<STATE_TYPE, ACTION_TYPE> transition) {
        return this.transitionModel.getTransitionProbability(transition.getInitialState(), transition.getAction(), transition.getDestinationState());
    }

    /** Delegates to the source's {@code execute} and returns its perception. */
    public MDPPerception<STATE_TYPE> execute(STATE_TYPE state, ACTION_TYPE action, Randomizer r) {
        return this.source.execute(state, action, r);
    }

    /**
     * @param state state to test
     * @return {@code true} if {@code state} is in the source's list of final states
     */
    public boolean isTerminalState(STATE_TYPE state) {
        return this.terminalStates.contains(state);
    }

    /** Returns the transition model's transitions that start in {@code initialState} under {@code action}. */
    public List<MDPTransition<STATE_TYPE, ACTION_TYPE>> getTransitionsWith(STATE_TYPE initialState, ACTION_TYPE action) {
        return this.transitionModel.getTransitionsWithStartingStateAndAction(initialState, action);
    }

    /** @return the actions reported by the source */
    public List<ACTION_TYPE> getAllActions() {
        return this.source.getAllActions();
    }

    @Override
    public String toString() {
        // String concatenation renders a null initial state as "null" instead of throwing NPE.
        return "initial State = " + this.initialState
                + "\n rewardFunction = " + this.rewardFunction
                + "\n transitionModel = " + this.transitionModel
                + "\n states = " + this.nonFinalStates;
    }
}

