package smile.classification;

import com.github.mikephil.charting.utils.Utils;
import java.lang.reflect.Array;
import java.util.Arrays;
import smile.classification.Classifier;
import smile.math.Math;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.EVD;
import smile.math.matrix.Matrix;

/* loaded from: classes2.dex */
/**
 * Discriminant classifier whose per-class covariance is a blend of the class's own
 * covariance and a covariance shared by all classes (the class name presumably stands
 * for "regularized discriminant analysis").
 *
 * <p>For each class {@code c} the constructor builds
 * {@code alpha * cov(c) + (1 - alpha) * common}, eigen-decomposes it, and stores the
 * decomposition. Prediction evaluates, for every class, a constant term minus half of
 * the squared distance measured in the eigenbasis, and returns the arg-max.
 *
 * <p>NOTE(review): this file comes from an obfuscated build. The helpers on
 * {@code smile.math.Math} and {@code EVD} are referenced by their obfuscated names;
 * their meanings below are inferred from how the results are used and should be
 * confirmed against the library: {@code Math.a} ~ absolute value, {@code Math.d} ~
 * column means, {@code Math.e} ~ exp, {@code Math.g} ~ log, {@code Math.h} ~ distinct
 * values, {@code EVD.a()} ~ eigenvectors, {@code EVD.b()} ~ eigenvalues.
 */
public class RDA implements SoftClassifier<double[]> {
    private static final long serialVersionUID = 1L;

    /** Per-class constant of the discriminant: {@code Math.g(prior) - 0.5 * sum(Math.g(ev))}. */
    private final double[] ct;
    /** Per-class values returned by {@code EVD.b()} (presumably eigenvalues); used as divisors in predict. */
    private final double[][] ev;
    /** Number of classes. */
    private final int k;
    /** Per-class mean vectors, {@code k x p}. */
    private final double[][] mu;
    /** Number of variables (length of the first training row). */
    private final int p;
    /** Priori probability of each class, either supplied by the caller or estimated from class counts. */
    private final double[] priori;
    /** Per-class matrices returned by {@code EVD.a()} (presumably eigenvectors); applied to centered inputs. */
    private final DenseMatrix[] scaling;

    /**
     * Trainer that forwards its stored settings to the five-argument {@link RDA} constructor.
     * No setters are visible in this class, so the fields keep their Java defaults unless
     * they are populated by means not shown here.
     */
    public static class Trainer extends ClassifierTrainer<double[]> {
        /** Forwarded as the regularization factor. */
        private double b;
        /** Forwarded as the priori probabilities; {@code null} means "estimate from the data". */
        private double[] c;
        /** Forwarded as the singularity tolerance. */
        private double d;

        /**
         * Builds an {@link RDA} model from the given samples and labels using this trainer's settings.
         *
         * @param x training samples, one row per sample
         * @param y class label of each sample
         * @return the fitted model
         */
        @Override // smile.classification.ClassifierTrainer
        public RDA a(double[][] x, int[] y) {
            return new RDA(x, y, this.c, this.b, this.d);
        }
    }

    /**
     * Fits a model with priors estimated from class frequencies and the default tolerance.
     *
     * @param x     training samples, one row per sample
     * @param y     class labels, expected to be {@code 0 .. k-1}
     * @param alpha regularization factor in {@code [0, 1]}
     */
    public RDA(double[][] x, int[] y, double alpha) {
        this(x, y, null, alpha);
    }

    /**
     * Fits a model with the default tolerance of {@code 1.0E-4}.
     *
     * @param x      training samples, one row per sample
     * @param y      class labels, expected to be {@code 0 .. k-1}
     * @param priori priori probabilities per class, or {@code null} to estimate them
     * @param alpha  regularization factor in {@code [0, 1]}
     */
    public RDA(double[][] x, int[] y, double[] priori, double alpha) {
        this(x, y, priori, alpha, 1.0E-4d);
    }

