package io.netty.testsuite.transport.socket;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.LineBasedFrameDecoder;
import io.netty.handler.codec.string.StringDecoder;
import io.netty.handler.codec.string.StringEncoder;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.concurrent.DefaultEventExecutorGroup;
import io.netty.util.concurrent.EventExecutorGroup;
import io.netty.util.concurrent.Future;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.io.File;
import java.io.IOException;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLEngine;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Exercises a STARTTLS-style upgrade over every combination of available server/client
 * {@link SslContext} providers (JDK always, OpenSSL when available).
 *
 * <p>The exchange is line based: the client sends {@code StartTlsRequest} in plaintext, the
 * server answers {@code StartTlsResponse} and both sides insert an {@link SslHandler}; the client
 * then sends {@code EncryptedRequest} and expects {@code EncryptedResponse} over TLS.
 */
@RunWith(Parameterized.class)
public class SocketStartTlsTest extends AbstractSocketTest {
    private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketStartTlsTest.class);
    private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
    /** How long to sleep between checks while waiting for the channels to become inactive. */
    private static final long POLL_INTERVAL_MILLIS = 50L;
    private static final File CERT_FILE;
    private static final File KEY_FILE;

    static {
        try {
            SelfSignedCertificate ssc = new SelfSignedCertificate();
            CERT_FILE = ssc.certificate();
            KEY_FILE = ssc.privateKey();
        } catch (CertificateException e) {
            throw new Error(e);
        }
    }

    /** Executor on which the STARTTLS handlers run, so they are off the I/O event loop. */
    private static EventExecutorGroup executor;

    private final SslContext serverCtx;
    private final SslContext clientCtx;

    /**
     * Client side of the exchange. It sends a plaintext {@code StartTlsRequest} on activation.
     * When {@code StartTlsResponse} arrives, it installs its {@link SslHandler} and sends
     * {@code EncryptedRequest}. It closes the channel after verifying {@code EncryptedResponse}
     * and a successful handshake.
     */
    public static class StartTlsClientHandler extends SimpleChannelInboundHandler<String> {
        private final SslHandler sslHandler;
        private final boolean autoRead;
        private Future<Channel> handshakeFuture;
        /** First exception seen on the client side, or {@code null} if none. */
        final AtomicReference<Throwable> exception = new AtomicReference<>();

        StartTlsClientHandler(SSLEngine engine, boolean autoRead) {
            engine.setUseClientMode(true);
            this.sslHandler = new SslHandler(engine);
            this.autoRead = autoRead;
        }

        @Override
        public void channelActive(ChannelHandlerContext ctx) throws Exception {
            // With AUTO_READ disabled nothing is read unless we ask for it.
            if (!autoRead) {
                ctx.read();
            }
            ctx.writeAndFlush("StartTlsRequest\n");
        }

        @Override
        public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
            if ("StartTlsResponse".equals(msg)) {
                // Place TLS between the logger and the line/string codecs so the codecs
                // operate on decrypted data from here on.
                ctx.pipeline().addAfter("logger", "ssl", sslHandler);
                handshakeFuture = sslHandler.handshakeFuture();
                ctx.writeAndFlush("EncryptedRequest\n");
                return;
            }
            Assert.assertEquals("EncryptedResponse", msg);
            Assert.assertNotNull(handshakeFuture);
            Assert.assertTrue(handshakeFuture.isSuccess());
            ctx.close();
        }

        @Override
        public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
            // Keep requesting data when reads are driven manually.
            if (!autoRead) {
                ctx.read();
            }
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            if (logger.isWarnEnabled()) {
                logger.warn("Unexpected exception from the client side", cause);
            }
            exception.compareAndSet(null, cause);
            ctx.close();
        }
    }

    /**
     * Server side of the exchange. On {@code StartTlsRequest} it installs an {@link SslHandler}
     * created in STARTTLS mode, so the following {@code StartTlsResponse} is still written in
     * plaintext. It then answers {@code EncryptedRequest} with {@code EncryptedResponse}.
     */
    public static class StartTlsServerHandler extends SimpleChannelInboundHandler<String> {
        private final SslHandler sslHandler;
        private final boolean autoRead;
        /** The accepted child channel; set in {@link #channelActive}, {@code null} before that. */
        volatile Channel channel;
        /** First exception seen on the server side, or {@code null} if none. */
        final AtomicReference<Throwable> exception = new AtomicReference<>();

        StartTlsServerHandler(SSLEngine engine, boolean autoRead) {
            engine.setUseClientMode(false);
            // startTls == true: the first write after insertion is not encrypted.
            this.sslHandler = new SslHandler(engine, true);
            this.autoRead = autoRead;
        }

        @Override
        public void channelActive(ChannelHandlerContext ctx) throws Exception {
            channel = ctx.channel();
            if (!autoRead) {
                ctx.read();
            }
        }

        @Override
        public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
            if ("StartTlsRequest".equals(msg)) {
                ctx.pipeline().addAfter("logger", "ssl", sslHandler);
                ctx.writeAndFlush("StartTlsResponse\n");
                return;
            }
            Assert.assertEquals("EncryptedRequest", msg);
            ctx.writeAndFlush("EncryptedResponse\n");
        }

        @Override
        public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
            if (!autoRead) {
                ctx.read();
            }
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            if (logger.isWarnEnabled()) {
                logger.warn("Unexpected exception from the server side", cause);
            }
            exception.compareAndSet(null, cause);
            ctx.close();
        }
    }

    /**
     * Builds the cross product of server and client contexts. The JDK provider is always
     * included; OpenSSL is added only when {@link OpenSsl#isAvailable()} reports it usable.
     *
     * @return one {@code {serverCtx, clientCtx}} pair per test run
     * @throws Exception if an {@link SslContext} cannot be built
     */
    @Parameterized.Parameters(name = "{index}: serverEngine = {0}, clientEngine = {1}")
    public static Collection<Object[]> data() throws Exception {
        List<SslContext> serverContexts = new ArrayList<>();
        List<SslContext> clientContexts = new ArrayList<>();
        serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
        clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK).trustManager(CERT_FILE).build());

        if (OpenSsl.isAvailable()) {
            serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.OPENSSL).build());
            clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).trustManager(CERT_FILE).build());
        } else {
            logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
        }

        List<Object[]> params = new ArrayList<>(serverContexts.size() * clientContexts.size());
        for (SslContext sctx : serverContexts) {
            for (SslContext cctx : clientContexts) {
                params.add(new Object[] { sctx, cctx });
            }
        }
        return params;
    }

    @BeforeClass
    public static void createExecutor() {
        executor = new DefaultEventExecutorGroup(2);
    }

    @AfterClass
    public static void shutdownExecutor() throws Exception {
        // Guard against createExecutor() never having run (e.g. class-level setup failure).
        if (executor != null) {
            executor.shutdownGracefully().sync();
        }
    }

    public SocketStartTlsTest(SslContext serverCtx, SslContext clientCtx) {
        this.serverCtx = serverCtx;
        this.clientCtx = clientCtx;
    }

    @Test(timeout = 30000)
    public void testStartTls() throws Throwable {
        run();
    }

    /** Invoked reflectively by {@code run()}; runs the exchange with AUTO_READ enabled. */
    public void testStartTls(ServerBootstrap sb, Bootstrap cb) throws Throwable {
        testStartTls(sb, cb, true);
    }

    @Test(timeout = 30000)
    public void testStartTlsNotAutoRead() throws Throwable {
        run();
    }

    /** Invoked reflectively by {@code run()}; runs the exchange with AUTO_READ disabled. */
    public void testStartTlsNotAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
        testStartTls(sb, cb, false);
    }

    /**
     * Runs one STARTTLS exchange and rethrows the first recorded handler failure.
     *
     * <p>All channels are closed in a {@code finally} block so a failed connect or an interrupt
     * (for example from the JUnit timeout) does not leak the bound server socket.
     *
     * @param sb server bootstrap supplied by the test harness
     * @param cb client bootstrap supplied by the test harness
     * @param autoRead value applied to {@link ChannelOption#AUTO_READ} on both sides
     * @throws Throwable a captured server/client exception, or any setup/interrupt failure
     */
    private void testStartTls(ServerBootstrap sb, Bootstrap cb, boolean autoRead) throws Throwable {
        sb.childOption(ChannelOption.AUTO_READ, Boolean.valueOf(autoRead));
        cb.option(ChannelOption.AUTO_READ, Boolean.valueOf(autoRead));

        final EventExecutorGroup handlerGroup = executor;
        SSLEngine serverEngine = serverCtx.newEngine(PooledByteBufAllocator.DEFAULT);
        SSLEngine clientEngine = clientCtx.newEngine(PooledByteBufAllocator.DEFAULT);
        final StartTlsServerHandler serverHandler = new StartTlsServerHandler(serverEngine, autoRead);
        final StartTlsClientHandler clientHandler = new StartTlsClientHandler(clientEngine, autoRead);

        sb.childHandler(new ChannelInitializer<Channel>() {
            @Override
            public void initChannel(Channel ch) throws Exception {
                ChannelPipeline p = ch.pipeline();
                p.addLast("logger", new LoggingHandler(LOG_LEVEL));
                p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
                p.addLast(handlerGroup, serverHandler);
            }
        });

        cb.handler(new ChannelInitializer<Channel>() {
            @Override
            public void initChannel(Channel ch) throws Exception {
                ChannelPipeline p = ch.pipeline();
                p.addLast("logger", new LoggingHandler(LOG_LEVEL));
                p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
                p.addLast(handlerGroup, clientHandler);
            }
        });

        Channel sc = sb.bind().sync().channel();
        Channel cc = null;
        try {
            cc = cb.connect().sync().channel();
            // The client closes once the exchange completes (or fails); then wait for the server
            // child to observe the close as well.
            waitUntilInactive(cc, serverHandler, clientHandler);
            waitUntilInactive(serverHandler.channel, serverHandler, clientHandler);
        } finally {
            Channel serverChild = serverHandler.channel;
            if (serverChild != null) {
                serverChild.close().awaitUninterruptibly();
            }
            if (cc != null) {
                cc.close().awaitUninterruptibly();
            }
            sc.close().awaitUninterruptibly();
        }

        rethrowCapturedFailure(serverHandler.exception.get(), clientHandler.exception.get());
    }

    /**
     * Polls until {@code channel} is inactive or either handler has recorded an exception.
     * Returns immediately when {@code channel} is {@code null} (not yet activated).
     *
     * @throws InterruptedException if the test thread is interrupted while sleeping; it is
     *         propagated rather than swallowed so a JUnit timeout actually stops the test
     */
    private static void waitUntilInactive(Channel channel, StartTlsServerHandler serverHandler,
                                          StartTlsClientHandler clientHandler) throws InterruptedException {
        if (channel == null) {
            return;
        }
        while (channel.isActive()
                && serverHandler.exception.get() == null
                && clientHandler.exception.get() == null) {
            Thread.sleep(POLL_INTERVAL_MILLIS);
        }
    }

    /**
     * Rethrows a captured failure, preferring non-{@link IOException} causes (server first, then
     * client) over I/O errors, which are often just a consequence of the peer closing.
     */
    private static void rethrowCapturedFailure(Throwable serverCause, Throwable clientCause) throws Throwable {
        if (serverCause != null && !(serverCause instanceof IOException)) {
            throw serverCause;
        }
        if (clientCause != null && !(clientCause instanceof IOException)) {
            throw clientCause;
        }
        if (serverCause != null) {
            throw serverCause;
        }
        if (clientCause != null) {
            throw clientCause;
        }
    }
}
