package org.deeplearning4j.nn.layers.mkldnn;

import java.util.Collections;
import java.util.Map;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2DDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.class */
/**
 * {@link ConvolutionHelper} that executes the 2D convolution forward pass via {@link Conv2D}
 * and the backward pass via {@link Conv2DDerivative}, each through its own {@link OpContext}.
 *
 * <p>Both contexts are created lazily on first use and then reused: only the input and output
 * array bindings are replaced on every call. The integer arguments are written once, when the
 * context is created.
 *
 * <p>Both entry points return {@code null} when the first or second array argument is not of
 * type {@link DataType#FLOAT}; presumably the caller treats that as "helper not applicable"
 * and falls back to another implementation (confirm against the calling layer).
 *
 * <p>The contexts are mutable instance state and no synchronization is performed in this class.
 */
public class MKLDNNConvHelper implements ConvolutionHelper {
    /** Context reused by {@link #preOutput}; {@code null} until the first forward call that runs the op. */
    protected OpContext context;
    /** Context reused by {@link #backpropGradient}; {@code null} until the first backward call that runs the op. */
    protected OpContext contextBwd;

    /**
     * Creates the helper.
     *
     * @param dataType accepted for interface compatibility; not used by this class
     */
    public MKLDNNConvHelper(DataType dataType) {
    }

    /**
     * Reports whether this helper can be used.
     *
     * @return the result of {@code BaseMKLDNNHelper.mklDnnEnabled()}
     */
    @Override // org.deeplearning4j.nn.layers.convolution.ConvolutionHelper
    public boolean checkSupported() {
        return BaseMKLDNNHelper.mklDnnEnabled();
    }

    /**
     * Runs {@link Conv2DDerivative} and packages the results as a gradient plus an
     * activation-gradient array.
     *
     * <p>Parameter meanings below marked "presumably" are inferred from how the values are used
     * here and should be confirmed against {@link ConvolutionHelper}.
     *
     * @param iNDArray first op input (presumably the layer input); must be FLOAT. Its data type
     *                 and shape are used to allocate the returned array
     * @param iNDArray2 presumably the weights; must be FLOAT. Passed to the op permuted as (2, 3, 1, 0)
     * @param iNDArray3 presumably the bias; only passed to the op when {@code iNDArray5} is non-null
     * @param iNDArray4 last op input (presumably the gradient w.r.t. the layer output)
     * @param iArr      values for integer-argument slots 0-1 (presumably kernel size)
     * @param iArr2     values for integer-argument slots 2-3 (presumably strides)
     * @param iArr3     values for integer-argument slots 4-5 (presumably padding); replaced by the
     *                  result of {@code ConvolutionUtils.getSameModeTopLeftPadding} in Same mode
     * @param iNDArray5 presumably the bias-gradient view; may be {@code null}, in which case neither
     *                  the bias input nor the bias-gradient output is bound and no {@code "b"} entry
     *                  is added to the gradient
     * @param iNDArray6 presumably the weight-gradient view; bound as an op output permuted as
     *                  (2, 3, 1, 0) and stored un-permuted under {@code "W"}
     * @param iActivation     not used by this implementation
     * @param algoMode        not used by this implementation
     * @param bwdFilterAlgo   not used by this implementation
     * @param bwdDataAlgo     not used by this implementation
     * @param convolutionMode selects Same-mode padding handling and sets integer-argument slot 8
     * @param iArr4     values for integer-argument slots 6-7 (presumably dilation)
     * @param layerWorkspaceMgr used to allocate the {@code ACTIVATION_GRAD} result array
     * @return the gradient map and the array written as the op's first output, or {@code null}
     *         if {@code iNDArray} or {@code iNDArray2} is not FLOAT
     */
    @Override // org.deeplearning4j.nn.layers.convolution.ConvolutionHelper
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, int[] iArr, int[] iArr2, int[] iArr3, INDArray iNDArray5, INDArray iNDArray6, IActivation iActivation, ConvolutionLayer.AlgoMode algoMode, ConvolutionLayer.BwdFilterAlgo bwdFilterAlgo, ConvolutionLayer.BwdDataAlgo bwdDataAlgo, ConvolutionMode convolutionMode, int[] iArr4, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (iNDArray.dataType() != DataType.FLOAT || iNDArray2.dataType() != DataType.FLOAT) {
            return null;
        }
        INDArray permutedWeights = iNDArray2.permute(2, 3, 1, 0);
        INDArray permutedWeightGrad = iNDArray6.permute(2, 3, 1, 0);

