package com.mayabot.nlp.fasttext.loss;

import com.lzy.okserver.download.DownloadInfo;
import com.mayabot.nlp.blas.Matrix;
import com.mayabot.nlp.blas.Vector;
import com.mayabot.nlp.common.IntArrayList;
import com.mayabot.nlp.fasttext.Model;
import com.mayabot.nlp.fasttext.ScoreIdPair;
import com.videogo.openapi.model.resp.GetCameraInfoListResp;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Unit;
import kotlin.collections.CollectionsKt;
import kotlin.comparisons.ComparisonsKt;
import kotlin.jvm.internal.Intrinsics;

/**
 * Hierarchical-softmax loss for fastText-style models.
 *
 * <p>The constructor builds a binary tree over the {@code targetCounts.length} output classes:
 * nodes {@code 0..osz-1} are the leaves (one per class), nodes {@code osz..2*osz-2} are internal
 * nodes created by repeatedly merging the two lowest-count available nodes, and the last node is
 * the root. Internal node {@code i} is scored with row {@code i - osz} of the output matrix.
 * For each class, {@link #getPaths()} holds the output-matrix rows on its leaf-to-root path and
 * {@link #getCodes()} the matching branch bits ({@code true} = that node is its parent's right child).
 *
 * <p>NOTE(review): the merge consumes leaves from the highest index downward, which yields a
 * minimum-weight (Huffman) tree only if {@code targetCounts} is sorted in non-increasing order —
 * presumably guaranteed by the dictionary; confirm with callers.
 */
/* compiled from: HierarchicalSoftmaxLoss.kt */
@Metadata(bv = {1, 0, 3}, d1 = {"\u0000x\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0016\n\u0002\b\u0002\n\u0002\u0010!\n\u0002\u0010\u0018\n\u0002\b\u0003\n\u0002\u0010\b\n\u0002\b\u0003\n\u0002\u0010\u0015\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0002\n\u0002\b\u0002\n\u0002\u0010\u0007\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\u0018\u00002\u00020\u0001:\u0001,B\u0015\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005¢\u0006\u0002\u0010\u0006J>\u0010\u0016\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\r2\u0006\u0010\u0019\u001a\u00020\u001a2\u0006\u0010\u001b\u001a\u00020\r2\u0006\u0010\u001c\u001a\u00020\u001a2\f\u0010\u001d\u001a\b\u0012\u0004\u0012\u00020\u001e0\b2\u0006\u0010\u001f\u001a\u00020 H\u0002J0\u0010!\u001a\u00020\u001a2\u0006\u0010\"\u001a\u00020#2\u0006\u0010$\u001a\u00020\r2\u0006\u0010%\u001a\u00020&2\u0006\u0010'\u001a\u00020\u001a2\u0006\u0010(\u001a\u00020)H\u0016J2\u0010*\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\r2\u0006\u0010\u0019\u001a\u00020\u001a2\u0010\u0010\u001d\u001a\f\u0012\u0004\u0012\u00020\u001e0\bj\u0002`+2\u0006\u0010%\u001a\u00020&H\u0016R\u0017\u0010\u0007\u001a\b\u0012\u0004\u0012\u00020\t0\b¢\u0006\b\n\u0000\u001a\u0004\b\n\u0010\u000bR\u0011\u0010\f\u001a\u00020\r¢\u0006\b\n\u0000\u001a\u0004\b\u000e\u0010\u000fR\u0017\u0010\u0010\u001a\b\u0012\u0004\u0012\u00020\u00110\b¢\u0006\b\n\u0000\u001a\u0004\b\u0012\u0010\u000bR\u0017\u0010\u0013\u001a\b\u0012\u0004\u0012\u00020\u00140\b¢\u0006\b\n\u0000\u001a\u0004\b\u0015\u0010\u000b¨\u0006-"}, d2 = {"Lcom/mayabot/nlp/fasttext/loss/HierarchicalSoftmaxLoss;", "Lcom/mayabot/nlp/fasttext/loss/BinaryLogisticLoss;", "wo", "Lcom/mayabot/nlp/blas/Matrix;", "targetCounts", "", 
"(Lcom/mayabot/nlp/blas/Matrix;[J)V", "codes", "", "", "getCodes", "()Ljava/util/List;", "osz", "", "getOsz", "()I", "paths", "", "getPaths", "tree", "Lcom/mayabot/nlp/fasttext/loss/HierarchicalSoftmaxLoss$Node;", "getTree", "dfs", "", "k", "threshold", "", "node", "score", "heap", "Lcom/mayabot/nlp/fasttext/ScoreIdPair;", "hidden", "Lcom/mayabot/nlp/blas/Vector;", "forward", "targets", "Lcom/mayabot/nlp/common/IntArrayList;", "targetIndex", DownloadInfo.STATE, "Lcom/mayabot/nlp/fasttext/Model$State;", "lr", "backprop", "", "predict", "Lcom/mayabot/nlp/fasttext/Predictions;", "Node", "mynlp"}, k = 1, mv = {1, 4, 1})
/* loaded from: classes.dex */
public final class HierarchicalSoftmaxLoss extends BinaryLogisticLoss {