    /**
     * Fits a model.
     *
     * @param x      training samples, one row per sample; the width of the first row defines
     *               the number of variables
     * @param y      class labels, must be the contiguous range {@code 0 .. k-1}
     * @param priori priori probabilities per class (each strictly between 0 and 1, summing
     *               to 1 within 1e-10), or {@code null} to use class frequencies
     * @param alpha  regularization factor in {@code [0, 1]}: weight of the class's own
     *               covariance; {@code 1 - alpha} is the weight of the common covariance
     * @param tol    non-negative tolerance; variances and eigenvalues below {@code tol * tol}
     *               are reported as singular
     * @throws IllegalArgumentException if any argument fails the checks above, a class has
     *         fewer than two samples, there are not more samples than classes, or a
     *         covariance matrix is close to singular
     */
    public RDA(double[][] x, int[] y, double[] priori, double alpha, double tol) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        // Written as a negated range test so that NaN is rejected as well.
        if (!(alpha >= 0.0 && alpha <= 1.0)) {
            throw new IllegalArgumentException("Invalid regularization factor: " + alpha);
        }
        if (priori != null) {
            checkPriori(priori);
        }

        final int numClasses = countClasses(y);
        if (numClasses < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        if (priori != null && numClasses != priori.length) {
            throw new IllegalArgumentException("The number of classes and the number of priori probabilities don't match.");
        }
        if (tol < 0.0) {
            throw new IllegalArgumentException("Invalid tol: " + tol);
        }
        final int n = x.length;
        if (n <= numClasses) {
            throw new IllegalArgumentException(String.format("Sample size is too small: %d <= %d", n, numClasses));
        }
        final int dim = x[0].length;
        this.k = numClasses;
        this.p = dim;

        // Number of samples per class.
        int[] ni = new int[numClasses];
        // Presumably the column means of x; used as the center of the common covariance.
        double[] mean = Math.d(x);
        DenseMatrix common = Matrix.CC.zeros(dim, dim);
        double[][] means = new double[numClasses][dim];
        DenseMatrix[] cov = new DenseMatrix[numClasses];

        // Class counts and per-class coordinate sums.
        for (int i = 0; i < n; i++) {
            int c = y[i];
            ni[c]++;
            for (int j = 0; j < dim; j++) {
                means[c][j] += x[i][j];
            }
        }

        // Turn the sums into means; a class needs at least two samples for (ni - 1) below.
        for (int c = 0; c < numClasses; c++) {
            if (ni[c] <= 1) {
                throw new IllegalArgumentException(String.format("Class %d has only one sample.", c));
            }
            cov[c] = Matrix.CC.zeros(dim, dim);
            for (int j = 0; j < dim; j++) {
                means[c][j] /= ni[c];
            }
        }
        this.mu = means;

        double[] prior = priori;
        if (prior == null) {
            prior = new double[numClasses];
            for (int c = 0; c < numClasses; c++) {
                // The cast is essential: int / int would truncate every frequency to 0.
                prior[c] = (double) ni[c] / n;
            }
        }
        this.priori = prior;

        // Accumulate the lower triangles of the per-class and common scatter matrices.
        for (int i = 0; i < n; i++) {
            int c = y[i];
            double[] row = x[i];
            double[] center = means[c];
            for (int j = 0; j < dim; j++) {
                for (int l = 0; l <= j; l++) {
                    cov[c].add(j, l, (row[j] - center[j]) * (row[l] - center[l]));
                    common.add(j, l, (row[j] - mean[j]) * (row[l] - mean[l]));
                }
            }
        }

        final double tolSq = tol * tol;

        // Normalize the common covariance, mirror it, and check its diagonal.
        for (int j = 0; j < dim; j++) {
            for (int l = 0; l <= j; l++) {
                common.div(j, l, n - numClasses);
                common.set(l, j, common.get(j, l));
            }
            if (common.get(j, j) < tolSq) {
                throw new IllegalArgumentException(String.format("Covariance matrix (variable %d) is close to singular.", j));
            }
        }

        double[][] eigenvalues = new double[numClasses][];
        for (int c = 0; c < numClasses; c++) {
            DenseMatrix classCov = cov[c];
            for (int j = 0; j < dim; j++) {
                for (int l = 0; l <= j; l++) {
                    classCov.div(j, l, ni[c] - 1);
                    // Blend the class covariance with the common one.
                    classCov.set(j, l, (classCov.get(j, l) * alpha) + ((1.0d - alpha) * common.get(j, l)));
                    classCov.set(l, j, classCov.get(j, l));
                }
                if (classCov.get(j, j) < tolSq) {
                    throw new IllegalArgumentException(String.format("Class %d covariance matrix (variable %d) is close to singular.", c, j));
                }
            }

            classCov.setSymmetric(true);
            EVD eigen = classCov.eigen();
            double[] values = eigen.b();
            for (double value : values) {
                if (value < tolSq) {
                    throw new IllegalArgumentException(String.format("Class %d covariance matrix is close to singular.", c));
                }
            }
            eigenvalues[c] = values;
            // From here on the slot holds the decomposition's matrix, not the covariance.
            cov[c] = eigen.a();
        }
        this.ev = eigenvalues;
        this.scaling = cov;

        double[] constants = new double[numClasses];
        for (int c = 0; c < numClasses; c++) {
            double logSum = 0.0;
            for (int j = 0; j < dim; j++) {
                logSum += Math.g(eigenvalues[c][j]);
            }
            constants[c] = Math.g(prior[c]) - (logSum * 0.5d);
        }
        this.ct = constants;
    }

