package com.kotlinnlp.languagedetector;

import com.kotlinnlp.languagedetector.LanguageDetector;
import com.kotlinnlp.languagedetector.utils.ExtensionsKt;
import com.kotlinnlp.languagedetector.utils.FrequencyDictionary;
import com.kotlinnlp.languagedetector.utils.Language;
import com.kotlinnlp.languagedetector.utils.TextTokenizer;
import com.kotlinnlp.simplednn.deeplearning.attention.han.HANEncoder;
import com.kotlinnlp.simplednn.deeplearning.attention.han.HANParameters;
import com.kotlinnlp.simplednn.deeplearning.attention.han.HierarchySequence;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.TypeCastException;
import kotlin.Unit;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: LanguageDetector.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��t\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0010\u0002\n\u0002\b\u0003\n\u0002\u0010\u000e\n��\n\u0002\u0010\u0006\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018��2\u00020\u0001:\u00010B!\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\n\b\u0002\u0010\u0006\u001a\u0004\u0018\u00010\u0007¢\u0006\u0002\u0010\bJ\u000e\u0010\u0012\u001a\u00020\u00132\u0006\u0010\u0014\u001a\u00020\u000bJ\u0018\u0010\u0015\u001a\u00020\u000b2\u0006\u0010\u0016\u001a\u00020\u00172\b\b\u0002\u0010\u0018\u001a\u00020\u0019J \u0010\u001a\u001a\u0014\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00020\u0017\u0012\u0004\u0012\u00020\u001d0\u001c0\u001b2\u0006\u0010\u001e\u001a\u00020\u0017J\u0016\u0010\u001f\u001a\u00020\u000b2\f\u0010 
\u001a\b\u0012\u0004\u0012\u00020\u000b0\u001bH\u0002J\u000e\u0010!\u001a\u00020\"2\u0006\u0010\u001e\u001a\u00020\u0017J)\u0010#\u001a\u00020\u00132\u0006\u0010\u001e\u001a\u00020\u00172\u0012\u0010$\u001a\u000e\u0012\u0004\u0012\u00020\u0017\u0012\u0004\u0012\u00020\u00130%H��¢\u0006\u0002\b&J\u0016\u0010'\u001a\b\u0012\u0004\u0012\u00020\u000b0(2\b\b\u0002\u0010)\u001a\u00020*J\u000e\u0010+\u001a\u00020\"2\u0006\u0010,\u001a\u00020\u000bJ\u0010\u0010-\u001a\u00020.2\b\b\u0002\u0010)\u001a\u00020*J\u000e\u0010/\u001a\u00020\u000b2\u0006\u0010\u001e\u001a\u00020\u0017R\u0014\u0010\t\u001a\b\u0012\u0004\u0012\u00020\u000b0\nX\u0082\u0004¢\u0006\u0002\n��R\u0013\u0010\u0006\u001a\u0004\u0018\u00010\u0007¢\u0006\b\n��\u001a\u0004\b\f\u0010\rR\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u000e\u0010\u000fR\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\u0010\u0010\u0011¨\u00061"}, d2 = {"Lcom/kotlinnlp/languagedetector/LanguageDetector;", "", "model", "Lcom/kotlinnlp/languagedetector/LanguageDetectorModel;", "tokenizer", "Lcom/kotlinnlp/languagedetector/utils/TextTokenizer;", "frequencyDictionary", "Lcom/kotlinnlp/languagedetector/utils/FrequencyDictionary;", "(Lcom/kotlinnlp/languagedetector/LanguageDetectorModel;Lcom/kotlinnlp/languagedetector/utils/TextTokenizer;Lcom/kotlinnlp/languagedetector/utils/FrequencyDictionary;)V", "encoder", "Lcom/kotlinnlp/simplednn/deeplearning/attention/han/HANEncoder;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "getFrequencyDictionary", "()Lcom/kotlinnlp/languagedetector/utils/FrequencyDictionary;", "getModel", "()Lcom/kotlinnlp/languagedetector/LanguageDetectorModel;", "getTokenizer", "()Lcom/kotlinnlp/languagedetector/utils/TextTokenizer;", "backward", "", "outputErrors", "classifyToken", "token", "", "dropout", "", "classifyTokens", "", "Lkotlin/Pair;", "Lcom/kotlinnlp/languagedetector/LanguageDetector$TokenClassification;", "text", "combineClassifications", 
"classifications", "detectLanguage", "Lcom/kotlinnlp/languagedetector/utils/Language;", "forEachToken", "callback", "Lkotlin/Function1;", "forEachToken$languagedetector", "getInputSequenceErrors", "Ljava/util/ArrayList;", "copy", "", "getLanguage", "prediction", "getParamsErrors", "Lcom/kotlinnlp/simplednn/deeplearning/attention/han/HANParameters;", "predict", "TokenClassification", "languagedetector"})
/* loaded from: input_file:com/kotlinnlp/languagedetector/LanguageDetector.class */
public final class LanguageDetector {
    /**
     * Encoder that classifies a single token; created in the constructor from
     * {@code model.getHan()} and used for the forward and backward passes.
     */
    private final HANEncoder<DenseNDArray> encoder;

