package io.goodforgod.testcontainers.extensions;

import io.goodforgod.testcontainers.extensions.ContainerMetadata;
import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import org.jetbrains.annotations.ApiStatus;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionConfigurationException;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.api.extension.ParameterResolutionException;
import org.junit.jupiter.api.extension.ParameterResolver;
import org.junit.platform.commons.support.AnnotationSupport;
import org.junit.platform.commons.util.ReflectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.GenericContainer;

@ApiStatus.Internal
/* loaded from: input_file:io/goodforgod/testcontainers/extensions/AbstractTestcontainersExtension.class */
public abstract class AbstractTestcontainersExtension<Connection, Container extends GenericContainer<?>, Metadata extends ContainerMetadata> implements BeforeAllCallback, BeforeEachCallback, AfterAllCallback, AfterEachCallback, ParameterResolver {
    static final Map<String, Map<SharedKey, ContainerContext<?>>> CLASS_TO_SHARED_CONTAINERS = new ConcurrentHashMap();
    protected final Logger logger = LoggerFactory.getLogger(getClass());

    /* loaded from: input_file:io/goodforgod/testcontainers/extensions/AbstractTestcontainersExtension$CallMode.class */
    public enum CallMode {
        CONSTRUCTOR,
        BEFORE_EACH,
        BEFORE_ALL
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:io/goodforgod/testcontainers/extensions/AbstractTestcontainersExtension$SharedContainerInstance.class */
    public static final class SharedContainerInstance implements SharedKey {
        private final GenericContainer<?> container;

        public SharedContainerInstance(GenericContainer<?> genericContainer) {
            this.container = genericContainer;
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            return obj != null && getClass() == obj.getClass() && this.container == ((SharedContainerInstance) obj).container;
        }

