package ai.djl.modality.cv;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.output.Mask;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.RandomUtils;
import com.itextpdf.text.pdf.PdfBoolean;
import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.FontMetrics;
import java.awt.Graphics2D;
import java.awt.Rectangle;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.awt.image.ImageObserver;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.file.Path;
import java.util.Iterator;
import java.util.List;
import javax.imageio.ImageIO;

/* loaded from: classes.dex */
/**
 * An {@link ImageFactory} backed by AWT: images are decoded and encoded with {@link ImageIO} and
 * exposed to callers as {@link Image} wrappers around {@link BufferedImage}.
 */
public class BufferedImageFactory extends ImageFactory {

    /**
     * {@link Image} implementation that wraps a {@link BufferedImage}.
     *
     * <p>Non-static because {@link #save} delegates to the enclosing factory's overridable
     * {@code save} method.
     */
    private class BufferedImageWrapper implements Image {
        private final BufferedImage image;

        BufferedImageWrapper(BufferedImage bufferedImage) {
            this.image = bufferedImage;
        }

        /**
         * Draws a randomly coloured, semi-transparent overlay built from the mask's probability
         * map onto {@code target}.
         *
         * <p>Each cell {@code probDist[x][y]} becomes the alpha of one overlay pixel. Values below
         * 0.1 are made fully transparent and values above 0.8 are capped at 0.8. The overlay is
         * drawn unscaled, with its top-left corner at the mask's relative position multiplied by
         * the image size and clamped to be non-negative.
         *
         * @param target the image to draw on
         * @param mask the mask whose probability map is rendered
         */
        private void drawMask(BufferedImage target, Mask mask) {
            float[][] probDist = mask.getProbDist();
            if (probDist.length == 0 || probDist[0].length == 0) {
                // BufferedImage rejects zero-sized dimensions; an empty mask has nothing to draw.
                return;
            }
            float red = RandomUtils.nextFloat();
            float green = RandomUtils.nextFloat();
            float blue = RandomUtils.nextFloat();
            int x = Math.max(0, (int) (mask.getX() * target.getWidth()));
            int y = Math.max(0, (int) (mask.getY() * target.getHeight()));

            // First index of probDist is the overlay's x coordinate, second is y.
            BufferedImage overlay =
                    new BufferedImage(
                            probDist.length, probDist[0].length, BufferedImage.TYPE_INT_ARGB);
            for (int px = 0; px < probDist.length; px++) {
                for (int py = 0; py < probDist[px].length; py++) {
                    float alpha = probDist[px][py];
                    if (alpha < 0.1d) {
                        alpha = 0.0f;
                    } else if (alpha > 0.8d) {
                        alpha = 0.8f;
                    }
                    overlay.setRGB(
                            px, py, new Color(red, green, blue, alpha).darker().getRGB());
                }
            }
            Graphics2D g = target.createGraphics();
            try {
                g.drawImage(overlay, x, y, null);
            } finally {
                g.dispose();
            }
        }

        /**
         * Draws {@code text} in white on a filled label background, using the graphics context's
         * current paint, anchored at the top-left corner of a box.
         *
         * <p>Leaves the paint of {@code g} set to white.
         *
         * @param g the graphics context to draw with
         * @param text the label text
         * @param x left edge of the box, in pixels
         * @param y top edge of the box, in pixels
         * @param stroke stroke width of the box outline; half of it offsets the label inwards
         * @param padding horizontal padding around the text, in pixels
         */
        private void drawText(Graphics2D g, String text, int x, int y, int stroke, int padding) {
            FontMetrics metrics = g.getFontMetrics();
            int halfStroke = stroke / 2;
            int left = x + halfStroke;
            int top = y + halfStroke;
            int labelWidth = metrics.stringWidth(text) + padding * 2 - halfStroke;
            int labelHeight = metrics.getHeight() + metrics.getDescent();
            g.fill(new Rectangle(left, top, labelWidth, labelHeight));
            g.setPaint(Color.WHITE);
            g.drawString(text, left + padding, top + metrics.getAscent());
        }

        /**
         * Maps an {@link Image.Type} to the matching {@link BufferedImage} type constant.
         *
         * @throws IllegalArgumentException if the type is not {@code TYPE_INT_ARGB}
         */
        private int getType(Image.Type type) {
            if (type == Image.Type.TYPE_INT_ARGB) {
                return BufferedImage.TYPE_INT_ARGB;
            }
            throw new IllegalArgumentException("the type is not supported!");
        }

