package ai.djl.nn.pooling;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import java.util.function.Function;

/* loaded from: classes.dex */
/**
 * Static helpers that apply pooling operations to arrays and build {@link Block} instances
 * wrapping those operations.
 *
 * <p>Every operation delegates to the corresponding method on the array's
 * {@code getNDArrayInternal()} object. The block factories validate the kernel, stride and pad
 * shapes up front and return a {@link LambdaBlock} that applies the pooling operation to the
 * single array contained in the incoming {@link NDList}.
 */
public final class Pool {

    /** Not instantiable: this class only exposes static members. */
    private Pool() {
    }

    /**
     * Validates the shapes handed to one of the block factories.
     *
     * <p>The checks short-circuit in kernel, stride, pad order, so the stride and pad are only
     * dereferenced when the preceding shapes already have the expected dimension.
     *
     * @param label name inserted into the error messages
     * @param rank dimension every shape must have
     * @param kernel kernel shape, must not be {@code null}
     * @param stride stride shape
     * @param pad pad shape
     * @throws IllegalArgumentException if {@code kernel} is {@code null} or any inspected shape
     *     does not have dimension {@code rank}
     */
    private static void checkShapes(String label, int rank, Shape kernel, Shape stride, Shape pad) {
        if (kernel == null) {
            throw new IllegalArgumentException("Kernel cannot be null for " + label + " Block");
        }
        if (kernel.dimension() != rank || stride.dimension() != rank || pad.dimension() != rank) {
            throw new IllegalArgumentException(
                    "Kernel , Stride and Pad dimensions for " + label + " layer should be " + rank);
        }
    }

    // ------------------------------------------------------------------------------------------
    // Average pooling
    // ------------------------------------------------------------------------------------------

    /**
     * Applies average pooling to an array using {@link PoolingConvention#VALID} and {@code true}
     * for the trailing boolean flag of the internal call.
     *
     * @param nDArray array to pool
     * @param shape kernel shape
     * @param shape2 stride shape
     * @param shape3 pad shape
     * @return the array returned by the internal {@code avgPool} call
     */
    public static NDArray avgPool(NDArray nDArray, Shape shape, Shape shape2, Shape shape3) {
        return avgPool(nDArray, shape, shape2, shape3, PoolingConvention.VALID, true);
    }

    /**
     * Forwards all arguments to the internal {@code avgPool} implementation of the array.
     *
     * @param z flag passed through unchanged (presumably whether padding is included in the
     *     average; confirm against the internal implementation)
     */
    private static NDArray avgPool(NDArray nDArray, Shape shape, Shape shape2, Shape shape3, PoolingConvention poolingConvention, boolean z) {
        return nDArray.getNDArrayInternal().avgPool(shape, shape2, shape3, poolingConvention, z);
    }

    /**
     * Applies average pooling to the single array held by {@code nDList} and wraps the result in
     * a new list. {@code singletonOrThrow()} is used to extract the array.
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static NDList avgPool(NDList nDList, Shape shape, Shape shape2, Shape shape3, PoolingConvention poolingConvention, boolean z) {
        NDArray pooled = avgPool(nDList.singletonOrThrow(), shape, shape2, shape3, poolingConvention, z);
        return new NDList(pooled);
    }

    /** Builds a 1D average pooling block; the kernel doubles as stride and the pad is {@code (0)}. */
    public static Block avgPool1DBlock(Shape shape) {
        return avgPool1DBlock(shape, shape, new Shape(0), PoolingConvention.VALID, true);
    }

    /** Builds a 1D average pooling block with pad {@code (0)} and {@link PoolingConvention#VALID}. */
    public static Block avgPool1DBlock(Shape shape, Shape shape2) {
        return avgPool1DBlock(shape, shape2, new Shape(0), PoolingConvention.VALID, true);
    }

    /** Builds a 1D average pooling block using {@link PoolingConvention#VALID}. */
    public static Block avgPool1DBlock(Shape shape, Shape shape2, Shape shape3) {
        return avgPool1DBlock(shape, shape2, shape3, PoolingConvention.VALID, true);
    }

