package com.kotlinnlp.languagemodel.training;

import com.kotlinnlp.languagemodel.CharLM;
import com.kotlinnlp.simplednn.core.arrays.UpdatableDenseArray;
import com.kotlinnlp.simplednn.core.embeddings.Embedding;
import com.kotlinnlp.simplednn.core.embeddings.EmbeddingsMap;
import com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor;
import com.kotlinnlp.simplednn.core.neuralprocessor.batchfeedforward.BatchFeedforwardProcessor;
import com.kotlinnlp.simplednn.core.neuralprocessor.recurrent.RecurrentNeuralProcessor;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: Processor.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��R\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0010\u000e\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0002\n\u0002\b\u0007\b��\u0018��2,\u0012\u0004\u0012\u00020\u0002\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00040\u0003\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00040\u0003\u0012\u0004\u0012\u00020\u0005\u0012\u0004\u0012\u00020\u00060\u0001B\u0015\u0012\u0006\u0010\u0007\u001a\u00020\b\u0012\u0006\u0010\t\u001a\u00020\n¢\u0006\u0002\u0010\u000bJ\u0016\u0010\u001b\u001a\u00020\u001c2\f\u0010\u001d\u001a\b\u0012\u0004\u0012\u00020\u00040\u0003H\u0016J\u0016\u0010\u001e\u001a\b\u0012\u0004\u0012\u00020\u00040\u00032\u0006\u0010\u001f\u001a\u00020\u0002H\u0016J\u0010\u0010 \u001a\u00020\u00052\u0006\u0010!\u001a\u00020\nH\u0016J\u0010\u0010\"\u001a\u00020\u00062\u0006\u0010!\u001a\u00020\nH\u0016R\u0014\u0010\f\u001a\b\u0012\u0004\u0012\u00020\u00040\rX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u000e\u001a\u00020\u000fX\u0096D¢\u0006\b\n��\u001a\u0004\b\u0010\u0010\u0011R\u0014\u0010\u0012\u001a\b\u0012\u0004\u0012\u00020\u00130\u0003X\u0082\u000e¢\u0006\u0002\n��R\u0014\u0010\u0014\u001a\b\u0012\u0004\u0012\u00020\u00130\u0003X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0007\u001a\u00020\bX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0015\u001a\u00020\nX\u0096D¢\u0006\b\n��\u001a\u0004\b\u0016\u0010\u0017R\u0014\u0010\u0018\u001a\b\u0012\u0004\u0012\u00020\u00040\u0019X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\t\u001a\u00020\nX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u001a\u0010\u0017¨\u0006#"}, d2 = {"Lcom/kotlinnlp/languagemodel/training/Processor;", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor;", "", "", 
"Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor$NoInputErrors;", "Lcom/kotlinnlp/languagemodel/training/ParamsErrors;", "model", "Lcom/kotlinnlp/languagemodel/CharLM;", "useDropout", "", "(Lcom/kotlinnlp/languagemodel/CharLM;Z)V", "classifierProcessor", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/batchfeedforward/BatchFeedforwardProcessor;", "id", "", "getId", "()I", "lastCharsEmbeddings", "Lcom/kotlinnlp/simplednn/core/embeddings/Embedding;", "lastCharsEmbeddingsErrors", "propagateToInput", "getPropagateToInput", "()Z", "recurrentProcessor", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/recurrent/RecurrentNeuralProcessor;", "getUseDropout", "backward", "", "outputErrors", "forward", "input", "getInputErrors", "copy", "getParamsErrors", "languagemodel"})
/* loaded from: input_file:com/kotlinnlp/languagemodel/training/Processor.class */
public final class Processor implements NeuralProcessor<String, List<? extends DenseNDArray>, List<? extends DenseNDArray>, NeuralProcessor.NoInputErrors, ParamsErrors> {

    /** Always false: this processor never produces input errors (see {@link #getInputErrors2}). */
    private final boolean propagateToInput = false;

    /** Fixed identifier of this processor. */
    private final int id = 0;

    /** Processor of the model's recurrent network; fed with the values of the character embeddings. */
    private final RecurrentNeuralProcessor<DenseNDArray> recurrentProcessor;

    /** Processor of the model's classifier; fed with the output of {@link #recurrentProcessor}. */
    private final BatchFeedforwardProcessor<DenseNDArray> classifierProcessor;

    /** Embeddings of the characters of the last input passed to {@link #forward(String)}. */
    private List<Embedding> lastCharsEmbeddings;

    /**
     * Errors of the embeddings in {@link #lastCharsEmbeddings}, as computed by the last backward step.
     * A new list is built on every backward step, so a previously returned list is never mutated.
     */
    private List<Embedding> lastCharsEmbeddingsErrors;

    /** The character language model whose networks and embeddings are used. */
    private final CharLM model;

    /** Whether dropout is enabled in the sub-processors. */
    private final boolean useDropout;

    /**
     * Build a processor for the given model.
     *
     * @param model the character language model to process
     * @param useDropout whether to enable dropout in the recurrent and classifier processors
     */
    public Processor(@NotNull CharLM model, boolean useDropout) {
        Intrinsics.checkParameterIsNotNull(model, "model");
        this.model = model;
        this.useDropout = useDropout;
        // Both sub-processors propagate errors to their input (third argument = true), because the
        // classifier errors feed the recurrent backward step and the recurrent input errors become
        // the embeddings errors. The trailing (null, 0, 24, null) arguments are the Kotlin
        // default-arguments mask emitted by the compiler.
        this.recurrentProcessor = new RecurrentNeuralProcessor<>(model.getRecurrentNetwork(), useDropout, true, null, 0, 24, null);
        this.classifierProcessor = new BatchFeedforwardProcessor<>(model.getClassifier(), useDropout, true, null, 0, 24, null);
        this.lastCharsEmbeddings = CollectionsKt.emptyList();
        this.lastCharsEmbeddingsErrors = CollectionsKt.emptyList();
    }

