package tech.ytsaurus.spyt.patch;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.instrument.ClassFileTransformer;
import java.security.ProtectionDomain;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import javassist.ClassMap;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.CtMethod;
import javassist.CtNewMethod;
import javassist.bytecode.ClassFile;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.math.Ordering;
import tech.ytsaurus.spyt.SparkVersionUtils;
import tech.ytsaurus.spyt.patch.annotations.AddInterfaces;
import tech.ytsaurus.spyt.patch.annotations.Applicability;
import tech.ytsaurus.spyt.patch.annotations.Decorate;
import tech.ytsaurus.spyt.patch.annotations.DecoratedMethod;
import tech.ytsaurus.spyt.patch.annotations.OriginClass;
import tech.ytsaurus.spyt.patch.annotations.PatchSource;
import tech.ytsaurus.spyt.patch.annotations.Subclass;

/* JADX INFO: Access modifiers changed from: package-private */
/* compiled from: SparkPatchAgent.java */
/* loaded from: input_file:tech/ytsaurus/spyt/patch/SparkPatchClassTransformer.class */
/**
 * {@link ClassFileTransformer} that replaces the bytecode of selected (origin) classes with the
 * bytecode of their patch classes, applying the patch annotations ({@link Subclass},
 * {@link AddInterfaces}, {@link Decorate}) found on the patch class.
 *
 * <p>Class names stored in the mappings are JVM internal names (slash-separated, e.g.
 * {@code org/apache/spark/Foo}). That is the form {@link #transform} receives and the form
 * javassist's {@link ClassFile#renameClass(Map)} expects.
 *
 * <p>NOTE(review): the decompiler reported that this class and some members were package-private
 * in the original source.
 */
public class SparkPatchClassTransformer implements ClassFileTransformer {
    // NOTE(review): logs under the SparkPatchAgent category rather than this class; presumably
    // deliberate since both classes were compiled from SparkPatchAgent.java — confirm before changing.
    private static final Logger log = LoggerFactory.getLogger(SparkPatchAgent.class);

    /** File-name suffix of class-file resources. */
    private static final String CLASS_SUFFIX = ".class";

    /** Patch class internal name -> origin class internal name; used to rename the patch. */
    private final Map<String, String> classMappings;
    /** Origin class internal name -> patch class internal name (inverse of classMappings). */
    private final Map<String, String> patchedClasses;

    /**
     * Resolves the patch/origin class pair described by a patch class-file resource.
     *
     * @param str class-path resource name of a patch class file, e.g. {@code a/b/Patch.class}
     * @return {@code [patchClassName, originClassName]} as slash-separated internal names, or empty
     *     when the resource is not a class file or cannot be found, or when the patch has no usable
     *     {@link OriginClass} annotation, or when its {@link Applicability} range excludes the
     *     current Spark version
     * @throws SparkPatchException if the class file cannot be read or its annotations resolved
     */
    public static Optional<String[]> toOriginClassName(String str) {
        // Class-path resource names and JVM internal names are always '/'-separated, whatever
        // the platform file separator is.
        String resourceName = str.replace(File.separatorChar, '/');
        if (!resourceName.endsWith(CLASS_SUFFIX)) {
            return Optional.empty();
        }
        try {
            Optional<ClassFile> patchFile = loadClassFile(resourceName);
            if (patchFile.isEmpty()) {
                return Optional.empty();
            }
            CtClass patchClass = ClassPool.getDefault().makeClass(patchFile.get());
            String originClass = getOriginClass(patchClass);
            boolean applicable = !patchClass.hasAnnotation(Applicability.class)
                    || checkApplicability((Applicability) patchClass.getAnnotation(Applicability.class));
            if (originClass == null || !applicable) {
                return Optional.empty();
            }
            String patchClassName =
                    resourceName.substring(0, resourceName.length() - CLASS_SUFFIX.length());
            PatchSource patchSource = (PatchSource) patchClass.getAnnotation(PatchSource.class);
            if (patchSource != null) {
                // The patch bytecode is taken from another class than the one being scanned.
                patchClassName = toInternalName(patchSource.value());
            }
            return Optional.of(new String[] {patchClassName, toInternalName(originClass)});
        } catch (IOException | ClassNotFoundException e) {
            throw new SparkPatchException(e);
        }
    }

