package smile.classification;

import com.github.mikephil.charting.utils.Utils;
import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.Classifier;
import smile.math.Math;

/* loaded from: classes2.dex */
/**
 * Multilayer perceptron classifier trained with online back-propagation.
 *
 * <p>The network is a chain of fully connected layers, input layer first and output layer last.
 * Every hidden layer applies the logistic sigmoid (via {@code Math.r}); the output layer applies
 * the configured {@link ActivationFunction}. Each layer carries one extra bias unit whose output
 * is fixed at 1.0.
 *
 * <p>Supported error/activation combinations (enforced by the constructor):
 * <ul>
 *   <li>{@link ErrorFunction#LEAST_MEAN_SQUARES} with {@link ActivationFunction#LINEAR} or
 *       {@link ActivationFunction#LOGISTIC_SIGMOID};</li>
 *   <li>{@link ErrorFunction#CROSS_ENTROPY} with {@link ActivationFunction#LOGISTIC_SIGMOID} and a
 *       single output unit (binary), or {@link ActivationFunction#SOFTMAX} and more than one
 *       output unit (multi-class).</li>
 * </ul>
 *
 * <p>With a single output unit, class 0 corresponds to a high output:
 * {@link #predict(double[])} returns 0 when the output exceeds 0.5, and training targets are
 * encoded accordingly.
 *
 * <p>Instances are not thread-safe: training and prediction both overwrite the per-layer output
 * and error buffers held by the network.
 */
public class NeuralNetwork implements OnlineClassifier<double[]>, SoftClassifier<double[]> {
    private static final Logger logger = LoggerFactory.a((Class<?>) NeuralNetwork.class);
    private static final long serialVersionUID = 1;

    /** Error function minimized during training. */
    private ErrorFunction errorFunction = ErrorFunction.LEAST_MEAN_SQUARES;
    /** Activation function of the output layer (hidden layers always use the logistic sigmoid). */
    private ActivationFunction activationFunction = ActivationFunction.LOGISTIC_SIGMOID;
    /** Learning rate; {@link #setLearningRate(double)} requires it to be strictly positive. */
    private double eta = 0.1;
    /** Momentum factor; {@link #setMomentum(double)} requires it to be in [0, 1). */
    private double alpha = 0.0;
    /** Weight decay factor; {@link #setWeightDecay(double)} requires it to be in [0, 0.1]. */
    private double lambda = 0.0;
    /** Number of units of the input layer. */
    private int p;
    /** Number of classes: the output size, or 2 when the output layer has a single unit. */
    private int k;
    /** Reusable buffer holding the encoded training target of a labelled sample. */
    private double[] target;
    /** All layers, input layer first and output layer last. */
    private Layer[] net;
    /** Alias of {@code net[0]}. */
    private Layer inputLayer;
    /** Alias of {@code net[net.length - 1]}. */
    private Layer outputLayer;

    /** Activation function applied by the output layer. */
    public enum ActivationFunction {
        /** Identity output; only allowed with least-mean-squares error. */
        LINEAR,
        /** Logistic sigmoid output; with cross-entropy error it requires a single output unit. */
        LOGISTIC_SIGMOID,
        /** Softmax output; requires cross-entropy error and more than one output unit. */
        SOFTMAX
    }

    /** Error function minimized by back-propagation. */
    public enum ErrorFunction {
        /** Sum of squared errors, halved. */
        LEAST_MEAN_SQUARES,
        /** Cross-entropy between the target and the network output. */
        CROSS_ENTROPY
    }

    /**
     * One fully connected layer. The {@code output} and {@code error} arrays have {@code units + 1}
     * slots; the last slot belongs to the bias unit.
     *
     * <p>Kept as a (non-static) inner class because changing it to a static nested class would
     * change its serialized form.
     */
    private class Layer implements Serializable {
        private static final long serialVersionUID = 1;
        /** Last weight change per connection, reused as the momentum term; null for the input layer. */
        double[][] delta;
        /** Back-propagated error term per unit (bias slot included). */
        double[] error;
        /** Activation per unit; {@code output[units]} is the bias unit, fixed at 1.0. */
        double[] output;
        /** Number of units, excluding the bias unit. */
        int units;
        /**
         * {@code weight[i][j]} connects unit {@code j} of the layer below (the last column is that
         * layer's bias unit) to unit {@code i} of this layer; null for the input layer.
         */
        double[][] weight;