    /** Builds a 1D average pooling block with the trailing boolean flag set to {@code true}. */
    public static Block avgPool1DBlock(Shape shape, Shape shape2, Shape shape3, PoolingConvention poolingConvention) {
        return avgPool1DBlock(shape, shape2, shape3, poolingConvention, true);
    }

    /**
     * Builds a block that applies 1D average pooling to a single-array {@link NDList}.
     *
     * @param shape kernel shape, must be non-null with dimension 1
     * @param shape2 stride shape, must have dimension 1
     * @param shape3 pad shape, must have dimension 1
     * @param poolingConvention convention forwarded to the pooling call
     * @param z flag forwarded to the pooling call (meaning not visible here; confirm)
     * @return a {@link LambdaBlock} performing the pooling
     * @throws IllegalArgumentException if the kernel is {@code null} or a shape has the wrong
     *     dimension
     */
    public static Block avgPool1DBlock(final Shape shape, final Shape shape2, final Shape shape3, final PoolingConvention poolingConvention, final boolean z) {
        checkShapes("avgPool1DBlock", 1, shape, shape2, shape3);
        return new LambdaBlock(list -> avgPool(list, shape, shape2, shape3, poolingConvention, z));
    }

    /** Builds a 2D average pooling block; the kernel doubles as stride and the pad is {@code (0, 0)}. */
    public static Block avgPool2DBlock(Shape shape) {
        return avgPool2DBlock(shape, shape, new Shape(0, 0), PoolingConvention.VALID, true);
    }

    /** Builds a 2D average pooling block with pad {@code (0, 0)} and {@link PoolingConvention#VALID}. */
    public static Block avgPool2DBlock(Shape shape, Shape shape2) {
        return avgPool2DBlock(shape, shape2, new Shape(0, 0), PoolingConvention.VALID, true);
    }

    /** Builds a 2D average pooling block using {@link PoolingConvention#VALID}. */
    public static Block avgPool2DBlock(Shape shape, Shape shape2, Shape shape3) {
        return avgPool2DBlock(shape, shape2, shape3, PoolingConvention.VALID, true);
    }

    /** Builds a 2D average pooling block with the trailing boolean flag set to {@code true}. */
    public static Block avgPool2DBlock(Shape shape, Shape shape2, Shape shape3, PoolingConvention poolingConvention) {
        return avgPool2DBlock(shape, shape2, shape3, poolingConvention, true);
    }

    /**
     * Builds a block that applies 2D average pooling to a single-array {@link NDList}.
     *
     * @param shape kernel shape, must be non-null with dimension 2
     * @param shape2 stride shape, must have dimension 2
     * @param shape3 pad shape, must have dimension 2
     * @param poolingConvention convention forwarded to the pooling call
     * @param z flag forwarded to the pooling call (meaning not visible here; confirm)
     * @return a {@link LambdaBlock} performing the pooling
     * @throws IllegalArgumentException if the kernel is {@code null} or a shape has the wrong
     *     dimension
     */
    public static Block avgPool2DBlock(final Shape shape, final Shape shape2, final Shape shape3, final PoolingConvention poolingConvention, final boolean z) {
        // The null-kernel message previously named avgPool3DBlock by mistake.
        checkShapes("avgPool2DBlock", 2, shape, shape2, shape3);
        return new LambdaBlock(list -> avgPool(list, shape, shape2, shape3, poolingConvention, z));
    }

    /** Builds a 3D average pooling block; the kernel doubles as stride and the pad is {@code (0, 0, 0)}. */
    public static Block avgPool3DBlock(Shape shape) {
        return avgPool3DBlock(shape, shape, new Shape(0, 0, 0), PoolingConvention.VALID, true);
    }

    /** Builds a 3D average pooling block with pad {@code (0, 0, 0)} and {@link PoolingConvention#VALID}. */
    public static Block avgPool3DBlock(Shape shape, Shape shape2) {
        return avgPool3DBlock(shape, shape2, new Shape(0, 0, 0), PoolingConvention.VALID, true);
    }