    /** Converts a dotted binary class name to its slash-separated JVM internal form. */
    private static String toInternalName(String className) {
        return className.replace('.', '/');
    }

    /**
     * Reads the origin class name from the {@link OriginClass} annotation of a patch class.
     *
     * <p>An origin name ending in {@code $} (presumably a Scala object's module class) is only
     * accepted when the patch class name itself also ends with {@code $}.
     *
     * @param ctClass the patch class
     * @return the annotation value as written, or {@code null} when absent or not accepted
     * @throws ClassNotFoundException if the annotation cannot be resolved
     */
    static String getOriginClass(CtClass ctClass) throws ClassNotFoundException {
        OriginClass originClass = (OriginClass) ctClass.getAnnotation(OriginClass.class);
        if (originClass == null) {
            return null;
        }
        String value = originClass.value();
        if (!value.endsWith("$") || ctClass.getName().endsWith("$")) {
            return value;
        }
        return null;
    }

    /** Parses raw class-file bytes into a javassist {@link ClassFile}. */
    static ClassFile loadClassFile(byte[] bArr) throws IOException {
        return new ClassFile(new DataInputStream(new ByteArrayInputStream(bArr)));
    }

    /**
     * Loads and parses a class file from this class's class loader.
     *
     * @param str resource name of the class file (slash-separated)
     * @return the parsed class file, or empty when the resource does not exist
     * @throws IOException if reading or parsing fails
     */
    static Optional<ClassFile> loadClassFile(String str) throws IOException {
        // A null resource is legal in try-with-resources and is simply not closed.
        try (InputStream in = SparkPatchClassTransformer.class.getClassLoader().getResourceAsStream(str)) {
            return in == null ? Optional.empty() : Optional.of(loadClassFile(in.readAllBytes()));
        }
    }

