package deepboof.graph;

import deepboof.Function;
import deepboof.Tensor;
import deepboof.misc.TensorFactory;
import deepboof.misc.TensorOps;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.ddogleg.struct.Tuple2;

/* loaded from: classes2.dex */
/**
 * Executes a list of {@link Function}s in order. Every node after the first reads the outputs of
 * one or more other nodes, identified by name through its {@code sources}. A node with more than
 * one source first merges its inputs using the node's {@code combine} operator and then applies
 * its function to the merged tensor.
 *
 * <p>The first node receives the external input and must not declare any sources. The last node
 * writes directly into the caller-supplied output tensor. Every other node writes into internally
 * managed storage, which is reshaped on each call to {@link #process}.
 *
 * <p>{@link #initialize(int[])} must be called before {@link #process} on a sequence with more
 * than one node.
 *
 * @param <T> tensor type
 * @param <F> function type
 */
public class FunctionSequence<T extends Tensor<T>, F extends Function<T>> {
    /** Creates the per-node storage tensors. */
    protected TensorFactory<T> factory;
    /** Nodes in execution order. The first is the input node; the last produces the final output. */
    protected List<Node<T, F>> sequence;
    /** Maps a node's name to the node, used to resolve source addresses. */
    protected Map<String, Node<T, F>> lookup = new HashMap<>();
    /**
     * Per-node storage created by {@link #initialize(int[])}. {@code data0} receives the node's
     * function output; {@code data1} receives the combined input of a multi-source node and is
     * {@code null} for single-input nodes once {@link #process} has run.
     */
    protected Map<String, Tuple2<T, T>> outputStorage = new HashMap<>();
    /** When true, node names and input/output shapes are printed to stdout during initialization. */
    boolean verbose = false;

    /**
     * Creates a sequence from the given nodes.
     *
     * @param list nodes in execution order; the list is stored by reference, not copied
     * @param cls  tensor class used to create the intermediate storage
     * @throws IllegalArgumentException if two nodes share the same name
     */
    public FunctionSequence(List<Node<T, F>> list, Class<T> cls) {
        this.sequence = list;
        for (Node<T, F> node : list) {
            if (lookup.containsKey(node.name)) {
                throw new IllegalArgumentException("Conflict.  Multiple nodes with the same name.  " + node.name);
            }
            lookup.put(node.name, node);
        }
        this.factory = new TensorFactory<>(cls);
    }

    /** Creates an empty pair of storage tensors (function output, combined input) for one node. */
    private Tuple2<T, T> createStorage() {
        return new Tuple2<>(factory.create(new int[0]), factory.create(new int[0]));
    }

    /**
     * Reshapes every node's storage for the given mini-batch size. This is a no-op for a
     * single-node sequence, which writes straight into the caller's output.
     *
     * @param miniBatch number of samples in the mini-batch, prepended to every stored shape
     */
    private void declareOutputStorage(int miniBatch) {
        if (sequence.size() == 1) {
            return;
        }
        int last = sequence.size() - 1;
        for (int i = 0; i < sequence.size(); i++) {
            Node<T, F> node = sequence.get(i);
            Tuple2<T, T> storage = outputStorage.get(node.name);

            // The last node writes into the caller's output tensor, so its data0 is never used.
            if (i != last) {
                storage.data0.reshape(TensorOps.WI(miniBatch, node.function.getOutputShape()));
            }

            if (i == 0 || node.sources.size() == 1) {
                // Single input: it is fed directly from the source's data0, no combine buffer needed.
                storage.data1 = null;
            } else {
                // Shapes reported by function/combine exclude the mini-batch dimension (they are
                // derived from function.getOutputShape()), so it must be prepended here, exactly
                // as in the single-source case above.
                storage.data1.reshape(TensorOps.WI(miniBatch, node.combine.getOutputShape()));
            }
        }
    }

