package org.elasticsearch.xpack.core.ml.search;

import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.query.SearchExecutionContext;

/* loaded from: input_file:org/elasticsearch/xpack/core/ml/search/WeightedTokensUtils.class */
/**
 * Static helpers that turn a list of {@code WeightedToken}s into a Lucene query over a single field.
 *
 * <p>Every token that is queried becomes a {@code SHOULD} clause consisting of the field type's term query for the
 * token, boosted by the token's weight. The clauses are combined in a {@link BooleanQuery} with
 * {@code minimumNumberShouldMatch = 1} and wrapped in a {@code SparseVectorQueryWrapper} tagged with the field name.
 */
public final class WeightedTokensUtils {
    private WeightedTokensUtils() {
        // Static utility class; not meant to be instantiated.
    }

    /**
     * Builds a query that scores every supplied token, without any pruning.
     *
     * @param str                    the field name recorded in the returned {@code SparseVectorQueryWrapper}
     * @param list                   the weighted tokens to query for
     * @param mappedFieldType        field type used to build each token's term query
     * @param searchExecutionContext context passed to {@link MappedFieldType#termQuery}
     * @return a wrapped boolean query with one boosted {@code SHOULD} clause per token
     */
    public static Query queryBuilderWithAllTokens(String str, List<WeightedToken> list, MappedFieldType mappedFieldType, SearchExecutionContext searchExecutionContext) {
        BooleanQuery.Builder builder = new BooleanQuery.Builder();
        for (WeightedToken weightedToken : list) {
            addTokenClause(builder, weightedToken, mappedFieldType, searchExecutionContext);
        }
        return new SparseVectorQueryWrapper(str, builder.setMinimumNumberShouldMatch(1).build());
    }

    /**
     * Builds a query over the tokens selected by {@code tokenPruningConfig}.
     *
     * <p>A token is "kept" when {@link #shouldKeepToken} accepts it. Normally only kept tokens are scored. When
     * the config asks to score only pruned tokens, the selection is inverted and only the rejected tokens are scored.
     * With a {@code null} config every token is kept.
     *
     * @param str                    the field name; used for index statistics and recorded in the wrapper query
     * @param tokenPruningConfig     pruning thresholds, or {@code null} to disable pruning
     * @param list                   the weighted tokens to query for
     * @param mappedFieldType        field type used to build each token's term query
     * @param searchExecutionContext supplies the index reader and is passed to {@link MappedFieldType#termQuery}
     * @return a {@link MatchNoDocsQuery} when the average token frequency ratio of the field is 0 (for example when
     *         no segment has terms for the field); otherwise a wrapped boolean query over the selected tokens
     * @throws IOException if reading index statistics fails
     */
    public static Query queryBuilderWithPrunedTokens(String str, TokenPruningConfig tokenPruningConfig, List<WeightedToken> list, MappedFieldType mappedFieldType, SearchExecutionContext searchExecutionContext) throws IOException {
        IndexReader reader = searchExecutionContext.getIndexReader();
        int fieldDocCount = reader.getDocCount(str);

        // Highest token weight in the query (0 if the list is empty or all weights are <= 0). It is the reference
        // for the weight threshold in shouldKeepToken. A plain loop avoids boxing every weight into a Float.
        float bestWeight = 0.0f;
        for (WeightedToken weightedToken : list) {
            bestWeight = Math.max(bestWeight, weightedToken.weight());
        }

        float averageTokenFreqRatio = getAverageTokenFreqRatio(str, reader, fieldDocCount);
        if (averageTokenFreqRatio == 0.0f) {
            return new MatchNoDocsQuery("query is against an empty field");
        }

        // When only pruned tokens should be scored, the keep/prune decision is flipped for every token.
        boolean invertSelection = tokenPruningConfig != null && tokenPruningConfig.isOnlyScorePrunedTokens();

        BooleanQuery.Builder builder = new BooleanQuery.Builder();
        for (WeightedToken weightedToken : list) {
            boolean keep = shouldKeepToken(str, tokenPruningConfig, reader, weightedToken, fieldDocCount, averageTokenFreqRatio, bestWeight);
            if (keep ^ invertSelection) {
                addTokenClause(builder, weightedToken, mappedFieldType, searchExecutionContext);
            }
        }
        return new SparseVectorQueryWrapper(str, builder.setMinimumNumberShouldMatch(1).build());
    }

