/*
 * Decompiled with CFR 0.152.
 */
package aima.learning.neural;

import aima.learning.neural.FeedForwardNeuralNetwork;
import aima.learning.neural.FunctionApproximator;
import aima.learning.neural.Layer;
import aima.learning.neural.LayerSensitivity;
import aima.learning.neural.NNTrainingScheme;
import aima.learning.neural.Vector;
import aima.util.Matrix;

public class BackPropLearning
implements NNTrainingScheme {
    private final double learningRate;
    private final double momentum;
    private Layer hiddenLayer;
    private Layer outputLayer;
    private LayerSensitivity hiddenSensitivity;
    private LayerSensitivity outputSensitivity;

    public BackPropLearning(double d, double d2) {
        this.learningRate = d;
        this.momentum = d2;
    }

    @Override
    public void setNeuralNetwork(FunctionApproximator functionApproximator) {
        FeedForwardNeuralNetwork feedForwardNeuralNetwork = (FeedForwardNeuralNetwork)functionApproximator;
        this.hiddenLayer = feedForwardNeuralNetwork.getHiddenLayer();
        this.outputLayer = feedForwardNeuralNetwork.getOutputLayer();
        this.hiddenSensitivity = new LayerSensitivity(this.hiddenLayer);
        this.outputSensitivity = new LayerSensitivity(this.outputLayer);
    }

    @Override
    public Vector processInput(FeedForwardNeuralNetwork feedForwardNeuralNetwork, Vector vector) {
        this.hiddenLayer.feedForward(vector);
        this.outputLayer.feedForward(this.hiddenLayer.getLastActivationValues());
        return this.outputLayer.getLastActivationValues();
    }

    @Override
    public void processError(FeedForwardNeuralNetwork feedForwardNeuralNetwork, Vector vector) {
        this.outputSensitivity.sensitivityMatrixFromErrorMatrix(vector);
        this.hiddenSensitivity.sensitivityMatrixFromSucceedingLayer(this.outputSensitivity);
        this.calculateWeightUpdates(this.outputSensitivity, this.hiddenLayer.getLastActivationValues(), this.learningRate, this.momentum);
        this.calculateWeightUpdates(this.hiddenSensitivity, this.hiddenLayer.getLastInputValues(), this.learningRate, this.momentum);
        this.calculateBiasUpdates(this.outputSensitivity, this.learningRate, this.momentum);
        this.calculateBiasUpdates(this.hiddenSensitivity, this.learningRate, this.momentum);
        this.outputLayer.updateWeights();
        this.outputLayer.updateBiases();
        this.hiddenLayer.updateWeights();
        this.hiddenLayer.updateBiases();
    }

    public Matrix calculateWeightUpdates(LayerSensitivity layerSensitivity, Vector vector, double d, double d2) {
        Layer layer = layerSensitivity.getLayer();
        Matrix matrix = vector.transpose();
        Matrix matrix2 = layerSensitivity.getSensitivityMatrix().times(matrix).times(d).times(-1.0);
        Matrix matrix3 = layer.getLastWeightUpdateMatrix().times(d2).plus(matrix2.times(1.0 - d2));
        layer.acceptNewWeightUpdate(matrix3.copy());
        return matrix3;
    }

    public static Matrix calculateWeightUpdates(LayerSensitivity layerSensitivity, Vector vector, double d) {
        Layer layer = layerSensitivity.getLayer();
        Matrix matrix = vector.transpose();
        Matrix matrix2 = layerSensitivity.getSensitivityMatrix().times(matrix).times(d).times(-1.0);
        layer.acceptNewWeightUpdate(matrix2.copy());
        return matrix2;
    }

    public Vector calculateBiasUpdates(LayerSensitivity layerSensitivity, double d, double d2) {
        Layer layer = layerSensitivity.getLayer();
        Matrix matrix = layerSensitivity.getSensitivityMatrix().times(d).times(-1.0);
        Matrix matrix2 = layer.getLastBiasUpdateVector().times(d2).plus(matrix.times(1.0 - d2));
        Vector vector = new Vector(matrix2.getRowDimension());
        for (int i = 0; i < matrix2.getRowDimension(); ++i) {
            vector.setValue(i, matrix2.get(i, 0));
        }
        layer.acceptNewBiasUpdate(vector.copyVector());
        return vector;
    }

    public static Vector calculateBiasUpdates(LayerSensitivity layerSensitivity, double d) {
        Layer layer = layerSensitivity.getLayer();
        Matrix matrix = layerSensitivity.getSensitivityMatrix().times(d).times(-1.0);
        Vector vector = new Vector(matrix.getRowDimension());
        for (int i = 0; i < matrix.getRowDimension(); ++i) {
            vector.setValue(i, matrix.get(i, 0));
        }
        layer.acceptNewBiasUpdate(vector.copyVector());
        return vector;
    }
}

