package org.deeplearning4j.nn.conf.graph;

import com.google.common.base.Preconditions;
import java.util.Map;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/**
 * Graph vertex that applies dot-product attention to three inputs, declared (in this order) as
 * {@code queries}, {@code keys} and {@code values}.
 *
 * <p>When {@code projectInput} is {@code true} the vertex owns four projection weights
 * ({@code Wq}, {@code Wk}, {@code Wv}, {@code Wo}) and delegates to
 * {@code sameDiff.nn.multiHeadDotProductAttention}; otherwise it has no parameters and delegates to
 * {@code sameDiff.nn.dotProductAttention}. Instances are normally created through {@link Builder}.
 */
public class AttentionVertex extends SameDiffVertex {
    /** Parameter key of the query projection weights. */
    private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
    /** Parameter key of the key projection weights. */
    private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
    /** Parameter key of the value projection weights. */
    private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv";
    /** Parameter key of the output projection weights. */
    private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo";

    /** Feature size of the keys input. */
    private long nInKeys = 0L;
    /** Feature size of the values input; also the output size when {@code projectInput} is false. */
    private long nInValues = 0L;
    /** Feature size of the queries input. */
    private long nInQueries = 0L;
    /** Output feature size when {@code projectInput} is true. */
    private long nOut = 0L;
    /** Size of each attention head; {@link Builder#build()} defaults it to {@code nOut / nHeads}. */
    private long headSize = 0L;
    /** Number of attention heads. */
    private int nHeads = 1;
    /** Whether learned projection weights are applied (multi-head attention) or not. */
    private boolean projectInput;
    /** Initialisation scheme used for the projection weights. */
    protected WeightInit weightInit;

    /**
     * Builder for {@link AttentionVertex}. {@code nOut}, {@code nInKeys}, {@code nInQueries} and
     * {@code nInValues} are mandatory; {@code nHeads} defaults to 1, {@code weightInit} to
     * {@link WeightInit#XAVIER}, and {@code headSize} to {@code nOut / nHeads}.
     */
    public static class Builder {
        private long nInKeys = 0;
        private long nInValues = 0;
        private long nInQueries = 0;
        private long nOut = 0;
        private long headSize = 0;
        private int nHeads = 1;
        private boolean projectInput;
        protected WeightInit weightInit;

        /** Sets the feature size of the keys input. */
        public Builder nInKeys(long nInKeys) {
            this.nInKeys = nInKeys;
            return this;
        }

        /** Sets the feature size of the queries input. */
        public Builder nInQueries(long nInQueries) {
            this.nInQueries = nInQueries;
            return this;
        }

        /** Sets the feature size of the values input. */
        public Builder nInValues(long nInValues) {
            this.nInValues = nInValues;
            return this;
        }

        /** Sets the per-head size; 0 means "derive as {@code nOut / nHeads}". */
        public Builder headSize(long headSize) {
            this.headSize = headSize;
            return this;
        }

        /** Sets the number of attention heads; 0 is treated as 1. */
        public Builder nHeads(int nHeads) {
            this.nHeads = nHeads;
            return this;
        }

        /** Sets the output feature size. */
        public Builder nOut(long nOut) {
            this.nOut = nOut;
            return this;
        }

        /** Sets the weight initialisation scheme; {@code null} falls back to XAVIER. */
        public Builder weightInit(WeightInit weightInit) {
            this.weightInit = weightInit;
            return this;
        }

        /** Enables or disables the learned input/output projections. */
        public Builder projectInput(boolean projectInput) {
            this.projectInput = projectInput;
            return this;
        }