    /** Builds a 3D average pooling block using {@link PoolingConvention#VALID}. */
    public static Block avgPool3DBlock(Shape shape, Shape shape2, Shape shape3) {
        return avgPool3DBlock(shape, shape2, shape3, PoolingConvention.VALID, true);
    }

    /** Builds a 3D average pooling block with the trailing boolean flag set to {@code true}. */
    public static Block avgPool3DBlock(Shape shape, Shape shape2, Shape shape3, PoolingConvention poolingConvention) {
        return avgPool3DBlock(shape, shape2, shape3, poolingConvention, true);
    }

    /**
     * Builds a block that applies 3D average pooling to a single-array {@link NDList}.
     *
     * @param shape kernel shape, must be non-null with dimension 3
     * @param shape2 stride shape, must have dimension 3
     * @param shape3 pad shape, must have dimension 3
     * @param poolingConvention convention forwarded to the pooling call
     * @param z flag forwarded to the pooling call (meaning not visible here; confirm)
     * @return a {@link LambdaBlock} performing the pooling
     * @throws IllegalArgumentException if the kernel is {@code null} or a shape has the wrong
     *     dimension
     */
    public static Block avgPool3DBlock(final Shape shape, final Shape shape2, final Shape shape3, final PoolingConvention poolingConvention, final boolean z) {
        checkShapes("avgPool3DBlock", 3, shape, shape2, shape3);
        return new LambdaBlock(list -> avgPool(list, shape, shape2, shape3, poolingConvention, z));
    }

    // ------------------------------------------------------------------------------------------
    // Global average pooling
    // ------------------------------------------------------------------------------------------

    /** Delegates to the internal {@code globalAvgPool} implementation of the array. */
    private static NDArray globalAvgPool(NDArray nDArray) {
        return nDArray.getNDArrayInternal().globalAvgPool();
    }

    /** Builds a block that applies global average pooling to a single-array {@link NDList}. */
    public static Block globalAvgPool1DBlock() {
        return new LambdaBlock(Pool::lambda$globalAvgPool1DBlock$9);
    }

    /** Builds a block that applies global average pooling to a single-array {@link NDList}. */
    public static Block globalAvgPool2DBlock() {
        return new LambdaBlock(Pool::lambda$globalAvgPool2DBlock$10);
    }

    /** Builds a block that applies global average pooling to a single-array {@link NDList}. */
    public static Block globalAvgPool3DBlock() {
        return new LambdaBlock(Pool::lambda$globalAvgPool3DBlock$11);
    }

    // ------------------------------------------------------------------------------------------
    // Global Lp pooling
    // ------------------------------------------------------------------------------------------

    /**
     * Delegates to the internal {@code globalLpPool} implementation of the array.
     *
     * @param i integer forwarded unchanged (presumably the p of the Lp norm; confirm)
     */
    private static NDArray globalLpPool(NDArray nDArray, int i) {
        return nDArray.getNDArrayInternal().globalLpPool(i);
    }

    /** Builds a block that applies global Lp pooling, forwarding {@code i} to the pooling call. */
    public static Block globalLpPool1DBlock(final int i) {
        return new LambdaBlock(list -> lambda$globalLpPool1DBlock$15(i, list));
    }

    /** Builds a block that applies global Lp pooling, forwarding {@code i} to the pooling call. */
    public static Block globalLpPool2DBlock(final int i) {
        return new LambdaBlock(list -> lambda$globalLpPool2DBlock$16(i, list));
    }

    /** Builds a block that applies global Lp pooling, forwarding {@code i} to the pooling call. */
    public static Block globalLpPool3DBlock(final int i) {
        return new LambdaBlock(list -> lambda$globalLpPool3DBlock$17(i, list));
    }

    // ------------------------------------------------------------------------------------------
    // Global max pooling
    // ------------------------------------------------------------------------------------------

