package com.hankcs.hanlp.model.perceptron.model;

import com.bumptech.glide.load.Key;
import com.github.mikephil.charting.utils.Utils;
import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.algorithm.MaxHeap;
import com.hankcs.hanlp.classification.utilities.io.ConsoleLogger;
import com.hankcs.hanlp.classification.utilities.io.ILogger;
import com.hankcs.hanlp.collection.trie.datrie.MutableDoubleArrayTrieInteger;
import com.hankcs.hanlp.corpus.io.ByteArray;
import com.hankcs.hanlp.corpus.io.ByteArrayStream;
import com.hankcs.hanlp.corpus.io.ICacheAble;
import com.hankcs.hanlp.corpus.io.IOUtil;
import com.hankcs.hanlp.model.perceptron.common.TaskType;
import com.hankcs.hanlp.model.perceptron.feature.FeatureMap;
import com.hankcs.hanlp.model.perceptron.feature.FeatureSortItem;
import com.hankcs.hanlp.model.perceptron.feature.ImmutableFeatureMDatMap;
import com.hankcs.hanlp.model.perceptron.instance.Instance;
import com.hankcs.hanlp.model.perceptron.tagset.TagSet;
import com.hankcs.hanlp.utility.MathUtility;
import java.io.BufferedOutputStream;
import java.io.BufferedWriter;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.lang.reflect.Array;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/* loaded from: classes2.dex */
/**
 * A linear model: a {@link FeatureMap} (feature string to integer id, plus the {@link TagSet})
 * together with a flat {@code float} weight vector.
 *
 * <p>For tagging-style tasks the weight of feature {@code f} for tag {@code t} is stored at
 * {@code parameter[f * tagSet.size() + t]}. For {@link TaskType#CLASSIFICATION} models,
 * {@link #load(ByteArray)} reads exactly one weight per feature instead.
 */
public class LinearModel implements ICacheAble {
    /**
     * Default minimum total weight a feature needs to survive {@link #compress(double)}.
     * This is the float literal {@code 1e-3f} widened to double.
     */
    private static final double DEFAULT_COMPRESS_THRESHOLD = 0.0010000000474974513d;

    /** Maps feature strings to ids; also carries the tag set. */
    public FeatureMap featureMap;
    /** Flat weight vector; see the class comment for its layout. */
    public float[] parameter;

    /**
     * Creates a zero-initialised model with one weight per (feature, tag) pair.
     *
     * @param featureMap feature/tag dictionaries; its current size fixes the weight vector length
     */
    public LinearModel(FeatureMap featureMap) {
        this.featureMap = featureMap;
        this.parameter = new float[featureMap.size() * featureMap.tagSet.size()];
    }

    /**
     * Wraps an existing feature map and weight vector. The array is used as-is, not copied.
     *
     * @param featureMap feature/tag dictionaries
     * @param parameter  weight vector laid out as described in the class comment
     */
    public LinearModel(FeatureMap featureMap, float[] parameter) {
        this.featureMap = featureMap;
        this.parameter = parameter;
    }

    /**
     * Loads a model from a binary file written by {@link #save(String)}.
     * NOTE(review): this calls the overridable {@link #load(String)} from a constructor.
     *
     * @param modelFile path of the model file
     * @throws IOException if the file cannot be read or parsed
     */
    public LinearModel(String modelFile) throws IOException {
        load(modelFile);
    }

    /**
     * Compresses this model in place using the default weight threshold.
     *
     * @param ratio fraction of non-reserved features to discard, in {@code [0, 1)}
     * @return this model
     */
    public LinearModel compress(double ratio) {
        return compress(ratio, DEFAULT_COMPRESS_THRESHOLD);
    }

    /**
     * Compresses this model in place. The reserved leading feature ids (the first
     * {@code tagSet.sizeIncludingBos()} ids) are always kept. Of the remaining features, those
     * whose total weight is below {@code threshold} are dropped. The rest are ranked by total
     * weight into a heap sized to {@code (1 - ratio)} of them. After that, {@link #featureMap}
     * and {@link #parameter} are replaced by compact versions.
     *
     * @param ratio     fraction to discard, in {@code [0, 1)}; {@code 0} leaves the model untouched
     * @param threshold minimum {@code FeatureSortItem.total} for a feature to be considered
     * @return this model
     * @throws IllegalArgumentException if {@code ratio} is outside {@code [0, 1)}
     */
    public LinearModel compress(double ratio, double threshold) {
        // Compare against a literal 0.0: the value must be exactly zero, not a tiny epsilon.
        if (ratio < 0.0d || ratio >= 1.0d) {
            throw new IllegalArgumentException("压缩比必须介于 0 和 1 之间");
        }
        if (ratio == 0.0d) {
            return this;
        }
        TagSet tagSet = this.featureMap.tagSet;
        MaxHeap<FeatureSortItem> kept = selectFeatures(ratio, threshold, tagSet);
        rebuild(kept, tagSet);
        return this;
    }