    /** @return always {@code false} */
    @Override
    public boolean getPropagateToInput() {
        return this.propagateToInput;
    }

    /** @return the fixed id of this processor */
    @Override
    public int getId() {
        return this.id;
    }

    /** @return whether dropout is enabled in the sub-processors */
    @Override
    public boolean getUseDropout() {
        return this.useDropout;
    }

    /**
     * Forward a string, one character at a time, through the recurrent network and then the classifier.
     *
     * <p>The embeddings of the input characters are retained for the next backward step.
     *
     * @param input the input string (iterated by UTF-16 code unit)
     * @return the classifier output, one array per character
     */
    @Override
    @NotNull
    public List<DenseNDArray> forward(@NotNull String input) {
        Intrinsics.checkParameterIsNotNull(input, "input");

        List<Embedding> charsEmbeddings = new ArrayList<>(input.length());
        for (int i = 0; i < input.length(); i++) {
            // Synthetic Kotlin default-args call: mask 2 = use the default for the second parameter.
            charsEmbeddings.add(EmbeddingsMap.get$default(this.model.getCharsEmbeddings(), Character.valueOf(input.charAt(i)), 0.0d, 2, null));
        }
        this.lastCharsEmbeddings = charsEmbeddings;

        List<DenseNDArray> embeddingsValues = new ArrayList<>(charsEmbeddings.size());
        for (Embedding embedding : charsEmbeddings) {
            embeddingsValues.add(embedding.getArray().getValues());
        }

        List<? extends DenseNDArray> recurrentOutput = (List<? extends DenseNDArray>) this.recurrentProcessor.forward((List<? extends DenseNDArray>) embeddingsValues);
        return this.classifierProcessor.forward(recurrentOutput);
    }

    /**
     * Backward step (decompiler-renamed implementation of {@code backward}; the bridge below delegates here).
     *
     * <p>Back-propagates the output errors through the classifier and then the recurrent network.
     * It then stores one embedding error per character of the last forwarded input, paired with
     * the id of that character's embedding.
     *
     * @param outputErrors the errors of the outputs returned by the last {@link #forward(String)}
     */
    public void backward2(@NotNull List<DenseNDArray> outputErrors) {
        Intrinsics.checkParameterIsNotNull(outputErrors, "outputErrors");
        this.classifierProcessor.backward2(outputErrors);
        this.recurrentProcessor.backward2((List<DenseNDArray>) this.classifierProcessor.getInputErrors2(false));

        List<? extends DenseNDArray> recurrentInputErrors = this.recurrentProcessor.getInputErrors2(false);
        List<Embedding> embeddingsErrors = new ArrayList<>(recurrentInputErrors.size());
        int index = 0;
        for (DenseNDArray charErrors : recurrentInputErrors) {
            int embeddingId = this.lastCharsEmbeddings.get(index++).getId();
            // Copy: the recurrent processor may reuse its input-errors buffers on the next step.
            embeddingsErrors.add(new Embedding(embeddingId, new UpdatableDenseArray(charErrors.copy())));
        }
        this.lastCharsEmbeddingsErrors = embeddingsErrors;
    }

    @Override
    public /* bridge */ /* synthetic */ void backward(List<? extends DenseNDArray> list) {
        backward2((List<DenseNDArray>) list);
    }

    /**
     * This processor does not propagate errors to its (string) input.
     *
     * @param z unused copy flag
     * @return the {@code NoInputErrors} singleton
     */
    @Override
    @NotNull
    public NeuralProcessor.NoInputErrors getInputErrors2(boolean z) {
        return NeuralProcessor.NoInputErrors.INSTANCE;
    }

    /**
     * Return the parameter errors computed by the last backward step.
     *
     * @param z whether the recurrent and classifier errors must be copies rather than references to
     *     the sub-processors' internal buffers; the embeddings errors are already a per-step snapshot
     * @return the recurrent/classifier parameter errors together with the embeddings errors
     */
    @Override
    @NotNull
    public ParamsErrors getParamsErrors2(boolean z) {
        // Forward the copy flag: hard-coding false would hand out live buffers even when a copy is requested.
        return new ParamsErrors(new CharLM.RecurrentClassifierParameters(this.recurrentProcessor.getParamsErrors2(z), this.classifierProcessor.getParamsErrors2(z)), this.lastCharsEmbeddingsErrors);
    }

    /**
     * Backward the given errors and accumulate the parameter errors into the optimizer, using the
     * interface's default implementation.
     *
     * @param errors the output errors
     * @param optimizer the optimizer that accumulates the parameter errors
     * @param z whether copies of the errors must be used
     * @return the {@code NoInputErrors} singleton
     */
    @NotNull
    public NeuralProcessor.NoInputErrors propagateErrors2(@NotNull List<DenseNDArray> errors, @NotNull com.kotlinnlp.simplednn.core.optimizer.Optimizer<? super ParamsErrors> optimizer, boolean z) {
        Intrinsics.checkParameterIsNotNull(errors, "errors");
        Intrinsics.checkParameterIsNotNull(optimizer, "optimizer");
        return (NeuralProcessor.NoInputErrors) NeuralProcessor.DefaultImpls.propagateErrors(this, errors, optimizer, z);
    }

    @Override
    public /* bridge */ /* synthetic */ NeuralProcessor.NoInputErrors propagateErrors(List<? extends DenseNDArray> list, com.kotlinnlp.simplednn.core.optimizer.Optimizer<? super ParamsErrors> optimizer, boolean z) {
        return propagateErrors2((List<DenseNDArray>) list, optimizer, z);
    }
}