    /** Delegates to the internal {@code globalMaxPool} implementation of the array. */
    private static NDArray globalMaxPool(NDArray nDArray) {
        return nDArray.getNDArrayInternal().globalMaxPool();
    }

    /** Builds a block that applies global max pooling to a single-array {@link NDList}. */
    public static Block globalMaxPool1DBlock() {
        return new LambdaBlock(Pool::lambda$globalMaxPool1DBlock$3);
    }

    /** Builds a block that applies global max pooling to a single-array {@link NDList}. */
    public static Block globalMaxPool2DBlock() {
        return new LambdaBlock(Pool::lambda$globalMaxPool2DBlock$4);
    }

    /** Builds a block that applies global max pooling to a single-array {@link NDList}. */
    public static Block globalMaxPool3DBlock() {
        return new LambdaBlock(Pool::lambda$globalMaxPool3DBlock$5);
    }

    // ------------------------------------------------------------------------------------------
    // Lambda bodies emitted by the decompiler. They are kept under their generated names so that
    // anything already referring to them continues to resolve.
    // ------------------------------------------------------------------------------------------

    /** Global average pooling of the single array in {@code nDList}, wrapped in a new list. */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static /* synthetic */ NDList lambda$globalAvgPool1DBlock$9(NDList nDList) {
        return new NDList(globalAvgPool(nDList.singletonOrThrow()));
    }

    /** Global average pooling of the single array in {@code nDList}, wrapped in a new list. */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static /* synthetic */ NDList lambda$globalAvgPool2DBlock$10(NDList nDList) {
        return new NDList(globalAvgPool(nDList.singletonOrThrow()));
    }

    /** Global average pooling of the single array in {@code nDList}, wrapped in a new list. */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static /* synthetic */ NDList lambda$globalAvgPool3DBlock$11(NDList nDList) {
        return new NDList(globalAvgPool(nDList.singletonOrThrow()));
    }

    /** Global Lp pooling of the single array in {@code nDList}, wrapped in a new list. */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static /* synthetic */ NDList lambda$globalLpPool1DBlock$15(int i, NDList nDList) {
        return new NDList(globalLpPool(nDList.singletonOrThrow(), i));
    }

    /** Global Lp pooling of the single array in {@code nDList}, wrapped in a new list. */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static /* synthetic */ NDList lambda$globalLpPool2DBlock$16(int i, NDList nDList) {
        return new NDList(globalLpPool(nDList.singletonOrThrow(), i));
    }

    /** Global Lp pooling of the single array in {@code nDList}, wrapped in a new list. */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static /* synthetic */ NDList lambda$globalLpPool3DBlock$17(int i, NDList nDList) {
        return new NDList(globalLpPool(nDList.singletonOrThrow(), i));
    }

    /** Global max pooling of the single array in {@code nDList}, wrapped in a new list. */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static /* synthetic */ NDList lambda$globalMaxPool1DBlock$3(NDList nDList) {
        return new NDList(globalMaxPool(nDList.singletonOrThrow()));
    }

    /** Global max pooling of the single array in {@code nDList}, wrapped in a new list. */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static /* synthetic */ NDList lambda$globalMaxPool2DBlock$4(NDList nDList) {
        return new NDList(globalMaxPool(nDList.singletonOrThrow()));
    }

    /** Global max pooling of the single array in {@code nDList}, wrapped in a new list. */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static /* synthetic */ NDList lambda$globalMaxPool3DBlock$5(NDList nDList) {
        return new NDList(globalMaxPool(nDList.singletonOrThrow()));
    }

    // ------------------------------------------------------------------------------------------
    // Lp pooling
    // ------------------------------------------------------------------------------------------

    /**
     * Forwards all arguments to the internal {@code lpPool} implementation of the array.
     *
     * @param i integer forwarded unchanged (presumably the p of the Lp norm; confirm)
     */
    private static NDArray lpPool(NDArray nDArray, Shape shape, Shape shape2, Shape shape3, PoolingConvention poolingConvention, int i) {
        return nDArray.getNDArrayInternal().lpPool(shape, shape2, shape3, poolingConvention, i);
    }