    /** Collects the heaviest non-reserved features into a heap holding at most (1 - ratio) of them. */
    private MaxHeap<FeatureSortItem> selectFeatures(double ratio, double threshold, TagSet tagSet) {
        Set<Map.Entry<String, Integer>> entries = this.featureMap.entrySet();
        int reserved = tagSet.sizeIncludingBos();
        int tagCount = tagSet.size();
        int featureCount = this.featureMap.size();
        MaxHeap<FeatureSortItem> heap = new MaxHeap<FeatureSortItem>(
                (int) ((entries.size() - reserved) * (1.0d - ratio)),
                new Comparator<FeatureSortItem>() {
                    @Override
                    public int compare(FeatureSortItem a, FeatureSortItem b) {
                        return Float.compare(a.total, b.total);
                    }
                });
        ConsoleLogger.logger.start("裁剪特征...\n");
        int logEvery = (int) Math.ceil(featureCount / 10000.0f);
        int seen = 0;
        for (Map.Entry<String, Integer> entry : entries) {
            seen++;
            logProgress(seen, featureCount, logEvery);
            if (entry.getValue() < reserved) {
                continue; // reserved ids are copied unconditionally by rebuild()
            }
            FeatureSortItem item = new FeatureSortItem(entry, this.parameter, tagCount);
            if (item.total >= threshold) {
                heap.add(item);
            }
        }
        ConsoleLogger.logger.finish("\n裁剪完毕\n");
        return heap;
    }

    /** Replaces featureMap/parameter with a compact map holding the reserved ids plus {@code kept}. */
    private void rebuild(MaxHeap<FeatureSortItem> kept, TagSet tagSet) {
        int tagCount = tagSet.size();
        int reserved = tagSet.sizeIncludingBos();
        float[] compressed = new float[(kept.size() + reserved) * tagCount];
        MutableDoubleArrayTrieInteger trie = new MutableDoubleArrayTrieInteger();
        // The reserved ids come first: one "BL=<tag>" entry per tag, then "BL=_BL_".
        // Presumably these are previous-label features (viterbiDecode writes label ids
        // straight into the last feature slot) - TODO confirm against FeatureMap.
        for (Iterator<Map.Entry<String, Integer>> it = tagSet.iterator(); it.hasNext(); ) {
            trie.add("BL=" + it.next().getKey());
        }
        trie.add("BL=_BL_");
        System.arraycopy(this.parameter, 0, compressed, 0, tagCount * reserved);

        ConsoleLogger.logger.start("构建双数组trie树...\n");
        int total = kept.size();
        int logEvery = (int) Math.ceil(total / 10000.0f);
        int done = 0;
        for (FeatureSortItem item : kept) {
            done++;
            logProgress(done, total, logEvery);
            int newId = trie.size();
            trie.put(item.key, newId);
            System.arraycopy(this.parameter, item.id * tagCount, compressed, newId * tagCount, tagCount);
        }
        ConsoleLogger.logger.finish("\n构建完毕\n");
        this.featureMap = new ImmutableFeatureMDatMap(trie, tagSet);
        this.parameter = compressed;
    }

    /** Prints a carriage-return progress line every {@code every} items and at the final item. */
    private static void logProgress(int done, int total, int every) {
        if (done % every == 0 || done == total) {
            ConsoleLogger.logger.out("\r%.2f%% ", MathUtility.percentage(done, total));
        }
    }

    /**
     * Binary decision: sums the weights at the given indices of {@link #parameter}.
     *
     * @param featureIds indices into {@link #parameter}
     * @return {@code -1} if the sum is negative, otherwise {@code 1}
     */
    public int decode(Collection<Integer> featureIds) {
        float sum = 0.0f;
        for (Integer id : featureIds) {
            sum += this.parameter[id];
        }
        return sum < 0.0f ? -1 : 1;
    }

    /**
     * Loads the model from a file path (via {@code ByteArrayStream}).
     *
     * @param modelFile path of the model file
     * @throws IOException if loading fails
     */
    public void load(String modelFile) throws IOException {
        if (HanLP.Config.DEBUG) {
            ConsoleLogger.logger.start("加载 %s ... ", modelFile);
        }
        if (!load(ByteArrayStream.createByteArrayStream(modelFile))) {
            throw new IOException(String.format("%s 加载失败", modelFile));
        }
        if (HanLP.Config.DEBUG) {
            ConsoleLogger.logger.finish(" 加载完毕\n");
        }
    }

