package com.kotlinnlp.simplednn.core.embeddings;

import com.kotlinnlp.simplednn.core.functionalities.updatemethods.UpdateMethod;
import com.kotlinnlp.simplednn.core.optimizer.Optimizer;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: EmbeddingsOptimizer.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��H\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0010\b\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010%\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0002\b\u0003\u0018��*\u0006\b��\u0010\u0001 ��2\u0014\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00020\u0004\u0012\u0004\u0012\u00020\u00050\u00030\u0002:\u0001\u0019B\u001f\u0012\f\u0010\u0006\u001a\b\u0012\u0004\u0012\u00028��0\u0007\u0012\n\u0010\b\u001a\u0006\u0012\u0002\b\u00030\t¢\u0006\u0002\u0010\nJ\u001d\u0010\u000e\u001a\u00020\u000f2\b\u0010\u0010\u001a\u0004\u0018\u00018��2\u0006\u0010\u0011\u001a\u00020\u0005¢\u0006\u0002\u0010\u0012J\u0016\u0010\u000e\u001a\u00020\u000f2\u0006\u0010\u0013\u001a\u00020\u00142\u0006\u0010\u0011\u001a\u00020\u0005J$\u0010\u000e\u001a\u00020\u000f2\u0012\u0010\u0015\u001a\u000e\u0012\u0004\u0012\u00020\u0004\u0012\u0004\u0012\u00020\u00050\u00032\u0006\u0010\u0016\u001a\u00020\u0017H\u0016J\b\u0010\u0018\u001a\u00020\u000fH\u0016R\u001a\u0010\u000b\u001a\u000e\u0012\u0004\u0012\u00020\u0004\u0012\u0004\u0012\u00020\r0\fX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0006\u001a\b\u0012\u0004\u0012\u00028��0\u0007X\u0082\u0004¢\u0006\u0002\n��¨\u0006\u001a"}, d2 = {"Lcom/kotlinnlp/simplednn/core/embeddings/EmbeddingsOptimizer;", "T", "Lcom/kotlinnlp/simplednn/core/optimizer/Optimizer;", "Lkotlin/Pair;", "", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "embeddingsMap", "Lcom/kotlinnlp/simplednn/core/embeddings/EmbeddingsMap;", "updateMethod", "Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;", "(Lcom/kotlinnlp/simplednn/core/embeddings/EmbeddingsMap;Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;)V", "embeddingsErrorsMap", "", 
"Lcom/kotlinnlp/simplednn/core/embeddings/EmbeddingsOptimizer$ErrorsAccumulator;", "accumulate", "", "embeddingKey", "errors", "(Ljava/lang/Object;Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;)V", "embedding", "Lcom/kotlinnlp/simplednn/core/embeddings/Embedding;", "paramsErrors", "copy", "", "update", "ErrorsAccumulator", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/embeddings/EmbeddingsOptimizer.class */
public final class EmbeddingsOptimizer<T> extends Optimizer<Pair<? extends Integer, ? extends DenseNDArray>> {

    /**
     * Errors accumulated since the last {@link #update()}, keyed by embedding id.
     * A {@link LinkedHashMap} is used, so {@link #update()} visits the embeddings in the order in
     * which they first received errors.
     */
    private final Map<Integer, ErrorsAccumulator> embeddingsErrorsMap;

    /**
     * The embeddings whose arrays are updated. Declared with the same {@code ? super T} wildcard as
     * the constructor parameter it is assigned from (an {@code EmbeddingsMap<T>} field cannot hold an
     * {@code EmbeddingsMap<? super T>} in Java).
     */
    private final EmbeddingsMap<? super T> embeddingsMap;