    /**
     * Validates caller-supplied priori probabilities: at least two entries, each strictly
     * between 0 and 1, and a total within 1e-10 of 1.
     */
    private static void checkPriori(double[] priori) {
        if (priori.length < 2) {
            throw new IllegalArgumentException("Invalid number of priori probabilities: " + priori.length);
        }
        double sum = 0.0;
        for (double probability : priori) {
            if (probability <= 0.0 || probability >= 1.0d) {
                throw new IllegalArgumentException("Invalid priori probability: " + probability);
            }
            sum += probability;
        }
        if (Math.a(sum - 1.0d) > 1.0E-10d) {
            throw new IllegalArgumentException("The sum of priori probabilities is not one: " + sum);
        }
    }

    /**
     * Validates that the labels form the contiguous range {@code 0 .. k-1} and returns
     * {@code k}. Labels are later used directly as array indices, so a range that does not
     * start at 0 must be rejected here.
     */
    private static int countClasses(int[] y) {
        int[] labels = Math.h(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            if (i == 0) {
                if (labels[0] > 0) {
                    throw new IllegalArgumentException("Missing class: 0");
                }
            } else if (labels[i] - labels[i - 1] > 1) {
                throw new IllegalArgumentException("Missing class: " + (labels[i - 1] + 1));
            }
        }
        return labels.length;
    }

    /**
     * Returns the priori probabilities used by the model. The internal array is returned
     * without copying.
     */
    public double[] getPriori() {
        return this.priori;
    }

    /**
     * Predicts the class of a sample without computing posteriori probabilities.
     *
     * @param x the sample; its length must equal the number of training variables
     * @return the predicted class label
     */
    @Override // smile.classification.Classifier
    public int predict(double[] x) {
        return predict(x, (double[]) null);
    }

    /**
     * Predicts the class of a sample and optionally fills in posteriori probabilities.
     *
     * @param x          the sample; its length must equal the number of training variables
     * @param posteriori if non-null, an array of length {@code k} that receives the
     *                   normalized posteriori probability of each class
     * @return the label with the largest discriminant value
     * @throws IllegalArgumentException if either array has the wrong length
     */
    @Override // smile.classification.SoftClassifier
    public int predict(double[] x, double[] posteriori) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        if (posteriori != null && posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }

        int best = 0;
        double max = Double.NEGATIVE_INFINITY;
        double[] centered = new double[this.p];
        double[] projected = new double[this.p];

        for (int c = 0; c < this.k; c++) {
            for (int j = 0; j < this.p; j++) {
                centered[j] = x[j] - this.mu[c][j];
            }
            this.scaling[c].atx(centered, projected);

            double distance = 0.0;
            for (int j = 0; j < this.p; j++) {
                distance += (projected[j] * projected[j]) / this.ev[c][j];
            }

            double score = this.ct[c] - (distance * 0.5d);
            if (max < score) {
                best = c;
                max = score;
            }
            if (posteriori != null) {
                posteriori[c] = score;
            }
        }

        if (posteriori != null) {
            // Subtract the maximum before exponentiating so the largest term is exactly 1.
            double sum = 0.0;
            for (int c = 0; c < this.k; c++) {
                posteriori[c] = Math.e(posteriori[c] - max);
                sum += posteriori[c];
            }
            for (int c = 0; c < this.k; c++) {
                posteriori[c] = posteriori[c] / sum;
            }
        }
        return best;
    }

    /**
     * Predicts the class of every sample by delegating to the interface's default
     * implementation.
     *
     * @param x the samples, one row per sample
     * @return the predicted label of each sample
     */
    @Override // smile.classification.Classifier
    public int[] predict(double[][] x) {
        return Classifier.CC.$default$predict(this, x);
    }
}
