package org.deeplearning4j.zoo.model.helper;

import java.util.Map;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/zoo/model/helper/NASNetHelper.class */
/**
 * Graph-construction helpers for the NASNet-A cells: the separable-convolution block, the
 * shape-adjust block, the normal cell and the reduction cell.
 *
 * <p>Each helper appends layers and vertices to the supplied {@link
 * ComputationGraphConfiguration.GraphBuilder} and returns the name(s) of the vertex holding its
 * output, so cells are wired together purely by vertex name.
 */
public class NASNetHelper {

    /**
     * Appends {@code ReLU -> SeparableConv -> BN -> ReLU -> SeparableConv -> BN}.
     *
     * <p>Only the first separable convolution uses {@code stride}; the second always uses stride 1,
     * so the block downsamples by exactly {@code stride}. The reduction cell relies on this to line
     * its separable-conv branches up with its stride-2 pooling branches.
     *
     * @param graphBuilder graph to append to
     * @param filters output channels of both separable convolutions
     * @param kernelSize square kernel size of both separable convolutions
     * @param stride stride of the first separable convolution
     * @param blockId suffix that makes the layer names unique
     * @param input name of the vertex feeding this block
     * @return name of the block's final batch-norm layer ({@code "sepConvBlock" + blockId + "_conv2_bn"})
     */
    public static String sepConvBlock(ComputationGraphConfiguration.GraphBuilder graphBuilder, int filters,
                    int kernelSize, int stride, String blockId, String input) {
        String prefix = "sepConvBlock" + blockId;
        graphBuilder
                .addLayer(prefix + "_act", new ActivationLayer(Activation.RELU), input)
                .addLayer(prefix + "_sepconv1", sepConv(filters, kernelSize, stride), prefix + "_act")
                .addLayer(prefix + "_conv1_bn", batchNorm(), prefix + "_sepconv1")
                .addLayer(prefix + "_act2", new ActivationLayer(Activation.RELU), prefix + "_conv1_bn")
                // Stride 1 here: striding both convolutions would shrink the output by stride^2.
                .addLayer(prefix + "_sepconv2", sepConv(filters, kernelSize, 1), prefix + "_act2")
                .addLayer(prefix + "_conv2_bn", batchNorm(), prefix + "_sepconv2");
        return prefix + "_conv2_bn";
    }

    /**
     * Adjusts {@code input} against itself; equivalent to
     * {@code adjustBlock(graphBuilder, filters, blockId, input, null)}.
     *
     * @return name of the adjusted output (may be {@code input} itself if nothing was added)
     */
    public static String adjustBlock(ComputationGraphConfiguration.GraphBuilder graphBuilder, int filters,
                    String blockId, String input) {
        return adjustBlock(graphBuilder, filters, blockId, input, null);
    }

    /**
     * Makes {@code input} compatible with {@code inputToMatch}.
     *
     * <ul>
     *   <li>If their heights differ, {@code input} is spatially reduced. Two stride-2 towers are
     *       concatenated and batch-normalised. The second tower is shifted by one pixel, so it
     *       samples the positions the first tower skips. Each tower has {@code filters / 2}
     *       channels.</li>
     *   <li>If {@code input} does not have {@code filters} channels, a {@code ReLU -> 1x1 conv -> BN}
     *       projection to {@code filters} channels is appended.</li>
     * </ul>
     *
     * @param graphBuilder graph to append to; its current activation types are read for shapes
     * @param filters target channel count
     * @param blockId suffix that makes the layer names unique
     * @param input vertex to adjust
     * @param inputToMatch vertex whose spatial size should be matched; {@code null} means {@code input}
     * @return name of the adjusted output (may be {@code input} itself if nothing was added)
     */
    public static String adjustBlock(ComputationGraphConfiguration.GraphBuilder graphBuilder, int filters,
                    String blockId, String input, String inputToMatch) {
        String prefix = "adjustBlock" + blockId;
        String output = input;
        if (inputToMatch == null) {
            inputToMatch = input;
        }
        Map<String, InputType> activationTypes = graphBuilder.getLayerActivationTypes();
        int[] shapeToMatch = activationTypes.get(inputToMatch).getShape();
        int[] inputShape = activationTypes.get(input).getShape();
        // Assumes channels-first shapes ending in (channels, height, width), with or without a
        // leading minibatch entry -- TODO confirm against the InputType version in use. Indexing
        // from the end handles both the 3- and 4-element forms.
        int heightToMatch = shapeToMatch[shapeToMatch.length - 2];
        int inputHeight = inputShape[inputShape.length - 2];
        int inputChannels = inputShape[inputShape.length - 3];

        if (heightToMatch != inputHeight) {
            int halfFilters = filters / 2;
            graphBuilder
                    .addLayer(prefix + "_relu1", new ActivationLayer(Activation.RELU), input)
                    // Tower 1: 1x1 stride-2 subsampling picks even rows/columns.
                    .addLayer(prefix + "_avgpool1", pool(PoolingType.AVG, 1, 2, ConvolutionMode.Truncate),
                                    prefix + "_relu1")
                    .addLayer(prefix + "_conv1", conv1x1(halfFilters), prefix + "_avgpool1")
                    // Tower 2: pad bottom/right by 1, then crop top/left by 1, shifting the map by
                    // one pixel at unchanged size, so the same subsampling picks odd rows/columns.
                    .addLayer(prefix + "_zeropad1", new ZeroPaddingLayer(0, 1, 0, 1), prefix + "_relu1")
                    .addLayer(prefix + "_crop1", new Cropping2D(1, 0, 1, 0), prefix + "_zeropad1")
                    .addLayer(prefix + "_avgpool2", pool(PoolingType.AVG, 1, 2, ConvolutionMode.Truncate),
                                    prefix + "_crop1")
                    .addLayer(prefix + "_conv2", conv1x1(halfFilters), prefix + "_avgpool2")
                    .addVertex(prefix + "_concat1", new MergeVertex(), prefix + "_conv1", prefix + "_conv2")
                    .addLayer(prefix + "_bn1", batchNorm(), prefix + "_concat1");
            output = prefix + "_bn1";
        }
        // NOTE(review): this tests the ORIGINAL input's channel count and is not an else-branch of
        // the reduction above, so a projection may also follow a reduction.
        if (inputChannels != filters) {
            graphBuilder
                    .addLayer(prefix + "_projection_relu", new ActivationLayer(Activation.RELU), output)
                    .addLayer(prefix + "_projection_conv", conv1x1(filters), prefix + "_projection_relu")
                    .addLayer(prefix + "_projection_bn", batchNorm(), prefix + "_projection_conv");
            output = prefix + "_projection_bn";
        }
        return output;
    }

