package org.nd4j.linalg.api.ops.impl.layers;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseModule;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.weightinit.WeightInitScheme;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/layers/Linear.class */
/**
 * Linear layer module.
 *
 * <p>In array mode ({@link #exec(INDArray...)}) the first input is multiplied by the transpose
 * of the weight parameter and, when a bias parameter is registered, the bias is added via
 * {@code addiColumnVector}. In SameDiff mode ({@link #execSameDiff(SDVariable...)}) the same
 * computation is expressed as an {@link Mmul} op (with {@code transposeB = true}) optionally
 * followed by an add.
 *
 * <p>Parameters are stored as the op's input arguments: index 0 is the weight matrix created
 * with shape {@code [nOut, nIn]}; index 1, if present, is the bias created with shape
 * {@code [nOut, 1]} (see {@link #getParams}).
 */
public class Linear extends BaseModule {
    /** Lazily built SameDiff forward function; {@code null} until {@link #execSameDiff} has run. */
    private DifferentialFunction forward;
    /** Number of input features (second dimension of the weight matrix). */
    private int nIn;
    /** Number of output features (first dimension of the weight matrix). */
    private int nOut;
    /** Initializer used to create the weight matrix. */
    private WeightInitScheme weightInitScheme;
    /** Initializer used to create the bias; {@code null} means the layer has no bias. */
    private WeightInitScheme biasWeightInitScheme;

    /**
     * Fluent builder for {@link Linear}.
     *
     * <p>NOTE(review): {@link #sameDiff(SameDiff)} stores its argument, but {@link #build()} never
     * reads it and always invokes the non-SameDiff constructor. This looks unintended; confirm
     * before relying on {@code sameDiffBuilder().sameDiff(...)}.
     */
    public static class LinearBuilder {
        private int nIn;
        private int nOut;
        private WeightInitScheme weightInitScheme;
        private WeightInitScheme biasWeightInitScheme;
        private SameDiff sameDiff;

        LinearBuilder() {
        }

        /** Sets the number of input features. */
        public LinearBuilder nIn(int nIn) {
            this.nIn = nIn;
            return this;
        }

        /** Sets the number of output features. */
        public LinearBuilder nOut(int nOut) {
            this.nOut = nOut;
            return this;
        }

        /** Sets the initializer for the weight matrix. */
        public LinearBuilder weightInitScheme(WeightInitScheme weightInitScheme) {
            this.weightInitScheme = weightInitScheme;
            return this;
        }

        /** Sets the initializer for the bias; leave unset ({@code null}) for a bias-free layer. */
        public LinearBuilder biasWeightInitScheme(WeightInitScheme biasWeightInitScheme) {
            this.biasWeightInitScheme = biasWeightInitScheme;
            return this;
        }

        /**
         * Builds a {@link Linear} through the {@code (nIn, nOut, weightInit, biasInit)} constructor.
         * The {@code sameDiff} value held by this builder is not used here.
         */
        public Linear build() {
            return new Linear(this.nIn, this.nOut, this.weightInitScheme, this.biasWeightInitScheme);
        }

        @Override
        public String toString() {
            return "Linear.LinearBuilder(nIn=" + this.nIn + ", nOut=" + this.nOut + ", weightInitScheme=" + this.weightInitScheme + ", biasWeightInitScheme=" + this.biasWeightInitScheme + ")";
        }

        /** Records a SameDiff instance on the builder (currently not consumed by {@link #build()}). */
        public LinearBuilder sameDiff(SameDiff sameDiff) {
            this.sameDiff = sameDiff;
            return this;
        }
    }