        public int hashCode() {
            return Objects.hash(this.container);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:io/goodforgod/testcontainers/extensions/AbstractTestcontainersExtension$SharedContainerKey.class */
    public static final class SharedContainerKey implements SharedKey {
        private final String image;
        private final boolean network;
        private final String alias;

        SharedContainerKey(String str, boolean z, String str2) {
            this.image = str;
            this.network = z;
            this.alias = str2;
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            SharedContainerKey sharedContainerKey = (SharedContainerKey) obj;
            return this.network == sharedContainerKey.network && Objects.equals(this.image, sharedContainerKey.image) && Objects.equals(this.alias, sharedContainerKey.alias);
        }

        public int hashCode() {
            return Objects.hash(this.image, Boolean.valueOf(this.network), this.alias);
        }

        public String toString() {
            return this.alias == null ? "[image=" + this.image + "]" : "[image=" + this.image + ", alias=" + this.alias + "]";
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:io/goodforgod/testcontainers/extensions/AbstractTestcontainersExtension$SharedKey.class */
    public interface SharedKey {
    }

    protected abstract Class<? extends Annotation> getContainerAnnotation();

    protected abstract Class<? extends Annotation> getConnectionAnnotation();

    protected abstract Class<Connection> getConnectionType();

    protected abstract Class<Container> getContainerType();

    protected abstract Optional<Metadata> findMetadata(ExtensionContext extensionContext);

    protected final Metadata getMetadata(ExtensionContext extensionContext) {
        return findMetadata(extensionContext).orElseThrow(() -> {
            return new ExtensionConfigurationException("Extension annotation not found");
        });
    }

    protected abstract ExtensionContext.Namespace getNamespace();

    protected abstract Container createContainerDefault(Metadata metadata);

    protected abstract ContainerContext<Connection> createContainerContext(Container container);

    protected final ExtensionContext.Store getStorage(ExtensionContext extensionContext) {
        return (extensionContext.getParent().isPresent() && ((ExtensionContext) extensionContext.getParent().get()).getParent().isPresent()) ? ((ExtensionContext) extensionContext.getParent().get()).getStore(getNamespace()) : extensionContext.getParent().isEmpty() ? extensionContext.getStore(getNamespace()) : extensionContext.getStore(getNamespace());
    }

    protected ContainerContext<Connection> getContainerContext(ExtensionContext extensionContext) {
        return (ContainerContext) getStorage(extensionContext).get(getMetadata(extensionContext).runMode(), ContainerContext.class);
    }

    protected <T extends Annotation> Optional<T> findAnnotation(Class<T> cls, ExtensionContext extensionContext) {
        Optional of = Optional.of(extensionContext);
        while (true) {
            Optional optional = of;
            if (!optional.isPresent()) {
                return Optional.empty();
            }
            Class requiredTestClass = ((ExtensionContext) optional.get()).getRequiredTestClass();
            while (true) {
                Class cls2 = requiredTestClass;
                if (!cls2.equals(Object.class)) {
                    Optional<T> findAnnotation = AnnotationSupport.findAnnotation(cls2, cls);
                    if (findAnnotation.isPresent()) {
                        return findAnnotation;
                    }
                    requiredTestClass = cls2.getSuperclass();
                }
            }
            of = ((ExtensionContext) optional.get()).getParent();
        }
    }

    protected Optional<Container> findContainerFromField(ExtensionContext extensionContext) {
        this.logger.debug("Looking for {} Container...", getContainerType().getSimpleName());
        if (extensionContext.getTestClass().isEmpty() || extensionContext.getTestInstance().isEmpty()) {
            return Optional.empty();
        }
        Optional<Container> findContainerInClassField = findContainerInClassField(extensionContext.getTestInstance().get());
        return findContainerInClassField.isPresent() ? findContainerInClassField : extensionContext.getTestClass().filter(cls -> {
            return cls.isAnnotationPresent(Nested.class);
        }).isPresent() ? (Optional<Container>) findParentTestClassIfNested(extensionContext).flatMap(this::findContainerInClassField) : Optional.empty();
    }

    private static Optional<Object> findParentTestClassIfNested(ExtensionContext extensionContext) {
        return extensionContext.getTestClass().filter(cls -> {
            return cls.isAnnotationPresent(Nested.class);
        }).isPresent() ? extensionContext.getTestInstance().flatMap(obj -> {
            return findParentTestClass(obj.getClass(), extensionContext).flatMap(cls2 -> {
                return Arrays.stream(obj.getClass().getDeclaredFields()).filter(field -> {
                    return field.getType().equals(cls2);
                }).findFirst().map(field2 -> {
                    try {
                        field2.setAccessible(true);
                        return field2.get(obj);
                    } catch (IllegalAccessException e) {
                        throw new IllegalStateException(e);
                    }
                });
            });
        }) : Optional.empty();
    }

    private Optional<Container> findContainerInClassField(Object obj) {
        return ReflectionUtils.findFields(obj.getClass(), field -> {
            return (field.isSynthetic() || field.getAnnotation(getContainerAnnotation()) == null) ? false : true;
        }, ReflectionUtils.HierarchyTraversalMode.TOP_DOWN).stream().findFirst().map(field2 -> {
            try {
                field2.setAccessible(true);
                Object obj2 = field2.get(obj);
                if (!getContainerType().isAssignableFrom(obj2.getClass())) {
                    throw new IllegalArgumentException(String.format("Field '%s' annotated with @%s value must be instance of %s", field2.getName(), getContainerAnnotation().getSimpleName(), getContainerType()));
                }
                this.logger.debug("Found {} Container in field: {}", getContainerType().getSimpleName(), field2.getName());
                return (GenericContainer) obj2;
            } catch (IllegalAccessException e) {
                throw new IllegalStateException(String.format("Failed retrieving value from field '%s' annotated with @%s", field2.getName(), getContainerAnnotation().getSimpleName()), e);
            }
        });
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static Optional<Class<?>> findParentTestClass(Class<?> cls, ExtensionContext extensionContext) {
        return extensionContext.getTestClass().filter(cls2 -> {
            return !cls2.equals(cls);
        }).or(() -> {
            return extensionContext.getParent().flatMap(extensionContext2 -> {
                return findParentTestClass(cls, extensionContext2);
            });
        });
    }

    protected void injectContext(ContainerContext<Connection> containerContext, ExtensionContext extensionContext) {
        extensionContext.getTestInstance().ifPresent(obj -> {
            injectContextIntoInstance(containerContext, obj);
        });
        if (extensionContext.getTestClass().filter(cls -> {
            return cls.isAnnotationPresent(Nested.class);
        }).isPresent()) {
            findParentTestClassIfNested(extensionContext).ifPresent(obj2 -> {
                injectContextIntoInstance(containerContext, obj2);
            });
        }
    }

    protected void injectContextIntoInstance(ContainerContext<Connection> containerContext, Object obj) {
        Class<? extends Annotation> connectionAnnotation = getConnectionAnnotation();
        List findFields = ReflectionUtils.findFields(obj.getClass(), field -> {
            return (field.isSynthetic() || Modifier.isFinal(field.getModifiers()) || Modifier.isStatic(field.getModifiers()) || field.getAnnotation(connectionAnnotation) == null) ? false : true;
        }, ReflectionUtils.HierarchyTraversalMode.TOP_DOWN);
        this.logger.debug("Starting field injection for connection: {}", containerContext.connection());
        Iterator it = findFields.iterator();
        while (it.hasNext()) {
            injectContextIntoField(containerContext, (Field) it.next(), obj);
        }
    }

    protected void injectContextIntoField(ContainerContext<Connection> containerContext, Field field, Object obj) {
        try {
            field.setAccessible(true);
            field.set(obj, containerContext.connection());
        } catch (IllegalAccessException e) {
            throw new IllegalStateException(String.format("Field '%s' annotated with @%s can't set connection", field.getName(), getConnectionAnnotation().getSimpleName()), e);
        }
    }

    public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException {
        Class<? extends Annotation> connectionAnnotation = getConnectionAnnotation();
        if (!(parameterContext.getParameter().getAnnotation(connectionAnnotation) != null)) {
            return false;
        }
        if (parameterContext.getParameter().getType().equals(getConnectionType())) {
            return true;
        }
        throw new ParameterResolutionException(String.format("Parameter '%s' annotated @%s is not of type %s", parameterContext.getParameter().getName(), connectionAnnotation.getSimpleName(), getConnectionType()));
    }

    public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException {
        CallMode callMode = getCallMode(parameterContext);
        ContainerContext<Connection> containerContext = getContainerContext(extensionContext);
        if (containerContext != null) {
            return containerContext.connection();
        }
        Metadata metadata = getMetadata(extensionContext);
        if (metadata.runMode() == ContainerMode.PER_RUN || metadata.runMode() == ContainerMode.PER_CLASS) {
            beforeAll(extensionContext);
        } else if (metadata.runMode() == ContainerMode.PER_METHOD) {
            TestInstance.Lifecycle lifecycle = (TestInstance.Lifecycle) extensionContext.getTestInstanceLifecycle().orElse(TestInstance.Lifecycle.PER_METHOD);
            if (callMode == CallMode.CONSTRUCTOR && lifecycle == TestInstance.Lifecycle.PER_CLASS) {
                throw new ParameterResolutionException(String.format("@%s can't be injected into constructor parameter when ContainerMode.%s is used and lifecycle is @%s", getConnectionAnnotation().getSimpleName(), ContainerMode.PER_METHOD, TestInstance.Lifecycle.PER_CLASS));
            }
            if (callMode == CallMode.BEFORE_ALL) {
                throw new ParameterResolutionException(String.format("@%s can't be injected into @%s method parameter when ContainerMode.%s is used", getConnectionAnnotation().getSimpleName(), BeforeAll.class.getSimpleName(), ContainerMode.PER_METHOD));
            }
            beforeEach(extensionContext);
        }
        return Optional.ofNullable(getContainerContext(extensionContext).connection()).orElseThrow(() -> {
            return new ParameterResolutionException(String.format("Parameter named '%s' with type '%s' can't be resolved cause it probably isn't initialized yet, please check extension annotation execution order", parameterContext.getParameter().getName(), getConnectionType()));
        });
    }

    private CallMode getCallMode(ParameterContext parameterContext) {
        if (parameterContext.getDeclaringExecutable().isAnnotationPresent(BeforeAll.class)) {
            return CallMode.BEFORE_ALL;
        }
        if (parameterContext.getDeclaringExecutable().isAnnotationPresent(BeforeEach.class)) {
            return CallMode.BEFORE_EACH;
        }
        if (parameterContext.getDeclaringExecutable().getDeclaringClass().getName().equals(parameterContext.getDeclaringExecutable().getName())) {
            return CallMode.CONSTRUCTOR;
        }
        return null;
    }

    public void beforeAll(ExtensionContext extensionContext) {
        Metadata metadata = getMetadata(extensionContext);
        ExtensionContext.Store storage = getStorage(extensionContext);
        if (getContainerContext(extensionContext) == null) {
            if (metadata.runMode() == ContainerMode.PER_RUN) {
                Optional<Container> findContainerFromField = findContainerFromField(extensionContext);
                ContainerContext<Connection> containerContext = (ContainerContext) CLASS_TO_SHARED_CONTAINERS.computeIfAbsent(getClass().getCanonicalName(), str -> {
                    return new ConcurrentHashMap();
                }).computeIfAbsent((SharedKey) findContainerFromField.map(genericContainer -> {
                    return new SharedContainerInstance(genericContainer);
                }).orElseGet(() -> {
                    String image = metadata.image();
                    Boolean bool = (Boolean) findContainerFromField.filter(genericContainer2 -> {
                        return genericContainer2.getNetwork() != null;
                    }).map(genericContainer3 -> {
                        return Boolean.valueOf(genericContainer3.getNetwork() == org.testcontainers.containers.Network.SHARED);
                    }).orElse(Boolean.valueOf(metadata.networkShared()));
                    return new SharedContainerKey(image, bool.booleanValue(), (String) findContainerFromField.map(genericContainer4 -> {
                        return genericContainer4.getNetworkAliases();
                    }).filter(list -> {
                        return !list.isEmpty();
                    }).map(list2 -> {
                        return (String) list2.stream().filter(str2 -> {
                            return str2.equals(metadata.networkAlias());
                        }).findFirst().orElse((String) list2.get(0));
                    }).orElse(metadata.networkAlias()));
                }), sharedKey -> {
                    GenericContainer genericContainer2 = (GenericContainer) findContainerFromField.orElseGet(() -> {
                        this.logger.debug("Getting default container for image: {}", metadata.image());
                        return createContainerDefault(metadata);
                    });
                    genericContainer2.withReuse(true);
                    ContainerContext createContainerContext = createContainerContext(genericContainer2);
                    this.logger.debug("Starting in mode '{}' container: {}", metadata.runMode(), createContainerContext);
                    createContainerContext.start();
                    this.logger.info("Started in mode '{}' container: {}", metadata.runMode(), createContainerContext);
                    return createContainerContext;
                });
                storage.put(metadata.runMode(), containerContext);
                injectContext(containerContext, extensionContext);
                return;
            }
            if (metadata.runMode() == ContainerMode.PER_CLASS) {
                ContainerContext<Connection> createContainerContext = createContainerContext(findContainerFromField(extensionContext).orElseGet(() -> {
                    this.logger.debug("Getting default container for image: {}", metadata.image());
                    return createContainerDefault(metadata);
                }));
                this.logger.debug("Starting in mode '{}' container: {}", metadata.runMode(), createContainerContext);
                createContainerContext.start();
                this.logger.info("Started in mode '{}' container: {}", metadata.runMode(), createContainerContext);
                storage.put(metadata.runMode(), createContainerContext);
                injectContext(createContainerContext, extensionContext);
            }
        }
    }

    public void beforeEach(ExtensionContext extensionContext) {
        ContainerContext<Connection> containerContext;
        Metadata metadata = getMetadata(extensionContext);
        ExtensionContext.Store storage = getStorage(extensionContext);
        if (getContainerContext(extensionContext) == null && metadata.runMode() == ContainerMode.PER_METHOD) {
            Container orElseGet = findContainerFromField(extensionContext).orElseGet(() -> {
                this.logger.debug("Getting default container for image: {}", metadata.image());
                return createContainerDefault(metadata);
            });
            ContainerContext<Connection> createContainerContext = createContainerContext(orElseGet);
            this.logger.debug("Starting in mode '{}' container: {}", metadata.runMode(), createContainerContext);
            orElseGet.start();
            this.logger.info("Started in mode '{}' container: {}", metadata.runMode(), createContainerContext);
            storage.put(metadata.runMode(), createContainerContext);
        }
        if (((TestInstance.Lifecycle) extensionContext.getTestInstanceLifecycle().orElse(TestInstance.Lifecycle.PER_METHOD)) == TestInstance.Lifecycle.PER_METHOD) {
            ContainerContext<Connection> containerContext2 = getContainerContext(extensionContext);
            if (containerContext2 != null) {
                injectContext(containerContext2, extensionContext);
                return;
            }
            return;
        }
        if (metadata.runMode() != ContainerMode.PER_METHOD || (containerContext = getContainerContext(extensionContext)) == null) {
            return;
        }
        injectContext(containerContext, extensionContext);
    }

    public void afterEach(ExtensionContext extensionContext) {
        Metadata metadata = getMetadata(extensionContext);
        if (metadata.runMode() == ContainerMode.PER_METHOD) {
            ExtensionContext.Store storage = getStorage(extensionContext);
            ContainerContext<Connection> containerContext = getContainerContext(extensionContext);
            if (containerContext != null) {
                this.logger.debug("Stopping in mode '{}' container: {}", metadata.runMode(), containerContext);
                containerContext.stop();
                this.logger.info("Stopped in mode '{}' container: {}", metadata.runMode(), containerContext);
                storage.remove(getConnectionType());
                storage.remove(metadata.runMode());
            }
        }
    }

    public void afterAll(ExtensionContext extensionContext) {
        ContainerContext<Connection> containerContext;
        Metadata metadata = getMetadata(extensionContext);
        if (metadata.runMode() != ContainerMode.PER_CLASS || (containerContext = getContainerContext(extensionContext)) == null) {
            return;
        }
        this.logger.debug("Stopping in mode '{}' container: {}", metadata.runMode(), containerContext);
        containerContext.stop();
        this.logger.info("Stopped in mode '{}' container: {}", metadata.runMode(), containerContext);
    }
}
