package ai.djl.training;

import ai.djl.ndarray.NDList;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.listener.TrainingListener;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

/* loaded from: classes.dex */
/**
 * Convenience helpers that drive a {@link Trainer} through simple training and validation loops.
 *
 * <p>This is a static utility class and cannot be instantiated.
 */
public final class EasyTrain {

    private EasyTrain() {
        // utility class: no instances
    }

    /**
     * Runs a basic training loop for a fixed number of epochs.
     *
     * <p>Each epoch trains on every batch of {@code trainingDataset} (calling {@link
     * #trainBatch(Trainer, Batch)} followed by {@link Trainer#step()}). Then, when {@code
     * validateDataset} is non-null, it validates on every batch of that dataset via {@link
     * #validateBatch(Trainer, Batch)}. Every batch is closed after use, even if processing it
     * throws. At the end of each epoch all training listeners are notified through {@link
     * TrainingListener#onEpoch(Trainer)}.
     *
     * @param trainer the trainer to fit
     * @param numEpoch the number of epochs to run; zero or a negative value runs nothing
     * @param trainingDataset the dataset to train on
     * @param validateDataset the dataset to validate on after each epoch, or {@code null} to skip
     *     validation
     */
    public static void fit(
            final Trainer trainer, int numEpoch, Dataset trainingDataset, Dataset validateDataset) {
        for (int epoch = 0; epoch < numEpoch; epoch++) {
            for (Batch batch : trainer.iterateDataset(trainingDataset)) {
                try {
                    trainBatch(trainer, batch);
                    trainer.step();
                } finally {
                    // release the batch even when training on it fails
                    batch.close();
                }
            }
            if (validateDataset != null) {
                for (Batch batch : trainer.iterateDataset(validateDataset)) {
                    try {
                        validateBatch(trainer, batch);
                    } finally {
                        batch.close();
                    }
                }
            }
            trainer.notifyListeners(listener -> listener.onEpoch(trainer));
        }
    }

    /**
     * Runs forward and backward passes for one batch, without updating parameters.
     *
     * <p>The batch is split across {@link Trainer#getDevices()}. For each split this method:
     *
     * <ul>
     *   <li>runs {@link Trainer#forward};
     *   <li>evaluates the trainer's loss on the labels and predictions;
     *   <li>back-propagates the loss through a {@link GradientCollector};
     *   <li>records the labels and predictions keyed by the device of their first array.
     * </ul>
     *
     * <p>After all splits are processed, the collector is closed and listeners are notified
     * through {@link TrainingListener#onTrainingBatch}. This method does not call {@link
     * Trainer#step()}; callers such as {@link #fit} do that themselves.
     *
     * @param trainer the trainer to use
     * @param batch the batch to train on; it is not closed by this method
     * @throws IllegalArgumentException if the batch's engine differs from the trainer's engine
     */
    public static void trainBatch(final Trainer trainer, Batch batch) {
        checkSameEngine(trainer, batch);
        Batch[] splits = batch.split(trainer.getDevices(), false);
        final TrainingListener.BatchData batchData =
                new TrainingListener.BatchData(
                        batch, new ConcurrentHashMap<>(), new ConcurrentHashMap<>());
        // try-with-resources closes the collector exactly once, on success or failure
        try (GradientCollector collector = trainer.newGradientCollector()) {
            for (Batch split : splits) {
                NDList data = split.getData();
                NDList labels = split.getLabels();
                NDList predictions = trainer.forward(data, labels);

                // "backward" metric covers loss evaluation plus back-propagation;
                // addMetric receives the start timestamp
                long backwardStart = System.nanoTime();
                collector.backward(trainer.getLoss().evaluate(labels, predictions));
                trainer.addMetric("backward", backwardStart);

                long bookkeepingStart = System.nanoTime();
                batchData.getLabels().put(labels.get(0).getDevice(), labels);
                batchData.getPredictions().put(predictions.get(0).getDevice(), predictions);
                trainer.addMetric("training-metrics", bookkeepingStart);
            }
        }
        trainer.notifyListeners(listener -> listener.onTrainingBatch(trainer, batchData));
    }

    /**
     * Runs forward passes for one batch and reports the results to listeners.
     *
     * <p>The batch is split across {@link Trainer#getDevices()} and each split is passed through
     * {@link Trainer#forward}. No gradient collector is opened and no backward pass is run. The
     * labels and predictions are recorded keyed by the device of their first array, and
     * listeners are notified through {@link TrainingListener#onValidationBatch}.
     *
     * @param trainer the trainer to use
     * @param batch the batch to validate on; it is not closed by this method
     * @throws IllegalArgumentException if the batch's engine differs from the trainer's engine
     */
    public static void validateBatch(final Trainer trainer, Batch batch) {
        checkSameEngine(trainer, batch);
        Batch[] splits = batch.split(trainer.getDevices(), false);
        final TrainingListener.BatchData batchData =
                new TrainingListener.BatchData(
                        batch, new ConcurrentHashMap<>(), new ConcurrentHashMap<>());
        for (Batch split : splits) {
            NDList data = split.getData();
            NDList labels = split.getLabels();
            NDList predictions = trainer.forward(data, labels);
            batchData.getLabels().put(labels.get(0).getDevice(), labels);
            batchData.getPredictions().put(predictions.get(0).getDevice(), predictions);
        }
        trainer.notifyListeners(listener -> listener.onValidationBatch(trainer, batchData));
    }

    /**
     * Ensures {@code batch} was produced by the same engine as {@code trainer}'s manager.
     *
     * <p>The comparison is by reference ({@code !=}) on the engine objects.
     *
     * @throws IllegalArgumentException if the engines differ
     */
    private static void checkSameEngine(Trainer trainer, Batch batch) {
        if (trainer.getManager().getEngine() != batch.getManager().getEngine()) {
            throw new IllegalArgumentException(
                    "The data must be on the same engine as the trainer. You may need to change one of your NDManagers.");
        }
    }
}