    /**
     * Appends a NASNet-A normal cell (spatial size preserved).
     *
     * @param graphBuilder graph to append to
     * @param filters channels per branch; the cell output has {@code 6 * filters} channels
     * @param blockId suffix that makes the layer names unique
     * @param inputX the cell's current input
     * @param inputP the previous cell's output, adjusted to match {@code inputX}
     * @return (name of this cell's output, {@code inputX}) -- the pair to feed the next cell
     */
    public static Pair<String, String> normalA(ComputationGraphConfiguration.GraphBuilder graphBuilder, int filters,
                    String blockId, String inputX, String inputP) {
        String prefix = "normalA" + blockId;
        String p = adjustBlock(graphBuilder, filters, prefix, inputP, inputX);
        // h is built from the cell's current input, so inputX actually feeds the cell.
        String h = reluConvBn(graphBuilder, prefix, filters, inputX);

        String add1 = addSum(graphBuilder, prefix + "_add1",
                        sepConvBlock(graphBuilder, filters, 5, 1, prefix + "_left1", h),
                        sepConvBlock(graphBuilder, filters, 3, 1, prefix + "_right1", p));

        String add2 = addSum(graphBuilder, prefix + "_add2",
                        sepConvBlock(graphBuilder, filters, 5, 1, prefix + "_left2", p),
                        sepConvBlock(graphBuilder, filters, 3, 1, prefix + "_right2", p));

        graphBuilder.addLayer(prefix + "_left3", pool(PoolingType.AVG, 3, 1, ConvolutionMode.Same), h);
        String add3 = addSum(graphBuilder, prefix + "_add3", prefix + "_left3", p);

        graphBuilder
                .addLayer(prefix + "_left4", pool(PoolingType.AVG, 3, 1, ConvolutionMode.Same), p)
                .addLayer(prefix + "_right4", pool(PoolingType.AVG, 3, 1, ConvolutionMode.Same), p);
        String add4 = addSum(graphBuilder, prefix + "_add4", prefix + "_left4", prefix + "_right4");

        // Use the name sepConvBlock returns; no vertex is called prefix + "_left5".
        String add5 = addSum(graphBuilder, prefix + "_add5",
                        sepConvBlock(graphBuilder, filters, 3, 1, prefix + "_left5", p), h);

        graphBuilder.addVertex(prefix, new MergeVertex(), p, add1, add2, add3, add4, add5);
        return new Pair<>(prefix, inputX);
    }

