package org.tensorflow.contrib.android;

import android.content.res.AssetManager;
import android.os.Trace;
import android.util.Log;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;

/* loaded from: classes2.dex */
/**
 * Wraps a TensorFlow {@link Graph} and {@link Session} behind a feed / run / fetch workflow that
 * reports failures through integer status codes.
 *
 * <p>Typical use:
 * <ol>
 *   <li>Call {@link #initializeTensorFlow} once.</li>
 *   <li>For each inference, call the {@code fillNode*} methods for the inputs, then
 *       {@link #runInference} with the output names, then the {@code readNode*} methods to copy
 *       results out.</li>
 *   <li>Call {@link #close()} to release the session, graph and any cached tensors.</li>
 * </ol>
 *
 * <p>NOTE(review): this class does no synchronization; treat it as not thread-safe.
 */
public class TensorFlowInferenceInterface {
    private static final String ASSET_FILE_PREFIX = "file:///android_asset/";
    private static final String TAG = "TensorFlowInferenceInterface";
    /** Chunk size used when streaming a serialized graph into memory. */
    private static final int READ_CHUNK_SIZE = 16 * 1024;

    private boolean enableStats;
    /** Tensors fed since the last run; closed by {@link #runInference} or {@link #close()}. */
    private final List<Tensor> feedTensors = new ArrayList<>();
    /** Output names requested by the last {@link #runInference}, parallel to {@link #fetchTensors}. */
    private final List<String> fetchNames = new ArrayList<>();
    /** Outputs of the last successful {@link #runInference}; reassigned on every run. */
    private List<Tensor> fetchTensors = new ArrayList<>();
    private Graph g;
    private RunStats runStats;
    private Session.Runner runner;
    private Session sess;

    /**
     * A parsed tensor reference of the form {@code "name"} or {@code "name:index"}.
     *
     * <p>If there is no ':' or the text after the last ':' is not an integer, the whole string is
     * taken as the node name and the output index is 0.
     */
    public static class TensorId {
        String name;
        int outputIndex;

        private TensorId() {
        }

        /**
         * Splits {@code str} into node name and output index at its last ':'.
         *
         * @param str tensor reference such as {@code "input"} or {@code "softmax:0"}
         * @return the parsed reference; never {@code null}
         */
        public static TensorId parse(String str) {
            TensorId tensorId = new TensorId();
            tensorId.name = str;
            tensorId.outputIndex = 0;
            int colon = str.lastIndexOf(':');
            if (colon >= 0) {
                try {
                    tensorId.outputIndex = Integer.parseInt(str.substring(colon + 1));
                    tensorId.name = str.substring(0, colon);
                } catch (NumberFormatException ignored) {
                    // Non-numeric suffix: the ':' is part of the node name, so keep the defaults.
                }
            }
            return tensorId;
        }
    }

    /**
     * Ensures the native TensorFlow library is available, loading {@code tensorflow_inference} if
     * it is not already loaded.
     *
     * @throws RuntimeException if the native library cannot be loaded
     */
    public TensorFlowInferenceInterface() {
        try {
            // Constructing RunStats raises UnsatisfiedLinkError when the natives are missing, so it
            // doubles as a probe. Close it right away so the probe does not leak its resources.
            new RunStats().close();
            Log.i(TAG, "Native methods already loaded.");
        } catch (UnsatisfiedLinkError unused) {
            Log.i(TAG, "Loading tensorflow_inference.");
            try {
                System.loadLibrary("tensorflow_inference");
            } catch (UnsatisfiedLinkError e) {
                throw new RuntimeException("Native TF methods not found; check that the correct native libraries are present and loaded.", e);
            }
        }
    }

    /** Registers {@code tensor} as the value for {@code str} and tracks it for later closing. */
    private void addFeed(String str, Tensor tensor) {
        TensorId parse = TensorId.parse(str);
        this.runner.feed(parse.name, parse.outputIndex, tensor);
        this.feedTensors.add(tensor);
    }

    /** Closes and forgets every tensor fed since the last run. */
    private void closeFeeds() {
        for (Tensor tensor : this.feedTensors) {
            tensor.close();
        }
        this.feedTensors.clear();
    }

