package org.elasticsearch.test;

import java.io.IOException;
import java.util.List;
import java.util.function.LongSupplier;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.CheckedBiFunction;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.hamcrest.Matchers;
import org.junit.Before;

/**
 * Base test case for {@link QueryVectorBuilder} implementations.
 * <p>
 * On top of the xcontent/wire round trips inherited from {@link AbstractXContentSerializingTestCase}, this class checks that
 * the builder round-trips when embedded in a {@link KnnSearchBuilder}, and that the vector it produces, either directly via
 * {@link QueryVectorBuilder#buildVector} or through a {@link KnnSearchBuilder} rewrite, equals the vector returned by the
 * subclass-provided {@link #createResponse} answer.
 * <p>
 * Subclasses describe the expected client interaction via {@link #doAssertClientRequest} and {@link #createResponse}.
 *
 * @param <T> the query vector builder under test
 */
public abstract class AbstractQueryVectorBuilderTestCase<T extends QueryVectorBuilder> extends AbstractXContentSerializingTestCase<T> {
    /** Number of randomized rounds executed by the looping tests. */
    private static final int ROUNDS = 20;

    private NamedWriteableRegistry namedWriteableRegistry;
    private NamedXContentRegistry namedXContentRegistry;

    /**
     * Client stub that validates every outgoing request with {@link #doAssertClientRequest} and answers it synchronously
     * with {@link #createResponse}. It is an inner (non-static) class because it delegates to the enclosing test's hooks.
     */
    private class AssertingClient extends NoOpClient {
        private final float[] array;
        private final T queryVectorBuilder;

        AssertingClient(ThreadPool threadPool, float[] array, T queryVectorBuilder) {
            super(threadPool);
            this.array = array;
            this.queryVectorBuilder = queryVectorBuilder;
        }