    /**
     * Initial count of internal nodes before they are built. It is far above any expected class
     * count, so during merging a remaining leaf is preferred over an internal node that has not
     * been built yet (assumes real counts stay below 1e15).
     */
    private static final long UNBUILT_NODE_COUNT = 1000000000000000L;

    /** Orders predictions by descending score. */
    private static final Comparator<ScoreIdPair> SCORE_DESCENDING = new Comparator<ScoreIdPair>() {
        @Override
        public int compare(ScoreIdPair a, ScoreIdPair b) {
            return Float.compare(b.getScore(), a.getScore());
        }
    };

    /** Per class: branch bits along its leaf-to-root path (true = right child). */
    private final List<boolean[]> codes;
    /** Number of output classes (leaves of the tree). */
    private final int osz;
    /** Per class: output-matrix rows of the internal nodes on its leaf-to-root path. */
    private final List<int[]> paths;
    /** All {@code 2 * osz - 1} tree nodes; leaves first, root last. */
    private final List<Node> tree;

    /**
     * A node of the hierarchical-softmax tree. Links are indices into the tree list,
     * with {@code -1} meaning "none".
     */
    /* compiled from: HierarchicalSoftmaxLoss.kt */
    @Metadata(bv = {1, 0, 3}, d1 = {"\u0000 \n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0000\n\u0002\u0010\t\n\u0000\n\u0002\u0010\b\n\u0002\b\u0003\u0018\u00002\u00020\u0001B\u0005¢\u0006\u0002\u0010\u0002R\u0012\u0010\u0003\u001a\u00020\u00048\u0006@\u0006X\u0087\u000e¢\u0006\u0002\n\u0000R\u0012\u0010\u0005\u001a\u00020\u00068\u0006@\u0006X\u0087\u000e¢\u0006\u0002\n\u0000R\u0012\u0010\u0007\u001a\u00020\b8\u0006@\u0006X\u0087\u000e¢\u0006\u0002\n\u0000R\u0012\u0010\t\u001a\u00020\b8\u0006@\u0006X\u0087\u000e¢\u0006\u0002\n\u0000R\u0012\u0010\n\u001a\u00020\b8\u0006@\u0006X\u0087\u000e¢\u0006\u0002\n\u0000¨\u0006\u000b"}, d2 = {"Lcom/mayabot/nlp/fasttext/loss/HierarchicalSoftmaxLoss$Node;", "", "()V", "binary", "", GetCameraInfoListResp.COUNT, "", "left", "", "parent", "right", "mynlp"}, k = 1, mv = {1, 4, 1})
    /* loaded from: classes.dex */
    public static final class Node {
        /** True when this node is the right child of its parent. */
        public boolean binary;
        /** Class count for a leaf; sum of the children's counts for an internal node. */
        public long count;
        /** Index of the left child, or -1 for a leaf. */
        public int left;
        /** Index of the parent, or -1 for the root. */
        public int parent;
        /** Index of the right child, or -1 for a leaf. */
        public int right;
    }