        private Layer() {
        }
    }

    /**
     * Builds and trains a {@link NeuralNetwork} for a fixed number of epochs.
     *
     * <p>Defaults mirror a freshly constructed network: least-mean-squares error, logistic sigmoid
     * output, learning rate 0.1, momentum 0, weight decay 0; training runs 25 epochs. The layer
     * sizes must be supplied through one of the topology constructors before training.
     */
    public static class Trainer extends ClassifierTrainer<double[]> {
        private ErrorFunction errorFunction = ErrorFunction.LEAST_MEAN_SQUARES;
        private ActivationFunction activationFunction = ActivationFunction.LOGISTIC_SIGMOID;
        private int[] numUnits;
        private double eta = 0.1;
        private double alpha = 0.0;
        private double lambda = 0.0;
        private int epochs = 25;

        /**
         * Creates a trainer without a network topology. Training throws
         * {@link IllegalStateException} until a topology constructor is used instead.
         */
        public Trainer() {
        }

        /**
         * Creates a trainer whose output activation is derived from the error function: softmax for
         * cross-entropy with several outputs, logistic sigmoid otherwise.
         *
         * @param errorFunction the error function to minimize
         * @param numUnits the number of units of each layer, input layer first
         */
        public Trainer(ErrorFunction errorFunction, int... numUnits) {
            this(errorFunction,
                    defaultActivation(errorFunction, numUnits.length == 0 ? 1 : numUnits[numUnits.length - 1]),
                    numUnits);
        }

        /**
         * Creates a trainer for the given error function, output activation and topology. The
         * combination is validated when the network is built in {@link #a(double[][], int[])}.
         *
         * @param errorFunction the error function to minimize
         * @param activationFunction the activation function of the output layer
         * @param numUnits the number of units of each layer, input layer first (copied)
         * @throws NullPointerException if any argument is null
         */
        public Trainer(ErrorFunction errorFunction, ActivationFunction activationFunction, int... numUnits) {
            this.errorFunction = Objects.requireNonNull(errorFunction, "errorFunction");
            this.activationFunction = Objects.requireNonNull(activationFunction, "activationFunction");
            this.numUnits = numUnits.clone();
        }

        /**
         * Sets the learning rate of the trained network.
         *
         * @throws IllegalArgumentException if {@code eta} is not strictly positive
         */
        public Trainer setLearningRate(double eta) {
            checkLearningRate(eta);
            this.eta = eta;
            return this;
        }

        /**
         * Sets the momentum factor of the trained network.
         *
         * @throws IllegalArgumentException if {@code alpha} is outside [0, 1)
         */
        public Trainer setMomentum(double alpha) {
            checkMomentum(alpha);
            this.alpha = alpha;
            return this;
        }

        /**
         * Sets the weight decay factor of the trained network.
         *
         * @throws IllegalArgumentException if {@code lambda} is outside [0, 0.1]
         */
        public Trainer setWeightDecay(double lambda) {
            checkWeightDecay(lambda);
            this.lambda = lambda;
            return this;
        }

        /**
         * Sets the number of passes over the training set.
         *
         * @throws IllegalArgumentException if {@code epochs} is less than 1
         */
        public Trainer setNumEpochs(int epochs) {
            if (epochs < 1) {
                throw new IllegalArgumentException("Invalid number of epochs: " + epochs);
            }
            this.epochs = epochs;
            return this;
        }

