package org.nd4j.linalg.api.ops.impl.accum.distances;

import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BaseAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Accumulation op computing the cosine similarity of two arrays:
 * {@code dot(x, y) / (norm2(x) * norm2(y))}.
 *
 * <p>Element-wise, {@link #op(double, double)} multiplies paired entries and
 * {@link #update(double, double)} sums them, i.e. the accumulation is a dot product. The
 * division by the norm product is applied in {@link #getAndSetFinalResult(double)} when
 * {@link #applyFinalTransform()} is enabled, or directly in {@link #exec()}.
 */
public class CosineSimilarity extends BaseAccumulation {
    /** L2 norm of {@code x}; populated by {@link #exec()}. */
    private Number constantNormalizedByNorm2X;
    /** L2 norm of {@code y}; populated by {@link #exec()}. */
    private Number constantNormalizedByNorm2Y;

    /** No-arg constructor; always marks the op as pass-through. */
    public CosineSimilarity() {
        this.passThrough = true;
    }

    /**
     * @param x first operand
     * @param y second operand
     * @param z result array
     * @param n number of elements to process
     */
    public CosineSimilarity(INDArray x, INDArray y, INDArray z, long n) {
        super(x, y, z, n);
        initExecutionDefaults();
    }

    /**
     * @param x first operand
     * @param y second operand
     * @param n number of elements to process
     */
    public CosineSimilarity(INDArray x, INDArray y, long n) {
        super(x, y, n);
        initExecutionDefaults();
    }

    /** @param x single operand (no {@code y}); {@link #exec()} requires a {@code y} */
    public CosineSimilarity(INDArray x) {
        super(x);
        initExecutionDefaults();
    }

    /**
     * @param x first operand
     * @param y second operand
     */
    public CosineSimilarity(INDArray x, INDArray y) {
        super(x, y);
        initExecutionDefaults();
    }

    /**
     * Shared constructor setup: the op is pass-through only when the executioner runs in
     * JAVA mode, and two zero-valued extra args are installed.
     * Private, so it is safe to call from constructors.
     */
    private void initExecutionDefaults() {
        this.passThrough = Nd4j.getExecutioner().executionMode() == OpExecutioner.ExecutionMode.JAVA;
        this.extraArgs = new Object[] {Float.valueOf(0.0f), Float.valueOf(0.0f)};
    }

    /** Product of the cached norms; both must have been set by {@link #exec()}. */
    private double normProduct() {
        return this.constantNormalizedByNorm2X.doubleValue() * this.constantNormalizedByNorm2Y.doubleValue();
    }

    /** Sums the running accumulator with the next partial value. */
    @Override
    public double update(double accum, double x) {
        return accum + x;
    }

    /** Adds the product {@code x * y} to the running accumulator. */
    @Override
    public double update(double accum, double x, double y) {
        return accum + (x * y);
    }

    /** Sums the running accumulator with the next partial value. */
    @Override
    public float update(float accum, float x) {
        return accum + x;
    }

    /** Adds the product {@code x * y} to the running accumulator. */
    @Override
    public float update(float accum, float x, float y) {
        return accum + (x * y);
    }

    /** Adds a real value to the complex accumulator. */
    @Override
    public IComplexNumber update(IComplexNumber accum, double x) {
        return accum.add(Double.valueOf(x));
    }

    /** Adds the real product {@code x * y} to the complex accumulator. */
    @Override
    public IComplexNumber update(IComplexNumber accum, double x, double y) {
        return accum.add(Double.valueOf(x * y));
    }

    /** Adds a complex value to the complex accumulator. */
    @Override
    public IComplexNumber update(IComplexNumber accum, IComplexNumber x) {
        return accum.add(x);
    }

    /** Adds the complex product {@code x * y} to the complex accumulator. */
    @Override
    public IComplexNumber update(IComplexNumber accum, IComplexNumber x, IComplexNumber y) {
        return accum.add(x.mul(y));
    }

    /** Adds the product of a complex {@code x} and real {@code y} to the accumulator. */
    @Override
    public IComplexNumber update(IComplexNumber accum, IComplexNumber x, double y) {
        return accum.add(x.mul(Double.valueOf(y)));
    }

    @Override
    public int opNum() {
        return 2;
    }

    @Override
    public String name() {
        return "cosinesimilarity";
    }

    /** Element-wise product; increments {@code numProcessed}. */
    @Override
    public IComplexNumber op(IComplexNumber origin, double other) {
        this.numProcessed++;
        return origin.mul(Double.valueOf(other));
    }

    /** Element-wise product; increments {@code numProcessed}. */
    @Override
    public IComplexNumber op(IComplexNumber origin, float other) {
        this.numProcessed++;
        return origin.mul(Float.valueOf(other));
    }

    /** Element-wise product; increments {@code numProcessed}. */
    @Override
    public IComplexNumber op(IComplexNumber origin, IComplexNumber other) {
        this.numProcessed++;
        return origin.mul(other);
    }

    /** Element-wise product; increments {@code numProcessed}. */
    @Override
    public float op(float origin, float other) {
        this.numProcessed++;
        return origin * other;
    }

    /** Element-wise product; increments {@code numProcessed}. */
    @Override
    public double op(double origin, double other) {
        this.numProcessed++;
        return origin * other;
    }

    /**
     * Builds a sub-op over the {@code index}-th vector along {@code dimension}.
     * The sub-op inherits this op's final-transform setting.
     */
    @Override
    public Op opForDimension(int index, int dimension) {
        INDArray xAlongDimension = this.x.vectorAlongDimension(index, dimension);
        CosineSimilarity ret;
        if (y() != null) {
            ret = new CosineSimilarity(
                    xAlongDimension,
                    this.y.vectorAlongDimension(index, dimension),
                    xAlongDimension.length());
        } else {
            ret = new CosineSimilarity(xAlongDimension);
        }
        ret.setApplyFinalTransform(applyFinalTransform());
        return ret;
    }

    /**
     * Builds a sub-op over the {@code index}-th tensor along {@code dimension}.
     * The sub-op inherits this op's final-transform setting.
     */
    @Override
    public Op opForDimension(int index, int... dimension) {
        INDArray xAlongDimension = this.x.tensorAlongDimension(index, dimension);
        CosineSimilarity ret;
        if (y() != null) {
            ret = new CosineSimilarity(
                    xAlongDimension,
                    this.y.tensorAlongDimension(index, dimension),
                    xAlongDimension.length());
        } else {
            ret = new CosineSimilarity(xAlongDimension);
        }
        ret.setApplyFinalTransform(applyFinalTransform());
        return ret;
    }

    /**
     * Computes the full cosine similarity of {@code x} and {@code y} and stores it in
     * {@code finalResult}. Also caches both norms and publishes them via {@code extraArgs}.
     *
     * @throws IllegalStateException if no {@code y} operand was supplied
     */
    @Override
    public void exec() {
        if (this.y == null) {
            throw new IllegalStateException("CosineSimilarity requires a second input array (y)");
        }
        this.constantNormalizedByNorm2X = this.x.norm2Number();
        this.constantNormalizedByNorm2Y = this.y.norm2Number();
        // First extra arg is the literal 0.0. The decompiled code borrowed an unrelated
        // commons-math constant (CMAESOptimizer.DEFAULT_STOPFITNESS == 0) for it.
        this.extraArgs = new Object[] {
            Double.valueOf(0.0), this.constantNormalizedByNorm2X, this.constantNormalizedByNorm2Y
        };
        // Zero norms give NaN/Infinity, as in the original.
        this.finalResult = Double.valueOf(Nd4j.getBlasWrapper().dot(this.x, this.y) / normProduct());
    }

    /**
     * Computes one cosine similarity per tensor along {@code dimension} and writes each
     * into a freshly created {@code z}.
     * NOTE(review): z's shape comes from ArrayUtil.removeIndex(x.shape(), dimension),
     * presumably x's shape with the reduced dimensions dropped.
     */
    @Override
    public void exec(int... dimension) {
        int[] retShape = ArrayUtil.removeIndex(this.x.shape(), dimension);
        int nOps = this.x.tensorssAlongDimension(dimension);
        this.z = Nd4j.create(retShape);
        for (int i = 0; i < nOps; i++) {
            Accumulation sub = (Accumulation) opForDimension(i, dimension);
            this.z.putScalar(i, Nd4j.getExecutioner().execAndReturn(sub).getFinalResult().doubleValue());
        }
    }

    /**
     * Divides the accumulated dot product by the cached norm product when the final
     * transform is enabled. Otherwise returns {@code accum} unchanged without touching
     * {@code finalResult}.
     */
    @Override
    public double getAndSetFinalResult(double accum) {
        if (!applyFinalTransform()) {
            return accum;
        }
        double result = accum / normProduct();
        this.finalResult = Double.valueOf(result);
        return result;
    }

    /**
     * Float variant. Delegates to the {@code double} overload.
     */
    @Override
    public float getAndSetFinalResult(float accum) {
        // The explicit (double) cast is required. Without it, overload resolution selects
        // this float method again and recurses until StackOverflowError.
        return (float) getAndSetFinalResult((double) accum);
    }

    /**
     * Divides the real component of {@code accum} by the cached norm product.
     * Stores and returns the result as a complex number with zero imaginary part.
     * Unlike the real overloads, this does not consult {@link #applyFinalTransform()}.
     */
    @Override
    public IComplexNumber getAndSetFinalResult(IComplexNumber accum) {
        this.finalResultComplex = Nd4j.createComplexNumber(
                Double.valueOf(accum.realComponent().doubleValue() / normProduct()), 0);
        return this.finalResultComplex;
    }

    @Override
    public double calculateFinalResult(double accum, long n) {
        throw new UnsupportedOperationException("Not supported for passthrough op");
    }

    @Override
    public float calculateFinalResult(float accum, long n) {
        throw new UnsupportedOperationException("Not supported for passthrough op");
    }
}