        /**
         * Returns a random opaque colour with independently chosen red, green and blue channels.
         */
        private Color randomColor() {
            // The packed-int Color(int rgb) constructor would leave red and green at zero for
            // small values, so each channel is drawn separately.
            return new Color(
                    RandomUtils.nextInt(256), RandomUtils.nextInt(256), RandomUtils.nextInt(256));
        }

        /**
         * Draws every detected object's bounding box in place: a 2px outline, a class-name label
         * and, for {@link Mask} boxes, the mask overlay.
         *
         * <p>Box coordinates are relative and are scaled by the image width and height.
         */
        @Override // ai.djl.modality.cv.Image
        public void drawBoundingBoxes(DetectedObjects detectedObjects) {
            final int stroke = 2;
            int imageWidth = image.getWidth();
            int imageHeight = image.getHeight();
            List<DetectedObjects.DetectedObject> items = detectedObjects.items();
            Graphics2D g = image.createGraphics();
            try {
                g.setStroke(new BasicStroke(stroke));
                g.setRenderingHint(
                        RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
                for (DetectedObjects.DetectedObject detected : items) {
                    BoundingBox box = detected.getBoundingBox();
                    g.setPaint(randomColor().darker());
                    ai.djl.modality.cv.output.Rectangle bounds = box.getBounds();
                    int x = (int) (bounds.getX() * imageWidth);
                    int y = (int) (bounds.getY() * imageHeight);
                    int w = (int) (bounds.getWidth() * imageWidth);
                    int h = (int) (bounds.getHeight() * imageHeight);
                    g.drawRect(x, y, w, h);
                    drawText(g, detected.getClassName(), x, y, stroke, 4);
                    if (box instanceof Mask) {
                        drawMask(image, (Mask) box);
                    }
                }
            } finally {
                g.dispose();
            }
        }

        /**
         * Draws each joint in place as a randomly coloured 10x10 filled oval.
         *
         * <p>The oval's top-left corner is placed at the joint's relative position scaled by the
         * image size.
         */
        @Override // ai.djl.modality.cv.Image
        public void drawJoints(Joints joints) {
            int imageWidth = image.getWidth();
            int imageHeight = image.getHeight();
            Graphics2D g = image.createGraphics();
            try {
                g.setStroke(new BasicStroke(2));
                for (Joints.Joint joint : joints.getJoints()) {
                    g.setPaint(randomColor().darker());
                    g.fillOval(
                            (int) (joint.getX() * imageWidth),
                            (int) (joint.getY() * imageHeight),
                            10,
                            10);
                }
            } finally {
                g.dispose();
            }
        }

        /**
         * Returns a deep copy of this image converted to the requested type.
         *
         * @throws IllegalArgumentException if the type is not {@code TYPE_INT_ARGB}
         */
        @Override // ai.djl.modality.cv.Image
        public Image duplicate(Image.Type type) {
            BufferedImage copy =
                    new BufferedImage(image.getWidth(), image.getHeight(), getType(type));
            Graphics2D g = copy.createGraphics();
            try {
                g.drawImage(image, 0, 0, null);
            } finally {
                g.dispose();
            }
            return new BufferedImageWrapper(copy);
        }

        /** Returns the image height in pixels. */
        @Override // ai.djl.modality.cv.Image
        public int getHeight() {
            return image.getHeight();
        }

        /**
         * Returns a sub-region of this image. It is created by {@link BufferedImage#getSubimage},
         * which shares pixel storage with this image.
         */
        @Override // ai.djl.modality.cv.Image
        public Image getSubimage(int i, int i2, int i3, int i4) {
            return new BufferedImageWrapper(image.getSubimage(i, i2, i3, i4));
        }

        /** Returns the image width in pixels. */
        @Override // ai.djl.modality.cv.Image
        public int getWidth() {
            return image.getWidth();
        }

        /** Returns the wrapped {@link BufferedImage}. */
        @Override // ai.djl.modality.cv.Image
        public Object getWrappedImage() {
            return image;
        }

        /** Encodes this image in the given format via the enclosing factory's {@code save}. */
        @Override // ai.djl.modality.cv.Image
        public void save(OutputStream outputStream, String str) throws IOException {
            BufferedImageFactory.this.save(image, outputStream, str);
        }

        /**
         * Converts the image to a {@code UINT8} NDArray of shape {@code (height, width,
         * channels)}.
         *
         * <p>Channels is 1 for {@link Image.Flag#GRAYSCALE}, otherwise 3 in R, G, B order. Alpha
         * is discarded.
         *
         * <ul>
         *   <li>{@code TYPE_BYTE_GRAY} sources copy the gray sample, replicated to all channels
         *       when colour output is requested.
         *   <li>Other sources are read as default sRGB, and grayscale output is the integer mean of
         *       R, G and B.
         * </ul>
         */
        @Override // ai.djl.modality.cv.Image
        public NDArray toNDArray(NDManager nDManager, Image.Flag flag) {
            int width = image.getWidth();
            int height = image.getHeight();
            boolean grayscale = flag == Image.Flag.GRAYSCALE;
            int channels = grayscale ? 1 : 3;
            ByteBuffer buffer = nDManager.allocateDirect(channels * height * width);
            if (image.getType() == BufferedImage.TYPE_BYTE_GRAY) {
                // Read through the raster instead of the raw DataBuffer. A sub-image shares its
                // parent's buffer, whose length and scanline stride need not match this image.
                int[] samples = image.getRaster().getSamples(0, 0, width, height, 0, (int[]) null);
                for (int sample : samples) {
                    byte gray = (byte) sample;
                    buffer.put(gray);
                    if (!grayscale) {
                        buffer.put(gray);
                        buffer.put(gray);
                    }
                }
            } else {
                int[] pixels = image.getRGB(0, 0, width, height, null, 0, width);
                for (int argb : pixels) {
                    int red = (argb >> 16) & 0xFF;
                    int green = (argb >> 8) & 0xFF;
                    int blue = argb & 0xFF;
                    if (grayscale) {
                        buffer.put((byte) ((red + green + blue) / 3));
                    } else {
                        buffer.put((byte) red);
                        buffer.put((byte) green);
                        buffer.put((byte) blue);
                    }
                }
            }
            buffer.rewind();
            return nDManager.create(buffer, new Shape(height, width, channels), DataType.UINT8);
        }
    }