        /**
         * Builds a new network and trains it on {@code x}/{@code y} for the configured number of
         * epochs, logging each completed epoch.
         *
         * @throws IllegalStateException if no network topology was configured
         */
        @Override // smile.classification.ClassifierTrainer
        public NeuralNetwork a(double[][] x, int[] y) {
            if (numUnits == null) {
                throw new IllegalStateException(
                        "Network topology not set: construct the Trainer with the number of units per layer.");
            }
            NeuralNetwork network = new NeuralNetwork(errorFunction, activationFunction, numUnits);
            network.setLearningRate(eta);
            network.setMomentum(alpha);
            network.setWeightDecay(lambda);
            for (int epoch = 1; epoch <= epochs; epoch++) {
                network.learn(x, y);
                logger.info("Neural network learns epoch {}", Integer.valueOf(epoch));
            }
            return network;
        }
    }

    /** Bare instance whose fields are filled in by {@link #clone()}. */
    private NeuralNetwork() {
    }

    /**
     * Constructs a network with randomly initialized weights.
     *
     * <p>The weights feeding layer {@code l} are drawn with {@code Math.h(-r, r)}, where
     * {@code r = 1 / Math.n(units of layer l-1)}. Presumably this is a uniform draw scaled by
     * 1/sqrt(fan-in); confirm against {@code smile.math.Math}.
     *
     * @param errorFunction the error function to minimize
     * @param activationFunction the activation function of the output layer
     * @param numUnits the number of units of each layer, input layer first and output layer last
     * @throws NullPointerException if {@code errorFunction} or {@code activationFunction} is null
     * @throws IllegalArgumentException if fewer than two layers are given, a layer has fewer than
     *     one unit, or the error/activation/output-size combination is not supported
     */
    public NeuralNetwork(ErrorFunction errorFunction, ActivationFunction activationFunction, int... numUnits) {
        Objects.requireNonNull(errorFunction, "errorFunction");
        Objects.requireNonNull(activationFunction, "activationFunction");
        int numLayers = numUnits.length;
        if (numLayers < 2) {
            throw new IllegalArgumentException("Invalid number of layers: " + numLayers);
        }
        for (int l = 0; l < numLayers; l++) {
            if (numUnits[l] < 1) {
                throw new IllegalArgumentException(String.format(
                        "Invalid number of units of layer %d: %d", l + 1, numUnits[l]));
            }
        }
        int outputUnits = numUnits[numLayers - 1];
        if (errorFunction == ErrorFunction.LEAST_MEAN_SQUARES && activationFunction == ActivationFunction.SOFTMAX) {
            throw new IllegalArgumentException("Softmax activation function is invalid for least mean squares error.");
        }
        if (errorFunction == ErrorFunction.CROSS_ENTROPY) {
            if (activationFunction == ActivationFunction.LINEAR) {
                throw new IllegalArgumentException("Linear activation function is invalid with cross entropy error.");
            }
            if (activationFunction == ActivationFunction.SOFTMAX && outputUnits == 1) {
                throw new IllegalArgumentException("Softmax activation function is for multi-class.");
            }
            if (activationFunction == ActivationFunction.LOGISTIC_SIGMOID && outputUnits != 1) {
                throw new IllegalArgumentException("For cross entropy error, logistic sigmoid output is for binary classification.");
            }
        }
        this.errorFunction = errorFunction;
        this.activationFunction = activationFunction;
        // Momentum and weight decay stay at their field defaults (0) until a setter is called.
        this.p = numUnits[0];
        this.k = outputUnits == 1 ? 2 : outputUnits;
        this.target = new double[outputUnits];

        this.net = new Layer[numLayers];
        for (int l = 0; l < numLayers; l++) {
            Layer layer = new Layer();
            layer.units = numUnits[l];
            layer.output = new double[numUnits[l] + 1];
            layer.error = new double[numUnits[l] + 1];
            layer.output[numUnits[l]] = 1.0; // bias unit
            net[l] = layer;
        }
        inputLayer = net[0];
        outputLayer = net[numLayers - 1];

        for (int l = 1; l < numLayers; l++) {
            Layer lower = net[l - 1];
            Layer upper = net[l];
            upper.weight = new double[upper.units][lower.units + 1];
            upper.delta = new double[upper.units][lower.units + 1];
            double r = 1.0 / Math.n(lower.units);
            for (int i = 0; i < upper.units; i++) {
                for (int j = 0; j <= lower.units; j++) {
                    upper.weight[i][j] = Math.h(-r, r);
                }
            }
        }
    }

