package com.kotlinnlp.simplednn.core.layers.models.merge.distance;

import com.kotlinnlp.simplednn.core.layers.Layer;
import com.kotlinnlp.simplednn.core.layers.helpers.BackwardHelper;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import kotlin.Metadata;
import kotlin.collections.IntIterator;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;

/* compiled from: DistanceBackwardHelper.kt */
@Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��$\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u0002\n\u0002\b\u0002\n\u0002\u0010\u000b\n��\b��\u0018��2\b\u0012\u0004\u0012\u00020\u00020\u0001B\r\u0012\u0006\u0010\u0003\u001a\u00020\u0004¢\u0006\u0002\u0010\u0005J\b\u0010\b\u001a\u00020\tH\u0002J\u0010\u0010\n\u001a\u00020\t2\u0006\u0010\u000b\u001a\u00020\fH\u0014R\u0014\u0010\u0003\u001a\u00020\u0004X\u0094\u0004¢\u0006\b\n��\u001a\u0004\b\u0006\u0010\u0007¨\u0006\r"}, d2 = {"Lcom/kotlinnlp/simplednn/core/layers/models/merge/distance/DistanceBackwardHelper;", "Lcom/kotlinnlp/simplednn/core/layers/helpers/BackwardHelper;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "layer", "Lcom/kotlinnlp/simplednn/core/layers/models/merge/distance/DistanceLayer;", "(Lcom/kotlinnlp/simplednn/core/layers/models/merge/distance/DistanceLayer;)V", "getLayer", "()Lcom/kotlinnlp/simplednn/core/layers/models/merge/distance/DistanceLayer;", "assignLayerGradients", "", "execBackward", "propagateToInput", "", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/layers/models/merge/distance/DistanceBackwardHelper.class */
public final class DistanceBackwardHelper extends BackwardHelper<DenseNDArray> {

    /** The distance layer this helper runs the backward pass for. */
    @NotNull
    private final DistanceLayer layer;

    /**
     * Runs the backward step of the distance layer.
     *
     * <p>The layer has no parameters of its own. So the only work is propagating the
     * output errors back to the two inputs, and that happens only when requested.
     *
     * @param propagateToInput whether to compute and assign the errors of the two input arrays
     */
    @Override // com.kotlinnlp.simplednn.core.layers.helpers.BackwardHelper
    protected void execBackward(boolean propagateToInput) {
        if (propagateToInput) {
            assignLayerGradients();
        }
    }

    /**
     * Computes the errors of both input arrays and assigns them.
     *
     * <p>The output errors are first multiplied in place by the output values. The first
     * element of that product (called {@code g} here) becomes the common magnitude. For each
     * index {@code i}:
     * <ul>
     *   <li>if {@code x1[i] > x2[i]}: error1 = {@code -g}, error2 = {@code g}</li>
     *   <li>if {@code x1[i] < x2[i]}: error1 = {@code g}, error2 = {@code -g}</li>
     *   <li>otherwise (equal, or not comparable such as NaN): both errors are 0</li>
     * </ul>
     *
     * <p>The field {@code layer} is read directly rather than through {@link #getLayer2()}.
     * That accessor is typed {@code Layer<DenseNDArray>}, which does not expose the
     * {@code DistanceLayer}-specific input arrays.
     */
    private void assignLayerGradients() {
        DenseNDArray outputErrors = this.layer.getOutputArray().getErrors();
        // NOTE: mutates the layer's output errors in place (kept as in the original).
        outputErrors.assignProd((NDArray<?>) this.layer.getOutputArray().getValues());
        double g = outputErrors.get(0).doubleValue();

        // Hoisted: the original re-fetched these arrays several times per iteration.
        // Assumes getValues() of the input arrays returns DenseNDArray — TODO confirm.
        DenseNDArray values1 = this.layer.getInputArray1$simplednn().getValues();
        DenseNDArray values2 = this.layer.getInputArray2$simplednn().getValues();

        DenseNDArray errors1 = DenseNDArrayFactory.INSTANCE.fill(values1.getShape(), g);
        DenseNDArray errors2 = DenseNDArrayFactory.INSTANCE.fill(values2.getShape(), g);

        // Iterates over the length of the first input; the second is assumed to match.
        int length = values1.getLength();
        for (int i = 0; i < length; i++) {
            double a = values1.get(i).doubleValue();
            double b = values2.get(i).doubleValue();
            if (a > b) {
                errors1.set(i, Double.valueOf(errors1.get(i).doubleValue() * (-1.0d)));
            } else if (a < b) {
                errors2.set(i, Double.valueOf(errors2.get(i).doubleValue() * (-1.0d)));
            } else {
                // Equal (or NaN) elements contribute no gradient.
                errors1.set(i, Double.valueOf(0.0d));
                errors2.set(i, Double.valueOf(0.0d));
            }
        }

        this.layer.getInputArray1$simplednn().assignErrors(errors1);
        this.layer.getInputArray2$simplednn().assignErrors(errors2);
    }

    /**
     * Returns the layer this helper operates on.
     *
     * <p>NOTE(review): JADX renamed this from the Kotlin-generated {@code getLayer()}. It
     * presumably clashed with a covariant bridge method. Because {@code BackwardHelper} is
     * not visible here, the decompiled signature is kept unchanged. Confirm whether it
     * should be {@code protected DistanceLayer getLayer()} instead.
     *
     * @return the distance layer
     */
    @Override // com.kotlinnlp.simplednn.core.layers.helpers.BackwardHelper
    @NotNull
    /* renamed from: getLayer */
    public Layer<DenseNDArray> getLayer2() {
        return this.layer;
    }

    /**
     * Creates a backward helper for the given distance layer.
     *
     * @param distanceLayer the layer to run the backward pass on; must not be null
     * @throws IllegalArgumentException if {@code distanceLayer} is null (via the Kotlin intrinsic)
     */
    public DistanceBackwardHelper(@NotNull DistanceLayer distanceLayer) {
        // The null check must run before super(...), as in the Kotlin original. Java requires
        // super(...) to be the first statement, so the check goes through a static helper.
        super(requireLayer(distanceLayer));
        this.layer = distanceLayer;
    }

    /**
     * Validates the constructor argument before it reaches the superclass constructor.
     *
     * @param distanceLayer the layer to check
     * @return {@code distanceLayer}, unchanged
     */
    @NotNull
    private static DistanceLayer requireLayer(DistanceLayer distanceLayer) {
        Intrinsics.checkParameterIsNotNull(distanceLayer, "layer");
        return distanceLayer;
    }
}
