package com.hankcs.hanlp.model.crf.crfpp;

import com.bumptech.glide.load.Key;
import com.github.mikephil.charting.utils.Utils;
import com.hankcs.hanlp.corpus.io.IOUtil;
import com.hankcs.hanlp.model.crf.crfpp.TaggerImpl;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

/* loaded from: classes.dex */
public class Encoder {
    public static int MODEL_VERSION = 100;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: com.hankcs.hanlp.model.crf.crfpp.Encoder$1, reason: invalid class name */
    /* loaded from: classes.dex */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$com$hankcs$hanlp$model$crf$crfpp$Encoder$Algorithm;

        static {
            int[] iArr = new int[Algorithm.values().length];
            $SwitchMap$com$hankcs$hanlp$model$crf$crfpp$Encoder$Algorithm = iArr;
            try {
                iArr[Algorithm.CRF_L1.ordinal()] = 1;
            } catch (NoSuchFieldError unused) {
            }
            try {
                $SwitchMap$com$hankcs$hanlp$model$crf$crfpp$Encoder$Algorithm[Algorithm.CRF_L2.ordinal()] = 2;
            } catch (NoSuchFieldError unused2) {
            }
            try {
                $SwitchMap$com$hankcs$hanlp$model$crf$crfpp$Encoder$Algorithm[Algorithm.MIRA.ordinal()] = 3;
            } catch (NoSuchFieldError unused3) {
            }
        }
    }

    /* loaded from: classes.dex */
    public enum Algorithm {
        CRF_L2,
        CRF_L1,
        MIRA;

        public static Algorithm fromString(String str) {
            String lowerCase = str.toLowerCase();
            if (lowerCase.equals("crf") || lowerCase.equals("crf-l2")) {
                return CRF_L2;
            }
            if (lowerCase.equals("crf-l1")) {
                return CRF_L1;
            }
            if (lowerCase.equals("mira")) {
                return MIRA;
            }
            throw new IllegalArgumentException("invalid algorithm: " + lowerCase);
        }
    }

    public static void main(String[] strArr) {
        if (strArr.length < 3) {
            System.err.println("incorrect No. of args");
            return;
        }
        String str = strArr[0];
        String str2 = strArr[1];
        String str3 = strArr[2];
        Encoder encoder = new Encoder();
        long time = new Date().getTime();
        if (encoder.learn(str, str2, str3, false, 100000, 1, 1.0E-4d, 1.0d, 1, 20, Algorithm.CRF_L2)) {
            System.out.println(new Date().getTime() - time);
        } else {
            System.err.println("error training model");
        }
    }

    private boolean runCRF(List<TaggerImpl> list, EncoderFeatureIndex encoderFeatureIndex, double[] dArr, int i, double d, double d2, int i2, int i3, boolean z) {
        ExecutorService executorService;
        int size;
        LbfgsOptimizer lbfgsOptimizer = new LbfgsOptimizer();
        ArrayList arrayList = new ArrayList();
        int i4 = 0;
        for (int i5 = 0; i5 < i3; i5++) {
            CRFEncoderThread cRFEncoderThread = new CRFEncoderThread(dArr.length);
            cRFEncoderThread.start_i = i5;
            cRFEncoderThread.size = list.size();
            cRFEncoderThread.threadNum = i3;
            cRFEncoderThread.x = list;
            arrayList.add(cRFEncoderThread);
        }
        int i6 = 0;
        for (int i7 = 0; i7 < list.size(); i7++) {
            i6 += list.get(i7).size();
        }
        ExecutorService newFixedThreadPool = Executors.newFixedThreadPool(i3);
        double d3 = 1.0E37d;
        int i8 = 0;
        int i9 = 0;
        while (i9 < i) {
            encoderFeatureIndex.clear();
            try {
                newFixedThreadPool.invokeAll(arrayList);
                int i10 = 1;
                while (i10 < i3) {
                    ((CRFEncoderThread) arrayList.get(i4)).obj += ((CRFEncoderThread) arrayList.get(i10)).obj;
                    ((CRFEncoderThread) arrayList.get(0)).err += ((CRFEncoderThread) arrayList.get(i10)).err;
                    ((CRFEncoderThread) arrayList.get(0)).zeroone += ((CRFEncoderThread) arrayList.get(i10)).zeroone;
                    i10++;
                    i8 = i8;
                    i6 = i6;
                    newFixedThreadPool = newFixedThreadPool;
                    i4 = 0;
                }
                int i11 = i8;
                int i12 = i6;
                ExecutorService executorService2 = newFixedThreadPool;
                for (int i13 = 1; i13 < i3; i13++) {
                    for (int i14 = 0; i14 < encoderFeatureIndex.size(); i14++) {
                        double[] dArr2 = ((CRFEncoderThread) arrayList.get(0)).expected;
                        dArr2[i14] = dArr2[i14] + ((CRFEncoderThread) arrayList.get(i13)).expected[i14];
                    }
                }
                if (z) {
                    size = 0;
                    for (int i15 = 0; i15 < encoderFeatureIndex.size(); i15++) {
                        ((CRFEncoderThread) arrayList.get(0)).obj += Math.abs(dArr[i15] / d);
                        if (dArr[i15] != Utils.DOUBLE_EPSILON) {
                            size++;
                        }
                    }
                } else {
                    size = encoderFeatureIndex.size();
                    for (int i16 = 0; i16 < encoderFeatureIndex.size(); i16++) {
                        ((CRFEncoderThread) arrayList.get(0)).obj += (dArr[i16] * dArr[i16]) / (2.0d * d);
                        double[] dArr3 = ((CRFEncoderThread) arrayList.get(0)).expected;
                        dArr3[i16] = dArr3[i16] + (dArr[i16] / d);
                    }
                }
                for (int i17 = 1; i17 < i3; i17++) {
                    ((CRFEncoderThread) arrayList.get(i17)).expected = null;
                }
                double abs = i9 == 0 ? 1.0d : Math.abs(d3 - ((CRFEncoderThread) arrayList.get(0)).obj) / d3;
                System.out.println("iter=" + i9 + " terr=" + ((((CRFEncoderThread) arrayList.get(0)).err * 1.0d) / i12) + " serr=" + ((((CRFEncoderThread) arrayList.get(0)).zeroone * 1.0d) / list.size()) + " act=" + size + " obj=" + ((CRFEncoderThread) arrayList.get(0)).obj + " diff=" + abs);
                double d4 = ((CRFEncoderThread) arrayList.get(0)).obj;
                int i18 = abs < d2 ? i11 + 1 : 0;
                if (i9 > i || i18 == 3) {
                    executorService = executorService2;
                    break;
                }
                int i19 = i9;
                int i20 = i18;
                if (lbfgsOptimizer.optimize(encoderFeatureIndex.size(), dArr, ((CRFEncoderThread) arrayList.get(0)).obj, ((CRFEncoderThread) arrayList.get(0)).expected, z, d) <= 0) {
                    return false;
                }
                i9 = i19 + 1;
                newFixedThreadPool = executorService2;
                d3 = d4;
                i6 = i12;
                i8 = i20;
                i4 = 0;
            } catch (Exception e) {
                e.printStackTrace();
                return false;
            }
        }
        executorService = newFixedThreadPool;
        executorService.shutdown();
        try {
            executorService.awaitTermination(-1L, TimeUnit.SECONDS);
            return true;
        } catch (Exception e2) {
            e2.printStackTrace();
            System.err.println("fail waiting executor to shutdown");
            return true;
        }
    }

    public boolean learn(String str, String str2, String str3, boolean z, int i, int i2, double d, double d2, int i3, int i4, Algorithm algorithm) {
        if (d <= Utils.DOUBLE_EPSILON) {
            System.err.println("eta must be > 0.0");
            return false;
        }
        if (d2 < Utils.DOUBLE_EPSILON) {
            System.err.println("C must be >= 0.0");
            return false;
        }
        if (i4 < 1) {
            System.err.println("shrinkingSize must be >= 1");
            return false;
        }
        if (i3 <= 0) {
            System.err.println("thread must be  > 0");
            return false;
        }
        EncoderFeatureIndex encoderFeatureIndex = new EncoderFeatureIndex(i3);
        List<TaggerImpl> arrayList = new ArrayList<>();
        if (!encoderFeatureIndex.open(str, str2)) {
            System.err.println("Fail to open " + str + " " + str2);
        }
        try {
            BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(IOUtil.newInputStream(str2), Key.STRING_CHARSET_NAME));
            int i5 = 0;
            while (true) {
                TaggerImpl taggerImpl = new TaggerImpl(TaggerImpl.Mode.LEARN);
                taggerImpl.open(encoderFeatureIndex);
                TaggerImpl.ReadStatus read = taggerImpl.read(bufferedReader);
                if (read == TaggerImpl.ReadStatus.ERROR) {
                    System.err.println("error when reading " + str2);
                    return false;
                }
                if (taggerImpl.empty()) {
                    if (read == TaggerImpl.ReadStatus.EOF) {
                        bufferedReader.close();
                        encoderFeatureIndex.shrink(i2, arrayList);
                        double[] dArr = new double[encoderFeatureIndex.size()];
                        Arrays.fill(dArr, Utils.DOUBLE_EPSILON);
                        encoderFeatureIndex.setAlpha_(dArr);
                        System.out.println("Number of sentences: " + arrayList.size());
                        System.out.println("Number of features:  " + encoderFeatureIndex.size());
                        System.out.println("Number of thread(s): " + i3);
                        System.out.println("Freq:                " + i2);
                        System.out.println("eta:                 " + d);
                        System.out.println("C:                   " + d2);
                        System.out.println("shrinking size:      " + i4);
                        int i6 = AnonymousClass1.$SwitchMap$com$hankcs$hanlp$model$crf$crfpp$Encoder$Algorithm[algorithm.ordinal()];
                        if (i6 != 1) {
                            if (i6 != 2) {
                                if (i6 == 3 && !runMIRA(arrayList, encoderFeatureIndex, dArr, i, d2, d, i4, i3)) {
                                    System.err.println("MIRA execute error");
                                    return false;
                                }
                            } else if (!runCRF(arrayList, encoderFeatureIndex, dArr, i, d2, d, i4, i3, false)) {
                                System.err.println("CRF_L2 execute error");
                                return false;
                            }
                        } else if (!runCRF(arrayList, encoderFeatureIndex, dArr, i, d2, d, i4, i3, true)) {
                            System.err.println("CRF_L1 execute error");
                            return false;
                        }
                        if (!encoderFeatureIndex.save(str3, z)) {
                            System.err.println("Failed to save model");
                        }
                        System.out.println("Done!");
                        return true;
                    }
                } else {
                    if (!taggerImpl.shrink()) {
                        System.err.println("fail to build feature index ");
                        return false;
                    }
                    taggerImpl.setThread_id_(i5 % i3);
                    arrayList.add(taggerImpl);
                    i5++;
                    if (i5 % 100 == 0) {
                        System.out.print(i5 + ".. ");
                    }
                }
            }
        } catch (IOException unused) {
            System.err.println("train file " + str2 + " does not exist.");
            return false;
        }
    }

    public boolean runMIRA(List<TaggerImpl> list, EncoderFeatureIndex encoderFeatureIndex, double[] dArr, int i, double d, double d2, int i2, int i3) {
        int i4;
        List<TaggerImpl> list2 = list;
        int i5 = i;
        double d3 = d;
        Integer[] numArr = new Integer[list.size()];
        int i6 = 0;
        Arrays.fill((Object[]) numArr, (Object) 0);
        List asList = Arrays.asList(numArr);
        Double[] dArr2 = new Double[list.size()];
        double d4 = Utils.DOUBLE_EPSILON;
        Double valueOf = Double.valueOf(Utils.DOUBLE_EPSILON);
        Arrays.fill(dArr2, valueOf);
        List asList2 = Arrays.asList(dArr2);
        List<Double> asList3 = Arrays.asList(new Double[encoderFeatureIndex.size()]);
        if (i3 > 1) {
            System.err.println("WARN: MIRA does not support multi-threading");
        }
        int i7 = 0;
        for (int i8 = 0; i8 < list.size(); i8++) {
            i7 += list2.get(i8).size();
        }
        int i9 = 0;
        int i10 = 0;
        while (i9 < i5) {
            int i11 = i9;
            int i12 = i7;
            int i13 = i10;
            int i14 = 0;
            int i15 = 0;
            double d5 = d4;
            int i16 = 0;
            int i17 = 0;
            while (i6 < list.size()) {
                double d6 = d5;
                if (((Integer) asList.get(i6)).intValue() < i2) {
                    i16++;
                    for (int i18 = 0; i18 < asList3.size(); i18++) {
                        asList3.set(i18, valueOf);
                    }
                    double collins = list2.get(i6).collins(asList3);
                    int eval = list2.get(i6).eval();
                    i15 += eval;
                    if (eval != 0) {
                        i17++;
                    }
                    if (eval == 0) {
                        asList.set(i6, Integer.valueOf(((Integer) asList.get(i6)).intValue() + 1));
                    } else {
                        asList.set(i6, 0);
                        double d7 = Utils.DOUBLE_EPSILON;
                        for (int i19 = 0; i19 < asList3.size(); i19++) {
                            d7 += asList3.get(i19).doubleValue() * asList3.get(i19).doubleValue();
                        }
                        int i20 = i17;
                        double d8 = eval - collins;
                        double max = Math.max(Utils.DOUBLE_EPSILON, d8 / d7);
                        if (((Double) asList2.get(i6)).doubleValue() + max > d3) {
                            max = d3 - ((Double) asList2.get(i6)).doubleValue();
                            i14++;
                            d5 = d6;
                        } else {
                            d5 = Math.max(d8, d6);
                        }
                        if (max > 1.0E-10d) {
                            asList2.set(i6, Double.valueOf(((Double) asList2.get(i6)).doubleValue() + max));
                            asList2.set(i6, Double.valueOf(Math.min(d3, ((Double) asList2.get(i6)).doubleValue())));
                            for (int i21 = 0; i21 < asList3.size(); i21++) {
                                dArr[i21] = dArr[i21] + (asList3.get(i21).doubleValue() * max);
                            }
                        }
                        i16 = i16;
                        i15 = i15;
                        i17 = i20;
                        i6++;
                        list2 = list;
                    }
                }
                d5 = d6;
                i6++;
                list2 = list;
            }
            double d9 = Utils.DOUBLE_EPSILON;
            for (int i22 = 0; i22 < encoderFeatureIndex.size(); i22++) {
                d9 += dArr[i22] * dArr[i22];
            }
            StringBuilder sb = new StringBuilder();
            sb.append("iter=");
            sb.append(i11);
            sb.append(" terr=");
            List list3 = asList2;
            Double d10 = valueOf;
            List<Double> list4 = asList3;
            sb.append((i15 * 1.0d) / i12);
            sb.append(" serr=");
            sb.append((i17 * 1.0d) / list.size());
            sb.append(" act=");
            sb.append(i16);
            sb.append(" uact=");
            sb.append(i14);
            sb.append(" obj=");
            sb.append(d9);
            sb.append(" kkt=");
            sb.append(d5);
            System.out.println(sb.toString());
            if (d5 <= Utils.DOUBLE_EPSILON) {
                for (int i23 = 0; i23 < asList.size(); i23++) {
                    asList.set(i23, 0);
                }
                i10 = i13 + 1;
                i4 = i;
            } else {
                i4 = i;
                i10 = 0;
            }
            if (i11 > i4 || i10 == 2) {
                return true;
            }
            i9 = i11 + 1;
            d4 = 0.0d;
            i5 = i4;
            i7 = i12;
            asList2 = list3;
            valueOf = d10;
            asList3 = list4;
            i6 = 0;
            list2 = list;
            d3 = d;
        }
        return true;
    }
}