    /**
     * Constructs a network whose output activation is derived from the error function: softmax
     * for cross-entropy with several output units, logistic sigmoid otherwise.
     *
     * @param errorFunction the error function to minimize
     * @param numUnits the number of units of each layer, input layer first and output layer last
     * @throws IllegalArgumentException under the same conditions as the three-argument constructor
     */
    public NeuralNetwork(ErrorFunction errorFunction, int... numUnits) {
        // An empty topology falls through to the main constructor's "Invalid number of layers" check.
        this(errorFunction,
                defaultActivation(errorFunction, numUnits.length == 0 ? 1 : numUnits[numUnits.length - 1]),
                numUnits);
    }

    /** Throws unless {@code eta} is a strictly positive learning rate (NaN is rejected). */
    private static void checkLearningRate(double eta) {
        if (!(eta > 0.0)) {
            throw new IllegalArgumentException("Invalid learning rate: " + eta);
        }
    }

    /** Throws unless {@code alpha} is a momentum factor in [0, 1) (NaN is rejected). */
    private static void checkMomentum(double alpha) {
        if (!(alpha >= 0.0 && alpha < 1.0)) {
            throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
        }
    }

    /** Throws unless {@code lambda} is a weight decay factor in [0, 0.1] (NaN is rejected). */
    private static void checkWeightDecay(double lambda) {
        if (!(lambda >= 0.0 && lambda <= 0.1)) {
            throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
        }
    }

    /**
     * Natural logarithm clamped at ln(1e-300) = -690.7755, so that log(0) does not produce
     * -Infinity in the cross-entropy error. {@code Math.g} is assumed to be the natural log.
     */
    private static double log(double x) {
        return x < 1.0E-300 ? -690.7755 : Math.g(x);
    }

    /** Output activation implied by the error function and the size of the output layer. */
    private static ActivationFunction defaultActivation(ErrorFunction errorFunction, int outputUnits) {
        if (errorFunction == ErrorFunction.CROSS_ENTROPY && outputUnits != 1) {
            return ActivationFunction.SOFTMAX;
        }
        return ActivationFunction.LOGISTIC_SIGMOID;
    }

    /**
     * Computes the error of the current output-layer activations against {@code expected} and
     * writes the output-layer gradient (expected - output, multiplied by the sigmoid derivative
     * for least-mean-squares with logistic output) into {@code gradient}.
     *
     * @return the error value for this sample
     * @throws IllegalArgumentException if {@code expected} does not match the output layer size
     */
    private double computeOutputError(double[] expected, double[] gradient) {
        if (expected.length != outputLayer.units) {
            throw new IllegalArgumentException(String.format(
                    "Invalid output vector size: %d, expected: %d", expected.length, outputLayer.units));
        }
        double error = 0.0;
        for (int i = 0; i < outputLayer.units; i++) {
            double out = outputLayer.output[i];
            double t = expected[i];
            double g = t - out;
            if (errorFunction == ErrorFunction.LEAST_MEAN_SQUARES) {
                error += 0.5 * g * g;
                if (activationFunction == ActivationFunction.LOGISTIC_SIGMOID) {
                    g *= out * (1.0 - out);
                }
            } else if (errorFunction == ErrorFunction.CROSS_ENTROPY) {
                if (activationFunction == ActivationFunction.SOFTMAX) {
                    error -= t * log(out);
                } else if (activationFunction == ActivationFunction.LOGISTIC_SIGMOID) {
                    // Binary case: the constructor guarantees a single output unit here.
                    error += -t * log(out) - (1.0 - t) * log(1.0 - out);
                }
            }
            gradient[i] = g;
        }
        return error;
    }

