/*
 * Decompiled with CFR 0.152.
 */
package aima.learning.neural;

import aima.learning.framework.DataSet;
import aima.learning.neural.FunctionApproximator;
import aima.learning.neural.Layer;
import aima.learning.neural.LogSigActivationFunction;
import aima.learning.neural.NNConfig;
import aima.learning.neural.NNDataSet;
import aima.learning.neural.NNExample;
import aima.learning.neural.NNTrainingScheme;
import aima.learning.neural.PureLinearActivationFunction;
import aima.learning.neural.Vector;
import aima.util.Matrix;

public class FeedForwardNeuralNetwork
implements FunctionApproximator {
    public static final String UPPER_LIMIT_WEIGHTS = "upper_limit_weights";
    public static final String LOWER_LIMIT_WEIGHTS = "lower_limit_weights";
    public static final String NUMBER_OF_OUTPUTS = "number_of_outputs";
    public static final String NUMBER_OF_HIDDEN_NEURONS = "number_of_hidden_neurons";
    public static final String NUMBER_OF_INPUTS = "number_of_inputs";
    private final Layer hiddenLayer;
    private final Layer outputLayer;
    private NNTrainingScheme trainingScheme;

    public FeedForwardNeuralNetwork(NNConfig nNConfig) {
        int n = nNConfig.getParameterAsInteger(NUMBER_OF_INPUTS);
        int n2 = nNConfig.getParameterAsInteger(NUMBER_OF_HIDDEN_NEURONS);
        int n3 = nNConfig.getParameterAsInteger(NUMBER_OF_OUTPUTS);
        double d = nNConfig.getParameterAsDouble(LOWER_LIMIT_WEIGHTS);
        double d2 = nNConfig.getParameterAsDouble(UPPER_LIMIT_WEIGHTS);
        this.hiddenLayer = new Layer(n2, n, d, d2, new LogSigActivationFunction());
        this.outputLayer = new Layer(n3, n2, d, d2, new PureLinearActivationFunction());
    }

    public FeedForwardNeuralNetwork(Matrix matrix, Vector vector, Matrix matrix2, Vector vector2) {
        this.hiddenLayer = new Layer(matrix, vector, new LogSigActivationFunction());
        this.outputLayer = new Layer(matrix2, vector2, new PureLinearActivationFunction());
    }

    @Override
    public void processError(Vector vector) {
        this.trainingScheme.processError(this, vector);
    }

    @Override
    public Vector processInput(Vector vector) {
        return this.trainingScheme.processInput(this, vector);
    }

    public void trainOn(NNDataSet nNDataSet, int n) {
        for (int i = 0; i < n; ++i) {
            nNDataSet.refreshDataset();
            while (nNDataSet.hasMoreExamples()) {
                NNExample nNExample = nNDataSet.getExampleAtRandom();
                this.processInput(nNExample.getInput());
                Vector vector = this.getOutputLayer().errorVectorFrom(nNExample.getTarget());
                this.processError(vector);
            }
        }
    }

    public Vector predict(NNExample nNExample) {
        return this.processInput(nNExample.getInput());
    }

    public int[] testOnDataSet(NNDataSet nNDataSet) {
        int[] nArray = new int[]{0, 0};
        nNDataSet.refreshDataset();
        while (nNDataSet.hasMoreExamples()) {
            Vector vector;
            NNExample nNExample = nNDataSet.getExampleAtRandom();
            if (nNExample.isCorrect(vector = this.predict(nNExample))) {
                nArray[0] = nArray[0] + 1;
                continue;
            }
            nArray[1] = nArray[1] + 1;
        }
        return nArray;
    }

    public void testOn(DataSet dataSet) {
    }

    public Matrix getHiddenLayerWeights() {
        return this.hiddenLayer.getWeightMatrix();
    }

    public Vector getHiddenLayerBias() {
        return this.hiddenLayer.getBiasVector();
    }

    public Matrix getOutputLayerWeights() {
        return this.outputLayer.getWeightMatrix();
    }

    public Vector getOutputLayerBias() {
        return this.outputLayer.getBiasVector();
    }

    public Layer getHiddenLayer() {
        return this.hiddenLayer;
    }

    public Layer getOutputLayer() {
        return this.outputLayer;
    }

    public void setTrainingScheme(NNTrainingScheme nNTrainingScheme) {
        this.trainingScheme = nNTrainingScheme;
        nNTrainingScheme.setNeuralNetwork(this);
    }
}