    /**
     * Appends a NASNet-A reduction cell. Every branch downsamples exactly once with stride 2,
     * so all element-wise sums see equal shapes.
     *
     * @param graphBuilder graph to append to
     * @param filters channels per branch; the cell output has {@code 4 * filters} channels
     * @param blockId suffix that makes the layer names unique
     * @param inputX the cell's current input
     * @param inputP the previous cell's output, adjusted to match {@code inputX}
     * @return (name of this cell's output, {@code inputX}) -- the pair to feed the next cell
     */
    public static Pair<String, String> reductionA(ComputationGraphConfiguration.GraphBuilder graphBuilder,
                    int filters, String blockId, String inputX, String inputP) {
        String prefix = "reductionA" + blockId;
        String p = adjustBlock(graphBuilder, filters, prefix, inputP, inputX);
        // h is built from the cell's current input, so inputX actually feeds the cell.
        String h = reluConvBn(graphBuilder, prefix, filters, inputX);

        String add1 = addSum(graphBuilder, prefix + "_add1",
                        sepConvBlock(graphBuilder, filters, 5, 2, prefix + "_left1", h),
                        sepConvBlock(graphBuilder, filters, 7, 2, prefix + "_right1", p));

        graphBuilder.addLayer(prefix + "_left2", pool(PoolingType.MAX, 3, 2, ConvolutionMode.Same), h);
        // Stride 2 so this branch has the same spatial size as the stride-2 max pool it is added to.
        String add2 = addSum(graphBuilder, prefix + "_add2", prefix + "_left2",
                        sepConvBlock(graphBuilder, filters, 7, 2, prefix + "_right2", p));

        graphBuilder.addLayer(prefix + "_left3", pool(PoolingType.AVG, 3, 2, ConvolutionMode.Same), h);
        String add3 = addSum(graphBuilder, prefix + "_add3", prefix + "_left3",
                        sepConvBlock(graphBuilder, filters, 5, 2, prefix + "_right3", p));

        // add1 is already downsampled, so branches built on it use stride 1.
        graphBuilder.addLayer(prefix + "_left4", pool(PoolingType.AVG, 3, 1, ConvolutionMode.Same), add1);
        String add4 = addSum(graphBuilder, prefix + "_add4", add2, prefix + "_left4");

        graphBuilder.addLayer(prefix + "_right5", pool(PoolingType.MAX, 3, 2, ConvolutionMode.Same), h);
        String add5 = addSum(graphBuilder, prefix + "_add5",
                        sepConvBlock(graphBuilder, filters, 3, 1, prefix + "_left5", add1), prefix + "_right5");

        graphBuilder.addVertex(prefix, new MergeVertex(), add2, add3, add4, add5);
        return new Pair<>(prefix, inputX);
    }

    /** Appends {@code ReLU -> 1x1 conv -> BN} as {@code prefix_relu1/_conv1/_bn1}; returns the BN name. */
    private static String reluConvBn(ComputationGraphConfiguration.GraphBuilder graphBuilder, String prefix,
                    int filters, String input) {
        graphBuilder
                .addLayer(prefix + "_relu1", new ActivationLayer(Activation.RELU), input)
                .addLayer(prefix + "_conv1", conv1x1(filters), prefix + "_relu1")
                .addLayer(prefix + "_bn1", batchNorm(), prefix + "_conv1");
        return prefix + "_bn1";
    }

    /** Adds the element-wise sum vertex {@code name = a + b} and returns {@code name}. */
    private static String addSum(ComputationGraphConfiguration.GraphBuilder graphBuilder, String name, String a,
                    String b) {
        graphBuilder.addVertex(name, new ElementWiseVertex(ElementWiseVertex.Op.Add), a, b);
        return name;
    }

    /** Builds a bias-free, "Same"-padded square separable convolution. */
    private static SeparableConvolution2D sepConv(int filters, int kernelSize, int stride) {
        return new SeparableConvolution2D.Builder(kernelSize, kernelSize)
                .stride(stride, stride)
                .nOut(filters)
                .hasBias(false)
                .convolutionMode(ConvolutionMode.Same)
                .build();
    }

    /** Builds a bias-free, "Same"-padded 1x1 convolution with stride 1. */
    private static ConvolutionLayer conv1x1(int filters) {
        return new ConvolutionLayer.Builder(1, 1)
                .stride(1, 1)
                .nOut(filters)
                .hasBias(false)
                .convolutionMode(ConvolutionMode.Same)
                .build();
    }

    /** Builds a square pooling layer of the given type, kernel, stride and convolution mode. */
    private static SubsamplingLayer pool(PoolingType type, int kernelSize, int stride, ConvolutionMode mode) {
        return new SubsamplingLayer.Builder(type)
                .kernelSize(kernelSize, kernelSize)
                .stride(stride, stride)
                .convolutionMode(mode)
                .build();
    }

    /** Builds the batch-norm configuration shared by every block in this class. */
    private static BatchNormalization batchNorm() {
        // NOTE(review): 0.9997 is presumably meant as the BN moving-average momentum (as in the
        // Keras NASNet); if so it belongs in decay(...) rather than gamma(...). Kept as-is -- confirm.
        return new BatchNormalization.Builder().eps(0.001d).gamma(0.9997d).build();
    }
}
