package org.deeplearning4j.nn.conf.layers;

import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;

/* loaded from: input_file:org/deeplearning4j/nn/conf/layers/BaseOutputLayer.class */
/**
 * Abstract configuration shared by output-style layers.
 *
 * <p>On top of what {@link FeedForwardLayer} provides, this class stores an
 * {@link ILossFunction} and a {@code hasBias} flag, both normally copied from a
 * {@link Builder}. {@code equals}, {@code hashCode} and {@code toString} include
 * both of these values in addition to the superclass state.
 */
public abstract class BaseOutputLayer extends FeedForwardLayer {
    /** Loss function of this layer configuration; may be {@code null} if never set. */
    protected ILossFunction lossFn;

    /** Bias flag of this layer configuration; {@code true} unless a builder or setter changes it. */
    protected boolean hasBias = true;

    /**
     * Fluent builder base for {@link BaseOutputLayer} subclasses.
     *
     * <p>Defaults: the loss function is a new {@link LossMCXENT} and {@code hasBias}
     * is {@code true}.
     *
     * @param <T> the concrete builder type, returned by the fluent methods so that
     *            calls can be chained without losing the subtype
     */
    public static abstract class Builder<T extends Builder<T>> extends FeedForwardLayer.Builder<T> {
        protected ILossFunction lossFn = new LossMCXENT();
        private boolean hasBias = true;

        /** Creates a builder with the default loss function and {@code hasBias == true}. */
        public Builder() {
        }

        /**
         * Creates a builder using the loss function that corresponds to the given enum value.
         *
         * @param lossFunction enum constant that is converted through {@code getILossFunction()}
         */
        public Builder(LossFunctions.LossFunction lossFunction) {
            lossFunction(lossFunction);
        }

        /**
         * Creates a builder using the given loss function instance.
         *
         * @param lossFunction loss function to store in this builder
         */
        public Builder(ILossFunction lossFunction) {
            setLossFn(lossFunction);
        }

        /**
         * Sets the loss function from an enum value.
         *
         * @param lossFunction enum constant that is converted through {@code getILossFunction()}
         * @return this builder
         */
        public T lossFunction(LossFunctions.LossFunction lossFunction) {
            return lossFunction(lossFunction.getILossFunction());
        }

        /**
         * Sets the bias flag.
         *
         * @param hasBias new value of the flag
         * @return this builder
         */
        public T hasBias(boolean hasBias) {
            setHasBias(hasBias);
            return self();
        }

        /**
         * Sets the loss function.
         *
         * @param lossFunction loss function to store in this builder
         * @return this builder
         */
        public T lossFunction(ILossFunction lossFunction) {
            setLossFn(lossFunction);
            return self();
        }

        /** @return the loss function currently held by this builder */
        public ILossFunction getLossFn() {
            return this.lossFn;
        }

        /** @return the bias flag currently held by this builder */
        public boolean isHasBias() {
            return this.hasBias;
        }

        /** @param lossFunction loss function to store in this builder */
        public void setLossFn(ILossFunction lossFunction) {
            this.lossFn = lossFunction;
        }

        /** @param hasBias new value of the bias flag */
        public void setHasBias(boolean hasBias) {
            this.hasBias = hasBias;
        }

        /**
         * Returns this builder viewed as the self-type {@code T}.
         *
         * <p>{@code this} is a {@code Builder<T>}, which the compiler cannot prove to be a
         * {@code T}; the cast is sound as long as subclasses bind {@code T} to themselves,
         * which is what the bound {@code T extends Builder<T>} is meant to express.
         */
        @SuppressWarnings("unchecked")
        private T self() {
            return (T) this;
        }
    }

    /**
     * Creates a layer configuration from a builder, copying its loss function and bias flag
     * after the superclass has consumed the builder.
     *
     * <p>NOTE(review): decompiler metadata on this constructor says its access was changed
     * from {@code protected}; it is kept {@code public} here so existing callers keep working.
     *
     * @param builder source of the configuration values
     */
    public BaseOutputLayer(Builder<?> builder) {
        super(builder);
        this.lossFn = builder.lossFn;
        this.hasBias = builder.hasBias;
    }

    /** Creates a configuration with no loss function set and {@code hasBias == true}. */
    public BaseOutputLayer() {
    }

    /** @return the bias flag; same value as {@link #isHasBias()} unless that method is overridden */
    public boolean hasBias() {
        return this.hasBias;
    }

    /**
     * Builds the memory report for this layer and the given input type.
     *
     * <p>Standard memory is the parameter count reported by the initializer plus the updater
     * state size for that count. The only non-zero working memory figure is the last argument
     * (passed after three zeros): the output elements per example, plus the input elements per
     * example when a dropout is configured. Both cache entries are
     * {@code MemoryReport.CACHE_MODE_ALL_ZEROS}.
     *
     * <p>The report is always tagged with {@code OutputLayer.class}, regardless of the
     * concrete subclass.
     *
     * @param inputType input type the report is computed for
     * @return the assembled memory report
     */
    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        InputType outputType = getOutputType(-1, inputType);
        long numParams = initializer().numParams(this);
        int updaterStateSize = (int) getIUpdater().stateSize(numParams);

        // Input-sized contribution is only counted when a dropout is configured.
        int dropoutElements = getIDropout() != null ? (int) inputType.arrayElementsPerExample() : 0;
        int trainVariablePerExample = (int) (dropoutElements + outputType.arrayElementsPerExample());

        return new LayerMemoryReport.Builder(this.layerName, OutputLayer.class, inputType, outputType)
                .standardMemory(numParams, updaterStateSize)
                .workingMemory(0L, 0L, 0, trainVariablePerExample)
                .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS)
                .build();
    }

    /** @return the loss function of this configuration, possibly {@code null} */
    public ILossFunction getLossFn() {
        return this.lossFn;
    }

    /** @return the bias flag of this configuration */
    public boolean isHasBias() {
        return this.hasBias;
    }

    /** @param lossFunction loss function to store in this configuration */
    public void setLossFn(ILossFunction lossFunction) {
        this.lossFn = lossFunction;
    }

    /** @param hasBias new value of the bias flag */
    public void setHasBias(boolean hasBias) {
        this.hasBias = hasBias;
    }

    /** @return a description containing the superclass description, the loss function and the bias flag */
    @Override
    public String toString() {
        return "BaseOutputLayer(super=" + super.toString() + ", lossFn=" + getLossFn() + ", hasBias=" + isHasBias() + ")";
    }

    /**
     * Two configurations are equal when the other object is a {@code BaseOutputLayer} that
     * accepts this one via {@link #canEqual(Object)}, the superclass considers them equal,
     * and their loss functions and bias flags match.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof BaseOutputLayer)) {
            return false;
        }
        BaseOutputLayer other = (BaseOutputLayer) obj;
        if (!other.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        ILossFunction mine = getLossFn();
        ILossFunction theirs = other.getLossFn();
        // Null-safe comparison: two absent loss functions count as equal.
        boolean sameLoss = (mine == null) ? (theirs == null) : mine.equals(theirs);
        if (!sameLoss) {
            return false;
        }
        return isHasBias() == other.isHasBias();
    }

    /** Lets subclasses veto equality with instances of unrelated types. */
    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof BaseOutputLayer;
    }

    /** Combines the superclass hash with the loss function and the bias flag. */
    @Override
    public int hashCode() {
        int result = super.hashCode();
        ILossFunction loss = getLossFn();
        result = result * 59 + (loss == null ? 43 : loss.hashCode());
        result = result * 59 + (isHasBias() ? 79 : 97);
        return result;
    }
}