        @Override
        @SuppressWarnings("unchecked") // subclasses must return the response type matching the request they expect
        protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
            ActionType<Response> action,
            Request request,
            ActionListener<Response> listener
        ) {
            // Qualified calls so that members of NoOpClient can never shadow the enclosing test's hooks.
            AbstractQueryVectorBuilderTestCase.this.doAssertClientRequest(request, queryVectorBuilder);
            listener.onResponse((Response) AbstractQueryVectorBuilderTestCase.this.createResponse(array, queryVectorBuilder));
        }
    }

    /**
     * Extra search plugins whose named xcontent / named writeable entries are registered for the tests.
     *
     * @return the plugins to register; none by default
     */
    protected List<SearchPlugin> additionalPlugins() {
        return List.of();
    }

    /** Rebuilds both registries before each test from a {@link SearchModule} that includes {@link #additionalPlugins()}. */
    @Before
    public void registerNamedXContents() {
        SearchModule searchModule = new SearchModule(Settings.EMPTY, additionalPlugins());
        namedXContentRegistry = new NamedXContentRegistry(searchModule.getNamedXContents());
        namedWriteableRegistry = new NamedWriteableRegistry(searchModule.getNamedWriteables());
    }

    @Override
    public NamedXContentRegistry xContentRegistry() {
        return namedXContentRegistry;
    }

    @Override
    public NamedWriteableRegistry getNamedWriteableRegistry() {
        return namedWriteableRegistry;
    }

    /**
     * Creates a builder for the vector-producing tests. {@code expectedVector} is the vector that the stub client will
     * answer with (via {@link #createResponse}). The default ignores it and delegates to {@link #createTestInstance()};
     * subclasses may override to tie the instance to the vector.
     *
     * @param expectedVector the vector the tests expect the builder to produce
     * @return a new builder instance
     */
    protected T createTestInstance(float[] expectedVector) {
        return createTestInstance();
    }

    /**
     * Parses a {@link KnnSearchBuilder} for the xcontent round-trip test, passing 10 to
     * {@code KnnSearchBuilder.Builder#build} (the same value used when building the test instances).
     *
     * @param parser parser positioned on the serialized knn section
     * @return the parsed builder
     * @throws IOException if parsing fails
     */
    protected KnnSearchBuilder parseKnnSearchBuilder(XContentParser parser) throws IOException {
        return KnnSearchBuilder.fromXContent(parser).build(10);
    }

    /** Checks that a {@link KnnSearchBuilder} carrying the builder under test survives an xcontent round trip. */
    public final void testKnnSearchBuilderXContent() throws Exception {
        AbstractXContentTestCase.xContentTester(
            // the cast selects the (XContent, BytesReference) overload of createParser
            (CheckedBiFunction<XContent, BytesReference, XContentParser, IOException>) this::createParser,
            () -> new KnnSearchBuilder.Builder().field(randomAlphaOfLength(10))
                .queryVectorBuilder(createTestInstance())
                .k(5)
                .numCandidates(10)
                .similarity(randomBoolean() ? null : Float.valueOf(randomFloat()))
                .build(10),
            getToXContentParams(),
            this::parseKnnSearchBuilder
        ).test();
    }

    /** Checks that a {@link KnnSearchBuilder} carrying the builder under test survives a wire round trip. */
    public final void testKnnSearchBuilderWireSerialization() throws IOException {
        for (int round = 0; round < ROUNDS; round++) {
            KnnSearchBuilder original = randomKnnSearchBuilder(createTestInstance());
            original.queryName(randomAlphaOfLengthBetween(5, 10));
            KnnSearchBuilder copy = copyWriteable(original, getNamedWriteableRegistry(), KnnSearchBuilder::new, TransportVersion.current());
            assertThat(copy, Matchers.equalTo(original));
            assertNotSame(copy, original);
        }
    }

    /**
     * Checks that rewriting a {@link KnnSearchBuilder} (either the original or its wire copy) fetches the vector through
     * the stub client and replaces the query vector builder with the concrete vector.
     */
    public final void testKnnSearchRewrite() throws Exception {
        // One pool for all rounds; try-with-resources guarantees a single close, with suppression on failure.
        try (TestThreadPool threadPool = createThreadPool()) {
            for (int round = 0; round < ROUNDS; round++) {
                float[] expected = randomVector(randomIntBetween(10, 1024));
                T builder = createTestInstance(expected);
                KnnSearchBuilder original = randomKnnSearchBuilder(builder);
                KnnSearchBuilder copy = copyWriteable(original, getNamedWriteableRegistry(), KnnSearchBuilder::new, TransportVersion.current());

                // casts pick the intended QueryRewriteContext constructor overload
                QueryRewriteContext context = new QueryRewriteContext(
                    (XContentParserConfiguration) null,
                    new AssertingClient(threadPool, expected, builder),
                    (LongSupplier) null
                );
                PlainActionFuture<KnnSearchBuilder> future = new PlainActionFuture<>();
                Rewriteable.rewriteAndFetch(randomFrom(copy, original), context, future);

                KnnSearchBuilder rewritten = future.get();
                assertThat(rewritten.getQueryVector().asFloatVector(), Matchers.equalTo(expected));
                assertThat(rewritten.getQueryVectorBuilder(), Matchers.nullValue());
            }
        }
    }

    /** Checks that {@link QueryVectorBuilder#buildVector} delivers the vector answered by the stub client. */
    public final void testVectorFetch() throws Exception {
        float[] expected = randomVector(randomIntBetween(10, 1024));
        T builder = createTestInstance(expected);
        try (TestThreadPool threadPool = createThreadPool()) {
            PlainActionFuture<float[]> future = new PlainActionFuture<>();
            builder.buildVector(new AssertingClient(threadPool, expected, builder), future);
            assertThat(future.get(), Matchers.equalTo(expected));
        }
    }

    /**
     * Wraps {@code queryVectorBuilder} in a {@link KnnSearchBuilder} on a random field, with 5 and 10 as the integer
     * arguments (presumably k and num_candidates, matching the xcontent test) and randomly present rescore/similarity.
     */
    private KnnSearchBuilder randomKnnSearchBuilder(QueryVectorBuilder queryVectorBuilder) {
        return new KnnSearchBuilder(
            randomAlphaOfLength(10),
            queryVectorBuilder,
            5,
            10,
            randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)),
            randomBoolean() ? null : Float.valueOf(randomFloat())
        );
    }

    /**
     * Asserts that the request sent by the builder under test is the one expected.
     *
     * @param request the request issued through the client
     * @param builder the builder that issued it
     */
    protected abstract void doAssertClientRequest(ActionRequest request, T builder);

    /**
     * Builds the response the stub client returns; it must carry {@code vector} so the builder produces it.
     * The returned object must be of the response type the builder's request expects.
     *
     * @param vector the vector the builder should end up producing
     * @param builder the builder that issued the request
     * @return the response to deliver to the builder
     */
    protected abstract ActionResponse createResponse(float[] vector, T builder);

    /**
     * Creates a vector of {@code dims} random floats.
     *
     * @param dims number of dimensions
     * @return the random vector
     */
    protected static float[] randomVector(int dims) {
        float[] vector = new float[dims];
        for (int i = 0; i < vector.length; i++) {
            vector[i] = randomFloat();
        }
        return vector;
    }
}