    /**
     * Adds a {@code SHOULD} clause for {@code token}: the field type's term query for it, boosted by its weight.
     */
    private static void addTokenClause(BooleanQuery.Builder builder, WeightedToken token, MappedFieldType fieldType, SearchExecutionContext context) {
        builder.add(new BoostQuery(fieldType.termQuery(token.token(), context), token.weight()), BooleanClause.Occur.SHOULD);
    }

    /**
     * Estimates the average fraction of the field's documents that contain a token.
     *
     * <p>The estimate is {@code sumDocFreq(field) / fieldDocCount / uniqueTokens}. Here {@code uniqueTokens} is the
     * largest per-segment unique-term count for the field, which is a lower bound on the index-wide count.
     *
     * @param fieldName     the field to inspect
     * @param indexReader   reader providing the segment and field statistics
     * @param fieldDocCount document count for the field, as reported by {@link IndexReader#getDocCount}
     * @return the estimated ratio, or {@code 0} when no segment reports terms for the field or the doc count is not
     *         positive
     * @throws IOException if reading index statistics fails
     */
    private static float getAverageTokenFreqRatio(String fieldName, IndexReader indexReader, int fieldDocCount) throws IOException {
        // Kept as a long: Terms.size() returns a long, and narrowing it to int could overflow on huge segments.
        // Terms.size() may also be -1 ("unknown"); taking the max with a non-negative running value ignores that.
        long maxUniqueTokens = 0;
        for (LeafReaderContext leaf : indexReader.getContext().leaves()) {
            Terms terms = leaf.reader().terms(fieldName);
            if (terms != null) {
                maxUniqueTokens = Math.max(maxUniqueTokens, terms.size());
            }
        }
        // Guard both divisors so callers never see Infinity/NaN or a negative ratio.
        if (maxUniqueTokens == 0 || fieldDocCount <= 0) {
            return 0.0f;
        }
        return ((float) indexReader.getSumDocFreq(fieldName)) / fieldDocCount / maxUniqueTokens;
    }

    /**
     * Decides whether a token survives pruning.
     *
     * <p>With a {@code null} config every token is kept. Tokens that do not occur in the field are never kept.
     * Otherwise a token is kept if either of these holds:
     * <ul>
     *   <li>its document-frequency ratio is below {@code tokensFreqRatioThreshold * averageTokenFreqRatio}
     *       (it is comparatively rare);</li>
     *   <li>its weight exceeds {@code tokensWeightThreshold * bestWeight} (it is comparatively heavy).</li>
     * </ul>
     *
     * @throws IOException if reading the token's document frequency fails
     */
    private static boolean shouldKeepToken(String fieldName, TokenPruningConfig tokenPruningConfig, IndexReader indexReader, WeightedToken token, int fieldDocCount, float averageTokenFreqRatio, float bestWeight) throws IOException {
        if (tokenPruningConfig == null) {
            return true;
        }
        int docFreq = indexReader.docFreq(new Term(fieldName, token.token()));
        if (docFreq == 0) {
            return false;
        }
        float tokenFreqRatio = ((float) docFreq) / ((float) fieldDocCount);
        if (tokenFreqRatio < tokenPruningConfig.getTokensFreqRatioThreshold() * averageTokenFreqRatio) {
            return true;
        }
        return token.weight() > tokenPruningConfig.getTokensWeightThreshold() * bestWeight;
    }
}
