package ai.djl.repository.zoo;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.ndarray.NDList;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Metadata;
import ai.djl.repository.Repository;
import ai.djl.repository.VersionRange;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import ai.djl.util.Progress;
import java.io.IOException;
import java.lang.reflect.Type;
import java.nio.file.Path;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/* loaded from: classes.dex */
/**
 * Shared {@link ModelLoader} implementation that resolves an {@link Artifact} from a {@link
 * Repository}, picks a {@link Translator} and an engine, and loads the resulting {@link Model}.
 *
 * @param <I> the input type parameter of the implemented {@link ModelLoader}
 * @param <O> the output type parameter of the implemented {@link ModelLoader}
 */
public abstract class BaseModelLoader<I, O> implements ModelLoader<I, O> {

    /**
     * Default translator factories keyed by (input type, output type). Used by {@link
     * #loadModel} when the criteria do not supply a translator.
     */
    protected Map<Pair<Type, Type>, TranslatorFactory<?, ?>> factories;

    /** Cached metadata for {@link #mrl}; stays {@code null} until a lookup returns non-null. */
    private Metadata metadata;

    protected ModelZoo modelZoo;
    protected MRL mrl;
    protected Repository repository;

    /** Version range string handed to {@link VersionRange#parse}; {@code null} means "any". */
    protected String version;

    /**
     * Creates a loader for the given repository location.
     *
     * <p>An {@code NDList -> NDList} pass-through translator factory is always registered.
     *
     * @param repository the repository to locate and download artifacts from
     * @param mrl the resource locator of the model family
     * @param version the version range to match, or {@code null}
     * @param modelZoo the owning model zoo, or {@code null}; used to pick a default engine
     */
    public BaseModelLoader(Repository repository, MRL mrl, String version, ModelZoo modelZoo) {
        this.repository = repository;
        this.mrl = mrl;
        this.version = version;
        this.factories = new ConcurrentHashMap<>();
        // Typed local: a lambda cannot target TranslatorFactory<?, ?> and return a
        // Translator<NDList, NDList> directly.
        TranslatorFactory<NDList, NDList> noopFactory = arguments -> new NoopTranslator();
        this.factories.put(new Pair<Type, Type>(NDList.class, NDList.class), noopFactory);
        this.modelZoo = modelZoo;
    }

    /**
     * Returns the repository metadata for {@link #mrl}, locating it on first use.
     *
     * <p>A {@code null} lookup result is not cached, so later calls retry the lookup.
     *
     * @return the non-null metadata
     * @throws IOException if the repository lookup fails
     * @throws ModelNotFoundException if the repository has no metadata for {@link #mrl}
     */
    private Metadata getMetadata() throws IOException, ModelNotFoundException {
        if (metadata == null) {
            metadata = repository.locate(mrl);
            if (metadata == null) {
                throw new ModelNotFoundException(mrl.getArtifactId() + " Models not found.");
            }
        }
        return metadata;
    }

    /**
     * Looks up the registered default translator factory for the criteria's input/output types.
     *
     * @return the matching factory, or {@code null} if none is registered
     */
    @SuppressWarnings("unchecked") // keys are (S, T) pairs, so the stored factory is for (S, T)
    private <S, T> TranslatorFactory<S, T> getTranslatorFactory(Criteria<S, T> criteria) {
        Pair<Type, Type> key =
                new Pair<Type, Type>(criteria.getInputClass(), criteria.getOutputClass());
        return (TranslatorFactory<S, T>) factories.get(key);
    }

    /**
     * Searches the metadata for artifacts within {@link #version} that match the filters.
     *
     * @param map artifact filters
     * @return the matching artifacts, in the order returned by the metadata search
     */
    private List<Artifact> search(Map<String, String> map)
            throws IOException, ModelNotFoundException {
        return getMetadata().search(VersionRange.parse(version), map);
    }

    /**
     * Creates the (not yet loaded) model instance.
     *
     * <p>The base implementation uses only {@code name}, {@code device} and {@code engine};
     * subclasses may override it to take the artifact or arguments into account.
     *
     * @param name the model name
     * @param device the device to create the model on
     * @param artifact the matched artifact (unused here)
     * @param arguments the merged artifact arguments (unused here)
     * @param engine the engine name, or {@code null}
     * @return a new model instance
     * @throws IOException if a subclass fails to create the model
     */
    protected Model createModel(
            String name,
            Device device,
            Artifact artifact,
            Map<String, Object> arguments,
            String engine)
            throws IOException {
        return Model.newInstance(name, device, engine);
    }

    /** {@inheritDoc} */
    @Override
    public String getArtifactId() {
        return mrl.getArtifactId();
    }