    /**
     * Applies Lp pooling to the single array held by {@code nDList} and wraps the result in a new
     * list. {@code singletonOrThrow()} is used to extract the array.
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static NDList lpPool(NDList nDList, Shape shape, Shape shape2, Shape shape3, PoolingConvention poolingConvention, int i) {
        NDArray pooled = lpPool(nDList.singletonOrThrow(), shape, shape2, shape3, poolingConvention, i);
        return new NDList(pooled);
    }

    /** Builds a 1D Lp pooling block; the kernel doubles as stride and the pad is {@code (0)}. */
    public static Block lpPool1DBlock(Shape shape, int i) {
        return lpPool1DBlock(shape, shape, new Shape(0), PoolingConvention.VALID, i);
    }

    /** Builds a 1D Lp pooling block with pad {@code (0)} and {@link PoolingConvention#VALID}. */
    public static Block lpPool1DBlock(Shape shape, Shape shape2, int i) {
        return lpPool1DBlock(shape, shape2, new Shape(0), PoolingConvention.VALID, i);
    }

    /** Builds a 1D Lp pooling block using {@link PoolingConvention#VALID}. */
    public static Block lpPool1DBlock(Shape shape, Shape shape2, Shape shape3, int i) {
        return lpPool1DBlock(shape, shape2, shape3, PoolingConvention.VALID, i);
    }

    /**
     * Builds a block that applies 1D Lp pooling to a single-array {@link NDList}.
     *
     * @param shape kernel shape, must be non-null with dimension 1
     * @param shape2 stride shape, must have dimension 1
     * @param shape3 pad shape, must have dimension 1
     * @param poolingConvention convention forwarded to the pooling call
     * @param i integer forwarded to the pooling call (presumably the p of the Lp norm; confirm)
     * @return a {@link LambdaBlock} performing the pooling
     * @throws IllegalArgumentException if the kernel is {@code null} or a shape has the wrong
     *     dimension
     */
    public static Block lpPool1DBlock(final Shape shape, final Shape shape2, final Shape shape3, final PoolingConvention poolingConvention, final int i) {
        checkShapes("lpPool1D", 1, shape, shape2, shape3);
        return new LambdaBlock(list -> lpPool(list, shape, shape2, shape3, poolingConvention, i));
    }

    /** Builds a 2D Lp pooling block; the kernel doubles as stride and the pad is {@code (0, 0)}. */
    public static Block lpPool2DBlock(Shape shape, int i) {
        return lpPool2DBlock(shape, shape, new Shape(0, 0), PoolingConvention.VALID, i);
    }

    /** Builds a 2D Lp pooling block with pad {@code (0, 0)} and {@link PoolingConvention#VALID}. */
    public static Block lpPool2DBlock(Shape shape, Shape shape2, int i) {
        return lpPool2DBlock(shape, shape2, new Shape(0, 0), PoolingConvention.VALID, i);
    }

    /** Builds a 2D Lp pooling block using {@link PoolingConvention#VALID}. */
    public static Block lpPool2DBlock(Shape shape, Shape shape2, Shape shape3, int i) {
        return lpPool2DBlock(shape, shape2, shape3, PoolingConvention.VALID, i);
    }

    /**
     * Builds a block that applies 2D Lp pooling to a single-array {@link NDList}.
     *
     * @param shape kernel shape, must be non-null with dimension 2
     * @param shape2 stride shape, must have dimension 2
     * @param shape3 pad shape, must have dimension 2
     * @param poolingConvention convention forwarded to the pooling call
     * @param i integer forwarded to the pooling call (presumably the p of the Lp norm; confirm)
     * @return a {@link LambdaBlock} performing the pooling
     * @throws IllegalArgumentException if the kernel is {@code null} or a shape has the wrong
     *     dimension
     */
    public static Block lpPool2DBlock(final Shape shape, final Shape shape2, final Shape shape3, final PoolingConvention poolingConvention, final int i) {
        checkShapes("lpPool2D", 2, shape, shape2, shape3);
        return new LambdaBlock(list -> lpPool(list, shape, shape2, shape3, poolingConvention, i));
    }