    /**
     * Builds the hierarchical-softmax tree plus the per-class paths and codes.
     *
     * @param wo output matrix; row {@code i} scores internal node {@code osz + i}
     * @param targetCounts per-class counts, one entry per output class
     * @throws IllegalArgumentException if {@code targetCounts} is empty
     */
    public HierarchicalSoftmaxLoss(Matrix wo, long[] targetCounts) {
        // The decompiler moved super(wo) ahead of the null checks; Java requires super(...) first.
        super(wo);
        Intrinsics.checkNotNullParameter(wo, "wo");
        Intrinsics.checkNotNullParameter(targetCounts, "targetCounts");
        if (targetCounts.length == 0) {
            throw new IllegalArgumentException("targetCounts must not be empty");
        }
        // Size everything by the number of classes (not wo.getRow()) so that the tree, the path
        // row offsets, the row lookup in dfs and the root index used by predict all agree.
        int classCount = targetCounts.length;
        this.osz = classCount;
        List<Node> nodes = buildTree(targetCounts);

        List<int[]> pathList = new ArrayList<>(classCount);
        List<boolean[]> codeList = new ArrayList<>(classCount);
        for (int leaf = 0; leaf < classCount; leaf++) {
            // First pass measures the depth so the primitive arrays can be filled without boxing.
            int depth = 0;
            for (int n = leaf; nodes.get(n).parent != -1; n = nodes.get(n).parent) {
                depth++;
            }
            int[] path = new int[depth];
            boolean[] code = new boolean[depth];
            int d = 0;
            for (int n = leaf; nodes.get(n).parent != -1; n = nodes.get(n).parent) {
                path[d] = nodes.get(n).parent - classCount;
                code[d] = nodes.get(n).binary;
                d++;
            }
            pathList.add(path);
            codeList.add(code);
        }
        this.paths = pathList;
        this.codes = codeList;
        this.tree = nodes;
    }

    /**
     * Creates the {@code 2n - 1} tree nodes: {@code n} leaves carrying {@code counts}, followed by
     * {@code n - 1} internal nodes, each merging the two lowest-count available nodes.
     * The last node created is the root.
     */
    private static List<Node> buildTree(long[] counts) {
        int leafCount = counts.length;
        int nodeCount = 2 * leafCount - 1;
        List<Node> nodes = new ArrayList<>(nodeCount);
        for (int i = 0; i < nodeCount; i++) {
            Node node = new Node();
            node.parent = -1;
            node.left = -1;
            node.right = -1;
            node.count = i < leafCount ? counts[i] : UNBUILT_NODE_COUNT;
            node.binary = false;
            nodes.add(node);
        }

        int leaf = leafCount - 1;  // next unused leaf; leaves are consumed from the end
        int internal = leafCount;  // next unused internal node; consumed in creation order
        for (int parent = leafCount; parent < nodeCount; parent++) {
            int[] children = new int[2];
            for (int j = 0; j < 2; j++) {
                if (leaf >= 0 && nodes.get(leaf).count < nodes.get(internal).count) {
                    children[j] = leaf--;
                } else {
                    children[j] = internal++;
                }
            }
            Node p = nodes.get(parent);
            p.left = children[0];
            p.right = children[1];
            p.count = nodes.get(children[0]).count + nodes.get(children[1]).count;
            nodes.get(children[0]).parent = parent;
            nodes.get(children[1]).parent = parent;
            nodes.get(children[1]).binary = true;
        }
        return nodes;
    }

