package ai.djl.training.optimizer.learningrate;

import ai.djl.TrainingDivergedException;
import ai.djl.training.optimizer.learningrate.FactorTracker;
import ai.djl.training.optimizer.learningrate.MultiFactorTracker;

/**
 * Base class for learning-rate schedules, with optional warm-up support.
 *
 * <p>Subclasses implement {@link #getNewLearningRate(int)}. They can use
 * {@link #getWarmUpLearningRate(int)} to compute the rate during a warm-up phase. That phase
 * starts at {@code warmUpBeginLearningRate} and, in {@link WarmUpMode#LINEAR} mode, ramps toward
 * {@code warmUpFinalLearningRate}, which equals the base learning rate.
 */
public abstract class LearningRateTracker {
    float baseLearningRate;
    float warmUpBeginLearningRate;
    float warmUpFinalLearningRate;
    WarmUpMode warmUpMode;
    int warmUpSteps;

    /**
     * Shared builder state for all learning-rate trackers.
     *
     * <p>Defaults:
     * <ul>
     *   <li>base learning rate {@code 0.01f}</li>
     *   <li>warm-up mode {@link WarmUpMode#LINEAR}</li>
     *   <li>0 warm-up steps</li>
     *   <li>warm-up begin rate {@code 0f}</li>
     * </ul>
     *
     * @param <T> the concrete builder type, returned from every setter to allow chaining
     */
    public static abstract class LrBaseBuilder<T extends LrBaseBuilder> {
        float warmUpBeginLearningRate;
        int warmUpSteps;
        float baseLearningRate = 0.01f;
        WarmUpMode warmUpMode = WarmUpMode.LINEAR;

        /**
         * Sets the base learning rate, which is also the rate the warm-up phase ramps toward.
         *
         * @param baseLearningRate the base learning rate
         * @return this builder
         */
        public T optBaseLearningRate(float baseLearningRate) {
            this.baseLearningRate = baseLearningRate;
            return self();
        }

        /**
         * Sets the learning rate used at the start of the warm-up phase.
         *
         * @param warmUpBeginLearningRate the initial warm-up learning rate
         * @return this builder
         */
        public T optWarmUpBeginLearningRate(float warmUpBeginLearningRate) {
            this.warmUpBeginLearningRate = warmUpBeginLearningRate;
            return self();
        }

        /**
         * Sets how the learning rate evolves during warm-up.
         *
         * @param warmUpMode the warm-up mode
         * @return this builder
         */
        public T optWarmUpMode(WarmUpMode warmUpMode) {
            this.warmUpMode = warmUpMode;
            return self();
        }

        /**
         * Sets the length of the warm-up phase.
         *
         * @param warmUpSteps the number of warm-up steps; a non-positive value means no ramp
         * @return this builder
         */
        public T optWarmUpSteps(int warmUpSteps) {
            this.warmUpSteps = warmUpSteps;
            return self();
        }

        /**
         * Returns this builder typed as the concrete subclass.
         *
         * @return this builder
         */
        protected abstract T self();
    }

    /**
     * Copies the schedule configuration out of the given builder.
     *
     * @param lrBaseBuilder the builder holding the configuration
     */
    public LearningRateTracker(LrBaseBuilder<?> lrBaseBuilder) {
        this.baseLearningRate = lrBaseBuilder.baseLearningRate;
        this.warmUpSteps = lrBaseBuilder.warmUpSteps;
        this.warmUpBeginLearningRate = lrBaseBuilder.warmUpBeginLearningRate;
        this.warmUpMode = lrBaseBuilder.warmUpMode;
        // The warm-up ramp always ends at the base learning rate.
        this.warmUpFinalLearningRate = this.baseLearningRate;
    }

    /**
     * Creates a builder for a {@link FactorTracker}.
     *
     * @return a new {@link FactorTracker.Builder}
     */
    public static FactorTracker.Builder factorTracker() {
        return new FactorTracker.Builder();
    }

    /**
     * Creates a tracker built by {@link FixedLearningRate} with the given base learning rate.
     *
     * @param learningRate the base learning rate
     * @return the built tracker
     */
    public static LearningRateTracker fixedLearningRate(float learningRate) {
        return FixedLearningRate.builder().optBaseLearningRate(learningRate).build();
    }

    /**
     * Creates a builder for a {@link MultiFactorTracker}.
     *
     * @return a new {@link MultiFactorTracker.Builder}
     */
    public static MultiFactorTracker.Builder multiFactorTracker() {
        return new MultiFactorTracker.Builder();
    }

    /**
     * Validates a computed learning rate.
     *
     * @param learningRate the value to validate
     * @throws TrainingDivergedException if the value is NaN or infinite
     */
    public void checkLearningRate(float learningRate) {
        if (Float.isNaN(learningRate)) {
            throw new TrainingDivergedException("Learning rate is Nan.");
        }
        // An infinite rate is as unusable as NaN; previously it slipped through unchecked.
        if (Float.isInfinite(learningRate)) {
            throw new TrainingDivergedException("Learning rate is infinite.");
        }
    }

    /**
     * Returns the learning rate to use for the given step.
     *
     * @param i the step index; presumably the number of updates performed so far (TODO confirm
     *     against subclasses)
     * @return the learning rate
     */
    public abstract float getNewLearningRate(int i);

    /**
     * Computes the learning rate during the warm-up phase.
     *
     * <p>In {@link WarmUpMode#LINEAR} mode the rate is interpolated linearly from the begin rate
     * toward the final rate over {@code warmUpSteps} steps. If {@code warmUpSteps} is not
     * positive, the ramp is treated as already complete and the final rate is returned. In any
     * other mode the begin rate is returned unchanged.
     *
     * @param i the current step within the warm-up phase
     * @return the warm-up learning rate
     * @throws TrainingDivergedException if the computed rate is NaN or infinite
     */
    public float getWarmUpLearningRate(int i) {
        float learningRate = this.warmUpBeginLearningRate;
        if (this.warmUpMode == WarmUpMode.LINEAR) {
            if (this.warmUpSteps <= 0) {
                // Zero-length window: dividing by warmUpSteps would yield NaN or Infinity.
                learningRate = this.warmUpFinalLearningRate;
            } else {
                float begin = this.warmUpBeginLearningRate;
                learningRate = begin + (((this.warmUpFinalLearningRate - begin) * i) / this.warmUpSteps);
            }
        }
        checkLearningRate(learningRate);
        return learningRate;
    }
}
