package org.multij.codegen;

import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.processing.ProcessingEnvironment;
import javax.lang.model.element.Element;
import javax.lang.model.element.ElementKind;
import javax.lang.model.element.ExecutableElement;
import javax.lang.model.element.Name;
import javax.lang.model.element.PackageElement;
import javax.lang.model.element.TypeElement;
import javax.lang.model.element.VariableElement;
import javax.lang.model.type.DeclaredType;
import javax.lang.model.type.ExecutableType;
import javax.lang.model.type.TypeKind;
import javax.lang.model.type.TypeMirror;
import javax.lang.model.util.Types;
import javax.tools.Diagnostic;
import org.multij.AmbiguityException;
import org.multij.CircularityException;
import org.multij.InjectionException;
import org.multij.MissingDefinitionException;
import org.multij.ModuleRepository;
import org.multij.MultiJModule;
import org.multij.model.DecisionTree;
import org.multij.model.EntryPoint;
import org.multij.model.Module;
import org.multij.model.MultiMethod;

/* loaded from: input_file:org/multij/codegen/CodeGenerator.class */
public class CodeGenerator {
    private final ProcessingEnvironment processingEnv;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/multij/codegen/CodeGenerator$MethodCodeGenerator.class */
    public class MethodCodeGenerator implements DecisionTree.Visitor {
        private int indentation = 2;
        private final TypeElement module;
        private final Name moduleName;
        private final PrintWriter writer;
        private final EntryPoint tree;
        private final ExecutableElement entryPoint;

        public MethodCodeGenerator(TypeElement typeElement, PrintWriter printWriter, EntryPoint entryPoint) {
            this.module = typeElement;
            this.moduleName = typeElement.getQualifiedName();
            this.writer = printWriter;
            this.tree = entryPoint;
            this.entryPoint = entryPoint.getEntryPoint();
        }

        public void generateCode() {
            ExecutableType asMemberOf = CodeGenerator.this.typeUtil().asMemberOf(this.module.asType(), this.entryPoint);
            this.writer.format("\tpublic %s%s %s(", asMemberOf.getTypeVariables().isEmpty() ? "" : (String) asMemberOf.getTypeVariables().stream().map(typeVariable -> {
                String obj = typeVariable.toString();
                if (typeVariable.getUpperBound() != null) {
                    obj = obj + " extends " + typeVariable.getUpperBound();
                }
                return obj;
            }).collect(Collectors.joining(", ", "<", "> ")), asMemberOf.getReturnType(), this.entryPoint.getSimpleName());
            int i = 0;
            for (TypeMirror typeMirror : asMemberOf.getParameterTypes()) {
                if (i > 0) {
                    this.writer.print(", ");
                }
                this.writer.print(typeMirror);
                this.writer.print(" p");
                int i2 = i;
                i++;
                this.writer.print(i2);
            }
            this.writer.println(") {");
            generateForNode(this.tree.getDecisionTree());
            this.writer.println("\t}\n");
        }

        private void generateForNode(DecisionTree decisionTree) {
            decisionTree.accept(this);
        }

        private void println(String str) {
            for (int i = 0; i < this.indentation; i++) {
                this.writer.append('\t');
            }
            this.writer.println(str);
        }

        @Override // org.multij.model.DecisionTree.Visitor
        public void visitDecision(DecisionTree.DecisionNode decisionNode) {
            if (!decisionNode.getDefinition().isDefault()) {
                ArrayList arrayList = new ArrayList();
                arrayList.add("\"" + this.moduleName + "\"");
                arrayList.add("\"" + decisionNode.getDefinition().getSimpleName() + "\"");
                int i = 0;
                for (VariableElement variableElement : decisionNode.getDefinition().getParameters()) {
                    if (variableElement.asType().getKind() == TypeKind.DECLARED) {
                        arrayList.add("p" + i + " == null ? \"null\" : p" + i + ".getClass().getCanonicalName()");
                    } else {
                        arrayList.add("\"" + variableElement.asType().toString() + "\"");
                    }
                    i++;
                }
                println("throw new " + MissingDefinitionException.class.getCanonicalName() + "(" + String.join(", ", arrayList) + ");");
                return;
            }
            String str = this.moduleName + ".super." + decisionNode.getDefinition().getSimpleName() + "(";
            int i2 = 0;
            Iterator it = decisionNode.getDefinition().getParameters().iterator();
            while (it.hasNext()) {
                TypeMirror asType = ((VariableElement) it.next()).asType();
                if (i2 > 0) {
                    str = str + ", ";
                }
                if (!CodeGenerator.this.typeUtil().isSameType(asType, ((VariableElement) this.entryPoint.getParameters().get(i2)).asType())) {
                    str = str + "(" + asType.toString() + ") ";
                }
                int i3 = i2;
                i2++;
                str = str + "p" + i3;
            }
            String str2 = str + ")";
            if (decisionNode.getDefinition().getReturnType().getKind() != TypeKind.VOID) {
                println("return " + str2 + ";");
            } else {
                println(str2 + ";");
                println("return;");
            }
        }