    /**
     * Depth-first search below {@code node}, collecting up to {@code k} best-scoring leaves.
     *
     * @param k maximum number of entries kept in {@code heap}
     * @param threshold minimum probability; subtrees whose path score falls below
     *        {@code LossKt.stdLog(threshold)} are pruned
     * @param node index into {@link #tree} of the subtree root to explore
     * @param score accumulated log-score of the path from the root down to {@code node}
     * @param heap result list; every insertion re-sorts it by descending score, so its last
     *        element is the weakest kept entry
     * @param hidden hidden vector scored against the output-matrix rows
     */
    private final void dfs(int k, float threshold, int node, float score, List<ScoreIdPair> heap, Vector hidden) {
        // Prune: the accumulated path score is already below the (log) threshold.
        if (score < LossKt.stdLog(threshold)) {
            return;
        }
        // Prune: k entries are held and this path cannot beat the weakest of them.
        if (heap.size() == k && score < heap.get(heap.size() - 1).getScore()) {
            return;
        }
        Node current = this.tree.get(node);
        if (current.left != -1 || current.right != -1) {
            // Internal node: the sigmoid of its output-matrix row gives the right-branch probability.
            float pRight = 1.0f / (1 + ((float) Math.exp(-getWo().dotRow(hidden, node - this.osz))));
            dfs(k, threshold, current.left, score + ((float) LossKt.stdLog(1.0f - pRight)), heap, hidden);
            dfs(k, threshold, current.right, score + ((float) LossKt.stdLog(pRight)), heap, hidden);
            return;
        }
        // Leaf: its index is the class id.
        heap.add(new ScoreIdPair(score, node));
        if (heap.size() > 1) {
            CollectionsKt.sortWith(heap, SCORE_DESCENDING);
        }
        if (heap.size() > k) {
            heap.remove(heap.size() - 1);  // drop the weakest entry (list is sorted descending)
        }
    }

    /**
     * Sums the binary-logistic losses of the internal nodes on the path of the target class.
     *
     * @param targets target class ids
     * @param targetIndex index into {@code targets} of the class to score
     * @param state model state passed through to {@code binaryLogistic}
     * @param lr learning rate passed through to {@code binaryLogistic}
     * @param backprop passed through to {@code binaryLogistic} (presumably enables the gradient
     *        update; see BinaryLogisticLoss)
     * @return the summed loss along the path
     */
    @Override // com.mayabot.nlp.fasttext.loss.Loss
    public float forward(IntArrayList targets, int targetIndex, Model.State state, float lr, boolean backprop) {
        Intrinsics.checkNotNullParameter(targets, "targets");
        Intrinsics.checkNotNullParameter(state, "state");
        int target = targets.get(targetIndex);
        boolean[] code = this.codes.get(target);
        int[] path = this.paths.get(target);
        float loss = 0.0f;
        for (int i = 0; i < path.length; i++) {
            loss += binaryLogistic(path[i], state, code[i], lr, backprop);
        }
        return loss;
    }

    /** Returns, per class, the branch bits along its leaf-to-root path (true = right child). */
    public final List<boolean[]> getCodes() {
        return this.codes;
    }

    /** Returns the number of output classes. */
    public final int getOsz() {
        return this.osz;
    }

    /** Returns, per class, the output-matrix rows of the internal nodes on its leaf-to-root path. */
    public final List<int[]> getPaths() {
        return this.paths;
    }

    /** Returns all tree nodes: leaves {@code 0..osz-1}, then internal nodes, root last. */
    public final List<Node> getTree() {
        return this.tree;
    }

    /**
     * Adds up to {@code k} best-scoring classes (log-score, class id) to {@code heap}, keeping
     * only those whose score is at least {@code LossKt.stdLog(threshold)}, and then sorts
     * {@code heap} by descending score.
     *
     * @param k maximum number of predictions kept; a non-positive value adds nothing
     * @param threshold minimum probability for a prediction to be kept
     * @param heap output list of predictions
     * @param state model state supplying the hidden vector
     */
    @Override // com.mayabot.nlp.fasttext.loss.Loss
    public void predict(int k, float threshold, List<ScoreIdPair> heap, Model.State state) {
        Intrinsics.checkNotNullParameter(heap, "heap");
        Intrinsics.checkNotNullParameter(state, "state");
        // With k == 0 and an empty heap, dfs would otherwise call heap.get(-1).
        if (k > 0) {
            // The root is the last node built: index 2 * osz - 2.
            dfs(k, threshold, (this.osz * 2) - 2, 0.0f, heap, state.getHidden());
        }
        if (heap.size() > 1) {
            CollectionsKt.sortWith(heap, SCORE_DESCENDING);
        }
    }
}