        int[] effectivePad = iArr3;
        if (convolutionMode == ConvolutionMode.Same) {
            int[] gradSpatial = {(int) iNDArray4.size(2), (int) iNDArray4.size(3)};
            int[] inputSpatial = {(int) iNDArray.size(2), (int) iNDArray.size(3)};
            effectivePad = ConvolutionUtils.getSameModeTopLeftPadding(gradSpatial, inputSpatial, iArr, iArr2, iArr4);
        }

        // NOTE(review): integer arguments are captured only when the context is first created, so
        // Same-mode padding computed on later calls is not propagated — confirm the native op
        // derives its own padding when the same-mode flag (slot 8) is set.
        if (this.contextBwd == null) {
            this.contextBwd = newContext(buildIntegerArgs(iArr, iArr2, effectivePad, iArr4, convolutionMode));
        }

        INDArray activationGrad = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, iNDArray.dataType(), iNDArray.shape());

        boolean withBias = iNDArray5 != null;
        INDArray[] opInputs = withBias
                ? new INDArray[]{iNDArray, permutedWeights, iNDArray3, iNDArray4}
                : new INDArray[]{iNDArray, permutedWeights, iNDArray4};
        INDArray[] opOutputs = withBias
                ? new INDArray[]{activationGrad, permutedWeightGrad, iNDArray5}
                : new INDArray[]{activationGrad, permutedWeightGrad};

        bindArrays(this.contextBwd, opInputs, opOutputs);
        Nd4j.exec(new Conv2DDerivative(), this.contextBwd);
        // Drop the bindings so the cached context does not keep referencing this call's arrays.
        clearArrays(this.contextBwd);

        DefaultGradient gradient = new DefaultGradient();
        if (withBias) {
            gradient.gradientForVariable().put("b", iNDArray5);
        }
        gradient.gradientForVariable().put("W", iNDArray6);
        return new Pair<>(gradient, activationGrad);
    }

    /**
     * Runs {@link Conv2D} and returns the freshly allocated output array.
     *
     * <p>Parameter meanings below marked "presumably" are inferred from how the values are used
     * here and should be confirmed against {@link ConvolutionHelper}.
     *
     * @param iNDArray  first op input (presumably the layer input); must be FLOAT. Dimension 0 is
     *                  used as the output's dimension 0 and dimensions 2-3 as the input spatial size
     * @param iNDArray2 presumably the weights; must be FLOAT. Dimension 0 is used as the output's
     *                  dimension 1; passed to the op permuted as (2, 3, 1, 0)
     * @param iNDArray3 presumably the bias; may be {@code null}, in which case it is not bound
     * @param iArr      values for integer-argument slots 0-1 (presumably kernel size)
     * @param iArr2     values for integer-argument slots 2-3 (presumably strides)
     * @param iArr3     values for integer-argument slots 4-5 (presumably padding); replaced by the
     *                  result of {@code ConvolutionUtils.getSameModeTopLeftPadding} in Same mode
     * @param algoMode  not used by this implementation
     * @param fwdAlgo   not used by this implementation
     * @param convolutionMode selects how output size and padding are computed and sets
     *                  integer-argument slot 8
     * @param iArr4     values for integer-argument slots 6-7 (presumably dilation)
     * @param layerWorkspaceMgr used to allocate the {@code ACTIVATIONS} result array
     * @return the array the op wrote its output into, or {@code null} if {@code iNDArray} or
     *         {@code iNDArray2} is not FLOAT
     */
    @Override // org.deeplearning4j.nn.layers.convolution.ConvolutionHelper
    public INDArray preOutput(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int[] iArr, int[] iArr2, int[] iArr3, ConvolutionLayer.AlgoMode algoMode, ConvolutionLayer.FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] iArr4, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (iNDArray.dataType() != DataType.FLOAT || iNDArray2.dataType() != DataType.FLOAT) {
            return null;
        }
        int inHeight = (int) iNDArray.size(2);
        int inWidth = (int) iNDArray.size(3);

        int[] outputSize;
        int[] effectivePad;
        if (convolutionMode == ConvolutionMode.Same) {
            outputSize = ConvolutionUtils.getOutputSize(iNDArray, iArr, iArr2, null, convolutionMode, iArr4);
            effectivePad = ConvolutionUtils.getSameModeTopLeftPadding(outputSize, new int[]{inHeight, inWidth}, iArr, iArr2, iArr4);
        } else {
            outputSize = ConvolutionUtils.getOutputSize(iNDArray, iArr, iArr2, iArr3, convolutionMode, iArr4);
            effectivePad = iArr3;
        }

        // Create the forward context once and reuse it, matching the backward path, instead of
        // allocating a new context on every forward pass.
        // NOTE(review): integer arguments are captured only when the context is first created, so
        // Same-mode padding computed on later calls is not propagated — confirm the native op
        // derives its own padding when the same-mode flag (slot 8) is set.
        if (this.context == null) {
            this.context = newContext(buildIntegerArgs(iArr, iArr2, effectivePad, iArr4, convolutionMode));
        }

        INDArray output = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, iNDArray.dataType(), iNDArray.size(0), (int) iNDArray2.size(0), outputSize[0], outputSize[1]);
        INDArray permutedWeights = iNDArray2.permute(2, 3, 1, 0);
        INDArray[] opInputs = iNDArray3 == null
                ? new INDArray[]{iNDArray, permutedWeights}
                : new INDArray[]{iNDArray, permutedWeights, iNDArray3};

        bindArrays(this.context, opInputs, new INDArray[]{output});
        Nd4j.exec(new Conv2D(), this.context);
        // Drop the bindings so the cached context does not keep referencing this call's arrays.
        clearArrays(this.context);
        return output;
    }

    /**
     * Applies the activation function by delegating to {@code iActivation.getActivation}.
     *
     * @param iNDArray    array handed to the activation function
     * @param iActivation activation function to apply
     * @param z           flag forwarded unchanged to {@code getActivation}
     * @return whatever {@code getActivation} returns
     */
    @Override // org.deeplearning4j.nn.layers.convolution.ConvolutionHelper
    public INDArray activate(INDArray iNDArray, IActivation iActivation, boolean z) {
        return iActivation.getActivation(iNDArray, z);
    }

    /**
     * Reports helper memory use.
     *
     * @return always an empty map
     */
    @Override // org.deeplearning4j.nn.layers.LayerHelper
    public Map<String, Long> helperMemoryUse() {
        return Collections.emptyMap();
    }

    /** Builds the 10-slot integer-argument array shared by the forward and backward ops. */
    private static long[] buildIntegerArgs(int[] slots01, int[] slots23, int[] slots45, int[] slots67, ConvolutionMode mode) {
        return new long[]{
            slots01[0], slots01[1],
            slots23[0], slots23[1],
            slots45[0], slots45[1],
            slots67[0], slots67[1],
            ArrayUtil.fromBoolean(mode == ConvolutionMode.Same),
            0
        };
    }

    /** Creates a context and sets its integer arguments before the caller stores it in a field. */
    private static OpContext newContext(long[] integerArgs) {
        OpContext created = Nd4j.getExecutioner().buildContext();
        created.setIArguments(integerArgs);
        return created;
    }

    /** Replaces all input and output array bindings of the given context. */
    private static void bindArrays(OpContext ctx, INDArray[] inputs, INDArray[] outputs) {
        clearArrays(ctx);
        for (int idx = 0; idx < inputs.length; idx++) {
            ctx.setInputArray(idx, inputs[idx]);
        }
        for (int idx = 0; idx < outputs.length; idx++) {
            ctx.setOutputArray(idx, outputs[idx]);
        }
    }

    /** Clears the input and output array lists exposed by the given context. */
    private static void clearArrays(OpContext ctx) {
        ctx.getInputArrays().clear();
        ctx.getOutputArrays().clear();
    }
}