    /**
     * Replaces the output-layer activations by their softmax. The maximum is subtracted before
     * exponentiating ({@code Math.e}, presumably exp) to avoid overflow.
     */
    private void softmax() {
        double[] out = outputLayer.output;
        int n = outputLayer.units;
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < n; i++) {
            if (out[i] > max) {
                max = out[i];
            }
        }
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            out[i] = Math.e(out[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < n; i++) {
            out[i] /= sum;
        }
    }

    /** Computes the activations of {@code upper} from those of {@code lower} (bias included). */
    private void propagate(Layer lower, Layer upper) {
        boolean isOutput = upper == outputLayer;
        for (int i = 0; i < upper.units; i++) {
            double[] w = upper.weight[i];
            double sum = 0.0;
            for (int j = 0; j <= lower.units; j++) {
                sum += w[j] * lower.output[j];
            }
            if (!isOutput || activationFunction == ActivationFunction.LOGISTIC_SIGMOID) {
                upper.output[i] = Math.r(sum);
            } else if (activationFunction == ActivationFunction.LINEAR
                    || activationFunction == ActivationFunction.SOFTMAX) {
                // Softmax is normalized across the whole layer below.
                upper.output[i] = sum;
            } else {
                throw new UnsupportedOperationException("Unsupported activation function.");
            }
        }
        if (isOutput && activationFunction == ActivationFunction.SOFTMAX) {
            softmax();
        }
    }

    /** Runs the forward pass from the input layer to the output layer. */
    private void propagate() {
        for (int l = 1; l < net.length; l++) {
            propagate(net[l - 1], net[l]);
        }
    }

    /**
     * Copies {@code x} into the input layer (the bias slot is left untouched).
     *
     * @throws IllegalArgumentException if {@code x} does not match the input layer size
     */
    private void setInput(double[] x) {
        if (x.length != inputLayer.units) {
            throw new IllegalArgumentException(String.format(
                    "Invalid input vector size: %d, expected: %d", x.length, inputLayer.units));
        }
        System.arraycopy(x, 0, inputLayer.output, 0, inputLayer.units);
    }

    /**
     * Copies the output-layer activations (without the bias slot) into {@code y}.
     *
     * @throws IllegalArgumentException if {@code y} does not match the output layer size
     */
    private void getOutput(double[] y) {
        if (y.length != outputLayer.units) {
            throw new IllegalArgumentException(String.format(
                    "Invalid output vector size: %d, expected: %d", y.length, outputLayer.units));
        }
        System.arraycopy(outputLayer.output, 0, y, 0, outputLayer.units);
    }

    /**
     * Propagates the error terms of {@code upper} down to {@code lower}, multiplying by the
     * logistic derivative out * (1 - out) of the (hidden) lower layer.
     */
    private void backpropagate(Layer upper, Layer lower) {
        for (int j = 0; j <= lower.units; j++) {
            double out = lower.output[j];
            double sum = 0.0;
            for (int i = 0; i < upper.units; i++) {
                sum += upper.weight[i][j] * upper.error[i];
            }
            lower.error[j] = out * (1.0 - out) * sum;
        }
    }

    /** Back-propagates the output error through all hidden layers. */
    private void backpropagate() {
        // adjustWeights() never reads net[0].error, so stop at the first hidden layer.
        for (int l = net.length - 1; l > 1; l--) {
            backpropagate(net[l], net[l - 1]);
        }
    }

    /**
     * Applies one gradient step with momentum to every weight:
     * {@code delta = (1 - alpha) * eta * error * input + alpha * previousDelta}. When weight
     * decay is enabled, non-bias weights are additionally scaled by {@code 1 - eta * lambda}.
     */
    private void adjustWeights() {
        double decay = 1.0 - eta * lambda;
        for (int l = 1; l < net.length; l++) {
            Layer lower = net[l - 1];
            Layer upper = net[l];
            for (int i = 0; i < upper.units; i++) {
                double err = upper.error[i];
                double[] w = upper.weight[i];
                double[] dw = upper.delta[i];
                for (int j = 0; j <= lower.units; j++) {
                    double change = (1.0 - alpha) * eta * err * lower.output[j] + alpha * dw[j];
                    dw[j] = change;
                    w[j] += change;
                    // j == lower.units is the bias weight, which is not decayed.
                    if (lambda != 0.0 && j < lower.units) {
                        w[j] *= decay;
                    }
                }
            }
        }
    }