    /**
     * Lists every artifact in the metadata.
     *
     * <p>When {@link #version} is non-null, only artifacts whose version string equals it exactly
     * are returned.
     */
    @Override
    public List<Artifact> listModels() throws IOException, ModelNotFoundException {
        return getMetadata().getArtifacts().stream()
                .filter(artifact -> version == null || version.equals(artifact.getVersion()))
                .collect(Collectors.toList());
    }

    /**
     * Loads the first artifact matching the criteria's filters.
     *
     * <p>The translator comes from the criteria, or else from the default factory registered for
     * the criteria's input/output types. The artifact is prepared (downloaded) through the
     * repository. If creating the block or loading the weights fails, the partially created model
     * is closed before the exception propagates. Any supplied {@link Progress} is always ended.
     *
     * @throws ModelNotFoundException if no artifact matches, no translator is available, or no
     *     usable engine is found
     */
    @Override
    public <S, T> ZooModel<S, T> loadModel(Criteria<S, T> criteria)
            throws IOException, ModelNotFoundException, MalformedModelException {
        Artifact artifact = match(criteria.getFilters());
        if (artifact == null) {
            throw new ModelNotFoundException("Model not found.");
        }
        Map<String, Object> criteriaArguments = criteria.getArguments();
        Progress progress = criteria.getProgress();
        Map<String, Object> arguments = artifact.getArguments(criteriaArguments);
        try {
            Translator<S, T> translator = criteria.getTranslator();
            if (translator == null) {
                TranslatorFactory<S, T> factory = getTranslatorFactory(criteria);
                if (factory == null) {
                    throw new ModelNotFoundException("No matching default translator found.");
                }
                translator = factory.newInstance(arguments);
            }

            repository.prepare(artifact, progress);
            if (progress != null) {
                progress.reset("Loading", 2L);
                progress.update(1L);
            }
            Path modelPath = repository.getResourceDirectory(artifact);
            String engine = resolveEngine(criteria.getEngine());

            String modelName = criteria.getModelName();
            if (modelName == null) {
                modelName = artifact.getName();
            }
            Model model =
                    createModel(modelName, Device.defaultDevice(), artifact, arguments, engine);
            try {
                if (criteria.getBlock() != null) {
                    model.setBlock(criteria.getBlock());
                }
                model.load(modelPath, null, criteria.getOptions());
            } catch (Throwable t) {
                // The caller never receives this model, so release it here instead of leaking
                // it; a failure while closing must not mask the original error.
                try {
                    model.close();
                } catch (RuntimeException closeFailure) {
                    t.addSuppressed(closeFailure);
                }
                throw t; // precise rethrow: only the checked types declared above
            }
            return new ZooModel<>(model, translator);
        } finally {
            if (progress != null) {
                progress.end();
            }
        }
    }

    /**
     * Determines the engine to load with.
     *
     * <p>When no engine is requested and a model zoo is set, the zoo's supported engines are
     * scanned in order. The current default engine wins immediately if it is listed. Otherwise
     * the last listed engine for which {@link Engine#hasEngine} is true is used.
     *
     * @param requested the engine requested by the criteria, or {@code null}
     * @return the engine name, or {@code null} when none was requested and no zoo is set
     * @throws ModelNotFoundException if no supported engine is available, or the chosen engine
     *     is not available
     */
    private String resolveEngine(String requested) throws ModelNotFoundException {
        String engine = requested;
        if (engine == null && modelZoo != null) {
            String defaultEngine = Engine.getInstance().getEngineName();
            for (String supported : modelZoo.getSupportedEngines()) {
                if (supported.equals(defaultEngine)) {
                    engine = supported;
                    break;
                }
                if (Engine.hasEngine(supported)) {
                    engine = supported;
                }
            }
            if (engine == null) {
                throw new ModelNotFoundException(
                        "No supported engine available for model zoo: " + modelZoo.getGroupId());
            }
        }
        if (engine != null && !Engine.hasEngine(engine)) {
            throw new ModelNotFoundException(engine + " is not supported.");
        }
        return engine;
    }

    /**
     * Returns the first artifact matching the filters.
     *
     * @param map artifact filters
     * @return the first match from the metadata search, or {@code null} if nothing matches
     */
    protected Artifact match(Map<String, String> map) throws IOException, ModelNotFoundException {
        List<Artifact> candidates = search(map);
        return candidates.isEmpty() ? null : candidates.get(0);
    }

    /** Describes this loader and lists its artifacts, or notes that the metadata failed to load. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append(repository.getName())
                .append(':')
                .append(mrl.getGroupId())
                .append(':')
                .append(mrl.getArtifactId())
                .append(" [\n");
        try {
            for (Artifact artifact : listModels()) {
                sb.append('\t').append(artifact).append('\n');
            }
        } catch (ModelNotFoundException | IOException ignored) {
            // toString() must not throw; report the failure inline instead.
            sb.append("\tFailed load metadata.");
        }
        sb.append("\n]");
        return sb.toString();
    }
}
