package ai.djl;

import ai.djl.engine.Engine;
import ai.djl.engine.StandardCapabilities;
import ai.djl.util.cuda.CudaUtils;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

/* loaded from: classes.dex */
/**
 * Identifies a compute device: the CPU or a numbered GPU.
 *
 * <p>Instances are interned and immutable. {@link #of(String, int)} returns the shared CPU
 * singleton for {@link Type#CPU} (ignoring the id). For any other type it returns one cached
 * instance per {@code type-id} pair.
 */
public final class Device {
    // Must be declared before GPU: GPU is created through of(), which reads CACHE.
    private static final Map<String, Device> CACHE = new ConcurrentHashMap<>();
    private static final Device CPU = new Device(Type.CPU, -1);
    private static final Device GPU = of(Type.GPU, 0);

    private final int deviceId;
    private final String deviceType;

    /** String constants naming the device types recognized by this class. */
    public interface Type {
        String CPU = "cpu";
        String GPU = "gpu";
    }

    private Device(String str, int i) {
        this.deviceType = str;
        this.deviceId = i;
    }

    /**
     * Returns the shared CPU device.
     *
     * @return the CPU singleton
     */
    public static Device cpu() {
        return CPU;
    }

    /**
     * Returns GPU 0 if the default engine reports at least one GPU, otherwise the CPU device.
     *
     * @return the default device
     */
    public static Device defaultDevice() {
        return getGpuCount() > 0 ? GPU : CPU;
    }

    /**
     * Returns {@code device} if it is non-null, otherwise {@link #defaultDevice()}.
     *
     * @param device the preferred device, may be {@code null}
     * @return a non-null device
     */
    public static Device defaultIfNull(Device device) {
        return device != null ? device : defaultDevice();
    }

    /**
     * Returns the first non-null of {@code device} and {@code device2}, falling back to
     * {@link #defaultDevice()} when both are {@code null}.
     *
     * @param device the preferred device, may be {@code null}
     * @param device2 the fallback device, may be {@code null}
     * @return a non-null device
     */
    public static Device defaultIfNull(Device device, Device device2) {
        return device != null ? device : defaultIfNull(device2);
    }

    /**
     * Returns up to {@code i} GPU devices with ids {@code 0..min(i, gpuCount)-1}.
     *
     * @param i the maximum number of GPUs requested
     * @return the GPU devices, or a single-element array holding the CPU device when {@code i <= 0}
     *     or no GPU is available
     */
    public static Device[] getDevices(int i) {
        int gpuCount = getGpuCount();
        if (i <= 0 || gpuCount <= 0) {
            return new Device[] {CPU};
        }
        int count = Math.min(i, gpuCount);
        Device[] devices = new Device[count];
        for (int id = 0; id < count; id++) {
            devices[id] = gpu(id);
        }
        return devices;
    }

    /**
     * Returns the GPU count if the default engine has the CUDA capability, otherwise 0.
     *
     * @return the number of usable GPUs
     */
    public static int getGpuCount() {
        if (Engine.getInstance().hasCapability(StandardCapabilities.CUDA)) {
            return CudaUtils.getGpuCount();
        }
        return 0;
    }

    /**
     * Returns the GPU count if the engine looked up by {@code str} has the CUDA capability,
     * otherwise 0.
     *
     * <p>The count itself comes from {@code CudaUtils.getGpuCount()}, independent of the engine.
     *
     * @param str the engine name passed to {@code Engine.getEngine}
     * @return the number of usable GPUs for that engine
     */
    public static int getGpuCount(String str) {
        if (Engine.getEngine(str).hasCapability(StandardCapabilities.CUDA)) {
            return CudaUtils.getGpuCount();
        }
        return 0;
    }

    /**
     * Returns GPU 0.
     *
     * @return the cached GPU 0 device
     */
    public static Device gpu() {
        return GPU;
    }

    /**
     * Returns the cached GPU device with the given id.
     *
     * @param i the GPU id
     * @return the GPU device
     */
    public static Device gpu(int i) {
        return of(Type.GPU, i);
    }

    /**
     * Creates a new, uncached {@code Device}. It is the cache-miss factory used by {@link #of}.
     *
     * @param str the device type
     * @param i the device id
     * @param str2 the cache key (ignored)
     * @return a new device instance
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static /* synthetic */ Device lambda$of$0(String str, int i, String str2) {
        return new Device(str, i);
    }

    /**
     * Returns the device of the given type and id.
     *
     * <p>For {@link Type#CPU} the CPU singleton is returned and {@code i} is ignored. For other
     * types a single instance per {@code type-id} pair is created on first request and cached.
     * Caching uses {@link ConcurrentHashMap#computeIfAbsent}, so concurrent callers share one
     * instance.
     *
     * @param str the device type
     * @param i the device id
     * @return the (possibly cached) device
     */
    public static Device of(final String str, final int i) {
        if (Type.CPU.equals(str)) {
            return CPU;
        }
        // A lambda here would make javac synthesize a method named lambda$of$0, which clashes
        // with the retained lambda$of$0 above; a typed anonymous Function avoids that and the
        // raw-type cast.
        return CACHE.computeIfAbsent(
                str + '-' + i,
                new Function<String, Device>() {
                    @Override
                    public Device apply(String key) {
                        return lambda$of$0(str, i, key);
                    }
                });
    }

    /**
     * Two devices are equal when their types match and, for non-CPU devices, their ids match.
     * CPU devices compare equal regardless of id.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        Device other = (Device) obj;
        if (!Objects.equals(this.deviceType, other.deviceType)) {
            return false;
        }
        return Type.CPU.equals(this.deviceType) || this.deviceId == other.deviceId;
    }

    /**
     * Returns the device id.
     *
     * @return the device id
     * @throws IllegalStateException if this is a CPU device
     */
    public int getDeviceId() {
        if (Type.CPU.equals(this.deviceType)) {
            throw new IllegalStateException("CPU doesn't have device id");
        }
        return this.deviceId;
    }

    /**
     * Returns the device type string, e.g. {@link Type#CPU} or {@link Type#GPU}.
     *
     * @return the device type
     */
    public String getDeviceType() {
        return this.deviceType;
    }

    /** Hash consistent with {@link #equals}: the id is excluded for CPU devices. */
    @Override
    public int hashCode() {
        if (Type.CPU.equals(this.deviceType)) {
            return Objects.hashCode(this.deviceType);
        }
        return Objects.hash(this.deviceType, this.deviceId);
    }

    /**
     * Formats the device as {@code type()} for CPU and {@code type(id)} otherwise.
     *
     * @return the string form, e.g. {@code "cpu()"} or {@code "gpu(0)"}
     */
    @Override
    public String toString() {
        return Type.CPU.equals(this.deviceType)
                ? this.deviceType + "()"
                : this.deviceType + '(' + this.deviceId + ')';
    }
}
