package ai.djl.modality.cv.translator;

import ai.djl.Model;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Mask;
import ai.djl.modality.cv.translator.BaseImageTranslator;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Transform;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.List;

/* loaded from: classes.dex */
/**
 * Translator for instance-segmentation models whose output {@link NDList} holds, in order, class
 * ids, scores, bounding boxes and per-detection masks.
 *
 * <p>The translator is also a {@link Transform}. It registers itself once, as the first step of
 * the preprocessing pipeline, and resizes each input image. The target is a short edge of
 * {@code shortEdge}. If that would make the long edge (after rounding) exceed {@code maxEdge},
 * the image is instead scaled so the long edge is about {@code maxEdge}. The resulting size is
 * remembered and used to normalise the predicted bounding boxes.
 *
 * <p>NOTE(review): the rescaled size is kept in unsynchronised instance fields between
 * preprocessing and postprocessing, so a single instance should not be shared by concurrent
 * predictions.
 */
public class InstanceSegmentationTranslator extends BaseImageTranslator<DetectedObjects>
        implements Transform {

    /** Context attachment key holding {@link Image#getHeight()} of the input. */
    private static final String ORIGINAL_HEIGHT = "originalHeight";
    /** Context attachment key holding {@link Image#getWidth()} of the input. */
    private static final String ORIGINAL_WIDTH = "originalWidth";

    private final BaseImageTranslator.SynsetLoader synsetLoader;
    private final float threshold;
    private final int shortEdge;
    private final int maxEdge;

    /** Class names; loaded lazily by {@link #prepare}. */
    private List<String> classes;
    /** Width of the most recently resized input, written by {@link #resizeShort}. */
    private int rescaledWidth;
    /** Height of the most recently resized input, written by {@link #resizeShort}. */
    private int rescaledHeight;

    /** Builder for {@link InstanceSegmentationTranslator}. */
    public static class Builder extends BaseImageTranslator.ClassificationBuilder<Builder> {

        float threshold = 0.3f;
        int shortEdge = 600;
        int maxEdge = 1000;

        Builder() {}

        /**
         * Sets the score threshold. Detections whose score is not strictly greater than this value
         * are discarded. Defaults to {@code 0.3}.
         *
         * @param threshold the score threshold
         * @return this builder
         */
        public Builder optThreshold(float threshold) {
            this.threshold = threshold;
            return this;
        }

        /**
         * Sets the target length of the shorter image edge after resizing. Defaults to {@code 600}.
         *
         * @param shortEdge the target short-edge length
         * @return this builder
         */
        public Builder optShortEdge(int shortEdge) {
            this.shortEdge = shortEdge;
            return this;
        }

        /**
         * Sets the cap on the longer image edge after resizing. Defaults to {@code 1000}.
         *
         * @param maxEdge the maximum long-edge length
         * @return this builder
         */
        public Builder optMaxEdge(int maxEdge) {
            this.maxEdge = maxEdge;
            return this;
        }

        /**
         * Validates the configuration and builds the translator.
         *
         * @return a new {@link InstanceSegmentationTranslator}
         */
        public InstanceSegmentationTranslator build() {
            validate();
            return new InstanceSegmentationTranslator(this);
        }

        /** {@inheritDoc} */
        @Override
        public Builder self() {
            return this;
        }
    }

    /**
     * Creates a translator from {@code builder} and installs its resize step at the front of the
     * preprocessing pipeline.
     *
     * @param builder the configured builder
     */
    public InstanceSegmentationTranslator(Builder builder) {
        super(builder);
        this.synsetLoader = builder.synsetLoader;
        this.threshold = builder.threshold;
        this.shortEdge = builder.shortEdge;
        this.maxEdge = builder.maxEdge;
        // Register the resize step exactly once. Doing it on every processInput call grew the
        // pipeline without bound and resized each image once per earlier call.
        getPipeline().insert(0, (String) null, this);
    }

    /**
     * Creates a builder with default settings.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Resizes {@code image} so its shorter edge becomes {@code shortEdge}. If that would make the
     * (rounded) longer edge exceed {@code maxEdge}, the image is instead scaled so its longer edge
     * is about {@code maxEdge}. The chosen size is stored in {@code rescaledWidth} and
     * {@code rescaledHeight}.
     *
     * @param image the image array; axis 0 is read as height and axis 1 as width
     * @return the resized array
     * @throws IllegalArgumentException if either edge has length zero
     */
    private NDArray resizeShort(NDArray image) {
        Shape shape = image.getShape();
        int height = (int) shape.get(0);
        int width = (int) shape.get(1);
        int shorter = Math.min(width, height);
        int longer = Math.max(width, height);
        if (shorter <= 0) {
            throw new IllegalArgumentException(
                    "Cannot resize an image with an empty edge: " + width + "x" + height);
        }
        // Use float division. Integer division truncated the scale (e.g. 600 / 480 == 1).
        float scale = (float) shortEdge / shorter;
        if (Math.round(scale * longer) > maxEdge) {
            scale = (float) maxEdge / longer;
        }
        rescaledHeight = Math.round(height * scale);
        rescaledWidth = Math.round(width * scale);
        return NDImageUtils.resize(image, rescaledWidth, rescaledHeight);
    }

    /**
     * Loads the class names from {@code model} through the synset loader, if not already loaded.
     *
     * @param manager the manager (unused here)
     * @param model the model whose synset is loaded
     * @throws IOException if the synset cannot be read
     */
    @Override
    public void prepare(NDManager manager, Model model) throws IOException {
        if (classes == null) {
            classes = synsetLoader.load(model);
        }
    }

    /**
     * Records the original image size in {@code ctx}, then runs the preprocessing pipeline.
     *
     * @param ctx the translator context; receives the original width and height attachments
     * @param image the input image
     * @return the preprocessed input
     */
    @Override
    public NDList processInput(TranslatorContext ctx, Image image) {
        ctx.setAttachment(ORIGINAL_HEIGHT, image.getHeight());
        ctx.setAttachment(ORIGINAL_WIDTH, image.getWidth());
        return super.processInput(ctx, image);
    }

    /**
     * Converts raw model outputs into {@link DetectedObjects}.
     *
     * <p>Entries of {@code list} are read as: 0 class ids, 1 scores, 2 bounding boxes (four values
     * per detection), 3 masks. A detection is kept only if its class id is non-negative and its
     * score is greater than the threshold. Box coordinates are divided by the rescaled input size.
     * Each kept mask is resized to the box size scaled by the original image size.
     *
     * @param ctx the translator context holding the original image size
     * @param list the model output
     * @return the kept detections with their masks
     * @throws AssertionError if a class id has no entry in the loaded class names
     */
    @Override
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        float[] ids = list.get(0).toFloatArray();
        float[] scores = list.get(1).toFloatArray();
        NDArray boundingBoxes = list.get(2);
        NDArray masks = list.get(3);
        // Kept boxed: they are only unboxed when a detection is actually kept.
        Integer imageWidth = (Integer) ctx.getAttachment(ORIGINAL_WIDTH);
        Integer imageHeight = (Integer) ctx.getAttachment(ORIGINAL_HEIGHT);

        List<String> classNames = new ArrayList<>();
        List<Double> probabilities = new ArrayList<>();
        List<Mask> detections = new ArrayList<>();

        for (int i = 0; i < ids.length; i++) {
            int classId = (int) ids[i];
            double probability = scores[i];
            if (classId < 0 || probability <= threshold) {
                continue;
            }
            if (classId >= classes.size()) {
                throw new AssertionError("Unexpected index: " + classId);
            }

            float[] box = boundingBoxes.get(i).toFloatArray();
            // Boxes are read as (x1, y1, x2, y2) in rescaled-input coordinates.
            double x = box[0] / rescaledWidth;
            double y = box[1] / rescaledHeight;
            double boxWidth = box[2] / rescaledWidth - x;
            double boxHeight = box[3] / rescaledHeight - y;

            int maskWidth = (int) (imageWidth * boxWidth);
            int maskHeight = (int) (imageHeight * boxHeight);
            float[][] maskValues = resizeMask(masks.get(i), maskWidth, maskHeight);

            classNames.add(classes.get(classId));
            probabilities.add(probability);
            detections.add(new Mask(x, y, boxWidth, boxHeight, maskValues));
        }
        // The copy lets the list's element type be inferred from the DetectedObjects constructor
        // parameter (presumably a list of a Mask supertype, e.g. BoundingBox; confirm).
        return new DetectedObjects(classNames, probabilities, new ArrayList<>(detections));
    }

    /**
     * Resizes one detection mask to {@code width x height}. In the returned array, the first
     * index ranges over {@code width} and the second over {@code height}.
     *
     * @param mask the raw mask for one detection
     * @param width target width
     * @param height target height
     * @return the resized mask values
     */
    private static float[][] resizeMask(NDArray mask, int width, int height) {
        // Append a size-1 axis before resizing. This is presumably a channel axis, so the mask is
        // handled like a single-channel image (TODO confirm against NDImageUtils.resize).
        NDArray withChannel = mask.reshape(mask.getShape().addAll(new Shape(1)));
        float[] flattened =
                NDImageUtils.resize(withChannel, width, height).transpose().toFloatArray();
        float[][] values = new float[width][height];
        for (int col = 0; col < width; col++) {
            System.arraycopy(flattened, col * height, values[col], 0, height);
        }
        return values;
    }

    /**
     * Pipeline hook that applies {@link #resizeShort} to each input.
     *
     * @param array the image array
     * @return the resized array
     */
    @Override
    public NDArray transform(NDArray array) {
        return resizeShort(array);
    }
}