    /** Builds a 3D Lp pooling block; the kernel doubles as stride and the pad is {@code (0, 0, 0)}. */
    public static Block lpPool3DBlock(Shape shape, int i) {
        return lpPool3DBlock(shape, shape, new Shape(0, 0, 0), PoolingConvention.VALID, i);
    }

    /** Builds a 3D Lp pooling block with pad {@code (0, 0, 0)} and {@link PoolingConvention#VALID}. */
    public static Block lpPool3DBlock(Shape shape, Shape shape2, int i) {
        return lpPool3DBlock(shape, shape2, new Shape(0, 0, 0), PoolingConvention.VALID, i);
    }

    /** Builds a 3D Lp pooling block using {@link PoolingConvention#VALID}. */
    public static Block lpPool3DBlock(Shape shape, Shape shape2, Shape shape3, int i) {
        return lpPool3DBlock(shape, shape2, shape3, PoolingConvention.VALID, i);
    }

    /**
     * Builds a block that applies 3D Lp pooling to a single-array {@link NDList}.
     *
     * @param shape kernel shape, must be non-null with dimension 3
     * @param shape2 stride shape, must have dimension 3
     * @param shape3 pad shape, must have dimension 3
     * @param poolingConvention convention forwarded to the pooling call
     * @param i integer forwarded to the pooling call (presumably the p of the Lp norm; confirm)
     * @return a {@link LambdaBlock} performing the pooling
     * @throws IllegalArgumentException if the kernel is {@code null} or a shape has the wrong
     *     dimension
     */
    public static Block lpPool3DBlock(final Shape shape, final Shape shape2, final Shape shape3, final PoolingConvention poolingConvention, final int i) {
        checkShapes("lpPool3D", 3, shape, shape2, shape3);
        return new LambdaBlock(list -> lpPool(list, shape, shape2, shape3, poolingConvention, i));
    }

    // ------------------------------------------------------------------------------------------
    // Max pooling
    // ------------------------------------------------------------------------------------------

    /** Forwards all arguments to the internal {@code maxPool} implementation of the array. */
    private static NDArray maxPool(NDArray nDArray, Shape shape, Shape shape2, Shape shape3, PoolingConvention poolingConvention) {
        return nDArray.getNDArrayInternal().maxPool(shape, shape2, shape3, poolingConvention);
    }

    /**
     * Applies max pooling to the single array held by {@code nDList} and wraps the result in a
     * new list. {@code singletonOrThrow()} is used to extract the array.
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static NDList maxPool(NDList nDList, Shape shape, Shape shape2, Shape shape3, PoolingConvention poolingConvention) {
        NDArray pooled = maxPool(nDList.singletonOrThrow(), shape, shape2, shape3, poolingConvention);
        return new NDList(pooled);
    }

    /** Builds a 1D max pooling block; the kernel doubles as stride and the pad is {@code (0)}. */
    public static Block maxPool1DBlock(Shape shape) {
        return maxPool1DBlock(shape, shape, new Shape(0), PoolingConvention.VALID);
    }

    /** Builds a 1D max pooling block with pad {@code (0)} and {@link PoolingConvention#VALID}. */
    public static Block maxPool1DBlock(Shape shape, Shape shape2) {
        return maxPool1DBlock(shape, shape2, new Shape(0), PoolingConvention.VALID);
    }

    /** Builds a 1D max pooling block using {@link PoolingConvention#VALID}. */
    public static Block maxPool1DBlock(Shape shape, Shape shape2, Shape shape3) {
        return maxPool1DBlock(shape, shape2, shape3, PoolingConvention.VALID);
    }