    static {
        // Respect a user-supplied value; otherwise default the property to "true".
        // NOTE(review): presumably this stops macOS from showing a Dock icon once AWT loads
        // -- confirm.
        if (System.getProperty("apple.awt.UIElement") == null) {
            System.setProperty("apple.awt.UIElement", "true");
        }
    }

    /**
     * Decodes an image file.
     *
     * @throws IOException if the file cannot be read or no ImageIO reader recognises it
     */
    @Override // ai.djl.modality.cv.ImageFactory
    public Image fromFile(Path path) throws IOException {
        BufferedImage decoded = ImageIO.read(path.toFile());
        if (decoded == null) {
            throw new IOException("Failed to read image from: " + path);
        }
        return new BufferedImageWrapper(decoded);
    }

    /**
     * Wraps an existing {@link BufferedImage} without copying it.
     *
     * @throws IllegalArgumentException if {@code obj} is not a {@link BufferedImage}
     */
    @Override // ai.djl.modality.cv.ImageFactory
    public Image fromImage(Object obj) {
        if (!(obj instanceof BufferedImage)) {
            throw new IllegalArgumentException("only BufferedImage allowed");
        }
        return new BufferedImageWrapper((BufferedImage) obj);
    }

    /**
     * Decodes an image from a stream. The stream is not closed.
     *
     * @throws IOException if reading fails or no ImageIO reader recognises the data
     */
    @Override // ai.djl.modality.cv.ImageFactory
    public Image fromInputStream(InputStream inputStream) throws IOException {
        BufferedImage decoded = ImageIO.read(inputStream);
        if (decoded == null) {
            throw new IOException("Failed to read image from input stream");
        }
        return new BufferedImageWrapper(decoded);
    }

    /**
     * Downloads and decodes an image from a URL.
     *
     * @throws IOException if reading fails or no ImageIO reader recognises the data
     */
    @Override // ai.djl.modality.cv.ImageFactory
    public Image fromUrl(URL url) throws IOException {
        BufferedImage decoded = ImageIO.read(url);
        if (decoded == null) {
            throw new IOException("Failed to read image from: " + url);
        }
        return new BufferedImageWrapper(decoded);
    }

    /**
     * Encodes {@code bufferedImage} to {@code outputStream} in the given ImageIO format (for
     * example {@code "png"}).
     *
     * @param bufferedImage the image to encode
     * @param outputStream destination stream; not closed by this method
     * @param str informal ImageIO format name
     * @throws IOException if writing fails, or if no ImageIO writer can encode this image in
     *     {@code str}. That case was previously ignored silently and left the stream empty.
     */
    protected void save(BufferedImage bufferedImage, OutputStream outputStream, String str)
            throws IOException {
        if (!ImageIO.write(bufferedImage, str, outputStream)) {
            throw new IOException("No ImageIO writer available for format: " + str);
        }
    }
}
