package org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModel;
import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/* loaded from: input_file:org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.class */
/**
 * Language-identification model made of two {@link LangNetLayer}s: a hidden layer whose output is fed into a
 * softmax layer. At inference time it reads a list of embeddings (one per run of text sharing a unicode script
 * type) from the field named by {@code embedded_vector_feature_name}. It scores each embedding through both
 * layers and combines the per-embedding class probabilities into one prediction over the supported language codes.
 */
public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, LenientlyParsedTrainedModel, InferenceModel {

    public static final ParseField NAME = new ParseField("lang_ident_neural_network");
    public static final ParseField EMBEDDED_VECTOR_FEATURE_NAME = new ParseField("embedded_vector_feature_name");
    public static final ParseField HIDDEN_LAYER = new ParseField("hidden_layer");
    public static final ParseField SOFTMAX_LAYER = new ParseField("softmax_layer");

    // The parsers are built from the ParseFields above, so they must be declared (initialized) after them.
    public static final ConstructingObjectParser<LangIdentNeuralNetwork, Void> STRICT_PARSER = createParser(false);
    public static final ConstructingObjectParser<LangIdentNeuralNetwork, Void> LENIENT_PARSER = createParser(true);

    /**
     * Output labels; position {@code i} labels entry {@code i} of the probability vector built in
     * {@code infer}. Every code is written as a literal (notably {@code "id"}) so that the label set does not
     * depend on unrelated constants that merely happen to share the same string value.
     */
    private static final List<String> LANGUAGE_NAMES = Collections.unmodifiableList(
        Arrays.asList(
            "eo", "co", "eu", "ta", "de", "mt", "ps", "te", "su", "uz", "zh-Latn", "ne",
            "nl", "sw", "sq", "hmn", "ja", "no", "mn", "so", "ko", "kk", "sl", "ig",
            "mr", "th", "zu", "ml", "hr", "bs", "lo", "sd", "cy", "hy", "uk", "pt",
            "lv", "iw", "cs", "vi", "jv", "be", "km", "mk", "tr", "fy", "am", "zh",
            "da", "sv", "fi", "ht", "af", "la", "id", "fil", "sm", "ca", "el", "ka",
            "sr", "it", "sk", "ru", "ru-Latn", "bg", "ny", "fa", "haw", "gl", "et", "ms",
            "gd", "bg-Latn", "ha", "is", "ur", "mi", "hi", "bn", "hi-Latn", "fr", "yi", "hu",
            "xh", "my", "tg", "ro", "ar", "lb", "el-Latn", "st", "ceb", "kn", "az", "si",
            "ky", "mg", "en", "gu", "es", "pl", "ja-Latn", "ga", "lt", "sn", "yo", "pa", "ku"
        )
    );

    /**
     * Class index reported when there are no embeddings to classify; its label is
     * {@link #MISSING_VALID_TXT_CLASSIFICATION_STR}.
     * NOTE(review): this equals the index of the last entry in LANGUAGE_NAMES ("ku"), so the numeric value
     * alone cannot distinguish "no valid text" from "ku" — only the label can. Confirm whether any consumer
     * relies on the numeric value before changing it.
     */
    private static final int MISSING_VALID_TXT_CLASSIFICATION = LANGUAGE_NAMES.size() - 1;
    private static final String MISSING_VALID_TXT_CLASSIFICATION_STR = "zxx";
    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LangIdentNeuralNetwork.class);

    private final LangNetLayer hiddenLayer;
    private final LangNetLayer softmaxLayer;
    private final String embeddedVectorFeatureName;

    /**
     * Builds the object parser for this model.
     *
     * @param lenient whether unknown fields are tolerated; also selects the lenient or strict parser used for
     *                the nested layer objects
     */
    private static ConstructingObjectParser<LangIdentNeuralNetwork, Void> createParser(boolean lenient) {
        ConstructingObjectParser<LangIdentNeuralNetwork, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            args -> new LangIdentNeuralNetwork((String) args[0], (LangNetLayer) args[1], (LangNetLayer) args[2])
        );
        parser.declareString(ConstructingObjectParser.constructorArg(), EMBEDDED_VECTOR_FEATURE_NAME);
        parser.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> parseLayer(p, c, lenient), HIDDEN_LAYER);
        parser.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> parseLayer(p, c, lenient), SOFTMAX_LAYER);
        return parser;
    }

    /** Parses one nested layer object with the lenient or strict {@link LangNetLayer} parser. */
    private static LangNetLayer parseLayer(XContentParser parser, Void context, boolean lenient) {
        return lenient ? LangNetLayer.LENIENT_PARSER.apply(parser, context) : LangNetLayer.STRICT_PARSER.apply(parser, context);
    }

    /** Parses a model definition, rejecting unknown fields. */
    public static LangIdentNeuralNetwork fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null);
    }

    /** Parses a model definition, tolerating unknown fields. */
    public static LangIdentNeuralNetwork fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null);
    }

    /**
     * @param embeddedVectorFeatureName name of the input field holding the embeddings; must not be null
     * @param hiddenLayer               first network layer; must not be null
     * @param softmaxLayer              output layer whose scores are passed through softmax; must not be null
     */
    public LangIdentNeuralNetwork(String embeddedVectorFeatureName, LangNetLayer hiddenLayer, LangNetLayer softmaxLayer) {
        this.embeddedVectorFeatureName = ExceptionsHelper.requireNonNull(embeddedVectorFeatureName, EMBEDDED_VECTOR_FEATURE_NAME);
        this.hiddenLayer = ExceptionsHelper.requireNonNull(hiddenLayer, HIDDEN_LAYER);
        this.softmaxLayer = ExceptionsHelper.requireNonNull(softmaxLayer, SOFTMAX_LAYER);
    }

    /** Reads the model in the order written by {@link #writeTo(StreamOutput)}. */
    public LangIdentNeuralNetwork(StreamInput in) throws IOException {
        this.embeddedVectorFeatureName = in.readString();
        this.hiddenLayer = new LangNetLayer(in);
        this.softmaxLayer = new LangNetLayer(in);
    }

    /**
     * Predicts the language of the embeddings stored under the configured feature name.
     *
     * Each {@link CustomWordEmbedding.StringLengthAndEmbedding} in the list is scored through the hidden and
     * softmax layers. Its softmax probabilities are added with weight {@code max(len * len, 1)}, where {@code len}
     * is its reported UTF-8 string length. The sum is then divided by the total weight actually applied.
     * Entries of any other type are ignored.
     *
     * @param fields            input document fields
     * @param config            must be a {@link ClassificationConfig} that does not request feature importance
     * @param featureDecoderMap unused by this model
     * @return the top language, or class {@code "zxx"} with probability 1.0 when the embedding list is empty
     * @throws RuntimeException a bad-request exception from {@link ExceptionsHelper#badRequestException} when
     *                          feature importance is requested, the config is not a classification config, or
     *                          the feature value is not a list
     */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModel
    public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
        if (config.requestingImportance()) {
            throw ExceptionsHelper.badRequestException("[{}] model does not supports feature importance", NAME.getPreferredName());
        }
        if (config instanceof ClassificationConfig == false) {
            throw ExceptionsHelper.badRequestException("[{}] model only supports classification", NAME.getPreferredName());
        }
        Object rawVector = fields.get(embeddedVectorFeatureName);
        if (rawVector instanceof List == false) {
            throw ExceptionsHelper.badRequestException(
                "[{}] model could not find non-null collection of embeddings separated by unicode script type [{}]. Please verify that the input is a string.",
                NAME.getPreferredName(),
                embeddedVectorFeatureName
            );
        }
        List<?> embeddedVector = (List<?>) rawVector;
        ClassificationConfig classificationConfig = (ClassificationConfig) config;
        if (embeddedVector.isEmpty()) {
            return new ClassificationInferenceResults(
                MISSING_VALID_TXT_CLASSIFICATION,
                MISSING_VALID_TXT_CLASSIFICATION_STR,
                Collections.<TopClassEntry>emptyList(),
                Collections.<ClassificationFeatureImportance>emptyList(),
                config,
                Double.valueOf(1.0d),
                Double.valueOf(1.0d)
            );
        }

        double[] probabilities = new double[LANGUAGE_NAMES.size()];
        int totalWeight = 0;
        for (Object entry : embeddedVector) {
            if (entry instanceof CustomWordEmbedding.StringLengthAndEmbedding == false) {
                continue;
            }
            CustomWordEmbedding.StringLengthAndEmbedding lengthAndEmbedding = (CustomWordEmbedding.StringLengthAndEmbedding) entry;
            int len = lengthAndEmbedding.getUtf8StringLen();
            // Longer runs count quadratically more. Zero-length runs still get weight 1. The normaliser must
            // accumulate this same clamped weight, otherwise the averaged probabilities no longer sum to 1.
            int weight = Math.max(len * len, 1);
            totalWeight += weight;
            InferenceHelpers.sumDoubleArrays(
                probabilities,
                Statistics.softMax(softmaxLayer.productPlusBias(true, hiddenLayer.productPlusBias(false, lengthAndEmbedding.getEmbedding()))),
                weight
            );
        }
        // totalWeight stays 0 only when no entry was a StringLengthAndEmbedding.
        if (totalWeight != 0) {
            InferenceHelpers.divMut(probabilities, totalWeight);
        }

        Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
            probabilities,
            LANGUAGE_NAMES,
            null,
            classificationConfig.getNumTopClasses(),
            PredictionFieldType.STRING
        );
        InferenceHelpers.TopClassificationValue top = topClasses.v1();
        assert top.getValue() >= 0 && top.getValue() < LANGUAGE_NAMES.size()
            : "Invalid language predicted. Predicted language index " + top.getValue();
        return new ClassificationInferenceResults(
            top.getValue(),
            LANGUAGE_NAMES.get(top.getValue()),
            topClasses.v2(),
            Collections.<ClassificationFeatureImportance>emptyList(),
            config,
            Double.valueOf(top.getProbability()),
            Double.valueOf(top.getScore())
        );
    }

    /** Not supported: this model cannot be used as a nested model. */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModel
    public InferenceResults infer(double[] features, InferenceConfig config) {
        throw new UnsupportedOperationException("[lang_ident] does not support nested inference");
    }

    /** Accepts only a null or empty mapping; any actual rewrite is unsupported. */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModel
    public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
        if (newFeatureIndexMapping != null && newFeatureIndexMapping.isEmpty() == false) {
            throw new UnsupportedOperationException("[lang_ident] does not support nested inference");
        }
    }

    /** @return a single-element array holding the embedding feature name */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModel
    public String[] getFeatureNames() {
        return new String[] { embeddedVectorFeatureName };
    }

    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel
    public TargetType targetType() {
        return TargetType.CLASSIFICATION;
    }

    /** No additional validation is performed; the constructor already rejects null components. */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel
    public void validate() {}

    /** @return the sum of the bias and weight array lengths of both layers */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel
    public long estimatedNumOperations() {
        long numOps = hiddenLayer.getBias().length;
        numOps += hiddenLayer.getWeights().length;
        numOps += softmaxLayer.getBias().length;
        numOps += softmaxLayer.getWeights().length;
        return numOps;
    }

    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModel
    public boolean supportsFeatureImportance() {
        return false;
    }

    /** @return shallow instance size plus the estimated sizes of both layers */
    public long ramBytesUsed() {
        return SHALLOW_SIZE + RamUsageEstimator.sizeOf(hiddenLayer) + RamUsageEstimator.sizeOf(softmaxLayer);
    }

    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    /** Writes the feature name, then the hidden layer, then the softmax layer. */
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(embeddedVectorFeatureName);
        hiddenLayer.writeTo(out);
        softmaxLayer.writeTo(out);
    }

    @Override // org.elasticsearch.xpack.core.ml.utils.NamedXContentObject
    public String getName() {
        return NAME.getPreferredName();
    }

    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(EMBEDDED_VECTOR_FEATURE_NAME.getPreferredName(), embeddedVectorFeatureName);
        builder.field(HIDDEN_LAYER.getPreferredName(), hiddenLayer);
        builder.field(SOFTMAX_LAYER.getPreferredName(), softmaxLayer);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        LangIdentNeuralNetwork other = (LangIdentNeuralNetwork) obj;
        return Objects.equals(embeddedVectorFeatureName, other.embeddedVectorFeatureName)
            && Objects.equals(hiddenLayer, other.hiddenLayer)
            && Objects.equals(softmaxLayer, other.softmaxLayer);
    }

    @Override
    public int hashCode() {
        return Objects.hash(embeddedVectorFeatureName, hiddenLayer, softmaxLayer);
    }
}