    /**
     * Model of the detector. Within this class it supplies the encoder parameters, the
     * embeddings, the list of supported languages and the maximum token length.
     */
    @NotNull
    private final LanguageDetectorModel model;

    /** Tokenizer used to split an input text into the tokens that get classified. */
    @NotNull
    private final TextTokenizer tokenizer;

    /**
     * Optional dictionary queried with {@code getFreqOf(token)}; when it yields a value,
     * that value is combined with the encoder classification. May be {@code null}.
     */
    @Nullable
    private final FrequencyDictionary frequencyDictionary;

    /* compiled from: LanguageDetector.kt */
    @Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��&\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\b\t\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0010\b\n��\n\u0002\u0010\u000e\n��\b\u0086\b\u0018��2\u00020\u0001B\u0015\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0003¢\u0006\u0002\u0010\u0005J\t\u0010\t\u001a\u00020\u0003HÆ\u0003J\t\u0010\n\u001a\u00020\u0003HÆ\u0003J\u001d\u0010\u000b\u001a\u00020��2\b\b\u0002\u0010\u0002\u001a\u00020\u00032\b\b\u0002\u0010\u0004\u001a\u00020\u0003HÆ\u0001J\u0013\u0010\f\u001a\u00020\r2\b\u0010\u000e\u001a\u0004\u0018\u00010\u0001HÖ\u0003J\t\u0010\u000f\u001a\u00020\u0010HÖ\u0001J\t\u0010\u0011\u001a\u00020\u0012HÖ\u0001R\u0011\u0010\u0004\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0006\u0010\u0007R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\b\u0010\u0007¨\u0006\u0013"}, d2 = {"Lcom/kotlinnlp/languagedetector/LanguageDetector$TokenClassification;", "", "languages", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "charsImportance", "(Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;)V", "getCharsImportance", "()Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "getLanguages", "component1", "component2", "copy", "equals", "", "other", "hashCode", "", "toString", "", "languagedetector"})
    /* loaded from: input_file:com/kotlinnlp/languagedetector/LanguageDetector$TokenClassification.class */
    /**
     * Immutable value holder (compiled from a Kotlin data class) pairing the language
     * classification of a token with the importance scores of its characters.
     */
    public static final class TokenClassification {

        /** Classification of the token over the languages. */
        @NotNull
        private final DenseNDArray languages;

        /** Importance scores read from the encoder for the token's characters. */
        @NotNull
        private final DenseNDArray charsImportance;

        /**
         * Builds a classification from its two components.
         *
         * @param denseNDArray the languages classification (Kotlin name: {@code languages})
         * @param denseNDArray2 the chars importance scores (Kotlin name: {@code charsImportance})
         */
        public TokenClassification(@NotNull DenseNDArray denseNDArray, @NotNull DenseNDArray denseNDArray2) {
            Intrinsics.checkParameterIsNotNull(denseNDArray, "languages");
            Intrinsics.checkParameterIsNotNull(denseNDArray2, "charsImportance");
            languages = denseNDArray;
            charsImportance = denseNDArray2;
        }

        /** @return the languages classification */
        @NotNull
        public final DenseNDArray getLanguages() {
            return languages;
        }

        /** @return the chars importance scores */
        @NotNull
        public final DenseNDArray getCharsImportance() {
            return charsImportance;
        }

        /** @return the first component, i.e. the languages classification */
        @NotNull
        public final DenseNDArray component1() {
            return languages;
        }

        /** @return the second component, i.e. the chars importance scores */
        @NotNull
        public final DenseNDArray component2() {
            return charsImportance;
        }

        /**
         * Creates a new instance from the given components.
         *
         * @param denseNDArray the languages classification of the new instance
         * @param denseNDArray2 the chars importance scores of the new instance
         * @return a new {@code TokenClassification} wrapping the two arguments
         */
        @NotNull
        public final TokenClassification copy(@NotNull DenseNDArray denseNDArray, @NotNull DenseNDArray denseNDArray2) {
            Intrinsics.checkParameterIsNotNull(denseNDArray, "languages");
            Intrinsics.checkParameterIsNotNull(denseNDArray2, "charsImportance");
            return new TokenClassification(denseNDArray, denseNDArray2);
        }

        /**
         * Synthetic entry point for {@code copy} with defaulted arguments: bit 0 of {@code i}
         * replaces the first argument with the receiver's {@code languages}, bit 1 replaces
         * the second with its {@code charsImportance}. {@code obj} is not used.
         */
        @NotNull
        public static /* bridge */ /* synthetic */ TokenClassification copy$default(TokenClassification tokenClassification, DenseNDArray denseNDArray, DenseNDArray denseNDArray2, int i, Object obj) {
            final DenseNDArray newLanguages = (i & 1) != 0 ? tokenClassification.languages : denseNDArray;
            final DenseNDArray newCharsImportance = (i & 2) != 0 ? tokenClassification.charsImportance : denseNDArray2;
            return tokenClassification.copy(newLanguages, newCharsImportance);
        }

        /** @return a textual form listing both components */
        @Override
        public String toString() {
            final StringBuilder text = new StringBuilder();
            text.append("TokenClassification(languages=").append(languages);
            text.append(", charsImportance=").append(charsImportance);
            text.append(")");
            return text.toString();
        }

        /** @return 31 times the hash of {@code languages} plus the hash of {@code charsImportance} */
        @Override
        public int hashCode() {
            final int languagesHash = languages == null ? 0 : languages.hashCode();
            final int charsImportanceHash = charsImportance == null ? 0 : charsImportance.hashCode();
            return languagesHash * 31 + charsImportanceHash;
        }

        /**
         * @param obj the object to compare with
         * @return {@code true} when {@code obj} is a {@code TokenClassification} whose two
         *     components are equal to the ones of this instance
         */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (obj instanceof TokenClassification) {
                final TokenClassification that = (TokenClassification) obj;
                if (!Intrinsics.areEqual(languages, that.languages)) {
                    return false;
                }
                return Intrinsics.areEqual(charsImportance, that.charsImportance);
            }
            return false;
        }
    }

    /**
     * Detects the language of a text by running {@link #predict(String)} and mapping the
     * resulting distribution with {@link #getLanguage(DenseNDArray)}.
     *
     * @param str the text to analyze (Kotlin name: {@code text})
     * @return the language selected for the text
     */
    @NotNull
    public final Language detectLanguage(@NotNull String str) {
        Intrinsics.checkParameterIsNotNull(str, "text");
        final DenseNDArray prediction = predict(str);
        return getLanguage(prediction);
    }

    /**
     * Maps a prediction array to a language.
     *
     * @param denseNDArray the prediction to interpret (Kotlin name: {@code prediction})
     * @return {@code Language.Unknown} when the prediction sums to exactly zero, otherwise
     *     the supported language of the model at the arg-max index of the prediction
     */
    @NotNull
    public final Language getLanguage(@NotNull DenseNDArray denseNDArray) {
        Intrinsics.checkParameterIsNotNull(denseNDArray, "prediction");
        if (denseNDArray.sum() == 0.0d) {
            return Language.Unknown;
        }
        // argMaxIndex$default is the Kotlin default-arguments entry point of argMaxIndex.
        return this.model.getSupportedLanguages().get(NDArray.DefaultImpls.argMaxIndex$default(denseNDArray, 0, 1, (Object) null));
    }

    /**
     * Predicts the languages distribution of a text.
     *
     * <p>Every token of the text is classified with the encoder (dropout 0.0). When a
     * frequency dictionary is set and knows the token, its frequency array is collected as
     * an additional classification. All collected arrays are then merged with
     * {@code combineClassifications}.
     *
     * @param str the text to analyze (Kotlin name: {@code text})
     * @return the combined classification, or an array of zeros with one entry per
     *     supported language when the text yields no token
     */
    @NotNull
    public final DenseNDArray predict(@NotNull String str) {
        Intrinsics.checkParameterIsNotNull(str, "text");
        final List<DenseNDArray> classifications = new ArrayList<>();
        // A single typed invoke replaces the decompiler's void/bridge pair and its
        // 'super(1)' instance initializer, neither of which is valid Java.
        forEachToken$languagedetector(str, new Function1<String, Unit>() {
            @Override
            public Unit invoke(@NotNull String token) {
                Intrinsics.checkParameterIsNotNull(token, "token");
                classifications.add(LanguageDetector.classifyToken$default(LanguageDetector.this, token, 0.0d, 2, null));
                final FrequencyDictionary dictionary = LanguageDetector.this.getFrequencyDictionary();
                if (dictionary != null) {
                    final DenseNDArray frequencies = dictionary.getFreqOf(token);
                    if (frequencies != null) {
                        classifications.add(frequencies);
                    }
                }
                return Unit.INSTANCE;
            }
        });
        if (classifications.isEmpty()) {
            // No token at all: return zeros, which getLanguage maps to Language.Unknown.
            return DenseNDArrayFactory.INSTANCE.zeros(new Shape(this.model.getSupportedLanguages().size(), 0, 2, (DefaultConstructorMarker) null));
        }
        return combineClassifications(classifications);
    }

    /**
     * Classifies each token of a text separately.
     *
     * <p>For every token the encoder classification (dropout 0.0) is computed and, when a
     * frequency dictionary is set and knows the token, combined with its frequency array.
     * The first element of the encoder's input importance scores is stored next to it as
     * the chars importance of the token.
     *
     * @param str the text to analyze (Kotlin name: {@code text})
     * @return one (token, classification) pair per token, in tokenization order
     * @throws TypeCastException if the encoder returns no input importance scores
     */
    @NotNull
    public final List<Pair<String, TokenClassification>> classifyTokens(@NotNull String str) {
        Intrinsics.checkParameterIsNotNull(str, "text");
        final List<Pair<String, TokenClassification>> classifiedTokens = new ArrayList<>();
        // A single typed invoke replaces the decompiler's void/bridge pair and its
        // 'super(1)' instance initializer, neither of which is valid Java.
        forEachToken$languagedetector(str, new Function1<String, Unit>() {
            @Override
            public Unit invoke(@NotNull String token) {
                Intrinsics.checkParameterIsNotNull(token, "token");
                // The forward pass must run before the importance scores are read below.
                DenseNDArray languages = LanguageDetector.classifyToken$default(LanguageDetector.this, token, 0.0d, 2, null);
                final FrequencyDictionary dictionary = LanguageDetector.this.getFrequencyDictionary();
                if (dictionary != null) {
                    final DenseNDArray frequencies = dictionary.getFreqOf(token);
                    if (frequencies != null) {
                        languages = LanguageDetector.this.combineClassifications(CollectionsKt.listOf(new DenseNDArray[]{languages, frequencies}));
                    }
                }
                final HierarchySequence importanceScores = LanguageDetector.this.encoder.getInputImportanceScores();
                if (importanceScores == null) {
                    throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.deeplearning.attention.han.HierarchySequence<com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray>");
                }
                final Object charsImportance = importanceScores.get(0);
                Intrinsics.checkExpressionValueIsNotNull(charsImportance, "(this.encoder.getInputIm…equence<DenseNDArray>)[0]");
                classifiedTokens.add(new Pair<String, TokenClassification>(token, new LanguageDetector.TokenClassification(languages, (DenseNDArray) charsImportance)));
                return Unit.INSTANCE;
            }
        });
        return CollectionsKt.toList(classifiedTokens);
    }

    /**
     * Classifies a single token with the encoder.
     *
     * @param str the token to classify (Kotlin name: {@code token}); must not be empty
     * @param d the dropout value forwarded to the hierarchy sequence conversion
     *     (Kotlin name: {@code dropout})
     * @return a copy of the encoder output for the token
     * @throws IllegalArgumentException if the token is empty
     */
    @NotNull
    public final DenseNDArray classifyToken(@NotNull String str, double d) {
        Intrinsics.checkParameterIsNotNull(str, "token");
        if (str.isEmpty()) {
            throw new IllegalArgumentException("Empty chars sequence");
        }
        // Copy the output so the caller does not share the array returned by the encoder.
        return this.encoder.forward(ExtensionsKt.toHierarchySequence(str, this.model.getEmbeddings(), d)).copy();
    }

    /**
     * Synthetic entry point for {@code classifyToken} with defaulted arguments: when bit 1
     * of {@code i} is set, the dropout falls back to 0.0. {@code obj} is not used.
     */
    @NotNull
    public static /* bridge */ /* synthetic */ DenseNDArray classifyToken$default(LanguageDetector languageDetector, String str, double d, int i, Object obj) {
        final double dropout = (i & 2) != 0 ? 0.0d : d;
        return languageDetector.classifyToken(str, dropout);
    }

    /**
     * Runs the backward pass of the encoder.
     *
     * @param denseNDArray the errors of the output (Kotlin name: {@code outputErrors})
     */
    public final void backward(@NotNull DenseNDArray denseNDArray) {
        Intrinsics.checkParameterIsNotNull(denseNDArray, "outputErrors");
        encoder.backward(denseNDArray);
    }

    /**
     * Returns the parameter errors held by the encoder.
     *
     * @param z forwarded unchanged to the encoder (Kotlin name: {@code copy}; presumably
     *     whether a copy is returned, to be confirmed against {@code HANEncoder})
     * @return the errors of the encoder parameters
     */
    @NotNull
    public final HANParameters getParamsErrors(boolean z) {
        final HANParameters paramsErrors = encoder.getParamsErrors(z);
        return paramsErrors;
    }

    /**
     * Synthetic entry point for {@code getParamsErrors} with defaulted arguments: when bit 0
     * of {@code i} is set, the flag falls back to {@code true}. {@code obj} is not used.
     */
    @NotNull
    public static /* bridge */ /* synthetic */ HANParameters getParamsErrors$default(LanguageDetector languageDetector, boolean z, int i, Object obj) {
        final boolean copy = (i & 1) != 0 || z;
        return languageDetector.getParamsErrors(copy);
    }

    /**
     * Returns the input errors held by the encoder, viewed as a hierarchy sequence.
     *
     * @param z forwarded unchanged to the encoder (Kotlin name: {@code copy}; presumably
     *     whether a copy is returned, to be confirmed against {@code HANEncoder})
     * @return the errors of the input sequence
     * @throws TypeCastException if the encoder returns {@code null}
     */
    @NotNull
    public final ArrayList<DenseNDArray> getInputSequenceErrors(boolean z) {
        final ArrayList<DenseNDArray> sequenceErrors = this.encoder.getInputErrors(z);
        if (sequenceErrors != null) {
            return (HierarchySequence) sequenceErrors;
        }
        throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.deeplearning.attention.han.HierarchySequence<com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray>");
    }

    /**
     * Synthetic entry point for {@code getInputSequenceErrors} with defaulted arguments:
     * when bit 0 of {@code i} is set, the flag falls back to {@code true}. {@code obj} is
     * not used.
     */
    @NotNull
    public static /* bridge */ /* synthetic */ ArrayList getInputSequenceErrors$default(LanguageDetector languageDetector, boolean z, int i, Object obj) {
        final boolean copy = (i & 1) != 0 || z;
        return languageDetector.getInputSequenceErrors(copy);
    }

    /**
     * Tokenizes a text and invokes a callback on each token, in iteration order.
     *
     * <p>The tokenizer is called with the text and the maximum token length of the model.
     *
     * @param str the text to tokenize (Kotlin name: {@code text})
     * @param function1 the callback invoked once per token (Kotlin name: {@code callback})
     */
    public final void forEachToken$languagedetector(@NotNull String str, @NotNull Function1<? super String, Unit> function1) {
        Intrinsics.checkParameterIsNotNull(str, "text");
        Intrinsics.checkParameterIsNotNull(function1, "callback");
        // Wildcard iterator: the decompiled 'Iterator<T>' named a type variable that is
        // declared nowhere in this class.
        final Iterator<?> tokens = this.tokenizer.tokenize(str, this.model.getMaxTokensLength()).iterator();
        while (tokens.hasNext()) {
            function1.invoke((String) tokens.next());
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Combines several classifications into one by multiplying them element-wise and
     * dividing the product by its sum.
     *
     * <p>The first array is copied, so no element of the list is modified. When the
     * product sums to exactly zero it is returned without the division: dividing by that
     * sum would presumably fill the array with NaN values (0/0), which
     * {@code getLanguage} could no longer recognize as an unknown language.
     *
     * @param list the classifications to combine (Kotlin name: {@code classifications});
     *     must contain at least one element
     * @return the combined classification
     * @throws IndexOutOfBoundsException if the list is empty
     */
    public final DenseNDArray combineClassifications(List<DenseNDArray> list) {
        final DenseNDArray combined = list.get(0).copy();
        for (int index = 1; index < list.size(); index++) {
            combined.assignProd(list.get(index));
        }
        final double total = combined.sum();
        if (total == 0.0d) {
            // Nothing to normalize: keep the all-zero result instead of dividing by zero.
            return combined;
        }
        return combined.assignDiv(total);
    }

    /** @return the model this detector was built with */
    @NotNull
    public final LanguageDetectorModel getModel() {
        return model;
    }

    /** @return the tokenizer this detector was built with */
    @NotNull
    public final TextTokenizer getTokenizer() {
        return tokenizer;
    }

    /** @return the frequency dictionary this detector was built with, or {@code null} if none */
    @Nullable
    public final FrequencyDictionary getFrequencyDictionary() {
        return frequencyDictionary;
    }

    /**
     * Creates a detector and its encoder.
     *
     * @param languageDetectorModel the model of the detector (Kotlin name: {@code model})
     * @param textTokenizer the tokenizer used to split texts (Kotlin name: {@code tokenizer})
     * @param frequencyDictionary an optional frequency dictionary, may be {@code null}
     */
    public LanguageDetector(@NotNull LanguageDetectorModel languageDetectorModel, @NotNull TextTokenizer textTokenizer, @Nullable FrequencyDictionary frequencyDictionary) {
        Intrinsics.checkParameterIsNotNull(languageDetectorModel, "model");
        Intrinsics.checkParameterIsNotNull(textTokenizer, "tokenizer");
        model = languageDetectorModel;
        tokenizer = textTokenizer;
        this.frequencyDictionary = frequencyDictionary;
        // NOTE(review): the trailing '24, null' looks like a Kotlin default-arguments mask
        // (bits 3 and 4), making '(Double) null' and '0' placeholders - confirm against
        // the HANEncoder constructor.
        encoder = new HANEncoder<>(languageDetectorModel.getHan(), false, true, (Double) null, 0, 24, (DefaultConstructorMarker) null);
    }

    /**
     * Synthetic constructor for defaulted arguments: when bit 2 of {@code i} is set, the
     * frequency dictionary falls back to {@code null}. The marker argument is not used.
     */
    public /* synthetic */ LanguageDetector(LanguageDetectorModel languageDetectorModel, TextTokenizer textTokenizer, FrequencyDictionary frequencyDictionary, int i, DefaultConstructorMarker defaultConstructorMarker) {
        this(languageDetectorModel, textTokenizer, (i & 4) == 0 ? frequencyDictionary : (FrequencyDictionary) null);
    }
}