    /** Serializes a javassist {@link ClassFile} back to class-file bytes. */
    static byte[] serializeClass(ClassFile classFile) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(buffer)) {
            classFile.write(out);
        }
        return buffer.toByteArray();
    }

    /**
     * Creates a transformer for the given mappings.
     *
     * @param map patch class internal name -> origin class internal name; the origin names must
     *     be unique ({@link Collectors#toMap} rejects duplicates with an
     *     {@link IllegalStateException})
     */
    public SparkPatchClassTransformer(Map<String, String> map) {
        this.classMappings = map;
        this.patchedClasses = map.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
        log.debug("Creating classfile transformer for the following classes: {}", this.patchedClasses);
    }

    /**
     * Returns the patched bytecode for an origin class.
     *
     * <p>The bytecode comes from the patch class, after annotation processing and after renaming
     * every patch class name to its origin name.
     *
     * @return the patched bytecode, or {@code null} (no transformation) when {@code className}
     *     is not a patched class
     * @throws SparkPatchException wrapping any failure while patching
     */
    @Override
    public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined,
                            ProtectionDomain protectionDomain, byte[] classfileBuffer) {
        if (!this.patchedClasses.containsKey(className)) {
            return null;
        }
        try {
            String patchResource = toPatchClassName(className);
            ClassFile patchFile = loadClassFile(patchResource).orElseThrow(
                    () -> new IOException("Patch class file " + patchResource + " was not found"));
            ClassFile patched = processAnnotations(patchFile, loader, classfileBuffer);
            patched.renameClass(this.classMappings);
            byte[] result = serializeClass(patched);
            log.info("Patch size for class {} is {} and after patching the size is {}",
                    className, classfileBuffer.length, result.length);
            return result;
        } catch (Exception e) {
            log.error("Can't patch class {} because an exception has occured", className, e);
            throw new SparkPatchException(e);
        }
    }

    /** Resource name of the patch class file registered for the given origin class. */
    private String toPatchClassName(String str) {
        return this.patchedClasses.get(str) + CLASS_SUFFIX;
    }

    /**
     * Applies the class-level patch annotations of {@code classFile}.
     *
     * @param classFile the patch class file; mutated in place by {@link Subclass}
     * @param classLoader loader that receives the generated base class for {@link Subclass}
     * @param bArr original bytecode of the origin class
     * @return the class file to emit: the patch itself, or for {@link AddInterfaces} /
     *     {@link Decorate} a modified copy of the origin class (the last such annotation wins)
     */
    private ClassFile processAnnotations(ClassFile classFile, ClassLoader classLoader, byte[] bArr)
            throws Exception {
        CtClass patchClass = ClassPool.getDefault().makeClass(classFile);
        String originClass = getOriginClass(patchClass);
        if (originClass == null) {
            return classFile;
        }
        ClassFile result = classFile;
        for (Object annotation : patchClass.getAnnotations()) {
            if (annotation instanceof Subclass) {
                applySubclass(classFile, originClass, classLoader, bArr);
            }
            if (annotation instanceof AddInterfaces) {
                result = applyAddInterfaces((AddInterfaces) annotation, bArr);
            }
            if (annotation instanceof Decorate) {
                result = applyDecorate(patchClass, bArr);
            }
        }
        return result;
    }

    /** {@link Subclass}: defines the origin bytecode as {@code <origin>Base} and makes the patch extend it. */
    private static void applySubclass(ClassFile patchFile, String originClass, ClassLoader classLoader,
                                      byte[] originBytes) throws Exception {
        String baseClass = originClass + "Base";
        log.info("Changing superclass of {} to {}", originClass, baseClass);
        patchFile.setSuperclass(baseClass);
        patchFile.renameClass(originClass, baseClass);
        ClassFile baseFile = loadClassFile(originBytes);
        baseFile.renameClass(originClass, baseClass);
        ClassPool.getDefault().makeClass(baseFile).toClass(classLoader, (ProtectionDomain) null);
    }

    /** {@link AddInterfaces}: returns the origin class with the listed interfaces added. */
    private static ClassFile applyAddInterfaces(AddInterfaces annotation, byte[] originBytes)
            throws Exception {
        ClassFile originFile = loadClassFile(originBytes);
        for (Class<?> iface : annotation.value()) {
            originFile.addInterface(iface.getName());
        }
        return ClassPool.getDefault().makeClass(originFile).getClassFile();
    }

    /**
     * {@link Decorate}: renames each decorated origin method to {@code __<name>}, runs its base
     * method processors, and copies the patch method of the same name into the origin class.
     */
    private static ClassFile applyDecorate(CtClass patchClass, byte[] originBytes) throws Exception {
        CtClass originClass = ClassPool.getDefault().makeClass(loadClassFile(originBytes));
        for (CtMethod patchMethod : patchClass.getDeclaredMethods()) {
            if (!checkDecoratedMethod(patchMethod)) {
                continue;
            }
            DecoratedMethod decoration = (DecoratedMethod) patchMethod.getAnnotation(DecoratedMethod.class);
            String name = patchMethod.getName();
            CtMethod baseMethod = originClass.getMethod(name, patchMethod.getSignature());
            log.debug("Patching decorated method {} with signature {}", name, baseMethod.getSignature());
            baseMethod.setName("__" + name);
            for (Class<? extends MethodProcesor> processor : decoration.baseMethodProcessors()) {
                processor.getDeclaredConstructor().newInstance().process(baseMethod);
            }
            originClass.addMethod(CtNewMethod.copy(patchMethod, originClass, (ClassMap) null));
        }
        return originClass.getClassFile();
    }

    /** True when the method is {@link DecoratedMethod}-annotated and applicable to this Spark version. */
    private static boolean checkDecoratedMethod(CtMethod ctMethod) throws Exception {
        return ctMethod.hasAnnotation(DecoratedMethod.class)
                && (!ctMethod.hasAnnotation(Applicability.class)
                        || checkApplicability((Applicability) ctMethod.getAnnotation(Applicability.class)));
    }

    /**
     * Checks whether the current Spark version lies within {@code [from, to]} (inclusive).
     * An empty bound means that side is unbounded.
     */
    private static boolean checkApplicability(Applicability applicability) {
        Ordering ordering = SparkVersionUtils.ordering();
        String currentVersion = SparkVersionUtils.currentVersion();
        boolean afterFrom = applicability.from().isEmpty()
                || ordering.gteq(currentVersion, applicability.from());
        boolean beforeTo = applicability.to().isEmpty()
                || ordering.lteq(currentVersion, applicability.to());
        return afterFrom && beforeTo;
    }
}