    /**
     * Reads an {@link ImmutableFeatureMDatMap} followed by the weights. Classification models
     * store one float per feature; other task types store {@code tagSet.size()} floats per
     * feature. The stream is closed once it has no more data.
     *
     * @param byteArray source, may be {@code null}
     * @return {@code false} if {@code byteArray} is {@code null}, otherwise {@code true}
     */
    @Override // com.hankcs.hanlp.corpus.io.ICacheAble
    public boolean load(ByteArray byteArray) {
        if (byteArray == null) {
            return false;
        }
        ImmutableFeatureMDatMap map = new ImmutableFeatureMDatMap();
        this.featureMap = map;
        map.load(byteArray);
        int featureCount = this.featureMap.size();
        TagSet tagSet = this.featureMap.tagSet;
        int weightsPerFeature = tagSet.type == TaskType.CLASSIFICATION ? 1 : tagSet.size();
        this.parameter = new float[featureCount * weightsPerFeature];
        // Weights are stored row-major, so a single sequential read fills the whole array.
        for (int i = 0; i < this.parameter.length; i++) {
            this.parameter[i] = byteArray.nextFloat();
        }
        if (!byteArray.hasMore()) {
            byteArray.close();
        }
        return true;
    }

    /**
     * Writes the feature map (converted to {@link ImmutableFeatureMDatMap} first if needed)
     * followed by every weight.
     *
     * @param dataOutputStream destination; not closed here
     * @throws IOException on write failure
     */
    @Override // com.hankcs.hanlp.corpus.io.ICacheAble
    public void save(DataOutputStream dataOutputStream) throws IOException {
        if (!(this.featureMap instanceof ImmutableFeatureMDatMap)) {
            this.featureMap = new ImmutableFeatureMDatMap(this.featureMap.entrySet(), tagSet());
        }
        this.featureMap.save(dataOutputStream);
        for (float weight : this.parameter) {
            dataOutputStream.writeFloat(weight);
        }
    }

    /**
     * Saves the model in binary form to a file.
     *
     * @param modelFile destination path
     * @throws IOException on write failure
     */
    public void save(String modelFile) throws IOException {
        writeBinary(modelFile);
    }

    /**
     * Compresses with the given ratio, then saves in binary form.
     *
     * @param modelFile destination path
     * @param ratio     compression ratio, see {@link #compress(double, double)}
     * @throws IOException on write failure
     */
    public void save(String modelFile, double ratio) throws IOException {
        save(modelFile, this.featureMap.entrySet(), ratio);
    }

    /**
     * Compresses with the given ratio, then saves in binary form (no text dump).
     *
     * @param modelFile    destination path
     * @param featureIdSet feature entries (unused unless a text dump is requested)
     * @param ratio        compression ratio
     * @throws IOException on write failure
     */
    public void save(String modelFile, Set<Map.Entry<String, Integer>> featureIdSet, double ratio) throws IOException {
        save(modelFile, featureIdSet, ratio, false);
    }

    /**
     * Compresses this model in place, writes it in binary form, and optionally writes a
     * tab-separated text dump to {@code modelFile + ".txt"}. The dump uses the weights from
     * BEFORE compression, indexed by the ids in {@code featureIdSet}.
     *
     * @param modelFile    destination path
     * @param featureIdSet feature-string to pre-compression id entries to dump
     * @param ratio        compression ratio
     * @param text         whether to also write the text dump
     * @throws IOException on write failure
     */
    public void save(String modelFile, Set<Map.Entry<String, Integer>> featureIdSet, double ratio, boolean text) throws IOException {
        // Keep a handle on the uncompressed weights: featureIdSet uses the original ids.
        float[] originalParameter = this.parameter;
        compress(ratio, DEFAULT_COMPRESS_THRESHOLD);
        writeBinary(modelFile);
        if (text) {
            writeText(modelFile + ".txt", featureIdSet, originalParameter);
        }
    }

    /** Writes the binary model to {@code path}, closing the stream even if writing fails. */
    private void writeBinary(String path) throws IOException {
        DataOutputStream out = new DataOutputStream(new BufferedOutputStream(IOUtil.newOutputStream(path)));
        try {
            save(out);
        } finally {
            out.close();
        }
    }

