package ai.djl.nn.recurrent;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.nn.recurrent.RecurrentBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Iterator;

/**
 * A long short-term memory (LSTM) {@link RecurrentBlock}.
 *
 * <p>The block configures the base recurrent block with mode {@code "lstm"} and 4 gates. When
 * state clipping is enabled through {@link Builder#optLstmStateClipMin(float, float)}, the
 * forward pass uses the dedicated LSTM operator with the configured clip bounds. Otherwise it
 * uses the generic RNN operator in {@code "lstm"} mode.
 *
 * <p>Begin states supplied through {@link #setBeginStates(NDList)} apply to the next forward
 * pass only. They are cleared when that pass finishes, whether it succeeds or throws.
 */
public class LSTM extends RecurrentBlock {
    /** Cell begin state for the next forward pass; {@code null} means zeros are used. */
    private NDArray beginStateCell;
    /** Whether the dedicated LSTM operator with clip bounds is used instead of the generic RNN op. */
    private boolean clipLstmState;
    /** Upper clip bound passed to the LSTM operator when {@link #clipLstmState} is set. */
    private double lstmStateClipMax;
    /** Lower clip bound passed to the LSTM operator when {@link #clipLstmState} is set. */
    private double lstmStateClipMin;

    /** Builder for {@link LSTM}. */
    public static final class Builder extends RecurrentBlock.BaseBuilder<Builder> {

        /**
         * Builds the {@link LSTM} block.
         *
         * @return the configured block
         * @throws IllegalArgumentException if {@code stateSize} or {@code numStackedLayers} is
         *     still {@code -1} (not set)
         */
        public LSTM build() {
            if (this.stateSize == -1 || this.numStackedLayers == -1) {
                throw new IllegalArgumentException("Must set stateSize and numStackedLayers");
            }
            return new LSTM(this);
        }

        /**
         * Enables LSTM state clipping with the given bounds.
         *
         * @param lstmStateClipMin the lower clip bound passed to the LSTM operator
         * @param lstmStateClipMax the upper clip bound passed to the LSTM operator
         * @return this builder
         */
        public Builder optLstmStateClipMin(float lstmStateClipMin, float lstmStateClipMax) {
            this.lstmStateClipMin = lstmStateClipMin;
            this.lstmStateClipMax = lstmStateClipMax;
            this.clipLstmState = true;
            return self();
        }

        /** {@inheritDoc} */
        @Override
        public Builder self() {
            return this;
        }
    }

    LSTM(Builder builder) {
        super(builder);
        this.mode = "lstm";
        this.gates = 4;
        this.clipLstmState = builder.clipLstmState;
        this.lstmStateClipMin = builder.lstmStateClipMin;
        this.lstmStateClipMax = builder.lstmStateClipMax;
    }

    /**
     * Creates a new {@link Builder} for an {@link LSTM} block.
     *
     * @return a fresh builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Runs the LSTM over the input.
     *
     * <p>The first output is the operator's primary result transposed with axes {@code (1, 0,
     * 2)}. The input is converted to TNC layout before the op, so this presumably restores the
     * caller's NTC layout (TODO confirm). When {@code stateOutputs} is set, the operator's
     * second and third outputs are appended as well. Any begin states are cleared afterwards,
     * even if the pass throws.
     *
     * @param parameterStore the store the block parameters are read from
     * @param nDList the block inputs
     * @param z the training flag (not read by this implementation)
     * @param pairList extra operator parameters forwarded to the native op
     * @return the output sequence, optionally followed by the two final states
     */
    @Override
    public NDList forward(
            ParameterStore parameterStore,
            NDList nDList,
            boolean z,
            PairList<String, Object> pairList) {
        try {
            NDList inputs = opInputs(parameterStore, nDList);
            NDList raw = runOperator(inputs, pairList);
            NDList output = new NDList(raw.head().transpose(1, 0, 2));
            if (this.stateOutputs) {
                output.add(raw.get(1));
                output.add(raw.get(2));
            }
            return output;
        } finally {
            // Begin states are meant for a single pass. Clear them even on failure so a failed
            // call cannot leak stale states into the next forward pass.
            resetBeginStates();
        }
    }