    /**
     * Maps the current output-layer activations to a class label: for a single output unit,
     * 0 when the output exceeds 0.5 and 1 otherwise; else the index of the largest output
     * (first one on ties).
     */
    private int labelFromOutput() {
        if (outputLayer.units == 1) {
            return outputLayer.output[0] > 0.5 ? 0 : 1;
        }
        double max = Double.NEGATIVE_INFINITY;
        int label = -1;
        for (int i = 0; i < outputLayer.units; i++) {
            if (outputLayer.output[i] > max) {
                max = outputLayer.output[i];
                label = i;
            }
        }
        return label;
    }

    /**
     * Encodes class {@code y} into {@link #target}. For a single output unit, the target is
     * {@code on} for class 0 and {@code off} for class 1. Otherwise every slot is {@code off}
     * except slot {@code y}, which is {@code on}.
     */
    private void encodeTarget(int y, double on, double off) {
        if (target.length == 1) {
            target[0] = y == 0 ? on : off;
        } else {
            Arrays.fill(target, off);
            target[y] = on;
        }
    }

    /**
     * Returns a deep copy of this network: configuration, weights, momentum terms and buffers.
     */
    @Override
    public NeuralNetwork clone() {
        NeuralNetwork copy = new NeuralNetwork();
        copy.errorFunction = errorFunction;
        copy.activationFunction = activationFunction;
        copy.p = p;
        copy.k = k;
        copy.eta = eta;
        copy.alpha = alpha;
        copy.lambda = lambda;
        copy.target = target.clone();
        int numLayers = net.length;
        copy.net = new Layer[numLayers];
        for (int l = 0; l < numLayers; l++) {
            Layer src = net[l];
            // Bind the new layer to the copy, not to this network, so the clone does not
            // retain (or serialize) the original through the inner-class reference.
            Layer dst = copy.new Layer();
            dst.units = src.units;
            dst.output = src.output.clone();
            dst.error = src.error.clone();
            if (l > 0) {
                dst.weight = Math.f(src.weight);
                dst.delta = Math.f(src.delta);
            }
            copy.net[l] = dst;
        }
        copy.inputLayer = copy.net[0];
        copy.outputLayer = copy.net[numLayers - 1];
        return copy;
    }

    /** Returns the learning rate. */
    public double getLearningRate() {
        return eta;
    }

    /** Returns the momentum factor. */
    public double getMomentum() {
        return alpha;
    }

    /**
     * Returns the live (not copied) weight matrix feeding layer {@code layer}. Row {@code i} holds
     * the weights into unit {@code i}; the last column is the bias weight. Layer 0 (the input
     * layer) has no weights, so this returns null for it.
     *
     * @param layer the layer index, 0 being the input layer
     */
    public double[][] getWeight(int layer) {
        return net[layer].weight;
    }

    /** Returns the weight decay factor. */
    public double getWeightDecay() {
        return lambda;
    }

    /**
     * Performs one online update with a real-valued target vector.
     *
     * @param x the input vector
     * @param y the desired output-layer activations
     * @param weight the sample weight; scales both the returned error and the gradient
     * @return the weighted error of this sample before the update
     * @throws IllegalArgumentException if {@code x} or {@code y} has the wrong size
     */
    public double learn(double[] x, double[] y, double weight) {
        setInput(x);
        propagate();
        double err = weight * computeOutputError(y, outputLayer.error);
        if (weight != 1.0) {
            for (int i = 0; i < outputLayer.units; i++) {
                outputLayer.error[i] *= weight;
            }
        }
        backpropagate();
        adjustWeights();
        return err;
    }