    /**
     * Initializes every node's function (and combine operator where needed) and allocates its
     * storage, propagating shapes from the first node through the sequence.
     *
     * @param inputShape shape passed to the first node's function
     * @throws IllegalStateException if the sequence is empty
     * @throws RuntimeException if the first node has sources, a later node has none, a source
     *                          name cannot be resolved, or a multi-source node has no combine
     */
    private void initializeSequence(int[] inputShape) {
        if (sequence.isEmpty()) {
            throw new IllegalStateException("Sequence is empty");
        }
        Node<T, F> root = sequence.get(0);
        if (root.sources.size() != 0) {
            throw new RuntimeException("Input sequence can't have a source address!");
        }
        root.function.initialize(inputShape);
        outputStorage.put(root.name, createStorage());
        if (verbose) {
            System.out.println("ROOT ========= " + root.name);
            printOutput(root, inputShape);
        }

        List<int[]> sourceShapes = new ArrayList<>();
        for (int i = 1; i < sequence.size(); i++) {
            Node<T, F> node = sequence.get(i);
            if (verbose) {
                System.out.println("============== " + node.name);
            }
            outputStorage.put(node.name, createStorage());
            if (node.sources.size() == 0) {
                throw new RuntimeException("No sources!  Node = " + node.name);
            }

            // Collect the output shape of every source node. Sources are assumed to appear
            // earlier in the sequence so their functions are already initialized.
            sourceShapes.clear();
            for (int j = 0; j < node.sources.size(); j++) {
                InputAddress address = node.sources.get(j);
                Node<T, F> source = lookup.get(address.nodeName);
                if (source == null) {
                    throw new RuntimeException("Can't find input node from name.  Bad network");
                }
                sourceShapes.add(source.function.getOutputShape());
                if (verbose) {
                    System.out.println("   input addr " + address.nodeName);
                }
            }

            if (sourceShapes.size() == 1) {
                node.function.initialize(sourceShapes.get(0));
                if (verbose) {
                    printOutput(node, sourceShapes.get(0));
                }
            } else {
                if (node.combine == null) {
                    throw new RuntimeException("Must specify a combine operator if there are multiple sources");
                }
                node.combine.initialize(sourceShapes);
                node.function.initialize(node.combine.getOutputShape());
                if (verbose) {
                    printOutput(node, node.combine.getOutputShape());
                }
            }
        }
    }

    /** Prints the node's function class name with its input shape and output shape. */
    private void printOutput(Node<T, F> node, int[] inputShape) {
        int[] outputShape = node.function.getOutputShape();
        System.out.printf("%30s input %25s  out = %25s\n", node.function.getClass().getSimpleName(), TensorOps.toStringShape(inputShape), TensorOps.toStringShape(outputShape));
    }

    /**
     * Returns the storage tensor that receives the function output of the node at {@code i}.
     * It is not written by {@link #process} for the last node of a multi-node sequence, nor for a
     * single-node sequence; those results go to the caller's output tensor instead.
     *
     * @param i index of the node in the sequence
     * @return the node's {@code data0} storage tensor
     */
    public T getNodeOutput(int i) {
        return outputStorage.get(sequence.get(i).name).data0;
    }

    /** Returns the output shape of the last node's function, i.e. the shape of the final result. */
    public int[] getOutputShape() {
        return sequence.get(sequence.size() - 1).function.getOutputShape();
    }

    /** Returns the node list passed to the constructor (not a copy). */
    public List<Node<T, F>> getSequence() {
        return sequence;
    }

    /** Returns the tensor class used to create intermediate storage. */
    public Class<T> getTensorType() {
        return factory.getTensorType();
    }

    /**
     * Initializes all nodes for the given input shape. Must be called before {@link #process} on
     * a multi-node sequence.
     *
     * @param iArr shape passed to the first node's function; presumably the per-sample shape
     *             without the mini-batch dimension (process takes that from {@code input.length(0)})
     */
    public void initialize(int[] iArr) {
        initializeSequence(iArr);
    }

    /**
     * Runs the input through every node in order and writes the last node's result to
     * {@code output}.
     *
     * @param input  tensor fed to the first node; {@code input.length(0)} is used as the
     *               mini-batch size when sizing intermediate storage
     * @param output tensor that receives the last node's output
     * @throws IllegalStateException if a multi-node sequence has not been initialized
     */
    public void process(T input, T output) {
        if (sequence.size() == 1) {
            sequence.get(0).function.forward(input, output);
            return;
        }
        if (outputStorage.isEmpty()) {
            throw new IllegalStateException("initialize() must be called before process()");
        }
        declareOutputStorage(input.length(0));

        Node<T, F> root = sequence.get(0);
        root.function.forward(input, outputStorage.get(root.name).data0);

        int last = sequence.size() - 1;
        List<T> inputs = new ArrayList<>();
        for (int i = 1; i <= last; i++) {
            Node<T, F> node = sequence.get(i);
            Tuple2<T, T> storage = outputStorage.get(node.name);
            // The final node writes straight into the caller's tensor.
            T target = (i == last) ? output : storage.data0;

            inputs.clear();
            for (int j = 0; j < node.sources.size(); j++) {
                inputs.add(outputStorage.get(node.sources.get(j).nodeName).data0);
            }

            if (node.sources.size() == 1) {
                node.function.forward(inputs.get(0), target);
            } else {
                node.combine.combine(inputs, storage.data1);
                node.function.forward(storage.data1, target);
            }
        }
    }

    /**
     * Assigns parameters to nodes by name. Nodes whose name is absent from the map are left
     * unchanged.
     *
     * @param map node name to the parameter tensors for that node's function
     */
    public void setParameters(Map<String, List<T>> map) {
        for (Node<T, F> node : sequence) {
            List<T> parameters = map.get(node.name);
            if (parameters != null) {
                node.function.setParameters(parameters);
            }
        }
    }
}