    /**
     * Invokes the native operator.
     *
     * <p>Uses the dedicated LSTM op with clip bounds when state clipping is enabled. Otherwise
     * it uses the generic RNN op in this block's mode.
     */
    private NDList runOperator(NDList inputs, PairList<String, Object> pairList) {
        NDArrayEx ex = inputs.head().getNDArrayInternal();
        if (this.clipLstmState) {
            return ex.lstm(
                    inputs,
                    this.stateSize,
                    this.dropRate,
                    this.numStackedLayers,
                    this.useSequenceLength,
                    isBidirectional(),
                    true,
                    this.lstmStateClipMin,
                    this.lstmStateClipMax,
                    pairList);
        }
        return ex.rnn(
                inputs,
                this.mode,
                this.stateSize,
                this.dropRate,
                this.numStackedLayers,
                this.useSequenceLength,
                isBidirectional(),
                true,
                pairList);
    }

    /**
     * Assembles the native operator inputs.
     *
     * <p>The inputs are, in order: the data converted to TNC layout, all block parameters
     * flattened and concatenated, the begin state, the begin cell state, and (when
     * {@code useSequenceLength} is set) the sequence-length input. If no begin states were set,
     * two FLOAT32 zero arrays of shape {@code (numStackedLayers * numDirections, batch,
     * stateSize)} are used instead.
     *
     * @param parameterStore the store the block parameters are read from
     * @param nDList the block inputs
     * @return the operator input list
     */
    @Override
    protected NDList opInputs(ParameterStore parameterStore, NDList nDList) {
        validateInputSize(nDList);
        // Dimension 0 of the caller's input is read before the TNC conversion, presumably the
        // batch dimension.
        long batchSize = nDList.head().getShape().get(0);
        NDList tncInputs = updateInputLayoutToTNC(nDList);
        // NOTE(review): singletonOrThrow() seems to conflict with tncInputs.get(1) below when
        // useSequenceLength is set. Confirm what updateInputLayoutToTNC returns in that case.
        NDArray data = tncInputs.singletonOrThrow();
        Device device = data.getDevice();

        NDList result = new NDList(data);
        result.add(flattenParameters(parameterStore, device));

        if (this.beginState != null) {
            result.add(this.beginState);
            result.add(this.beginStateCell);
        } else {
            Shape stateShape =
                    new Shape(this.numStackedLayers * this.numDirections, batchSize, this.stateSize);
            result.add(data.getManager().zeros(stateShape, DataType.FLOAT32, device));
            result.add(data.getManager().zeros(stateShape, DataType.FLOAT32, device));
        }
        if (this.useSequenceLength) {
            result.add(tncInputs.get(1));
        }
        return result;
    }

    /**
     * Flattens every parameter of this block on {@code device} and concatenates them.
     *
     * <p>The intermediate flattened arrays are closed exactly once, after the concatenation.
     */
    private NDArray flattenParameters(ParameterStore parameterStore, Device device) {
        try (NDList flattened = new NDList()) {
            for (Parameter parameter : this.parameters.values()) {
                flattened.add(parameterStore.getValue(parameter, device).flatten());
            }
            return NDArrays.concat(flattened);
        }
    }

    /** Clears both begin states so the next forward pass starts from zeros. */
    @Override
    protected void resetBeginStates() {
        this.beginState = null;
        this.beginStateCell = null;
    }

    /**
     * Sets the begin states for the next forward pass only.
     *
     * @param nDList element 0 is the begin state and element 1 is the begin cell state
     * @throws IllegalArgumentException if fewer than two arrays are supplied
     */
    @Override
    public void setBeginStates(NDList nDList) {
        if (nDList.size() < 2) {
            throw new IllegalArgumentException(
                    "LSTM expects two begin states (state and cell state), got " + nDList.size());
        }
        this.beginState = nDList.get(0);
        this.beginStateCell = nDList.get(1);
    }
}