    /** Writes one line per feature: its key, then its weight(s), tab-separated, in UTF-8. */
    private void writeText(String path, Set<Map.Entry<String, Integer>> featureIdSet, float[] weights) throws IOException {
        BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(IOUtil.newOutputStream(path), "UTF-8"));
        try {
            int tagCount = this.featureMap.tagSet.size();
            // One weight per feature when the sizes match (e.g. classification); otherwise one per tag.
            boolean onePerFeature = featureIdSet.size() == weights.length;
            for (Map.Entry<String, Integer> entry : featureIdSet) {
                writer.write(entry.getKey());
                int id = entry.getValue();
                if (onePerFeature) {
                    writer.write("\t");
                    writer.write(String.valueOf(weights[id]));
                } else {
                    for (int tag = 0; tag < tagCount; tag++) {
                        writer.write("\t");
                        writer.write(String.valueOf(weights[id * tagCount + tag]));
                    }
                }
                writer.newLine();
            }
        } finally {
            writer.close();
        }
    }

    /**
     * Sums the weights of the given features for one tag.
     *
     * @param featureVector feature ids; entries equal to {@code -1} are skipped
     * @param currentTag    tag index, {@code 0 <= currentTag < tagSet.size()}
     * @return the summed score
     * @throws IllegalArgumentException if an id is {@code < -1} or {@code >= featureMap.size()}
     */
    public double score(int[] featureVector, int currentTag) {
        int tagCount = this.featureMap.tagSet.size();
        int featureCount = this.featureMap.size();
        double sum = 0.0d;
        for (int id : featureVector) {
            if (id == -1) {
                continue;
            }
            if (id < -1 || id >= featureCount) {
                throw new IllegalArgumentException("在打分时传入了非法的下标");
            }
            sum += this.parameter[id * tagCount + currentTag];
        }
        return sum;
    }

    /** @return the tag set of the underlying feature map */
    public TagSet tagSet() {
        return this.featureMap.tagSet;
    }

    /** @return the task type recorded in the tag set */
    public TaskType taskType() {
        return this.featureMap.tagSet.type;
    }

    /**
     * Adds {@code delta} to each listed weight.
     *
     * @param featureIds indices into {@link #parameter}
     * @param delta      amount added to each of those weights
     */
    public void update(Collection<Integer> featureIds, int delta) {
        for (Integer id : featureIds) {
            this.parameter[id] += delta;
        }
    }

    /**
     * Viterbi-decodes {@code instance}, writing the predicted labels into its own
     * {@code tagArray}.
     *
     * @param instance instance to decode
     * @return score of the best path
     */
    public double viterbiDecode(Instance instance) {
        return viterbiDecode(instance, instance.tagArray);
    }

    /**
     * Viterbi decoding with a first-order transition feature. The last slot of each position's
     * feature vector is overwritten with the previous label index (or {@code bosTag()} at
     * position 0), so the instance's feature arrays are mutated.
     *
     * @param instance   instance to decode; its sentence length is {@code instance.tagArray.length}
     * @param guessLabel receives {@code allLabels()[k]} for each position; must be at least
     *                   as long as the sentence
     * @return score of the best path ({@code 0.0} for an empty sentence)
     */
    public double viterbiDecode(Instance instance, int[] guessLabel) {
        int[] allLabels = this.featureMap.allLabels();
        int bos = this.featureMap.bosTag();
        int sentenceLength = instance.tagArray.length;
        int labelCount = allLabels.length;
        if (sentenceLength == 0) {
            return 0.0d;
        }
        int[][] backPointer = new int[sentenceLength][labelCount];
        // Two rolling rows: row (i & 1) holds the best scores ending at position i.
        double[][] best = new double[2][labelCount];

        for (int i = 0; i < sentenceLength; i++) {
            int cur = i & 1;
            int prev = 1 - cur;
            int[] features = instance.getFeatureAt(i);
            int transitionSlot = features.length - 1;
            if (i == 0) {
                features[transitionSlot] = bos;
                for (int tag = 0; tag < labelCount; tag++) {
                    backPointer[0][tag] = tag;
                    best[0][tag] = score(features, tag);
                }
                continue;
            }
            for (int tag = 0; tag < labelCount; tag++) {
                // Start at -infinity so the first candidate always overwrites the stale row value.
                double max = Double.NEGATIVE_INFINITY;
                for (int prevTag = 0; prevTag < labelCount; prevTag++) {
                    features[transitionSlot] = prevTag;
                    double candidate = best[prev][prevTag] + score(features, tag);
                    if (candidate > max) {
                        max = candidate;
                        backPointer[i][tag] = prevTag;
                        best[cur][tag] = candidate;
                    }
                }
            }
        }

        int lastRow = (sentenceLength - 1) & 1;
        int bestTag = 0;
        double bestScore = best[lastRow][0];
        for (int tag = 1; tag < labelCount; tag++) {
            if (best[lastRow][tag] > bestScore) {
                bestScore = best[lastRow][tag];
                bestTag = tag;
            }
        }
        for (int i = sentenceLength - 1; i >= 0; i--) {
            guessLabel[i] = allLabels[bestTag];
            bestTag = backPointer[i][bestTag];
        }
        return bestScore;
    }
}
