package com.xiaomi.ai.minmt.common;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Hashtable;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: classes2.dex */
/**
 * Bidirectional mapping between vocabulary tokens and integer ids.
 *
 * <p>The vocabulary is read from a stream with one token per line; a token's id is its
 * zero-based line number. Text passed to the encoding methods is expected to be
 * pre-tokenized, with tokens separated by whitespace.
 */
public class Encoder {
    private static final int PAD = SpecialToken.PAD.getId();
    private static final int EOS = SpecialToken.EOS.getId();
    private static final String UNK = SpecialToken.UNK.getToken();
    /** Token separator for pre-tokenized text; compiled once instead of on every call. */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");
    private static final Logger LOGGER = LoggerFactory.getLogger(Encoder.class);

    // Populated only in the constructor and read-only afterwards.
    private final Map<String, Integer> tokenToId = new Hashtable<>();
    private final List<String> idToToken = new ArrayList<>();

    /** Selects which prefix of an id sequence {@link #ids2tokens} decodes. */
    public enum DecodeType {
        /** Decode every id. */
        DEFAULT,
        /** Decode every id except the last one. */
        TRIM_LAST,
        /** Decode ids up to, but not including, the first PAD or EOS id. */
        TRIM_EOS_PAD
    }

    /**
     * Reads a vocabulary from {@code inputStream}, one UTF-8 token per line.
     *
     * <p>The stream is fully consumed and closed before the constructor returns. If a token
     * occurs on several lines, {@link #getId} maps it to its last occurrence, while
     * {@link #getToken} still resolves every line's id.
     *
     * @param inputStream vocabulary source; must not be null
     * @throws IOException if reading or closing the stream fails
     * @throws NullPointerException if {@code inputStream} is null
     */
    public Encoder(InputStream inputStream) throws IOException {
        Objects.requireNonNull(inputStream, "inputStream");
        // The reader must stay open for the whole loop and be closed exactly once afterwards;
        // closing it per iteration makes the next readLine() fail with "Stream closed".
        try (BufferedReader reader =
                new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                // The id of a line is the number of lines read before it.
                tokenToId.put(line, idToToken.size());
                idToToken.add(line);
            }
        }
    }

    /**
     * Returns the id of {@code str}, or the id of the UNK token if {@code str} is not in the
     * vocabulary (a warning is logged in that case).
     *
     * @param str token to look up; must not be null
     * @return the token's id, or the UNK id for out-of-vocabulary tokens
     * @throws NullPointerException if {@code str} is null
     * @throws IllegalStateException if {@code str} is unknown and the vocabulary has no UNK token
     */
    public int getId(String str) {
        Objects.requireNonNull(str);
        Integer id = tokenToId.get(str);
        if (id != null) {
            return id;
        }
        LOGGER.warn("input word {} not in vocab", str);
        Integer unkId = tokenToId.get(UNK);
        if (unkId == null) {
            throw new IllegalStateException("Vocabulary does not contain the UNK token: " + UNK);
        }
        return unkId;
    }

    /**
     * Returns the token with id {@code i}. Ids at or beyond {@link #size()} are logged as an
     * error and mapped to the UNK token string.
     *
     * @param i token id
     * @return the token for {@code i}, or the UNK token string if {@code i} is out of range
     * @throws IllegalStateException if {@code i} is negative
     */
    public String getToken(int i) {
        if (i < 0) {
            throw new IllegalStateException("Token id should not be negative.");
        }
        if (i < idToToken.size()) {
            return idToToken.get(i);
        }
        LOGGER.error("input id {} exceed vocab size", i);
        return UNK;
    }

    /**
     * Decodes {@code iArr} with {@link DecodeType#TRIM_EOS_PAD} and joins the tokens with
     * single spaces.
     */
    public String ids2text(int[] iArr) {
        return ids2text(iArr, DecodeType.TRIM_EOS_PAD);
    }

    /** Decodes {@code iArr} per {@code decodeType} and joins the tokens with single spaces. */
    public String ids2text(int[] iArr, DecodeType decodeType) {
        return Utils.join(" ", ids2tokens(iArr, decodeType)).trim();
    }

    /**
     * Converts a prefix of {@code iArr} to tokens, selected by {@code decodeType}.
     *
     * @param iArr ids to decode; must not be null
     * @param decodeType which prefix to decode; {@code null} behaves like
     *     {@link DecodeType#DEFAULT}
     * @return a new mutable list of tokens, possibly empty
     * @throws NullPointerException if {@code iArr} is null
     * @throws IllegalStateException if a decoded id is negative
     */
    public List<String> ids2tokens(int[] iArr, DecodeType decodeType) {
        Objects.requireNonNull(iArr, "iArr");
        int length;
        if (decodeType == DecodeType.TRIM_LAST) {
            // An empty array has no last id to drop; avoid a negative length.
            length = Math.max(0, iArr.length - 1);
        } else if (decodeType == DecodeType.TRIM_EOS_PAD) {
            length = 0;
            while (length < iArr.length && iArr[length] != PAD && iArr[length] != EOS) {
                length++;
            }
        } else {
            length = iArr.length;
        }
        List<String> tokens = new ArrayList<>(length);
        for (int k = 0; k < length; k++) {
            tokens.add(getToken(iArr[k]));
        }
        return tokens;
    }

    /** Returns the number of vocabulary lines, i.e. one more than the largest valid id. */
    public int size() {
        return idToToken.size();
    }

    /**
     * Encodes whitespace-separated text.
     *
     * @deprecated use {@link #textToIds(String)}
     */
    @Deprecated
    public int[] text2ids(String str) {
        return textToIds(str);
    }

    /**
     * Encodes whitespace-separated text into at most {@code maxLength} ids, the last of which
     * is always EOS. If the tokens plus EOS do not fit, an error is logged and the tokens are
     * truncated so that EOS still occupies the final slot.
     *
     * @param str text to encode; must not be null
     * @param maxLength maximum length of the returned array; must be positive
     * @return token ids followed by EOS, of length {@code min(tokenCount + 1, maxLength)}
     * @throws IllegalArgumentException if {@code maxLength} is not positive
     * @throws NullPointerException if {@code str} is null
     * @deprecated use {@link #textToIds(String)} and truncate as needed
     */
    @Deprecated
    public int[] text2ids(String str, int maxLength) {
        if (maxLength <= 0) {
            throw new IllegalArgumentException("maxLength must be positive: " + maxLength);
        }
        String[] tokens = splitTokens(str);
        // One slot per token plus the trailing EOS, capped at maxLength. This is sized by token
        // count, not character count, so no zero ids end up between the tokens and EOS.
        int length = Math.min(tokens.length + 1, maxLength);
        if (tokens.length >= length) {
            LOGGER.error("input size {} exceeds the maxLength {}", tokens.length + 1, length);
        }
        int[] ids = new int[length];
        // The final slot is reserved for EOS, so only length - 1 tokens are looked up.
        for (int k = 0; k < length - 1; k++) {
            ids[k] = getId(tokens[k]);
        }
        ids[length - 1] = EOS;
        return ids;
    }

    /**
     * Encodes a list of tokens.
     *
     * @deprecated use {@link #tokensToIds(String[])}
     */
    @Deprecated
    public int[] text2ids(List<String> list) {
        return tokensToIds(list.toArray(new String[0]));
    }

    /**
     * Encodes whitespace-separated text as token ids followed by EOS. Blank text yields just
     * {@code [EOS]}.
     *
     * @param str text to encode
     * @return the encoded ids, or {@code null} if {@code str} is null
     */
    public int[] textToIds(String str) {
        if (str == null) {
            return null;
        }
        return tokensToIds(splitTokens(str));
    }

    /**
     * Encodes tokens as ids followed by EOS; unknown tokens map to the UNK id.
     *
     * @param strArr tokens to encode; elements must not be null
     * @return an array of length {@code strArr.length + 1}, or {@code null} if {@code strArr}
     *     is null
     */
    public int[] tokensToIds(String[] strArr) {
        if (strArr == null) {
            return null;
        }
        int[] ids = new int[strArr.length + 1];
        for (int k = 0; k < strArr.length; k++) {
            ids[k] = getId(strArr[k]);
        }
        ids[strArr.length] = EOS;
        return ids;
    }

    /** Splits whitespace-separated text into tokens; blank text yields no tokens. */
    private static String[] splitTokens(String text) {
        String trimmed = text.trim();
        // "".split(regex) returns {""}, which would otherwise encode as a spurious UNK token.
        return trimmed.isEmpty() ? new String[0] : WHITESPACE.split(trimmed);
    }
}