        /**
         * Applies defaults, validates the configuration and creates the vertex.
         *
         * @return the configured vertex
         * @throws IllegalArgumentException if a mandatory size is unset, {@code nHeads} or
         *         {@code headSize} is negative, {@code nOut} is not divisible by {@code nHeads}
         *         without an explicit head size, or {@code projectInput} is disabled for an
         *         incompatible configuration
         */
        public AttentionVertex build() {
            this.nHeads = this.nHeads == 0 ? 1 : this.nHeads;
            this.weightInit = this.weightInit == null ? WeightInit.XAVIER : this.weightInit;
            // Negative values would otherwise pass the checks below and only fail later,
            // when weight parameters with negative shapes are declared.
            Preconditions.checkArgument(this.nHeads > 0, "nHeads must be positive, got %s", this.nHeads);
            Preconditions.checkArgument(this.nOut > 0, "You have to set nOut");
            Preconditions.checkArgument(this.nInKeys > 0, "You have to set nInKeys");
            Preconditions.checkArgument(this.nInQueries > 0, "You have to set nInQueries");
            Preconditions.checkArgument(this.nInValues > 0, "You have to set nInValues");
            Preconditions.checkArgument(this.headSize >= 0, "headSize must not be negative, got %s", this.headSize);
            Preconditions.checkArgument(this.headSize > 0 || this.nOut % ((long) this.nHeads) == 0,
                    "You have to set a head size if nOut isn't cleanly divided by nHeads");
            Preconditions.checkArgument(this.projectInput
                            || (this.nInQueries == this.nInKeys && this.nInKeys == this.nInValues
                                && this.nInValues == this.nOut && this.nHeads == 1),
                    "You may only disable projectInput if all nIn* equal to nOut and you want to use only a single attention head");
            this.headSize = this.headSize == 0 ? this.nOut / this.nHeads : this.headSize;
            return new AttentionVertex(this);
        }

        public long getNInKeys() {
            return this.nInKeys;
        }

        public long getNInValues() {
            return this.nInValues;
        }

        public long getNInQueries() {
            return this.nInQueries;
        }

        public long getNOut() {
            return this.nOut;
        }

        public long getHeadSize() {
            return this.headSize;
        }

        public int getNHeads() {
            return this.nHeads;
        }

        public boolean isProjectInput() {
            return this.projectInput;
        }

        public WeightInit getWeightInit() {
            return this.weightInit;
        }

        public void setNInKeys(long nInKeys) {
            this.nInKeys = nInKeys;
        }

        public void setNInValues(long nInValues) {
            this.nInValues = nInValues;
        }

        public void setNInQueries(long nInQueries) {
            this.nInQueries = nInQueries;
        }

        public void setNOut(long nOut) {
            this.nOut = nOut;
        }

        public void setHeadSize(long headSize) {
            this.headSize = headSize;
        }

        public void setNHeads(int nHeads) {
            this.nHeads = nHeads;
        }

        public void setProjectInput(boolean projectInput) {
            this.projectInput = projectInput;
        }

        public void setWeightInit(WeightInit weightInit) {
            this.weightInit = weightInit;
        }
    }

    /**
     * Creates a vertex from an already-validated builder.
     *
     * @param builder source of every configuration value
     */
    protected AttentionVertex(Builder builder) {
        this.nInKeys = builder.nInKeys;
        this.nInValues = builder.nInValues;
        this.nInQueries = builder.nInQueries;
        this.nOut = builder.nOut;
        this.headSize = builder.headSize;
        this.projectInput = builder.projectInput;
        this.nHeads = builder.nHeads;
        this.weightInit = builder.weightInit;
    }

    /** No-arg constructor; leaves every size at 0 and {@code nHeads} at 1. */
    public AttentionVertex() {
    }