    /** Performs one online update for a labelled sample with weight 1. */
    @Override // smile.classification.OnlineClassifier
    public void learn(double[] x, int y) {
        learn(x, y, 1.0);
    }

    /**
     * Performs one online update for a labelled sample.
     *
     * <p>Cross-entropy training uses 1/0 targets; least-mean-squares training uses 0.9/0.1 so the
     * sigmoid is not pushed towards saturation. With a single output unit, class 0 maps to the
     * high target.
     *
     * @param x the input vector
     * @param y the class label
     * @param weight the sample weight; a zero weight skips the sample
     * @throws IllegalArgumentException if {@code weight} is negative or {@code y} is out of range
     */
    public void learn(double[] x, int y, double weight) {
        if (weight < 0.0) {
            throw new IllegalArgumentException("Invalid weight: " + weight);
        }
        if (weight == 0.0) {
            logger.info("Ignore the training instance with zero weight.");
            return;
        }
        if (y < 0) {
            throw new IllegalArgumentException("Invalid class label: " + y);
        }
        if (outputLayer.units == 1 && y > 1) {
            throw new IllegalArgumentException("Invalid class label: " + y);
        }
        if (outputLayer.units > 1 && y >= outputLayer.units) {
            throw new IllegalArgumentException("Invalid class label: " + y);
        }
        if (errorFunction == ErrorFunction.CROSS_ENTROPY) {
            encodeTarget(y, 1.0, 0.0);
        } else {
            encodeTarget(y, 0.9, 0.1);
        }
        learn(x, target, weight);
    }

    /**
     * Runs one epoch of online learning, visiting the samples in the order given by
     * {@code Math.f(n)} (presumably a random permutation of 0..n-1).
     *
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length
     */
    public void learn(double[][] x, int[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format(
                    "The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        int n = x.length;
        int[] order = Math.f(n);
        for (int i = 0; i < n; i++) {
            learn(x[order[i]], y[order[i]]);
        }
    }

    /** Predicts the class label of {@code x}. */
    @Override // smile.classification.Classifier
    public int predict(double[] x) {
        setInput(x);
        propagate();
        return labelFromOutput();
    }

    /**
     * Predicts the class label of {@code x} and stores the output activations in
     * {@code posteriori}.
     *
     * <p>{@code posteriori} normally has one slot per output unit. For a single-output network, a
     * length-2 array is also accepted and receives {@code {output, 1 - output}}, ordered as
     * classes 0 and 1. These values are probabilities only when the output is a logistic sigmoid.
     *
     * @throws IllegalArgumentException if {@code posteriori} has an unsupported length
     */
    @Override // smile.classification.SoftClassifier
    public int predict(double[] x, double[] posteriori) {
        setInput(x);
        propagate();
        if (outputLayer.units == 1 && posteriori.length == 2) {
            posteriori[0] = outputLayer.output[0];
            posteriori[1] = 1.0 - outputLayer.output[0];
        } else {
            getOutput(posteriori);
        }
        return labelFromOutput();
    }

    /** Predicts the class label of each row of {@code x}. */
    @Override // smile.classification.Classifier
    public int[] predict(double[][] x) {
        int[] labels = new int[x.length];
        for (int i = 0; i < x.length; i++) {
            labels[i] = predict(x[i]);
        }
        return labels;
    }

    /**
     * Sets the learning rate.
     *
     * @throws IllegalArgumentException if {@code eta} is not strictly positive
     */
    public void setLearningRate(double eta) {
        checkLearningRate(eta);
        this.eta = eta;
    }

    /**
     * Sets the momentum factor.
     *
     * @throws IllegalArgumentException if {@code alpha} is outside [0, 1)
     */
    public void setMomentum(double alpha) {
        checkMomentum(alpha);
        this.alpha = alpha;
    }

    /**
     * Sets the weight decay factor.
     *
     * @throws IllegalArgumentException if {@code lambda} is outside [0, 0.1]
     */
    public void setWeightDecay(double lambda) {
        checkWeightDecay(lambda);
        this.lambda = lambda;
    }
}