        @Override // org.multij.model.DecisionTree.Visitor
        public void visitAmbiguity(DecisionTree.AmbiguityNode ambiguityNode) {
            println("throw new " + AmbiguityException.class.getCanonicalName() + "();");
        }

        @Override // org.multij.model.DecisionTree.Visitor
        public void visitCondition(DecisionTree.ConditionNode conditionNode) {
            println("if (p" + conditionNode.getCondition().getArgument() + " instanceof " + CodeGenerator.this.processingEnv.getTypeUtils().erasure(conditionNode.getCondition().getType()) + ") {");
            this.indentation++;
            generateForNode(conditionNode.getIsTrue());
            this.indentation--;
            println("} else {");
            this.indentation++;
            generateForNode(conditionNode.getIsFalse());
            this.indentation--;
            println("}");
        }
    }

    public CodeGenerator(ProcessingEnvironment processingEnvironment) {
        this.processingEnv = processingEnvironment;
    }

    /* JADX INFO: Access modifiers changed from: private */
    public Types typeUtil() {
        return this.processingEnv.getTypeUtils();
    }

    private String className(Element element) {
        String str = "MultiJ";
        do {
            str = element.getSimpleName() + "$" + str;
            element = element.getEnclosingElement();
        } while (element.getKind() != ElementKind.PACKAGE);
        return str;
    }

    public void generateSource(Module module) {
        try {
            String className = className(module.getTypeElement());
            PackageElement packageOf = this.processingEnv.getElementUtils().getPackageOf(module.getTypeElement());
            PrintWriter printWriter = new PrintWriter(this.processingEnv.getFiler().createSourceFile(packageOf.isUnnamed() ? className : packageOf.getQualifiedName() + "." + className, new Element[]{module.getTypeElement()}).openWriter());
            if (!packageOf.isUnnamed()) {
                printWriter.format("package %s;\n", packageOf.getQualifiedName());
            }
            printWriter.format("public final class %s implements %s, %s {\n", className, module.getTypeElement().getQualifiedName(), MultiJModule.class.getCanonicalName());
            generateModuleRefs(module, printWriter);
            generateInjecions(module, printWriter);
            generateCachedAttrs(module, printWriter);
            generateMultiMethods(module, printWriter);
            printWriter.println("}");
            printWriter.close();
        } catch (IOException e) {
            this.processingEnv.getMessager().printMessage(Diagnostic.Kind.ERROR, e.getMessage());
        }
    }

    private void generateMultiMethods(Module module, PrintWriter printWriter) {
        Iterator<MultiMethod> it = module.getMultiMethods().iterator();
        while (it.hasNext()) {
            Iterator<EntryPoint> it2 = it.next().getEntryPoints().iterator();
            while (it2.hasNext()) {
                new MethodCodeGenerator(module.getTypeElement(), printWriter, it2.next()).generateCode();
            }
        }
    }

    private void generateModuleRefs(Module module, PrintWriter printWriter) {
        for (ExecutableElement executableElement : module.getModuleReferences()) {
            printWriter.format("\tprivate %s module$%s;\n", executableElement.getReturnType(), executableElement.getSimpleName());
        }
        for (ExecutableElement executableElement2 : module.getModuleReferences()) {
            printWriter.format("\tpublic %s %s() { return module$%2$s; }\n", executableElement2.getReturnType(), executableElement2.getSimpleName());
        }
        printWriter.println("\tprivate boolean multij$initialized = false;");
        printWriter.format("\tpublic void multij$init(%s repo) {\n", ModuleRepository.class.getCanonicalName());
        printWriter.format("\t\tif (multij$initialized) throw new %s();\n", IllegalStateException.class.getCanonicalName());
        for (ExecutableElement executableElement3 : module.getModuleReferences()) {
            printWriter.format("\t\tmodule$%s = repo.getModule(%s.class);\n", executableElement3.getSimpleName(), executableElement3.getReturnType());
        }
        printWriter.println("\t\tmultij$initialized = true;");
        printWriter.println("\t}\n");
    }