    /**
     * Builds a block that applies 1D max pooling to a single-array {@link NDList}.
     *
     * @param shape kernel shape, must be non-null with dimension 1
     * @param shape2 stride shape, must have dimension 1
     * @param shape3 pad shape, must have dimension 1
     * @param poolingConvention convention forwarded to the pooling call
     * @return a {@link LambdaBlock} performing the pooling
     * @throws IllegalArgumentException if the kernel is {@code null} or a shape has the wrong
     *     dimension
     */
    public static Block maxPool1DBlock(final Shape shape, final Shape shape2, final Shape shape3, final PoolingConvention poolingConvention) {
        checkShapes("maxPool1DBlock", 1, shape, shape2, shape3);
        return new LambdaBlock(list -> maxPool(list, shape, shape2, shape3, poolingConvention));
    }

    /** Builds a 2D max pooling block; the kernel doubles as stride and the pad is {@code (0, 0)}. */
    public static Block maxPool2DBlock(Shape shape) {
        return maxPool2DBlock(shape, shape, new Shape(0, 0), PoolingConvention.VALID);
    }

    /** Builds a 2D max pooling block with pad {@code (0, 0)} and {@link PoolingConvention#VALID}. */
    public static Block maxPool2DBlock(Shape shape, Shape shape2) {
        return maxPool2DBlock(shape, shape2, new Shape(0, 0), PoolingConvention.VALID);
    }

    /** Builds a 2D max pooling block using {@link PoolingConvention#VALID}. */
    public static Block maxPool2DBlock(Shape shape, Shape shape2, Shape shape3) {
        return maxPool2DBlock(shape, shape2, shape3, PoolingConvention.VALID);
    }

    /**
     * Builds a block that applies 2D max pooling to a single-array {@link NDList}.
     *
     * @param shape kernel shape, must be non-null with dimension 2
     * @param shape2 stride shape, must have dimension 2
     * @param shape3 pad shape, must have dimension 2
     * @param poolingConvention convention forwarded to the pooling call
     * @return a {@link LambdaBlock} performing the pooling
     * @throws IllegalArgumentException if the kernel is {@code null} or a shape has the wrong
     *     dimension
     */
    public static Block maxPool2DBlock(final Shape shape, final Shape shape2, final Shape shape3, final PoolingConvention poolingConvention) {
        checkShapes("maxPool2DBlock", 2, shape, shape2, shape3);
        return new LambdaBlock(list -> maxPool(list, shape, shape2, shape3, poolingConvention));
    }

    /** Builds a 3D max pooling block; the kernel doubles as stride and the pad is {@code (0, 0, 0)}. */
    public static Block maxPool3DBlock(Shape shape) {
        return maxPool3DBlock(shape, shape, new Shape(0, 0, 0), PoolingConvention.VALID);
    }

    /** Builds a 3D max pooling block with pad {@code (0, 0, 0)} and {@link PoolingConvention#VALID}. */
    public static Block maxPool3DBlock(Shape shape, Shape shape2) {
        return maxPool3DBlock(shape, shape2, new Shape(0, 0, 0), PoolingConvention.VALID);
    }

    /** Builds a 3D max pooling block using {@link PoolingConvention#VALID}. */
    public static Block maxPool3DBlock(Shape shape, Shape shape2, Shape shape3) {
        return maxPool3DBlock(shape, shape2, shape3, PoolingConvention.VALID);
    }

    /**
     * Builds a block that applies 3D max pooling to a single-array {@link NDList}.
     *
     * @param shape kernel shape, must be non-null with dimension 3
     * @param shape2 stride shape, must have dimension 3
     * @param shape3 pad shape, must have dimension 3
     * @param poolingConvention convention forwarded to the pooling call
     * @return a {@link LambdaBlock} performing the pooling
     * @throws IllegalArgumentException if the kernel is {@code null} or a shape has the wrong
     *     dimension
     */
    public static Block maxPool3DBlock(final Shape shape, final Shape shape2, final Shape shape3, final PoolingConvention poolingConvention) {
        checkShapes("maxPool3DBlock", 3, shape, shape2, shape3);
        return new LambdaBlock(list -> maxPool(list, shape, shape2, shape3, poolingConvention));
    }
}