    /**
     * Computes the output type from the queries input (index 0).
     *
     * @param layerIndex   index of this vertex (unused)
     * @param vertexInputs input types; the first must be a recurrent type
     * @return a recurrent type with {@code nOut} features when projecting, otherwise
     *         {@code nInValues} features, and the queries' time-series length
     * @throws InvalidInputTypeException if no inputs are given or the first is not recurrent
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
        if (vertexInputs == null || vertexInputs.length == 0) {
            throw new InvalidInputTypeException("AttentionVertex requires at least one (queries) input, got none");
        }
        if (!(vertexInputs[0] instanceof InputType.InputTypeRecurrent)) {
            throw new InvalidInputTypeException(
                    "AttentionVertex expects a recurrent queries input, got: " + vertexInputs[0]);
        }
        InputType.InputTypeRecurrent queries = (InputType.InputTypeRecurrent) vertexInputs[0];
        long outputSize = this.projectInput ? this.nOut : this.nInValues;
        return InputType.recurrent(outputSize, queries.getTimeSeriesLength());
    }

    /**
     * Declares the inputs {@code queries}, {@code keys}, {@code values} and, when projecting, the
     * weights {@code Wq}/{@code Wk}/{@code Wv} with shape {@code [nHeads, headSize, nIn*]} and
     * {@code Wo} with shape {@code [nHeads * headSize, nOut]}.
     *
     * @param params container that is cleared and then populated
     */
    @Override
    public void defineParametersAndInputs(SDVertexParams params) {
        params.clear();
        params.defineInputs("queries", "keys", "values");
        if (this.projectInput) {
            params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, this.nHeads, this.headSize, this.nInQueries);
            params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, this.nHeads, this.headSize, this.nInKeys);
            params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, this.nHeads, this.headSize, this.nInValues);
            params.addWeightParam(WEIGHT_KEY_OUT_PROJECTION, this.nHeads * this.headSize, this.nOut);
        }
    }

    /**
     * Initialises each known projection weight in place with {@link #weightInit}, using the
     * matching input size and {@code headSize} (or {@code nHeads * headSize} and {@code nOut} for
     * {@code Wo}) as fan-in/fan-out. Entries with other keys are left untouched.
     *
     * @param params parameter views keyed by parameter name
     */
    @Override
    public void initializeParameters(Map<String, INDArray> params) {
        // Initialisation runs with workspaces scoped out; try-with-resources guarantees the scope
        // is closed on every path (the decompiled manual try/finally did not compile).
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            for (Map.Entry<String, INDArray> entry : params.entrySet()) {
                INDArray view = entry.getValue();
                switch (entry.getKey()) {
                    case WEIGHT_KEY_QUERY_PROJECTION:
                        WeightInitUtil.initWeights(this.nInQueries, this.headSize, view.shape(),
                                this.weightInit, (Distribution) null, 'c', view);
                        break;
                    case WEIGHT_KEY_KEY_PROJECTION:
                        WeightInitUtil.initWeights(this.nInKeys, this.headSize, view.shape(),
                                this.weightInit, (Distribution) null, 'c', view);
                        break;
                    case WEIGHT_KEY_VALUE_PROJECTION:
                        WeightInitUtil.initWeights(this.nInValues, this.headSize, view.shape(),
                                this.weightInit, (Distribution) null, 'c', view);
                        break;
                    case WEIGHT_KEY_OUT_PROJECTION:
                        WeightInitUtil.initWeights(this.nHeads * this.headSize, this.nOut, view.shape(),
                                this.weightInit, (Distribution) null, 'c', view);
                        break;
                    default:
                        // Not one of this vertex's weights: leave it as-is.
                        break;
                }
            }
        }
    }

    /**
     * Propagates the queries mask (index 0) to the next vertex.
     *
     * @param maskArrays       input masks, or {@code null} if no input is masked
     * @param currentMaskState current mask state, passed through unchanged
     * @param minibatchSize    minibatch size (unused)
     * @return {@code (null, state)} when there are no masks at all; {@code null} when masks exist
     *         but the queries are unmasked; otherwise {@code (queriesMask, state)}
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) {
        if (maskArrays == null) {
            return Pair.of(null, currentMaskState);
        }
        if (maskArrays[0] == null) {
            // Queries are unmasked: no mask is passed on.
            return null;
        }
        return Pair.of(maskArrays[0], currentMaskState);
    }

    /**
     * Builds the attention graph.
     *
     * <p>If masks are supplied, the keys and values masks are combined with an element-wise
     * {@code min} and passed to the attention op, and the result is multiplied by the queries mask
     * expanded at dimension 1. NOTE(review): assumes masks for all three inputs are present
     * whenever {@code maskVars} is non-null — confirm against {@code SameDiffVertex}.
     *
     * @param sameDiff   graph under construction
     * @param layerInput inputs keyed {@code queries}, {@code keys}, {@code values}
     * @param paramTable projection weights keyed {@code Wq}, {@code Wk}, {@code Wv}, {@code Wo}
     * @param maskVars   masks keyed like the inputs, or {@code null}
     * @return the attention output
     */
    @Override
    public SDVariable defineVertex(SameDiff sameDiff, Map<String, SDVariable> layerInput,
                                   Map<String, SDVariable> paramTable, Map<String, SDVariable> maskVars) {
        SDVariable queries = layerInput.get("queries");
        SDVariable keys = layerInput.get("keys");
        SDVariable values = layerInput.get("values");
        SDVariable mask = maskVars != null ? sameDiff.min(maskVars.get("keys"), maskVars.get("values")) : null;

        // The trailing `true` is the op's boolean flag (presumably "scaled" attention — confirm).
        SDVariable attention;
        if (this.projectInput) {
            attention = sameDiff.nn.multiHeadDotProductAttention(getLayerName(), queries, keys, values,
                    paramTable.get(WEIGHT_KEY_QUERY_PROJECTION), paramTable.get(WEIGHT_KEY_KEY_PROJECTION),
                    paramTable.get(WEIGHT_KEY_VALUE_PROJECTION), paramTable.get(WEIGHT_KEY_OUT_PROJECTION),
                    mask, true);
        } else {
            attention = sameDiff.nn.dotProductAttention(getLayerName(), queries, keys, values, mask, true);
        }

        if (maskVars == null) {
            return attention;
        }
        return attention.mul(sameDiff.expandDims(maskVars.get("queries"), 1));
    }

    public long getNInKeys() {
        return this.nInKeys;
    }

    public long getNInValues() {
        return this.nInValues;
    }

    public long getNInQueries() {
        return this.nInQueries;
    }

    public long getNOut() {
        return this.nOut;
    }

    public long getHeadSize() {
        return this.headSize;
    }

    public int getNHeads() {
        return this.nHeads;
    }

    public boolean isProjectInput() {
        return this.projectInput;
    }

    public WeightInit getWeightInit() {
        return this.weightInit;
    }

    public void setNInKeys(long nInKeys) {
        this.nInKeys = nInKeys;
    }

    public void setNInValues(long nInValues) {
        this.nInValues = nInValues;
    }

    public void setNInQueries(long nInQueries) {
        this.nInQueries = nInQueries;
    }

    public void setNOut(long nOut) {
        this.nOut = nOut;
    }

    public void setHeadSize(long headSize) {
        this.headSize = headSize;
    }

    public void setNHeads(int nHeads) {
        this.nHeads = nHeads;
    }

    public void setProjectInput(boolean projectInput) {
        this.projectInput = projectInput;
    }

    public void setWeightInit(WeightInit weightInit) {
        this.weightInit = weightInit;
    }

    /** Equal when the superclass state and every configuration field match. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof AttentionVertex)) {
            return false;
        }
        AttentionVertex other = (AttentionVertex) obj;
        if (!other.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        if (getNInKeys() != other.getNInKeys()
                || getNInValues() != other.getNInValues()
                || getNInQueries() != other.getNInQueries()
                || getNOut() != other.getNOut()
                || getHeadSize() != other.getHeadSize()
                || getNHeads() != other.getNHeads()
                || isProjectInput() != other.isProjectInput()) {
            return false;
        }
        WeightInit mine = getWeightInit();
        WeightInit theirs = other.getWeightInit();
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof AttentionVertex;
    }

    /** Hash consistent with {@link #equals(Object)} (prime-59 accumulation over every field). */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        result = result * prime + Long.hashCode(getNInKeys());
        result = result * prime + Long.hashCode(getNInValues());
        result = result * prime + Long.hashCode(getNInQueries());
        result = result * prime + Long.hashCode(getNOut());
        result = result * prime + Long.hashCode(getHeadSize());
        result = result * prime + getNHeads();
        result = result * prime + (isProjectInput() ? 79 : 97);
        WeightInit init = getWeightInit();
        return result * prime + (init == null ? 43 : init.hashCode());
    }

    @Override
    public String toString() {
        return "AttentionVertex(nInKeys=" + getNInKeys() + ", nInValues=" + getNInValues()
                + ", nInQueries=" + getNInQueries() + ", nOut=" + getNOut()
                + ", headSize=" + getHeadSize() + ", nHeads=" + getNHeads()
                + ", projectInput=" + isProjectInput() + ", weightInit=" + getWeightInit() + ")";
    }
}