    private void generateInjecions(Module module, PrintWriter printWriter) {
        for (ExecutableElement executableElement : module.getInjectedAttributes()) {
            Name simpleName = executableElement.getSimpleName();
            TypeMirror returnType = executableElement.getReturnType();
            printWriter.format("\tprivate boolean inj$init$%s;\n", simpleName);
            printWriter.format("\tprivate %s inj$attr$%s;\n\n", returnType, simpleName);
            printWriter.format("\tpublic %s %s() { return inj$attr$%2$s; }\n\n", returnType, simpleName);
        }
        String canonicalName = MultiJModule.NotFound.class.getCanonicalName();
        String canonicalName2 = MultiJModule.WrongType.class.getCanonicalName();
        String canonicalName3 = MultiJModule.AlreadySet.class.getCanonicalName();
        String canonicalName4 = MultiJModule.class.getCanonicalName();
        printWriter.format("\tpublic %s multij$getModule(String name) throws %s {\n", canonicalName4, canonicalName);
        printWriter.format("\t\tswitch(name) {\n", new Object[0]);
        Iterator<ExecutableElement> it = module.getModuleReferences().iterator();
        while (it.hasNext()) {
            printWriter.format("\t\tcase \"%1$s\": return (%2$s) module$%1$s;\n", it.next().getSimpleName(), MultiJModule.class.getCanonicalName());
        }
        printWriter.format("\t\tdefault: throw new %s();\n", canonicalName);
        printWriter.format("\t\t}\n", new Object[0]);
        printWriter.format("\t}\n\n", new Object[0]);
        printWriter.format("\tpublic void multij$setField(String name, Object value) throws %s, %s, %s{\n", canonicalName, canonicalName2, canonicalName3);
        printWriter.format("\t\tswitch(name) {\n", new Object[0]);
        for (ExecutableElement executableElement2 : module.getInjectedAttributes()) {
            Name simpleName2 = executableElement2.getSimpleName();
            TypeMirror erasure = this.processingEnv.getTypeUtils().erasure(executableElement2.getReturnType());
            printWriter.format("\t\tcase \"%s\":\n", simpleName2);
            printWriter.format("\t\t\tif (inj$init$%s) {\n", simpleName2);
            printWriter.format("\t\t\t\tthrow new %s();\n", canonicalName3);
            printWriter.format("\t\t\t} else if (value instanceof %s || value == null) {\n", erasure);
            printWriter.format("\t\t\t\tinj$init$%s = true;\n", simpleName2);
            printWriter.format("\t\t\t\tinj$attr$%s = (%s) value;\n", simpleName2, erasure);
            printWriter.format("\t\t\t\treturn;\n", new Object[0]);
            printWriter.format("\t\t\t} else {\n", new Object[0]);
            printWriter.format("\t\t\t\tthrow new %s(%s.class);\n", canonicalName2, erasure);
            printWriter.format("\t\t\t}\n", new Object[0]);
        }
        printWriter.format("\t\tdefault: throw new %s();\n", canonicalName);
        printWriter.format("\t\t}\n", new Object[0]);
        printWriter.format("\t}\n\n", new Object[0]);
        printWriter.format("\tpublic void multij$checkInjected(java.util.Set<%s> visited, java.util.List<String> prefix) {\n", canonicalName4);
        printWriter.format("\t\tif (visited.contains(this)) return;\n", new Object[0]);
        printWriter.format("\t\tvisited.add(this);\n", new Object[0]);
        for (ExecutableElement executableElement3 : module.getInjectedAttributes()) {
            printWriter.format("\t\tif (!inj$init$%s) {\n", executableElement3.getSimpleName());
            printWriter.format("\t\t\tprefix.add(\"%s\");\n", executableElement3.getSimpleName());
            printWriter.format("\t\t\tthrow %s.notInjected(prefix.toArray(new String[prefix.size()]));\n", InjectionException.class.getCanonicalName());
            printWriter.format("\t\t}\n", new Object[0]);
        }
        printWriter.format("\t}\n\n", new Object[0]);
    }

    private void generateCachedAttrs(Module module, PrintWriter printWriter) {
        for (ExecutableElement executableElement : module.getCachedAttributes()) {
            Name simpleName = executableElement.getSimpleName();
            TypeMirror returnType = executableElement.getReturnType();
            Name qualifiedName = module.getTypeElement().getQualifiedName();
            printWriter.format("boolean cache$init$%s = false;\n", simpleName);
            printWriter.format("boolean cache$done$%s = false;\n", simpleName);
            printWriter.format("%s cache$value$%s;\n\n", returnType, simpleName);
            printWriter.format("public synchronized %s %s() {\n", returnType, simpleName);
            printWriter.format("\tif (cache$done$%s) {\n", simpleName);
            printWriter.format("\t\treturn cache$value$%s;\n", simpleName);
            printWriter.format("\t} else if (cache$init$%s) {\n", simpleName);
            printWriter.format("\t\tthrow new %s();\n", CircularityException.class.getCanonicalName());
            printWriter.format("\t} else {\n", new Object[0]);
            printWriter.format("\t\ttry {\n", new Object[0]);
            printWriter.format("\t\t\tcache$init$%s = true;\n", simpleName);
            printWriter.format("\t\t\tcache$value$%s = %s.super.%1$s();\n", simpleName, qualifiedName);
            printWriter.format("\t\t\tcache$done$%s = true;\n", simpleName);
            printWriter.format("\t\t\treturn cache$value$%s;\n", simpleName);
            printWriter.format("\t\t} catch (%s e) {\n", CircularityException.class.getCanonicalName());
            printWriter.format("\t\t\tthrow e;\n", new Object[0]);
            printWriter.format("\t\t} catch (%s | %s e) {\n", RuntimeException.class.getCanonicalName(), Error.class.getCanonicalName());
            printWriter.format("\t\t\tcache$init$%s = false;\n", simpleName);
            printWriter.format("\t\t\tthrow e;\n", new Object[0]);
            printWriter.format("\t\t}\n", new Object[0]);
            printWriter.format("\t}\n", new Object[0]);
            printWriter.format("}\n\n", new Object[0]);
        }
    }
}
