package org.apache.kafka.common.network;

import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.channels.SelectionKey;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.security.GeneralSecurityException;
import java.security.Security;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import javax.net.ssl.SSLEngine;
import org.apache.kafka.common.memory.SimpleMemoryPool;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.metrics.Sensor;
import org.apache.kafka.common.security.auth.SecurityProtocol;
import org.apache.kafka.common.security.ssl.SslFactory;
import org.apache.kafka.common.security.ssl.mock.TestKeyManagerFactory;
import org.apache.kafka.common.security.ssl.mock.TestProviderCreator;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.test.TestSslUtils;
import org.apache.kafka.test.TestUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * SSL flavour of {@link SelectorTest}.
 *
 * <p>{@link #setUp()} starts an SSL {@link EchoServer} and a client {@link Selector} built on an
 * {@link SslChannelBuilder}, so the inherited {@link SelectorTest} tests run over SSL. This class
 * adds SSL-specific tests for:
 * <ul>
 *   <li>bytes left buffered inside the SSL layer;</li>
 *   <li>custom key managers supplied through a security provider;</li>
 *   <li>muting on memory-pool exhaustion.</li>
 * </ul>
 * Concrete subclasses supply the client-side SSL configuration via
 * {@link #createSslClientConfigs(File)}.
 */
public abstract class SslSelectorTest extends SelectorTest {

    /** Connection max-idle time passed to every {@link Selector} created in this class. */
    private static final long CONNECTION_MAX_IDLE_MS = 5000L;

    /** Metric group prefix passed to every {@link Selector} created in this class. */
    private static final String METRIC_GROUP = "MetricGroup";

    /** Size of the requests sent before waiting for a channel to report buffered bytes. */
    private static final int LARGE_REQUEST_SIZE = 100 * 1024;

    /** Send and receive buffer sizes passed to {@link Selector#connect}. */
    private static final int SOCKET_BUFFER_SIZE = 4 * 1024;

    /** TLS protocol version used by {@link #testMuteOnOOM()}. */
    private static final String OOM_TEST_TLS_PROTOCOL = "TLSv1.2";

    /**
     * Capacity of the memory pool in {@link #testMuteOnOOM()}. Each sender's payload has exactly
     * this size, so one completed receive drains the pool.
     */
    private static final int OOM_POOL_SIZE = 900;

    /**
     * Max receive size for the {@link #testMuteOnOOM()} selector.
     * NOTE(review): -1 presumably means "no limit" (NetworkReceive.UNLIMITED) - confirm.
     */
    private static final int UNLIMITED_RECEIVE_SIZE = -1;

    /** Client SSL configuration created in {@link #setUp()} by {@link #createSslClientConfigs(File)}. */
    private Map<String, Object> sslClientConfigs;

    /**
     * Client {@link SslChannelBuilder} whose transport layers are {@link TestSslTransportLayer}s.
     * Tests use it to throttle socket reads so that data stays buffered in the SSL layer.
     */
    public static class TestSslChannelBuilder extends SslChannelBuilder {

        /**
         * {@link SslTransportLayer} that throttles socket reads (see {@link #readFromSocketChannel()}).
         * Every instance is recorded by channel id in {@link #transportLayers} so tests can reach it.
         */
        public static class TestSslTransportLayer extends SslTransportLayer {

            /**
             * Every instance created so far, keyed by channel id. It is cleared in
             * {@link SslSelectorTest#tearDown()} so entries do not leak across tests.
             */
            static final Map<String, TestSslTransportLayer> transportLayers = new HashMap<>();

            /** While {@code true}, {@link #readFromSocketChannel()} returns 0 instead of reading. */
            boolean muteSocket = false;

            public TestSslTransportLayer(String channelId, SelectionKey key, SSLEngine sslEngine,
                                         ChannelMetadataRegistry metadataRegistry) {
                super(channelId, key, sslEngine, metadataRegistry);
                transportLayers.put(channelId, this);
            }

            /**
             * Performs a real socket read at most on every other call.
             *
             * <p>A real read mutes the socket. While the socket is muted, calls return 0 without
             * reading. If a muted call finds {@link SelectionKey#OP_READ} set in the key's interest
             * ops, it un-mutes the socket so that the following call reads again.
             *
             * @return the result of the real read, or 0 while muted
             * @throws IOException if the underlying read fails
             */
            @Override
            protected int readFromSocketChannel() throws IOException {
                if (muteSocket) {
                    if ((selectionKey().interestOps() & SelectionKey.OP_READ) != 0) {
                        muteSocket = false;
                    }
                    return 0;
                }
                muteSocket = true;
                return super.readFromSocketChannel();
            }

            /**
             * Sets the network read buffer's position to 1 and the application read buffer's
             * position to 0, then mutes the socket.
             * NOTE(review): presumably leaves one byte of undecrypted data pending, so the channel
             * still reports buffered bytes - confirm.
             */
            void truncateReadBuffer() throws Exception {
                netReadBuffer().position(1);
                appReadBuffer().position(0);
                muteSocket = true;
            }
        }

        /** Creates a builder for {@code mode} with no listener name and a fresh {@link LogContext}. */
        public TestSslChannelBuilder(Mode mode) {
            super(mode, (ListenerName) null, false, new LogContext());
        }

        /** Builds a {@link TestSslTransportLayer} over an SSL engine created for the key's socket. */
        @Override
        protected SslTransportLayer buildTransportLayer(SslFactory sslFactory, String id, SelectionKey key,
                                                        ChannelMetadataRegistry metadataRegistry) throws IOException {
            SocketChannel socketChannel = (SocketChannel) key.channel();
            SSLEngine engine = sslFactory.createSslEngine(socketChannel.socket());
            return new TestSslTransportLayer(id, key, engine, metadataRegistry);
        }
    }

    /**
     * Starts an SSL echo server whose trust store is written to a temp file. Then builds the
     * client SSL config from that file and creates the client channel builder, metrics and selector.
     */
    @Override
    @BeforeEach
    public void setUp() throws Exception {
        File trustStoreFile = TestUtils.tempFile("truststore", ".jks");
        Map<String, Object> sslServerConfigs =
            TestSslUtils.createSslConfig(false, true, Mode.SERVER, trustStoreFile, "server");
        this.server = new EchoServer(SecurityProtocol.SSL, sslServerConfigs);
        this.server.start();
        this.time = new MockTime();
        this.sslClientConfigs = createSslClientConfigs(trustStoreFile);
        LogContext logContext = new LogContext();
        this.channelBuilder = new SslChannelBuilder(Mode.CLIENT, (ListenerName) null, false, logContext);
        this.channelBuilder.configure(this.sslClientConfigs);
        this.metrics = new Metrics();
        this.selector = new Selector(CONNECTION_MAX_IDLE_MS, this.metrics, this.time, METRIC_GROUP,
            this.channelBuilder, logContext);
    }

    /**
     * Creates the client-side SSL configuration.
     *
     * @param trustStoreFile the trust store file written for the echo server in {@link #setUp()}
     * @return client SSL configs used to configure the client channel builders
     */
    protected abstract Map<String, Object> createSslClientConfigs(File trustStoreFile)
        throws GeneralSecurityException, IOException;

    /**
     * Closes the selector, the echo server and the metrics. It always clears the static
     * transport-layer registry so instances do not leak into later tests.
     */
    @Override
    @AfterEach
    public void tearDown() throws Exception {
        try {
            this.selector.close();
            this.server.close();
            this.metrics.close();
        } finally {
            TestSslChannelBuilder.TestSslTransportLayer.transportLayers.clear();
        }
    }

    /** Returns the client SSL configuration created in {@link #setUp()}. */
    @Override
    protected Map<String, Object> clientConfigs() {
        return this.sslClientConfigs;
    }

    /**
     * Connects to a server whose key manager comes from a custom security provider. Checks that
     * the cipher metric tracks the connection and that cipher information is recorded.
     */
    @Test
    public void testConnectionWithCustomKeyManager() throws Exception {
        TestProviderCreator testProviderCreator = new TestProviderCreator();
        String node = "0";
        String request = TestUtils.randomString(LARGE_REQUEST_SIZE);
        Map<String, Object> sslServerConfigs = TestSslUtils.createSslConfig(
            "TestAlgorithm", "TestAlgorithm", TestSslUtils.DEFAULT_TLS_PROTOCOL_FOR_TESTS);
        sslServerConfigs.put("security.providers", testProviderCreator.getClass().getName());

        Metrics localMetrics = new Metrics();
        EchoServer echoServer = null;
        Selector customSelector = null;
        try {
            echoServer = new EchoServer(SecurityProtocol.SSL, sslServerConfigs);
            echoServer.start();

            File trustStoreFile = new File(TestKeyManagerFactory.TestKeyManager.mockTrustStoreFile);
            Map<String, Object> customClientConfigs =
                TestSslUtils.createSslConfig(true, true, Mode.CLIENT, trustStoreFile, "client");
            TestSslChannelBuilder customChannelBuilder = new TestSslChannelBuilder(Mode.CLIENT);
            customChannelBuilder.configure(customClientConfigs);
            customSelector = new Selector(CONNECTION_MAX_IDLE_MS, localMetrics, new MockTime(), METRIC_GROUP,
                customChannelBuilder, new LogContext());

            customSelector.connect(node, new InetSocketAddress("localhost", echoServer.port),
                SOCKET_BUFFER_SIZE, SOCKET_BUFFER_SIZE);
            NetworkTestUtils.waitForChannelReady(customSelector, node);
            customSelector.send(createSend(node, request));
            waitForBytesBuffered(customSelector, node);

            TestUtils.waitForCondition(() -> cipherMetrics(localMetrics).size() == 1,
                "Waiting for cipher metrics to be created.");
            Assertions.assertEquals(1, cipherMetrics(localMetrics).get(0).metricValue());
            Assertions.assertNotNull(customSelector.channel(node).channelMetadataRegistry().cipherInformation());

            customSelector.close(node);
            super.verifySelectorEmpty(customSelector);
            Assertions.assertEquals(1, cipherMetrics(localMetrics).size());
            Assertions.assertEquals(0, cipherMetrics(localMetrics).get(0).metricValue());
        } finally {
            // java.security.Security is JVM-wide state. Always unregister the test provider, even
            // when an assertion failed, so that later tests are not affected by it.
            Security.removeProvider(testProviderCreator.getProvider().getName());
            if (customSelector != null) {
                customSelector.close();
            }
            if (echoServer != null) {
                echoServer.close();
            }
            localMetrics.close();
        }
    }

    /** Closing a channel while bytes are buffered in its SSL layer must leave the selector empty. */
    @Test
    public void testDisconnectWithIntermediateBufferedBytes() throws Exception {
        String node = "0";
        String request = TestUtils.randomString(LARGE_REQUEST_SIZE);
        this.selector.close();
        this.channelBuilder = new TestSslChannelBuilder(Mode.CLIENT);
        this.channelBuilder.configure(this.sslClientConfigs);
        this.selector = new Selector(CONNECTION_MAX_IDLE_MS, this.metrics, this.time, METRIC_GROUP,
            this.channelBuilder, new LogContext());
        connect(node, new InetSocketAddress("localhost", this.server.port));
        this.selector.send(createSend(node, request));
        waitForBytesBuffered(this.selector, node);
        this.selector.close(node);
        verifySelectorEmpty();
    }

    /**
     * Polls {@code targetSelector} with a zero timeout until channel {@code node} reports buffered
     * bytes. Fails after 2 seconds.
     */
    private void waitForBytesBuffered(Selector targetSelector, String node) throws Exception {
        TestUtils.waitForCondition(() -> {
            try {
                targetSelector.poll(0L);
                return targetSelector.channel(node).hasBytesBuffered();
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }, 2000L, "Failed to reach socket state with bytes buffered");
    }

    /** A buffered channel whose key has OP_READ cleared must not be polled unnecessarily. */
    @Test
    public void testBytesBufferedChannelWithNoIncomingBytes() throws Exception {
        verifyNoUnnecessaryPollWithBytesBuffered(
            key -> key.interestOps(key.interestOps() & ~SelectionKey.OP_READ));
    }

    /** A buffered channel that has been muted must not be polled unnecessarily. */
    @Test
    public void testBytesBufferedChannelAfterMute() throws Exception {
        verifyNoUnnecessaryPollWithBytesBuffered(key -> ((KafkaChannel) key.attachment()).mute());
    }

    /**
     * Checks that node "1" is polled only once while node "2" exchanges 100 requests.
     *
     * <p>Setup: node "1" is left with bytes buffered in its SSL layer, its read buffer is
     * truncated, and then {@code keyAction} is applied to its selection key. The selector
     * counts how often node "1"'s key is passed to {@code pollSelectionKeys}, and the count is
     * reset before node "2" starts.
     */
    private void verifyNoUnnecessaryPollWithBytesBuffered(Consumer<SelectionKey> keyAction) throws Exception {
        this.selector.close();
        final String bufferedNode = "1";
        final AtomicInteger bufferedNodePolls = new AtomicInteger();
        this.channelBuilder = new TestSslChannelBuilder(Mode.CLIENT);
        this.channelBuilder.configure(this.sslClientConfigs);
        this.selector = new Selector(CONNECTION_MAX_IDLE_MS, this.metrics, this.time, METRIC_GROUP,
                this.channelBuilder, new LogContext()) {
            @Override
            void pollSelectionKeys(Set<SelectionKey> selectionKeys, boolean isImmediatelyConnected,
                                   long currentTimeNanos) {
                for (SelectionKey key : selectionKeys) {
                    KafkaChannel channel = (KafkaChannel) key.attachment();
                    if (channel != null && channel.id().equals(bufferedNode)) {
                        bufferedNodePolls.incrementAndGet();
                    }
                }
                super.pollSelectionKeys(selectionKeys, isImmediatelyConnected, currentTimeNanos);
            }
        };

        connect(bufferedNode, new InetSocketAddress("localhost", this.server.port));
        this.selector.send(createSend(bufferedNode, TestUtils.randomString(LARGE_REQUEST_SIZE)));
        waitForBytesBuffered(this.selector, bufferedNode);
        TestSslChannelBuilder.TestSslTransportLayer.transportLayers.get(bufferedNode).truncateReadBuffer();
        keyAction.accept(this.selector.channel(bufferedNode).selectionKey());
        bufferedNodePolls.set(0);

        String activeNode = "2";
        connect(activeNode, new InetSocketAddress("localhost", this.server.port));
        String request = TestUtils.randomString(10);
        this.selector.send(createSend(activeNode, request));
        int receiveCount = 0;
        while (receiveCount < 100) {
            receiveCount += this.selector.completedReceives().size();
            // Send the next request as soon as the previous one has gone out.
            if (!this.selector.completedSends().isEmpty()) {
                this.selector.send(createSend(activeNode, request));
            }
            this.selector.poll(5L);
        }
        Assertions.assertEquals(1, bufferedNodePolls.get());

        this.selector.close(bufferedNode);
        this.selector.close(activeNode);
        verifySelectorEmpty();
    }

    /**
     * Checks muting on memory exhaustion with a server-mode SSL selector backed by a
     * {@link SimpleMemoryPool} that holds exactly one payload.
     *
     * <p>Two senders each send one such payload. The test checks three things:
     * <ul>
     *   <li>only one receive completes at first, and the selector then reports out-of-memory;</li>
     *   <li>releasing that receive returns the memory to the pool;</li>
     *   <li>after that, the second payload can be read.</li>
     * </ul>
     */
    @Override
    @Test
    public void testMuteOnOOM() throws Exception {
        // Replace the default client selector with a server-mode selector using a finite pool.
        this.selector.close();
        SimpleMemoryPool pool = new SimpleMemoryPool(OOM_POOL_SIZE, OOM_POOL_SIZE, false, (Sensor) null);
        Map<String, Object> sslServerConfigs = new TestSslUtils.SslConfigsBuilder(Mode.SERVER)
            .tlsProtocol(OOM_TEST_TLS_PROTOCOL)
            .createNewTrustStore(TestUtils.tempFile("truststore", ".jks"))
            .build();
        this.channelBuilder = new SslChannelBuilder(Mode.SERVER, (ListenerName) null, false, new LogContext());
        this.channelBuilder.configure(sslServerConfigs);
        this.selector = new Selector(UNLIMITED_RECEIVE_SIZE, CONNECTION_MAX_IDLE_MS, this.metrics, this.time,
            METRIC_GROUP, new HashMap<String, String>(), true, false, this.channelBuilder, pool, new LogContext());

        try (ServerSocketChannel serverSocket = ServerSocketChannel.open()) {
            serverSocket.bind(new InetSocketAddress(0));
            InetSocketAddress serverAddress = (InetSocketAddress) serverSocket.getLocalAddress();

            SslSender sender1 = createSender(OOM_TEST_TLS_PROTOCOL, serverAddress, randomPayload(OOM_POOL_SIZE));
            SslSender sender2 = createSender(OOM_TEST_TLS_PROTOCOL, serverAddress, randomPayload(OOM_POOL_SIZE));
            sender1.start();
            sender2.start();

            // Which sender is accepted first is not determined, hence the neutral X/Y names.
            SocketChannel channelX = serverSocket.accept();
            channelX.configureBlocking(false);
            SocketChannel channelY = serverSocket.accept();
            channelY.configureBlocking(false);
            this.selector.register("clientX", channelX);
            this.selector.register("clientY", channelY);

            // Keep polling until two conditions hold: both senders have completed their
            // handshakes, and exactly one payload has been fully received. The loop exits early
            // once, in addition, the selector reports out-of-memory.
            boolean handshaked = false;
            NetworkReceive firstReceive = null;
            long deadline = System.currentTimeMillis() + 5000;
            while (System.currentTimeMillis() < deadline) {
                this.selector.poll(10L);
                Collection<NetworkReceive> completed = this.selector.completedReceives();
                if (firstReceive == null) {
                    if (!completed.isEmpty()) {
                        Assertions.assertEquals(1, completed.size(), "expecting a single request");
                        firstReceive = completed.iterator().next();
                        Assertions.assertTrue(this.selector.isMadeReadProgressLastPoll());
                        Assertions.assertEquals(0L, pool.availableMemory());
                    }
                } else {
                    Assertions.assertTrue(completed.isEmpty(), "only expecting single request");
                }
                handshaked = sender1.waitForHandshake(1L) && sender2.waitForHandshake(1L);
                if (handshaked && firstReceive != null && this.selector.isOutOfMemory()) {
                    break;
                }
            }
            Assertions.assertTrue(handshaked, "could not initiate connections within timeout");

            this.selector.poll(10L);
            Assertions.assertTrue(this.selector.completedReceives().isEmpty());
            Assertions.assertEquals(0L, pool.availableMemory());
            Assertions.assertNotNull(firstReceive, "First receive not complete");
            Assertions.assertTrue(this.selector.isOutOfMemory(), "Selector not out of memory");

            // Releasing the first receive must hand its memory back to the pool.
            firstReceive.close();
            Assertions.assertEquals((long) OOM_POOL_SIZE, pool.availableMemory());

            Collection<NetworkReceive> remaining = Collections.emptyList();
            deadline = System.currentTimeMillis() + 5000;
            while (System.currentTimeMillis() < deadline && remaining.isEmpty()) {
                this.selector.poll(1000L);
                remaining = this.selector.completedReceives();
            }
            Assertions.assertEquals(1, remaining.size(), "could not read remaining request within timeout");
            Assertions.assertEquals(0L, pool.availableMemory());
            Assertions.assertFalse(this.selector.isOutOfMemory());
        }
    }

    /** Connects using {@code blockingConnect} inherited from {@link SelectorTest}. */
    @Override
    protected void connect(String node, InetSocketAddress serverAddr) throws IOException {
        blockingConnect(node, serverAddr);
    }

    /** Creates (but does not start) an {@link SslSender} that sends {@code payload} to {@code serverAddress}. */
    private SslSender createSender(String tlsProtocol, InetSocketAddress serverAddress, byte[] payload) {
        return new SslSender(tlsProtocol, serverAddress, payload);
    }
}