    /* JADX INFO: Access modifiers changed from: private */
    /* compiled from: EmbeddingsOptimizer.kt */
    @Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��.\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\b\n\u0002\u0010\u0002\n\u0002\b\u0004\n\u0002\u0010\u000b\n\u0002\b\u0003\n\u0002\u0010\u000e\n��\b\u0082\b\u0018��2\u00020\u0001B\u0015\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005¢\u0006\u0002\u0010\u0006J\u000e\u0010\r\u001a\u00020\u000e2\u0006\u0010\u0002\u001a\u00020\u0003J\t\u0010\u000f\u001a\u00020\u0003HÆ\u0003J\t\u0010\u0010\u001a\u00020\u0005HÆ\u0003J\u001d\u0010\u0011\u001a\u00020��2\b\b\u0002\u0010\u0002\u001a\u00020\u00032\b\b\u0002\u0010\u0004\u001a\u00020\u0005HÆ\u0001J\u0013\u0010\u0012\u001a\u00020\u00132\b\u0010\u0014\u001a\u0004\u0018\u00010\u0001HÖ\u0003J\t\u0010\u0015\u001a\u00020\u0005HÖ\u0001J\t\u0010\u0016\u001a\u00020\u0017HÖ\u0001R\u001a\u0010\u0004\u001a\u00020\u0005X\u0086\u000e¢\u0006\u000e\n��\u001a\u0004\b\u0007\u0010\b\"\u0004\b\t\u0010\nR\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u000b\u0010\f¨\u0006\u0018"}, d2 = {"Lcom/kotlinnlp/simplednn/core/embeddings/EmbeddingsOptimizer$ErrorsAccumulator;", "", "errors", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "count", "", "(Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;I)V", "getCount", "()I", "setCount", "(I)V", "getErrors", "()Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "accumulate", "", "component1", "component2", "copy", "equals", "", "other", "hashCode", "toString", "", "simplednn"})
    /* loaded from: input_file:com/kotlinnlp/simplednn/core/embeddings/EmbeddingsOptimizer$ErrorsAccumulator.class */
    /**
     * Running sum of the errors received by one embedding, together with the number of contributions.
     * {@link EmbeddingsOptimizer#update()} divides the sum by the count before applying it.
     */
    public static final class ErrorsAccumulator {

        /** Sum of all errors accumulated so far; updated in place by {@link #accumulate}. */
        @NotNull
        private final DenseNDArray errors;

        /** Number of contributions summed into {@link #errors}. */
        private int count;

        /**
         * Adds {@code denseNDArray} to the running sum (in place, via {@code assignSum}) and
         * increments the contribution count.
         *
         * @param denseNDArray the errors to add
         */
        public final void accumulate(@NotNull DenseNDArray denseNDArray) {
            Intrinsics.checkParameterIsNotNull(denseNDArray, "errors");
            this.errors.assignSum((NDArray<?>) denseNDArray);
            this.count++;
        }

        /** @return the running sum of the accumulated errors (the live array, not a copy) */
        @NotNull
        public final DenseNDArray getErrors() {
            return this.errors;
        }

        /** @return the number of contributions accumulated so far */
        public final int getCount() {
            return this.count;
        }

        /** @param i the new contribution count */
        public final void setCount(int i) {
            this.count = i;
        }

        /**
         * @param denseNDArray the initial errors sum; stored as-is (not copied)
         * @param i the initial contribution count
         */
        public ErrorsAccumulator(@NotNull DenseNDArray denseNDArray, int i) {
            Intrinsics.checkParameterIsNotNull(denseNDArray, "errors");
            this.errors = denseNDArray;
            this.count = i;
        }

        // The members below are data-class style accessors emitted by the Kotlin compiler.

        @NotNull
        public final DenseNDArray component1() {
            return this.errors;
        }

        public final int component2() {
            return this.count;
        }

        @NotNull
        public final ErrorsAccumulator copy(@NotNull DenseNDArray denseNDArray, int i) {
            Intrinsics.checkParameterIsNotNull(denseNDArray, "errors");
            return new ErrorsAccumulator(denseNDArray, i);
        }

        /** Default-argument helper: bit 0 of {@code i2} keeps {@code errors}, bit 1 keeps {@code count}. */
        @NotNull
        public static /* bridge */ /* synthetic */ ErrorsAccumulator copy$default(ErrorsAccumulator errorsAccumulator, DenseNDArray denseNDArray, int i, int i2, Object obj) {
            if ((i2 & 1) != 0) {
                denseNDArray = errorsAccumulator.errors;
            }
            if ((i2 & 2) != 0) {
                i = errorsAccumulator.count;
            }
            return errorsAccumulator.copy(denseNDArray, i);
        }

        public String toString() {
            return "ErrorsAccumulator(errors=" + this.errors + ", count=" + this.count + ")";
        }

        public int hashCode() {
            DenseNDArray denseNDArray = this.errors;
            return ((denseNDArray != null ? denseNDArray.hashCode() : 0) * 31) + Integer.hashCode(this.count);
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof ErrorsAccumulator)) {
                return false;
            }
            ErrorsAccumulator errorsAccumulator = (ErrorsAccumulator) obj;
            if (Intrinsics.areEqual(this.errors, errorsAccumulator.errors)) {
                return this.count == errorsAccumulator.count;
            }
            return false;
        }
    }

    /**
     * Applies the average of the accumulated errors of every embedding through the update method,
     * then forgets them.
     *
     * <p>Each embedding is resolved before its errors are averaged in place, and its entry is removed
     * from the map right after it has been applied. If an id cannot be resolved (which raises via
     * {@code Intrinsics.throwNpe()}), that entry and every later one are therefore left untouched in
     * the map. Entries that were already applied cannot be divided and applied a second time by a
     * later call.
     */
    @Override // com.kotlinnlp.simplednn.core.optimizer.Optimizer, com.kotlinnlp.simplednn.core.optimizer.ScheduledUpdater
    public void update() {
        Iterator<Map.Entry<Integer, ErrorsAccumulator>> entries = this.embeddingsErrorsMap.entrySet().iterator();
        while (entries.hasNext()) {
            Map.Entry<Integer, ErrorsAccumulator> entry = entries.next();
            int embeddingId = entry.getKey().intValue();
            ErrorsAccumulator accumulator = entry.getValue();

            // Resolve first: a failure here must not leave an already-averaged sum in the map.
            Embedding embedding = this.embeddingsMap.getById(embeddingId);
            if (embedding == null) {
                Intrinsics.throwNpe();
            }

            // Turn the running sum into the mean error over all contributions.
            accumulator.getErrors().assignDiv(accumulator.getCount());
            UpdateMethod<?> updateMethod = getUpdateMethod();
            updateMethod.update(embedding.getArray(), accumulator.getErrors());

            // Forget the entry as soon as it has been applied.
            entries.remove();
        }
    }

    /**
     * Accumulates the errors of one embedding.
     *
     * <p>The first contribution for an id is stored as a {@code copy()} of {@code pair.getSecond()}.
     * Later contributions are summed into that copy, so the in-place averaging done by
     * {@link #update()} operates on this optimizer's own array.
     *
     * @param pair the embedding id and its errors
     * @param copy the {@code copy} flag of the overridden Kotlin signature. NOTE(review): currently
     *     ignored — the errors are always copied on first insertion; confirm this is intended.
     */
    /* JADX WARN: Multi-variable type inference failed */
    /* renamed from: accumulate, reason: avoid collision after fix types in other method */
    public void accumulate2(@NotNull Pair<Integer, DenseNDArray> pair, boolean copy) {
        Intrinsics.checkParameterIsNotNull(pair, "paramsErrors");
        Integer embeddingId = pair.getFirst();
        DenseNDArray errors = pair.getSecond();
        ErrorsAccumulator errorsAccumulator = this.embeddingsErrorsMap.get(embeddingId);
        if (errorsAccumulator != null) {
            errorsAccumulator.accumulate(errors);
        } else {
            this.embeddingsErrorsMap.put(embeddingId, new ErrorsAccumulator(errors.copy(), 1));
        }
    }

    /** Bridge for the generic {@link Optimizer} signature; delegates to {@link #accumulate2}. */
    @Override // com.kotlinnlp.simplednn.core.optimizer.Optimizer
    public /* bridge */ /* synthetic */ void accumulate(Pair<? extends Integer, ? extends DenseNDArray> pair, boolean z) {
        accumulate2((Pair<Integer, DenseNDArray>) pair, z);
    }

    /**
     * Accumulates {@code denseNDArray} as errors of the given embedding, keyed by its id.
     *
     * @param embedding the embedding that received the errors
     * @param denseNDArray the errors
     */
    public final void accumulate(@NotNull Embedding embedding, @NotNull DenseNDArray denseNDArray) {
        Intrinsics.checkParameterIsNotNull(embedding, "embedding");
        Intrinsics.checkParameterIsNotNull(denseNDArray, "errors");
        // Kotlin default-argument helper: mask 2 means "use the declared default for `copy`".
        Optimizer.accumulate$default(this, new Pair(Integer.valueOf(embedding.getId()), denseNDArray), false, 2, null);
    }

    /**
     * Accumulates {@code denseNDArray} as errors of the embedding that {@link #embeddingsMap} returns
     * for key {@code t}.
     *
     * @param t the embedding key (may be null)
     * @param denseNDArray the errors
     */
    public final void accumulate(@Nullable T t, @NotNull DenseNDArray denseNDArray) {
        Intrinsics.checkParameterIsNotNull(denseNDArray, "errors");
        // The 0.0d is a placeholder: mask 2 makes EmbeddingsMap.get use its declared default for its
        // second parameter (Kotlin `$default` convention). NOTE(review): confirm that default.
        Optimizer.accumulate$default(this, new Pair(Integer.valueOf(EmbeddingsMap.get$default(this.embeddingsMap, t, 0.0d, 2, null).getId()), denseNDArray), false, 2, null);
    }

    /**
     * @param embeddingsMap the embeddings whose arrays this optimizer updates
     * @param updateMethod the update method passed to {@link Optimizer}
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    /* JADX WARN: Multi-variable type inference failed */
    public EmbeddingsOptimizer(@NotNull EmbeddingsMap<? super T> embeddingsMap, @NotNull UpdateMethod<?> updateMethod) {
        // Unlike the original bytecode, the super constructor now runs before the null checks below.
        super(updateMethod);
        Intrinsics.checkParameterIsNotNull(embeddingsMap, "embeddingsMap");
        Intrinsics.checkParameterIsNotNull(updateMethod, "updateMethod");
        this.embeddingsMap = embeddingsMap;
        this.embeddingsErrorsMap = new LinkedHashMap();
    }
}
