package com.hankcs.hanlp.model.perceptron;

import com.hankcs.hanlp.corpus.io.IOUtil;
import com.hankcs.hanlp.model.perceptron.common.TaskType;
import com.hankcs.hanlp.model.perceptron.feature.FeatureMap;
import com.hankcs.hanlp.model.perceptron.feature.LockableFeatureMap;
import com.hankcs.hanlp.model.perceptron.model.AveragedPerceptron;
import com.hankcs.hanlp.model.perceptron.model.LinearModel;
import com.hankcs.hanlp.model.perceptron.tagset.TagSet;
import com.hankcs.hanlp.model.perceptron.utility.Utility;
import java.io.IOException;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Base class for binary (two-label) perceptron text classifiers.
 *
 * <p>Training/evaluation corpora are read line by line in the form {@code text,label}; the label is
 * everything after the last comma. Subclasses decide how a text is turned into feature ids by
 * implementing {@link #extractFeature(String, FeatureMap)}. The two labels are mapped to the
 * perceptron targets {@code -1} (tag id 0) and {@code +1} (tag id 1).
 */
public abstract class PerceptronClassifier {
    /** The underlying linear model; {@code null} until one is supplied or trained. */
    public LinearModel model;

    /** Precision, recall and F1 of a binary classifier ({@link #evaluate} reports them in percent). */
    public static class BinaryClassificationFMeasure {
        /** F1 score: harmonic mean of precision and recall. */
        public float F1;
        /** Precision. */
        public float P;
        /** Recall. */
        public float R;

        public BinaryClassificationFMeasure(float precision, float recall, float f1) {
            this.P = precision;
            this.R = recall;
            this.F1 = f1;
        }

        @Override
        public String toString() {
            return String.format("P=%.2f R=%.2f F1=%.2f", P, R, F1);
        }
    }

    /** One example: extracted feature ids plus a target which this class sets to {@code -1} or {@code +1}. */
    public static class Instance {
        /** Feature ids extracted from the example's text. */
        public List<Integer> x;
        /** Target label ({@code -1} or {@code +1} as produced by {@code readInstance}). */
        public int y;

        public Instance(List<Integer> x, int y) {
            this.x = x;
            this.y = y;
        }
    }

    /** Creates a classifier without a model; call {@link #train} before predicting or evaluating. */
    public PerceptronClassifier() {
    }

    /**
     * Wraps an existing model.
     *
     * @param linearModel a classification model, or {@code null} to start without one
     * @throws IllegalArgumentException if the model's task type is not {@link TaskType#CLASSIFICATION}
     */
    public PerceptronClassifier(LinearModel linearModel) {
        if (linearModel != null && linearModel.taskType() != TaskType.CLASSIFICATION) {
            throw new IllegalArgumentException("传入的模型并非分类模型");
        }
        this.model = linearModel;
    }

    /**
     * Loads a model via {@code new LinearModel(modelPath)} (presumably a model file path — confirm
     * against LinearModel) and wraps it.
     *
     * @throws IOException if the model cannot be loaded
     * @throws IllegalArgumentException if the loaded model is not a classification model
     */
    public PerceptronClassifier(String modelPath) throws IOException {
        this(new LinearModel(modelPath));
    }

    /**
     * Appends the id of {@code feature} to {@code featureList}, skipping it when
     * {@link FeatureMap#idOf} returns -1 (feature not available in the map).
     */
    public static void addFeature(String feature, FeatureMap featureMap, List<Integer> featureList) {
        int id = featureMap.idOf(feature);
        if (id != -1) {
            featureList.add(id);
        }
    }

    /**
     * Reads a {@code text,label} corpus into instances.
     *
     * <p>Side effects: every label is registered in {@code featureMap.tagSet}, and
     * {@link #extractFeature} may register new features when the map is mutable (as during training).
     * Blank lines are skipped. The text is everything before the last comma, so texts may themselves
     * contain commas.
     *
     * @throws IllegalArgumentException if a line has no label, or more than two labels are seen
     */
    private Instance[] readInstance(String corpus, FeatureMap featureMap) {
        IOUtil.LineIterator lineIterator = new IOUtil.LineIterator(corpus);
        List<Instance> instances = new LinkedList<Instance>();
        Iterator<String> lines = lineIterator.iterator();
        while (lines.hasNext()) {
            String line = lines.next();
            if (line.trim().length() == 0) {
                continue; // tolerate blank lines such as a trailing newline at end of file
            }
            int separator = line.lastIndexOf(',');
            if (separator < 0 || separator == line.length() - 1) {
                throw new IllegalArgumentException("Malformed corpus line, expected \"text,label\": " + line);
            }
            String text = line.substring(0, separator);
            String label = line.substring(separator + 1);
            List<Integer> features = extractFeature(text, featureMap);
            int target = featureMap.tagSet.add(label);
            if (target == 0) {
                target = -1; // perceptron convention: targets are -1 / +1, so tag 0 becomes -1
            } else if (target > 1) {
                throw new IllegalArgumentException("类别数大于2，目前只支持二分类。");
            }
            instances.add(new Instance(features, target));
        }
        return instances.toArray(new Instance[0]);
    }

    /**
     * Trains an averaged perceptron. Each instance visit advances a global timestamp that is passed to
     * {@code update}/{@code average}, which use it (presumably) to compute time-weighted averages.
     * Note: {@code instances} is shuffled in place every epoch.
     */
    private static LinearModel trainAveragedPerceptron(Instance[] instances, FeatureMap featureMap, int maxIteration) {
        int size = featureMap.size();
        float[] weights = new float[size];
        double[] totals = new double[size];
        int[] timestamps = new int[size];
        AveragedPerceptron perceptron = new AveragedPerceptron(featureMap, weights);
        int timestamp = 0;
        for (int epoch = 0; epoch < maxIteration; epoch++) {
            Utility.shuffleArray(instances);
            for (Instance instance : instances) {
                ++timestamp;
                if (perceptron.decode(instance.x) != instance.y) {
                    perceptron.update(instance.x, instance.y, totals, timestamps, timestamp);
                }
            }
        }
        perceptron.average(totals, timestamps, timestamp);
        return perceptron;
    }

    /**
     * Trains a plain (non-averaged) perceptron: update the weights on every misclassified instance.
     * Note: {@code instances} is shuffled in place every epoch.
     */
    private static LinearModel trainNaivePerceptron(Instance[] instances, FeatureMap featureMap, int maxIteration) {
        LinearModel linearModel = new LinearModel(featureMap, new float[featureMap.size()]);
        for (int epoch = 0; epoch < maxIteration; epoch++) {
            Utility.shuffleArray(instances);
            for (Instance instance : instances) {
                if (linearModel.decode(instance.x) != instance.y) {
                    linearModel.update(instance.x, instance.y);
                }
            }
        }
        return linearModel;
    }

    /**
     * Evaluates the current model on a {@code text,label} corpus.
     *
     * <p>NOTE(review): reading the corpus calls {@code tagSet.add} on the model's feature map, so an
     * unseen label would be registered in the model's tag set.
     *
     * @throws IllegalStateException if no model has been loaded or trained
     */
    public BinaryClassificationFMeasure evaluate(String corpus) {
        return evaluate(readInstance(corpus, requireModel().featureMap));
    }

    /**
     * Computes precision, recall and F1 (in percent), treating target {@code +1} as the positive class.
     * Metrics whose denominator is zero are reported as 0 instead of NaN.
     *
     * @throws IllegalStateException if no model has been loaded or trained
     */
    public BinaryClassificationFMeasure evaluate(Instance[] instances) {
        LinearModel linearModel = requireModel();
        int truePositive = 0;
        int falsePositive = 0;
        int falseNegative = 0;
        for (Instance instance : instances) {
            boolean predictedPositive = linearModel.decode(instance.x) == 1;
            boolean actualPositive = instance.y == 1;
            if (predictedPositive) {
                if (actualPositive) {
                    truePositive++;
                } else {
                    falsePositive++;
                }
            } else if (actualPositive) {
                falseNegative++;
            }
        }
        float precision = percentage(truePositive, truePositive + falsePositive);
        float recall = percentage(truePositive, truePositive + falseNegative);
        float f1 = precision + recall == 0f ? 0f : 2.0f * precision * recall / (precision + recall);
        return new BinaryClassificationFMeasure(precision, recall, f1);
    }

    /** Returns {@code numerator / denominator * 100}, or 0 when the denominator is 0 (avoids NaN). */
    private static float percentage(int numerator, int denominator) {
        return denominator == 0 ? 0f : (float) numerator / denominator * 100.0f;
    }

    /** Returns the current model, failing with a clear message instead of an NPE when it is absent. */
    private LinearModel requireModel() {
        if (model == null) {
            throw new IllegalStateException("No model available: train or load a model first");
        }
        return model;
    }

    /**
     * Converts a text into feature ids.
     *
     * @param text       the raw text of one example
     * @param featureMap map used to look up (and, if mutable, presumably register) feature ids
     * @return the feature ids of {@code text}
     */
    public abstract List<Integer> extractFeature(String text, FeatureMap featureMap);

    /** Returns the current model, or {@code null} if none has been loaded or trained. */
    public LinearModel getModel() {
        return this.model;
    }

    /**
     * Predicts the label string for {@code text}. A decoded value of {@code -1} is mapped back to
     * tag id 0; any other value is used directly as the tag id.
     *
     * @throws IllegalStateException if no model has been loaded or trained
     */
    public String predict(String text) {
        LinearModel linearModel = requireModel();
        int decoded = linearModel.decode(extractFeature(text, linearModel.featureMap));
        int tagId = decoded == -1 ? 0 : decoded;
        return linearModel.tagSet().stringOf(tagId);
    }

    /** Trains an averaged perceptron; equivalent to {@code train(corpus, maxIteration, true)}. */
    public BinaryClassificationFMeasure train(String corpus, int maxIteration) {
        return train(corpus, maxIteration, true);
    }

    /**
     * Trains a new model on a {@code text,label} corpus, replacing {@link #model}.
     *
     * @param corpus         the training corpus
     * @param maxIteration   number of passes (epochs) over the data
     * @param averagePerceptron {@code true} for the averaged perceptron, {@code false} for the plain one
     * @return precision/recall/F1 of the trained model measured on the training data itself
     */
    public BinaryClassificationFMeasure train(String corpus, int maxIteration, boolean averagePerceptron) {
        LockableFeatureMap featureMap = new LockableFeatureMap(new TagSet(TaskType.CLASSIFICATION));
        featureMap.mutable = true; // let feature extraction register new features while reading the corpus
        Instance[] instances = readInstance(corpus, featureMap);
        this.model = averagePerceptron
                ? trainAveragedPerceptron(instances, featureMap, maxIteration)
                : trainNaivePerceptron(instances, featureMap, maxIteration);
        featureMap.mutable = false; // freeze the feature space once training is done
        return evaluate(instances);
    }
}