    /**
     * Creates a layer for array execution; the weight (and optional bias) arrays are created
     * immediately via {@link #getParams} and passed to the superclass as input arguments.
     *
     * @param nIn number of input features
     * @param nOut number of output features
     * @param weightInitScheme initializer for the weights; must not be {@code null}
     * @param biasWeightInitScheme initializer for the bias, or {@code null} for no bias
     */
    public Linear(int nIn, int nOut, WeightInitScheme weightInitScheme, WeightInitScheme biasWeightInitScheme) {
        super(null, getParams(nIn, nOut, weightInitScheme, biasWeightInitScheme), new INDArray[0], new ArrayList(), new ArrayList(), new ArrayList());
        this.weightInitScheme = weightInitScheme;
        this.biasWeightInitScheme = biasWeightInitScheme;
        this.nIn = nIn;
        this.nOut = nOut;
    }

    /**
     * Creates a layer bound to a {@link SameDiff} instance. No parameter arrays are created here.
     *
     * @param sameDiff the SameDiff instance passed to the superclass
     * @param nIn number of input features
     * @param nOut number of output features
     * @param weightInitScheme initializer for the weights
     * @param biasWeightInitScheme initializer for the bias, or {@code null} for no bias
     */
    public Linear(SameDiff sameDiff, int nIn, int nOut, WeightInitScheme weightInitScheme, WeightInitScheme biasWeightInitScheme) {
        super(null, sameDiff, null, false, new ArrayList());
        this.weightInitScheme = weightInitScheme;
        this.biasWeightInitScheme = biasWeightInitScheme;
        this.nIn = nIn;
        this.nOut = nOut;
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String opName() {
        return "linear";
    }

    /** Intentionally a no-op: this module has no TensorFlow import mapping. */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> map, GraphDef graphDef) {
    }

    /** Intentionally a no-op: this module has no ONNX import mapping. */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromOnnx(OnnxProto3.NodeProto nodeProto, SameDiff sameDiff, Map<String, OnnxProto3.AttributeProto> map, OnnxProto3.GraphProto graphProto) {
    }

    /**
     * Delegates differentiation to the forward function built by {@link #execSameDiff}.
     *
     * @param list incoming gradients, passed through to the forward function
     * @return the gradients produced by the forward function
     * @throws IllegalStateException if the op has no arguments, or if the forward function has
     *     not been built yet (i.e. {@link #execSameDiff} was never called with an input)
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public List<SDVariable> doDiff(List<SDVariable> list) {
        // Validates state only: with no input variable this cannot build a missing forward
        // function, and execSameDiff reports that case explicitly.
        execSameDiff(new SDVariable[0]);
        return this.forward.doDiff(list);
    }

    /**
     * Computes output shape descriptors from the registered input arguments.
     *
     * <p>NOTE(review): this reads input argument index 1 unconditionally (for its data type and
     * transposed shape), so it fails for a layer that only has a weight argument — confirm
     * whether bias-free layers are expected to call this.
     *
     * @return two descriptors, plus a third of shape {@code [nOut, 1]} when a bias initializer is set
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public List<LongShapeDescriptor> calculateOutputShape() {
        INDArray[] params = inputArguments();
        List<LongShapeDescriptor> shapes = new ArrayList<>();
        shapes.add(LongShapeDescriptor.fromShape(Shape.getMatrixMultiplyShape(params[0].shape(), new long[]{this.nOut, this.nIn}), params[1].dataType()));
        shapes.add(LongShapeDescriptor.fromShape(Shape.getMatrixMultiplyShape(params[0].shape(), params[1].transpose().shape()), params[1].dataType()));
        if (this.biasWeightInitScheme != null) {
            shapes.add(LongShapeDescriptor.fromShape(new long[]{this.nOut, 1}, params[1].dataType()));
        }
        return shapes;
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
    }

    /**
     * Runs the layer on concrete arrays: {@code input.mmul(weights^T)}, plus the bias when one
     * is registered and no output array was pre-allocated.
     *
     * <p>If an output argument already exists, the product is written into it. Otherwise the
     * result is registered as a new output argument.
     *
     * @param iNDArrayArr inputs; only the first element is used
     * @throws IllegalStateException if no parameter arrays are registered
     * @throws IllegalArgumentException if no input array is supplied
     */
    @Override // org.nd4j.linalg.api.ops.Module
    public void exec(INDArray... iNDArrayArr) {
        INDArray[] params = inputArguments();
        if (params == null || params.length < 1) {
            throw new IllegalStateException("No arguments found.");
        }
        if (iNDArrayArr == null || iNDArrayArr.length < 1) {
            throw new IllegalArgumentException("At least one input array is required.");
        }
        INDArray input = iNDArrayArr[0];
        INDArray weights = params[0];
        INDArray[] outputs = outputArguments();
        if (outputs != null && outputs.length >= 1) {
            // NOTE(review): the bias is not applied on this path; confirm whether that is intended.
            input.mmul(weights.transpose(), outputs[0]);
            return;
        }
        INDArray product = input.mmul(weights.transpose());
        if (params.length == 1) {
            // Bias-free layer: index 1 does not exist, so it must not be read.
            addOutputArgument(product);
        } else {
            addOutputArgument(product.addiColumnVector(params[1]));
        }
    }

