package org.elasticsearch.xpack.core.ml.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
import org.elasticsearch.xpack.core.security.authc.support.mapper.expressiondsl.FieldExpression;

/* loaded from: input_file:org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.class */
/**
 * Builder for the {@code sparse_vector} query.
 *
 * <p>A query is specified either with an explicit map of weighted tokens ({@code query_vector})
 * or with a {@code query} string. When only a query string is given, {@link #doRewrite} issues an
 * asynchronous {@link CoordinatedInferenceAction} request against {@code inference_id}. Once the
 * {@link TextExpansionResults} arrive, it rewrites itself into a builder that carries the returned
 * weighted tokens. Only fields whose mapped type is {@code sparse_vector} can be queried.
 */
public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQueryBuilder> {
    public static final String NAME = "sparse_vector";
    public static final String ALLOWED_FIELD_TYPE = "sparse_vector";
    private static final boolean DEFAULT_PRUNE = false;

    private final String fieldName;
    private final List<WeightedToken> queryVectors;
    private final String inferenceId;
    private final String query;
    private final boolean shouldPruneTokens;
    /**
     * Receives the inference result during an in-flight async rewrite. Non-null only on the
     * intermediate builder created by {@link #doRewrite}; such a builder must not be serialized.
     */
    private final SetOnce<TextExpansionResults> weightedTokensSupplier;

    @Nullable
    private final TokenPruningConfig tokenPruningConfig;

    // "field" is spelled out here rather than borrowed from an unrelated class's constant.
    public static final ParseField FIELD_FIELD = new ParseField("field");
    public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
    public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
    public static final ParseField QUERY_FIELD = new ParseField("query");
    public static final ParseField PRUNE_FIELD = new ParseField("prune");
    public static final ParseField PRUNING_CONFIG_FIELD = new ParseField("pruning_config");

    // args[1] is produced by XContentParser#map() in the static block below, hence the unchecked cast.
    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<SparseVectorQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(
        NAME,
        args -> new SparseVectorQueryBuilder(
            (String) args[0],
            parseWeightedTokens((Map<String, Object>) args[1]),
            (String) args[2],
            (String) args[3],
            (Boolean) args[4],
            (TokenPruningConfig) args[5]
        )
    );

    /**
     * Creates a query that runs inference on {@code query} with the model {@code inferenceId},
     * with token pruning disabled.
     */
    public SparseVectorQueryBuilder(String fieldName, String inferenceId, String query) {
        this(fieldName, null, inferenceId, query, DEFAULT_PRUNE, null);
    }

    /**
     * Full constructor.
     *
     * @param fieldName          the sparse_vector field to query; required
     * @param queryVectors       explicit weighted tokens; mutually exclusive with {@code query}
     * @param inferenceId        model used to expand {@code query}; not allowed with {@code queryVectors}
     * @param query              text to expand via inference; mutually exclusive with {@code queryVectors}
     * @param shouldPruneTokens  whether to prune tokens; {@code null} means {@code false}
     * @param tokenPruningConfig explicit pruning config; defaulted when pruning is enabled and none is given
     * @throws NullPointerException     if {@code fieldName} is null
     * @throws IllegalArgumentException if both or neither of query vectors / query are given, or if
     *                                  query vectors are combined with an inference id
     */
    public SparseVectorQueryBuilder(
        String fieldName,
        @Nullable List<WeightedToken> queryVectors,
        @Nullable String inferenceId,
        @Nullable String query,
        @Nullable Boolean shouldPruneTokens,
        @Nullable TokenPruningConfig tokenPruningConfig
    ) {
        this.fieldName = Objects.requireNonNull(fieldName, () -> "[" + NAME + "] requires a [" + FIELD_FIELD.getPreferredName() + "]");
        this.shouldPruneTokens = shouldPruneTokens != null ? shouldPruneTokens : DEFAULT_PRUNE;
        this.queryVectors = queryVectors;
        this.inferenceId = inferenceId;
        this.query = query;
        // An explicit config always wins; otherwise a default config is created only when pruning is on.
        this.tokenPruningConfig = tokenPruningConfig != null
            ? tokenPruningConfig
            : (this.shouldPruneTokens ? new TokenPruningConfig() : null);
        this.weightedTokensSupplier = null;

        String requiresOneOf = "["
            + NAME
            + "] requires one of ["
            + QUERY_VECTOR_FIELD.getPreferredName()
            + "] or ["
            + INFERENCE_ID_FIELD.getPreferredName()
            + "] for sparse_vector fields";
        if (queryVectors != null && inferenceId != null) {
            throw new IllegalArgumentException(requiresOneOf);
        }
        // Exactly one source of tokens: explicit vectors XOR a query string to expand.
        if ((queryVectors == null) == (query == null)) {
            throw new IllegalArgumentException(requiresOneOf);
        }
    }

    /** Deserializes a builder written by {@link #doWriteTo}. */
    public SparseVectorQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.fieldName = in.readString();
        this.shouldPruneTokens = in.readBoolean();
        this.queryVectors = in.readOptionalCollectionAsList(WeightedToken::new);
        this.inferenceId = in.readOptionalString();
        this.query = in.readOptionalString();
        this.tokenPruningConfig = in.readOptionalWriteable(TokenPruningConfig::new);
        this.weightedTokensSupplier = null;
    }

    /**
     * Copies {@code other}'s query definition and attaches the supplier that the async inference
     * action will fill. Boost and query name are not copied here. NOTE(review): they are presumably
     * re-applied by AbstractQueryBuilder#rewrite; confirm.
     */
    private SparseVectorQueryBuilder(SparseVectorQueryBuilder other, SetOnce<TextExpansionResults> weightedTokensSupplier) {
        this.fieldName = other.fieldName;
        this.shouldPruneTokens = other.shouldPruneTokens;
        this.queryVectors = other.queryVectors;
        this.inferenceId = other.inferenceId;
        this.query = other.query;
        this.tokenPruningConfig = other.tokenPruningConfig;
        this.weightedTokensSupplier = weightedTokensSupplier;
    }

    public String getFieldName() {
        return fieldName;
    }

    public List<WeightedToken> getQueryVectors() {
        return queryVectors;
    }

    public String getInferenceId() {
        return inferenceId;
    }

    public String getQuery() {
        return query;
    }

    public boolean shouldPruneTokens() {
        return shouldPruneTokens;
    }

    public TokenPruningConfig getTokenPruningConfig() {
        return tokenPruningConfig;
    }

    /**
     * Serializes the query definition.
     *
     * @throws IllegalStateException if called on the intermediate builder that is waiting for inference
     */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        if (weightedTokensSupplier != null) {
            throw new IllegalStateException("weighted tokens supplier must be null, can't serialize suppliers, missing a rewriteAndFetch?");
        }
        out.writeString(fieldName);
        out.writeBoolean(shouldPruneTokens);
        out.writeOptionalCollection(queryVectors);
        out.writeOptionalString(inferenceId);
        out.writeOptionalString(query);
        out.writeOptionalWriteable(tokenPruningConfig);
    }

    /** Renders the query. Explicit query vectors take precedence over inference_id/query. */
    @Override
    protected void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(NAME);
        builder.field(FIELD_FIELD.getPreferredName(), fieldName);
        if (queryVectors != null) {
            builder.startObject(QUERY_VECTOR_FIELD.getPreferredName());
            for (WeightedToken token : queryVectors) {
                token.toXContent(builder, params);
            }
            builder.endObject();
        } else {
            if (inferenceId != null) {
                builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
            }
            builder.field(QUERY_FIELD.getPreferredName(), query);
        }
        builder.field(PRUNE_FIELD.getPreferredName(), shouldPruneTokens);
        if (tokenPruningConfig != null) {
            builder.field(PRUNING_CONFIG_FIELD.getPreferredName(), tokenPruningConfig);
        }
        boostAndQueryNameToXContent(builder);
        builder.endObject();
    }

    /**
     * Builds the Lucene query from the (already rewritten) weighted tokens.
     *
     * @return a {@link MatchNoDocsQuery} if there are no query vectors or the field is unmapped
     * @throws IllegalArgumentException if the field is mapped but not of type sparse_vector
     */
    @Override
    protected Query doToQuery(SearchExecutionContext context) throws IOException {
        if (queryVectors == null) {
            return new MatchNoDocsQuery("Empty query vectors");
        }
        MappedFieldType fieldType = context.getFieldType(fieldName);
        if (fieldType == null) {
            return new MatchNoDocsQuery("The \"" + getName() + "\" query is against a field that does not exist");
        }
        String fieldTypeName = fieldType.typeName();
        if (fieldTypeName.equals(ALLOWED_FIELD_TYPE) == false) {
            throw new IllegalArgumentException(
                "field [" + fieldName + "] must be type [" + ALLOWED_FIELD_TYPE + "] but is type [" + fieldTypeName + "]"
            );
        }
        if (shouldPruneTokens) {
            return WeightedTokensUtils.queryBuilderWithPrunedTokens(fieldName, tokenPruningConfig, queryVectors, fieldType, context);
        }
        return WeightedTokensUtils.queryBuilderWithAllTokens(fieldName, queryVectors, fieldType, context);
    }

    /**
     * Resolves a query string into weighted tokens.
     *
     * <ol>
     *   <li>Explicit query vectors: nothing to do.</li>
     *   <li>Supplier present: return a vector-based builder once the inference result has arrived,
     *       or {@code this} while it is still pending.</li>
     *   <li>Otherwise: register an async inference call and return an intermediate builder holding
     *       the supplier the call will fill.</li>
     * </ol>
     *
     * @throws IllegalArgumentException if a query string must be expanded but no inference id was given
     */
    @Override
    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
        if (queryVectors != null) {
            return this;
        }
        if (weightedTokensSupplier != null) {
            TextExpansionResults textExpansionResults = weightedTokensSupplier.get();
            if (textExpansionResults == null) {
                return this; // inference still in flight
            }
            return new SparseVectorQueryBuilder(
                fieldName,
                textExpansionResults.getWeightedTokens(),
                null,
                null,
                shouldPruneTokens,
                tokenPruningConfig
            );
        }
        if (inferenceId == null) {
            throw new IllegalArgumentException("inference_id required to perform vector search on query string");
        }

        CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput(
            inferenceId,
            List.of(query),
            TextExpansionConfigUpdate.EMPTY_UPDATE,
            false,
            InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
        );
        inferRequest.setHighPriority(true);
        inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);

        SetOnce<TextExpansionResults> textExpansionResultsSupplier = new SetOnce<>();
        queryRewriteContext.registerAsyncAction(
            (client, listener) -> ClientHelper.executeAsyncWithOrigin(
                client,
                "ml",
                CoordinatedInferenceAction.INSTANCE,
                inferRequest,
                ActionListener.wrap(inferenceResponse -> {
                    List<InferenceResults> inferenceResults = inferenceResponse.getInferenceResults();
                    if (inferenceResults.isEmpty()) {
                        listener.onFailure(new IllegalStateException("inference response contain no results"));
                        return;
                    }
                    if (inferenceResults.size() > 1) {
                        listener.onFailure(new IllegalStateException("inference response should contain only one result"));
                        return;
                    }
                    InferenceResults result = inferenceResults.get(0);
                    if (result instanceof TextExpansionResults) {
                        textExpansionResultsSupplier.set((TextExpansionResults) result);
                        listener.onResponse(null);
                    } else if (result instanceof WarningInferenceResults) {
                        listener.onFailure(new IllegalStateException(((WarningInferenceResults) result).getWarning()));
                    } else {
                        listener.onFailure(
                            new IllegalArgumentException(
                                "expected a result of type [text_expansion_result] received ["
                                    + result.getWriteableName()
                                    + "]. Is ["
                                    + inferenceId
                                    + "] a compatible model?"
                            )
                        );
                    }
                }, listener::onFailure)
            )
        );
        return new SparseVectorQueryBuilder(this, textExpansionResultsSupplier);
    }

    /**
     * Compares the query definition. The async supplier is intentionally not compared.
     * Kept public to match the existing (widened) visibility.
     */
    @Override
    public boolean doEquals(SparseVectorQueryBuilder other) {
        return Objects.equals(fieldName, other.fieldName)
            && Objects.equals(tokenPruningConfig, other.tokenPruningConfig)
            && Objects.equals(queryVectors, other.queryVectors)
            && shouldPruneTokens == other.shouldPruneTokens
            && Objects.equals(inferenceId, other.inferenceId)
            && Objects.equals(query, other.query);
    }

    @Override
    protected int doHashCode() {
        return Objects.hash(fieldName, queryVectors, tokenPruningConfig, shouldPruneTokens, inferenceId, query);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.V_8_15_0;
    }

    /**
     * Converts a parsed {@code query_vector} object into weighted tokens.
     *
     * @param weightedTokenMap token to weight; may be null when the field was absent
     * @return the tokens, or {@code null} if {@code weightedTokenMap} is null
     * @throws IllegalArgumentException if a weight is not numeric
     */
    private static List<WeightedToken> parseWeightedTokens(Map<String, Object> weightedTokenMap) {
        if (weightedTokenMap == null) {
            return null;
        }
        List<WeightedToken> weightedTokens = new ArrayList<>(weightedTokenMap.size());
        for (Map.Entry<String, Object> entry : weightedTokenMap.entrySet()) {
            Object weight = entry.getValue();
            if ((weight instanceof Number) == false) {
                throw new IllegalArgumentException("weight must be a number, was [" + weight + "]");
            }
            weightedTokens.add(new WeightedToken(entry.getKey(), ((Number) weight).floatValue()));
        }
        return weightedTokens;
    }

    /**
     * Parses a {@code sparse_vector} query body.
     *
     * @throws ParsingException wrapping any {@link IllegalArgumentException} raised during parsing/validation
     */
    public static SparseVectorQueryBuilder fromXContent(XContentParser parser) {
        try {
            return PARSER.apply(parser, null);
        } catch (IllegalArgumentException e) {
            throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e);
        }
    }

    // Declaration order must match the args[] indices used by the PARSER constructor lambda.
    static {
        PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD_FIELD);
        PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> p.map(), QUERY_VECTOR_FIELD);
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), INFERENCE_ID_FIELD);
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), QUERY_FIELD);
        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), PRUNE_FIELD);
        PARSER.declareObject(
            ConstructingObjectParser.optionalConstructorArg(),
            (p, c) -> TokenPruningConfig.fromXContent(p),
            PRUNING_CONFIG_FIELD
        );
        declareStandardFields(PARSER);
    }
}