    /** Closes and forgets the outputs (and their names) from the previous run. */
    private void closeFetches() {
        for (Tensor tensor : this.fetchTensors) {
            tensor.close();
        }
        this.fetchTensors.clear();
        this.fetchNames.clear();
    }

    /**
     * Looks up a fetched output by the exact name passed to {@link #runInference}.
     *
     * @return the tensor, or {@code null} if the name was not fetched or no output is available
     */
    private Tensor getTensor(String str) {
        int index = this.fetchNames.indexOf(str);
        // Guard against the names and tensors lists being out of step (e.g. after a failed run).
        if (index < 0 || index >= this.fetchTensors.size()) {
            return null;
        }
        return this.fetchTensors.get(index);
    }

    /**
     * Creates a fresh graph and session and imports the serialized GraphDef read from
     * {@code inputStream}. The stream is not closed here.
     *
     * @throws IOException if reading fails or the bytes are not a valid graph serialization
     */
    private void load(InputStream inputStream) throws IOException {
        this.g = new Graph();
        this.sess = new Session(this.g);
        this.runner = this.sess.runner();
        long startMillis = System.currentTimeMillis();
        Trace.beginSection("initializeTensorFlow");
        try {
            byte[] graphDef;
            Trace.beginSection("readGraphDef");
            try {
                graphDef = readFully(inputStream);
            } finally {
                Trace.endSection();
            }
            Trace.beginSection("importGraphDef");
            try {
                this.g.importGraphDef(graphDef);
            } catch (IllegalArgumentException e) {
                throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage(), e);
            } finally {
                Trace.endSection();
            }
        } finally {
            Trace.endSection();
        }
        Log.i(TAG, "Model load took " + (System.currentTimeMillis() - startMillis) + "ms, TensorFlow version: " + TensorFlow.version());
    }

    /**
     * Reads {@code inputStream} to end-of-stream.
     *
     * <p>{@code available()} is only an estimate, and a single {@code read} may return fewer bytes
     * than requested. So this method loops until {@code -1} and uses {@code available()} only as
     * a capacity hint.
     */
    private static byte[] readFully(InputStream inputStream) throws IOException {
        ByteArrayOutputStream out =
                new ByteArrayOutputStream(Math.max(inputStream.available(), READ_CHUNK_SIZE));
        byte[] chunk = new byte[READ_CHUNK_SIZE];
        int n;
        while ((n = inputStream.read(chunk)) != -1) {
            out.write(chunk, 0, n);
        }
        return out.toByteArray();
    }

    /**
     * Converts the remaining elements of {@code intBuffer} to a {@code long[]} shape.
     *
     * <p>The buffer is read through a duplicate, so the caller's position is left untouched.
     * Position, limit and array offset are honored whether or not the buffer is array-backed.
     */
    private long[] mkDims(IntBuffer intBuffer) {
        int[] dims = new int[intBuffer.remaining()];
        intBuffer.duplicate().get(dims);
        return mkDims(dims);
    }

    /** Widens an {@code int[]} shape to the {@code long[]} form that {@code Tensor.create} takes. */
    private long[] mkDims(int[] iArr) {
        long[] jArr = new long[iArr.length];
        for (int i = 0; i < iArr.length; i++) {
            jArr[i] = iArr[i];
        }
        return jArr;
    }

    /**
     * Releases pending feeds, cached outputs, the session, the graph and any stats collector.
     * Safe to call even if {@link #initializeTensorFlow} was never called successfully.
     */
    public void close() {
        closeFeeds();
        closeFetches();
        if (this.sess != null) {
            this.sess.close();
        }
        if (this.g != null) {
            this.g.close();
        }
        if (this.runStats != null) {
            this.runStats.close();
        }
        this.runStats = null;
        this.enableStats = false;
    }

    /**
     * Turns per-run statistics collection on or off. A collector is created the first time stats
     * are enabled and is kept until {@link #close()}.
     */
    public void enableStatLogging(boolean z) {
        this.enableStats = z;
        if (this.enableStats && this.runStats == null) {
            this.runStats = new RunStats();
        }
    }

    /** Feeds {@code bArr} as a {@code UINT8} tensor of shape {@code iArr} to input {@code str}. */
    public void fillNodeByte(String str, int[] iArr, byte[] bArr) {
        addFeed(str, Tensor.create(DataType.UINT8, mkDims(iArr), ByteBuffer.wrap(bArr)));
    }

    /** Feeds {@code dArr} as a double tensor of shape {@code iArr} to input {@code str}. */
    public void fillNodeDouble(String str, int[] iArr, double[] dArr) {
        addFeed(str, Tensor.create(mkDims(iArr), DoubleBuffer.wrap(dArr)));
    }

    /** Feeds {@code fArr} as a float tensor of shape {@code iArr} to input {@code str}. */
    public void fillNodeFloat(String str, int[] iArr, float[] fArr) {
        addFeed(str, Tensor.create(mkDims(iArr), FloatBuffer.wrap(fArr)));
    }

    /** Feeds {@code byteBuffer} as a {@code UINT8} tensor whose shape is the remaining {@code intBuffer}. */
    public void fillNodeFromByteBuffer(String str, IntBuffer intBuffer, ByteBuffer byteBuffer) {
        addFeed(str, Tensor.create(DataType.UINT8, mkDims(intBuffer), byteBuffer));
    }

    /** Feeds {@code doubleBuffer} as a double tensor whose shape is the remaining {@code intBuffer}. */
    public void fillNodeFromDoubleBuffer(String str, IntBuffer intBuffer, DoubleBuffer doubleBuffer) {
        addFeed(str, Tensor.create(mkDims(intBuffer), doubleBuffer));
    }

    /** Feeds {@code floatBuffer} as a float tensor whose shape is the remaining {@code intBuffer}. */
    public void fillNodeFromFloatBuffer(String str, IntBuffer intBuffer, FloatBuffer floatBuffer) {
        addFeed(str, Tensor.create(mkDims(intBuffer), floatBuffer));
    }

    /** Feeds {@code intBuffer2} as an int tensor whose shape is the remaining {@code intBuffer}. */
    public void fillNodeFromIntBuffer(String str, IntBuffer intBuffer, IntBuffer intBuffer2) {
        addFeed(str, Tensor.create(mkDims(intBuffer), intBuffer2));
    }

    /** Feeds {@code iArr2} as an int tensor of shape {@code iArr} to input {@code str}. */
    public void fillNodeInt(String str, int[] iArr, int[] iArr2) {
        addFeed(str, Tensor.create(mkDims(iArr), IntBuffer.wrap(iArr2)));
    }

    /** Returns the collected run statistics summary, or {@code ""} if stats were never enabled. */
    public String getStatString() {
        return this.runStats == null ? "" : this.runStats.summary();
    }

    /** Returns the graph created by the last successful or attempted model load. */
    public Graph graph() {
        return this.g;
    }

    /**
     * Loads a serialized GraphDef.
     *
     * <p>A path starting with {@code file:///android_asset/} is opened as an Android asset (with
     * the prefix stripped). Any other path is first tried as an asset name and, if that fails, as
     * a file on disk.
     *
     * @param assetManager used to open asset paths
     * @param str model location
     * @return 0 on success, 1 on failure (details are logged)
     */
    public int initializeTensorFlow(AssetManager assetManager, String str) {
        boolean hasAssetPrefix = str.startsWith(ASSET_FILE_PREFIX);
        InputStream inputStream;
        try {
            String assetName = hasAssetPrefix ? str.substring(ASSET_FILE_PREFIX.length()) : str;
            inputStream = assetManager.open(assetName);
        } catch (IOException e) {
            if (hasAssetPrefix) {
                Log.e(TAG, "Failed to initialize: " + e.toString());
                return 1;
            }
            // Not found as an asset; the path may refer to a file on disk instead.
            try {
                inputStream = new FileInputStream(str);
            } catch (IOException e2) {
                Log.e(TAG, "Failed to open " + str + ": " + e2.toString());
                return 1;
            }
        }
        try {
            load(inputStream);
            return 0;
        } catch (IOException e3) {
            Log.e(TAG, "Failed to initialize: " + e3.toString());
            return 1;
        } finally {
            try {
                inputStream.close();
            } catch (IOException e4) {
                // The graph is already loaded (or loading already failed); a close error is not fatal.
                Log.w(TAG, "Failed to close model stream: " + e4.toString());
            }
        }
    }

    /** Copies output {@code str} into {@code bArr}; returns -1 if it was not fetched, else 0. */
    public int readNodeByte(String str, byte[] bArr) {
        return readNodeIntoByteBuffer(str, ByteBuffer.wrap(bArr));
    }

    /** Copies output {@code str} into {@code dArr}; returns -1 if it was not fetched, else 0. */
    public int readNodeDouble(String str, double[] dArr) {
        return readNodeIntoDoubleBuffer(str, DoubleBuffer.wrap(dArr));
    }

    /** Copies output {@code str} into {@code fArr}; returns -1 if it was not fetched, else 0. */
    public int readNodeFloat(String str, float[] fArr) {
        return readNodeIntoFloatBuffer(str, FloatBuffer.wrap(fArr));
    }

    /** Copies output {@code str} into {@code iArr}; returns -1 if it was not fetched, else 0. */
    public int readNodeInt(String str, int[] iArr) {
        return readNodeIntoIntBuffer(str, IntBuffer.wrap(iArr));
    }

    /**
     * Writes output {@code str} into {@code byteBuffer} via {@code Tensor.writeTo}.
     * NOTE(review): the dtype/capacity mismatch behavior is whatever {@code writeTo} does — confirm.
     *
     * @return -1 if {@code str} was not fetched by the last successful run, otherwise 0
     */
    public int readNodeIntoByteBuffer(String str, ByteBuffer byteBuffer) {
        Tensor tensor = getTensor(str);
        if (tensor == null) {
            return -1;
        }
        tensor.writeTo(byteBuffer);
        return 0;
    }

    /**
     * Writes output {@code str} into {@code doubleBuffer} via {@code Tensor.writeTo}.
     *
     * @return -1 if {@code str} was not fetched by the last successful run, otherwise 0
     */
    public int readNodeIntoDoubleBuffer(String str, DoubleBuffer doubleBuffer) {
        Tensor tensor = getTensor(str);
        if (tensor == null) {
            return -1;
        }
        tensor.writeTo(doubleBuffer);
        return 0;
    }

    /**
     * Writes output {@code str} into {@code floatBuffer} via {@code Tensor.writeTo}.
     *
     * @return -1 if {@code str} was not fetched by the last successful run, otherwise 0
     */
    public int readNodeIntoFloatBuffer(String str, FloatBuffer floatBuffer) {
        Tensor tensor = getTensor(str);
        if (tensor == null) {
            return -1;
        }
        tensor.writeTo(floatBuffer);
        return 0;
    }

    /**
     * Writes output {@code str} into {@code intBuffer} via {@code Tensor.writeTo}.
     *
     * @return -1 if {@code str} was not fetched by the last successful run, otherwise 0
     */
    public int readNodeIntoIntBuffer(String str, IntBuffer intBuffer) {
        Tensor tensor = getTensor(str);
        if (tensor == null) {
            return -1;
        }
        tensor.writeTo(intBuffer);
        return 0;
    }

    /**
     * Runs the session with the tensors fed since the previous run, fetching {@code strArr}.
     *
     * <p>Outputs from the previous call are closed first. The fed tensors are always closed, and a
     * fresh runner is prepared, whether or not the run succeeds.
     *
     * @param strArr output references ({@code "name"} or {@code "name:index"})
     * @return 0 on success; -1 if registering a fetch or running the session throws (logged)
     */
    public int runInference(String[] strArr) {
        closeFetches();
        try {
            for (String outputName : strArr) {
                this.fetchNames.add(outputName);
                TensorId tensorId = TensorId.parse(outputName);
                this.runner.fetch(tensorId.name, tensorId.outputIndex);
            }
            if (this.enableStats) {
                Session.Run run = this.runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
                this.fetchTensors = run.outputs;
                this.runStats.add(run.metadata);
            } else {
                this.fetchTensors = this.runner.run();
            }
        } catch (RuntimeException e) {
            Log.e(TAG, "Failed to run TensorFlow session: " + e.toString());
            // Drop partial results so readNode* reports -1 rather than reading stale outputs.
            closeFetches();
            return -1;
        } finally {
            // This run is over either way: release the inputs and start a clean runner.
            closeFeeds();
            this.runner = this.sess.runner();
        }
        return 0;
    }
}