    /**
     * Builds (once) the SameDiff forward function: an {@link Mmul} of the first supplied variable
     * with the first op argument (second operand transposed), plus an add of the second op
     * argument when there is one. Subsequent calls are no-ops once the function exists.
     *
     * @param sDVariableArr input variables; only the first element is used
     * @throws IllegalStateException if the op has no arguments, or if the forward function still
     *     needs to be built and no input variable was supplied
     */
    @Override // org.nd4j.linalg.api.ops.Module
    public void execSameDiff(SDVariable... sDVariableArr) {
        SDVariable[] args = args();
        if (args == null || args.length == 0) {
            throw new IllegalStateException("No arguments found");
        }
        if (this.forward != null) {
            return;
        }
        if (sDVariableArr == null || sDVariableArr.length == 0) {
            throw new IllegalStateException("No input variable supplied; the forward function has not been built yet");
        }
        SDVariable input = sDVariableArr[0];
        MMulTranspose transpose = MMulTranspose.builder().transposeA(false).transposeB(true).build();
        if (args.length > 1) {
            this.forward = f().add(new Mmul(this.sameDiff, input, args[0], transpose).outputVariables()[0], args[1]);
        } else {
            this.forward = new Mmul(this.sameDiff, input, args[0], transpose);
        }
        this.outputVariables = this.forward.outputVariables();
    }

    /**
     * Creates the parameter arrays using the default floating point type.
     *
     * @param nIn number of input features
     * @param nOut number of output features
     * @param weightInitScheme initializer for the {@code [nOut, nIn]} weights; must not be {@code null}
     * @param biasWeightInitScheme initializer for the {@code [nOut, 1]} bias, or {@code null} for none
     * @return {@code {weights}} or {@code {weights, bias}}
     * @throws NullPointerException if {@code weightInitScheme} is {@code null}
     */
    private static INDArray[] getParams(int nIn, int nOut, WeightInitScheme weightInitScheme, WeightInitScheme biasWeightInitScheme) {
        if (weightInitScheme == null) {
            throw new NullPointerException("weightInitScheme must not be null");
        }
        INDArray weights = weightInitScheme.create(Nd4j.defaultFloatingPointType(), nOut, nIn);
        if (biasWeightInitScheme == null) {
            return new INDArray[]{weights};
        }
        INDArray bias = biasWeightInitScheme.create(Nd4j.defaultFloatingPointType(), nOut, 1);
        return new INDArray[]{weights, bias};
    }

    /** Returns a new builder (identical to {@link #sameDiffBuilder()}). */
    public static LinearBuilder execBuilder() {
        return new LinearBuilder();
    }

    /** Returns a new builder (identical to {@link #execBuilder()}). */
    public static LinearBuilder sameDiffBuilder() {
        return new LinearBuilder();
    }

    /** Creates an unconfigured instance; all fields keep their default values. */
    public Linear() {
    }
}
