package com.kotlinnlp.simplednn.deeplearning.treernn;

import com.kotlinnlp.simplednn.core.neuralprocessor.feedforward.FeedforwardNeuralProcessor;
import com.kotlinnlp.simplednn.core.neuralprocessor.feedforward.FeedforwardNeuralProcessorsPool;
import com.kotlinnlp.simplednn.core.neuralprocessor.recurrent.RecurrentNeuralProcessor;
import com.kotlinnlp.simplednn.core.neuralprocessor.recurrent.RecurrentNeuralProcessorsPool;
import com.kotlinnlp.simplednn.core.optimizer.Optimizer;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.simplemath.SimpleMathKt;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: TreeEncoder.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��T\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010%\n\u0002\u0010\b\n\u0002\u0018\u0002\n��\n\u0002\u0010#\n\u0002\b\u0002\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\n\n\u0002\u0010 \n\u0002\b\u0007\u0018��2\u00020\u0001:\u0001(B\r\u0012\u0006\u0010\u0002\u001a\u00020\u0003¢\u0006\u0002\u0010\u0004J\u0016\u0010\u0013\u001a\u00020\u00142\f\u0010\u0015\u001a\b\u0012\u0004\u0012\u00020\u00170\u0016H\u0002J\u0016\u0010\u0018\u001a\u00020\u00142\u0006\u0010\u0019\u001a\u00020\u000e2\u0006\u0010\u001a\u001a\u00020\u0007J\u001a\u0010\u001b\u001a\u00060\u000fR\u00020��2\u0006\u0010\u001c\u001a\u00020\u000e2\u0006\u0010\u001d\u001a\u00020\u0007J\b\u0010\u001e\u001a\u00020\u0014H\u0002J\u0006\u0010\u001f\u001a\u00020\u0014J\u0012\u0010 \u001a\u00060\u000fR\u00020��2\u0006\u0010\u001c\u001a\u00020\u000eJ\f\u0010!\u001a\b\u0012\u0004\u0012\u00020\u000e0\"J\u0014\u0010#\u001a\u00020\u00142\n\u0010$\u001a\u00060\u000fR\u00020��H\u0002J\u0014\u0010%\u001a\u00020\u00142\f\u0010\u0015\u001a\b\u0012\u0004\u0012\u00020\u00170\u0016J\u0016\u0010&\u001a\u00020\u00142\u0006\u0010\u0019\u001a\u00020\u000e2\u0006\u0010'\u001a\u00020\u000eR\u0014\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00070\u0006X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\b\u001a\b\u0012\u0004\u0012\u00020\u00070\tX\u0082\u0004¢\u0006\u0002\n��R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\n\u0010\u000bR\u001e\u0010\f\u001a\u0012\u0012\u0004\u0012\u00020\u000e\u0012\b\u0012\u00060\u000fR\u00020��0\rX\u0082\u0004¢\u0006\u0002\n��R\u0018\u0010\u0010\u001a\f\u0012\b\u0012\u00060\u000fR\u00020��0\u0011X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0012\u001a\b\u0012\u0004\u0012\u00020\u00070\tX\u0082\u0004¢\u0006\u0002\n��¨\u0006)"}, d2 = 
{"Lcom/kotlinnlp/simplednn/deeplearning/treernn/TreeEncoder;", "", "network", "Lcom/kotlinnlp/simplednn/deeplearning/treernn/TreeRNN;", "(Lcom/kotlinnlp/simplednn/deeplearning/treernn/TreeRNN;)V", "concatProcessorsPool", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/feedforward/FeedforwardNeuralProcessorsPool;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "leftProcessorsPool", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/recurrent/RecurrentNeuralProcessorsPool;", "getNetwork", "()Lcom/kotlinnlp/simplednn/deeplearning/treernn/TreeRNN;", "nodes", "", "", "Lcom/kotlinnlp/simplednn/deeplearning/treernn/TreeEncoder$Node;", "nodesWithEncodingErrors", "", "rightProcessorsPool", "accumulateParamsErrors", "", "optimizer", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;", "Lcom/kotlinnlp/simplednn/deeplearning/treernn/TreeRNNParameters;", "addEncodingErrors", "nodeId", "errors", "addNode", "id", "vector", "clearNodeErrors", "clearTree", "getNode", "getRootsIds", "", "launchErrorsPropagation", "node", "propagateErrors", "setHead", "headId", "Node", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/treernn/TreeEncoder.class */
public final class TreeEncoder {
    // Pool of recurrent processors built on network.getLeftRNN(); every Node takes one when created.
    private final RecurrentNeuralProcessorsPool<DenseNDArray> leftProcessorsPool;
    // Pool of recurrent processors built on network.getRightRNN(); every Node takes one when created.
    private final RecurrentNeuralProcessorsPool<DenseNDArray> rightProcessorsPool;
    // Pool of feedforward processors built on network.getConcatNetwork(); every Node takes one when
    // created. All three pools are released together by clearTree().
    private final FeedforwardNeuralProcessorsPool<DenseNDArray> concatProcessorsPool;
    // Nodes of the current tree indexed by id, kept in insertion order (LinkedHashMap).
    private final Map<Integer, Node> nodes;
    // Nodes that received errors through addEncodingErrors() since the last propagateErrors() or
    // clearTree(); the errors propagation starts from these nodes.
    private final Set<Node> nodesWithEncodingErrors;

    // The TreeRNN whose sub-networks the processor pools of this encoder are built on.
    @NotNull
    private final TreeRNN network;

    /* compiled from: TreeEncoder.kt */
    @Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��V\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0010\b\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\n\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u000b\n\u0002\b\u0005\n\u0002\u0010!\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\n\n\u0002\u0010\u0002\n\u0002\b\t\n\u0002\u0010 \n��\n\u0002\u0018\u0002\n\u0002\b\u000b\b\u0086\u0004\u0018��2\u00020\u0001B\u0015\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005¢\u0006\u0002\u0010\u0006J\u0019\u0010,\u001a\u00020-2\n\u0010.\u001a\u00060��R\u00020\u0014H��¢\u0006\u0002\b/J\u0015\u00100\u001a\u00020-2\u0006\u00101\u001a\u00020\u0005H��¢\u0006\u0002\b2J\u0010\u00103\u001a\u00020-2\u0006\u00101\u001a\u00020\u0005H\u0002J0\u00104\u001a\u00020-2\f\u00105\u001a\b\u0012\u0004\u0012\u00020\u00050\"2\u0010\u00106\u001a\f\u0012\b\u0012\u00060��R\u00020\u0014072\u0006\u00101\u001a\u00020\u0005H\u0002J\u001c\u00108\u001a\u000e\u0012\u0004\u0012\u00020\u0005\u0012\u0004\u0012\u00020\u0005092\u0006\u00101\u001a\u00020\u0005H\u0002J\b\u0010:\u001a\u00020\u0005H\u0002J0\u0010;\u001a\u00020\u00052\f\u00105\u001a\b\u0012\u0004\u0012\u00020\u00050\"2\u0006\u0010\u0013\u001a\u00020\u00052\u0010\u00106\u001a\f\u0012\b\u0012\u00060��R\u00020\u001407H\u0002J\b\u0010<\u001a\u00020-H\u0002J\u000f\u0010=\u001a\u0004\u0018\u00010\u0005H��¢\u0006\u0002\b>J\r\u0010?\u001a\u00020-H��¢\u0006\u0002\b@J\r\u0010A\u001a\u00020-H��¢\u0006\u0002\bBJ\u001c\u0010C\u001a\u000e\u0012\u0004\u0012\u00020\u0005\u0012\u0004\u0012\u00020\u0005092\u0006\u00101\u001a\u00020\u0005H\u0002R\u0010\u0010\u0007\u001a\u0004\u0018\u00010\u0005X\u0082\u000e¢\u0006\u0002\n��R\u001a\u0010\b\u001a\b\u0012\u0004\u0012\u00020\u00050\tX\u0080\u0004¢\u0006\b\n��\u001a\u0004\b\n\u0010\u000bR$\u0010\r\u001a\u00020\u00052\u0006\u0010\f\u001a\u00020\u0005@BX\u0086\u000e¢\u0006\u000e\n��\u001a\u0004\b\u000e\u0010\u000f\"\u0004\b\u0010\u0010\u0011R\u0010\u0010
\u0012\u001a\u0004\u0018\u00010\u0005X\u0082\u000e¢\u0006\u0002\n��R\u0014\u0010\u0013\u001a\b\u0018\u00010��R\u00020\u0014X\u0082\u000e¢\u0006\u0002\n��R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0015\u0010\u0016R\u001a\u0010\u0017\u001a\u00020\u0018X\u0080\u000e¢\u0006\u000e\n��\u001a\u0004\b\u0019\u0010\u001a\"\u0004\b\u001b\u0010\u001cR\u001e\u0010\u001d\u001a\f\u0012\b\u0012\u00060��R\u00020\u00140\u001eX\u0080\u0004¢\u0006\b\n��\u001a\u0004\b\u001f\u0010 R\u001a\u0010!\u001a\b\u0012\u0004\u0012\u00020\u00050\"X\u0080\u0004¢\u0006\b\n��\u001a\u0004\b#\u0010$R\u001e\u0010%\u001a\f\u0012\b\u0012\u00060��R\u00020\u00140\u001eX\u0080\u0004¢\u0006\b\n��\u001a\u0004\b&\u0010 R\u001a\u0010'\u001a\b\u0012\u0004\u0012\u00020\u00050\"X\u0080\u0004¢\u0006\b\n��\u001a\u0004\b(\u0010$R\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b)\u0010\u000fR\u0011\u0010*\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b+\u0010\u000f¨\u0006D"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/treernn/TreeEncoder$Node;", "", "id", "", "vector", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "(Lcom/kotlinnlp/simplednn/deeplearning/treernn/TreeEncoder;ILcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;)V", "backwardErrors", "concatProcessor", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/feedforward/FeedforwardNeuralProcessor;", "getConcatProcessor$simplednn", "()Lcom/kotlinnlp/simplednn/core/neuralprocessor/feedforward/FeedforwardNeuralProcessor;", "<set-?>", "encoding", "getEncoding", "()Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "setEncoding", "(Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;)V", "encodingErrors", "head", "Lcom/kotlinnlp/simplednn/deeplearning/treernn/TreeEncoder;", "getId", "()I", "isRoot", "", "isRoot$simplednn", "()Z", "setRoot$simplednn", "(Z)V", "leftChildren", "", "getLeftChildren$simplednn", "()Ljava/util/List;", "leftProcessor", 
"Lcom/kotlinnlp/simplednn/core/neuralprocessor/recurrent/RecurrentNeuralProcessor;", "getLeftProcessor$simplednn", "()Lcom/kotlinnlp/simplednn/core/neuralprocessor/recurrent/RecurrentNeuralProcessor;", "rightChildren", "getRightChildren$simplednn", "rightProcessor", "getRightProcessor$simplednn", "getVector", "vectorErrors", "getVectorErrors", "addChild", "", "child", "addChild$simplednn", "addEncodingErrors", "errors", "addEncodingErrors$simplednn", "backward", "backwardChildren", "processor", "children", "", "backwardConcat", "Lkotlin/Pair;", "encode", "encodeChildren", "encodeHead", "getNodeErrors", "getNodeErrors$simplednn", "propagateErrors", "propagateErrors$simplednn", "resetNodeErrors", "resetNodeErrors$simplednn", "splitLeftAndRightErrors", "simplednn"})
    /* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/treernn/TreeEncoder$Node.class */
    public final class Node {

        // Output of the concat processor over the left and right recurrent encodings of this node.
        // It is recomputed when a child is added to this node or to any of its descendants.
        @NotNull
        private DenseNDArray encoding;

        // Errors of the input vector, summed in place by every backward step of this node.
        // NOTE(review): resetNodeErrors$simplednn() does not zero it, so it keeps accumulating
        // across propagations — confirm this is intended.
        @NotNull
        private final DenseNDArray vectorErrors;
        // True on creation; set to false when the node is attached to a head by addChild$simplednn().
        private boolean isRoot;

        // Children with an id lower than this node's id, in the order they were attached.
        @NotNull
        private final List<Node> leftChildren;

        // Children with an id not lower than this node's id, in the order they were attached.
        @NotNull
        private final List<Node> rightChildren;

        // Recurrent processor reading this node's vector followed by the encodings of leftChildren.
        @NotNull
        private final RecurrentNeuralProcessor<DenseNDArray> leftProcessor;

        // Recurrent processor reading this node's vector followed by the encodings of rightChildren.
        @NotNull
        private final RecurrentNeuralProcessor<DenseNDArray> rightProcessor;

        // Feedforward processor applied to the concatenation of the two recurrent outputs.
        @NotNull
        private final FeedforwardNeuralProcessor<DenseNDArray> concatProcessor;
        // The node this one is a child of, or null while it has no head.
        private Node head;
        // Errors of this node's encoding set by the backward step of its head; null when unset.
        private DenseNDArray backwardErrors;
        // Errors of this node's encoding added from outside; null when none were added.
        private DenseNDArray encodingErrors;
        // Identifier of the node; its comparison with the head's id decides the children side.
        private final int id;

        // Input vector of the node: the first step of both recurrent sequences.
        @NotNull
        private final DenseNDArray vector;
        // Enclosing TreeEncoder (inner-class reference emitted by the compiler).
        final /* synthetic */ TreeEncoder this$0;

        /** Returns the current encoding of this node (the last output of its concat processor). */
        @NotNull
        public final DenseNDArray getEncoding() {
            return encoding;
        }

        /** Replaces the stored encoding of this node. */
        private final void setEncoding(DenseNDArray value) {
            encoding = value;
        }

        /** Returns the errors accumulated on the input vector by the backward steps of this node. */
        @NotNull
        public final DenseNDArray getVectorErrors() {
            return vectorErrors;
        }

        /** Returns whether this node is currently flagged as a root of the tree. */
        public final boolean isRoot$simplednn() {
            return isRoot;
        }

        /** Sets the root flag of this node. */
        public final void setRoot$simplednn(boolean root) {
            isRoot = root;
        }

        /** Returns the live list of children whose id is lower than this node's id. */
        @NotNull
        public final List<Node> getLeftChildren$simplednn() {
            return leftChildren;
        }

        /** Returns the live list of children whose id is not lower than this node's id. */
        @NotNull
        public final List<Node> getRightChildren$simplednn() {
            return rightChildren;
        }

        /** Returns the recurrent processor that encodes the left children. */
        @NotNull
        public final RecurrentNeuralProcessor<DenseNDArray> getLeftProcessor$simplednn() {
            return leftProcessor;
        }

        /** Returns the recurrent processor that encodes the right children. */
        @NotNull
        public final RecurrentNeuralProcessor<DenseNDArray> getRightProcessor$simplednn() {
            return rightProcessor;
        }

        /** Returns the feedforward processor that merges the left and right encodings. */
        @NotNull
        public final FeedforwardNeuralProcessor<DenseNDArray> getConcatProcessor$simplednn() {
            return concatProcessor;
        }

        /**
         * Forgets both the backward errors and the encoding errors of this node.
         *
         * <p>{@code vectorErrors} is not touched.
         */
        public final void resetNodeErrors$simplednn() {
            backwardErrors = null;
            encodingErrors = null;
        }

        /**
         * Adds errors to the encoding errors of this node.
         *
         * <p>The first call stores a copy of {@code errors}; later calls sum into that copy in place.
         *
         * @param errors errors of this node's encoding
         */
        public final void addEncodingErrors$simplednn(@NotNull DenseNDArray errors) {
            Intrinsics.checkParameterIsNotNull(errors, "errors");
            DenseNDArray accumulated = encodingErrors;
            if (accumulated != null) {
                accumulated.assignSum((NDArray<?>) errors);
            } else {
                encodingErrors = errors.copy();
            }
        }

        /**
         * Returns the total errors of this node's encoding.
         *
         * @return the {@code sum()} of the encoding errors and the backward errors when both are set,
         *     the single one that is set otherwise (not copied), or {@code null} when neither is set
         */
        @Nullable
        public final DenseNDArray getNodeErrors$simplednn() {
            DenseNDArray fromOutside = encodingErrors;
            DenseNDArray fromHead = backwardErrors;
            if (fromOutside == null) {
                return fromHead;
            }
            return fromHead == null ? fromOutside : fromOutside.sum(fromHead);
        }

        /**
         * Attaches {@code child} to this node and refreshes the encodings of this node and of all its
         * ancestors.
         *
         * <p>A child with an id lower than this node's id joins the left children, any other child
         * joins the right children.
         *
         * @param child the node to attach
         */
        public final void addChild$simplednn(@NotNull Node child) {
            Intrinsics.checkParameterIsNotNull(child, "child");
            child.isRoot = false;
            child.head = this;
            List<Node> side = child.id < id ? leftChildren : rightChildren;
            side.add(child);
            encoding = encode();
            encodeHead();
        }

        /**
         * Computes the encoding of this node.
         *
         * <p>The left and right recurrent encodings are concatenated and given to the concat
         * processor.
         *
         * @return the output of the concat processor
         */
        private final DenseNDArray encode() {
            DenseNDArray leftEncoding = encodeChildren(leftProcessor, vector, leftChildren);
            DenseNDArray rightEncoding = encodeChildren(rightProcessor, vector, rightChildren);
            NDArray concatenated = (NDArray) SimpleMathKt.concatVectorsV(leftEncoding, rightEncoding);
            return FeedforwardNeuralProcessor.forward$default((FeedforwardNeuralProcessor) concatProcessor, concatenated, false, 2, (Object) null);
        }

        /**
         * Back-propagates the errors of this node's encoding, then recurses into every left and right
         * child.
         *
         * <p>The backward step runs only when the node has errors and at least one of their values
         * is non-zero.
         */
        public final void propagateErrors$simplednn() {
            DenseNDArray errors = getNodeErrors$simplednn();
            if (errors != null) {
                boolean anyNonZero = false;
                for (int k = 0; !anyNonZero && k < errors.getLength(); k++) {
                    anyNonZero = errors.get(k).doubleValue() != 0.0d;
                }
                if (anyNonZero) {
                    backward(errors);
                }
            }
            for (Node child : leftChildren) {
                child.propagateErrors$simplednn();
            }
            for (Node child : rightChildren) {
                child.propagateErrors$simplednn();
            }
        }

        /**
         * Back-propagates {@code errors} through the concat processor and then through the left and
         * right recurrent processors, which sets the backward errors of the children.
         *
         * @param errors errors of this node's encoding
         */
        private final void backward(DenseNDArray errors) {
            Pair<DenseNDArray, DenseNDArray> splitErrors = backwardConcat(errors);
            backwardChildren(leftProcessor, leftChildren, splitErrors.getFirst());
            backwardChildren(rightProcessor, rightChildren, splitErrors.getSecond());
        }

        /**
         * Runs the backward step of the concat processor.
         *
         * @param errors errors of the concat processor output
         * @return its input errors, split into the left part and the right part
         */
        private final Pair<DenseNDArray, DenseNDArray> backwardConcat(DenseNDArray errors) {
            FeedforwardNeuralProcessor.backward$default(concatProcessor, errors, true, null, 4, null);
            DenseNDArray concatInputErrors = concatProcessor.getInputErrors(false);
            return splitLeftAndRightErrors(concatInputErrors);
        }

        /**
         * Feeds {@code headVector} and then the encodings of {@code children} to {@code processor}
         * as one sequence.
         *
         * <p>The first forward call passes {@code true} as its third argument and the later calls
         * pass {@code false} (presumably marking the start of a new sequence — confirm against
         * RecurrentNeuralProcessor).
         *
         * @param processor the recurrent processor to run
         * @param headVector the vector of this node, used as the first step
         * @param children the children whose encodings form the following steps
         * @return the output of the processor after the last step
         */
        private final DenseNDArray encodeChildren(RecurrentNeuralProcessor<DenseNDArray> processor, DenseNDArray headVector, List<Node> children) {
            RecurrentNeuralProcessor.forward$default((RecurrentNeuralProcessor) processor, (NDArray) headVector, true, (List) null, false, false, 28, (Object) null);
            for (Node child : children) {
                RecurrentNeuralProcessor.forward$default((RecurrentNeuralProcessor) processor, (NDArray) child.encoding, false, (List) null, false, false, 28, (Object) null);
            }
            return processor.getOutput(false);
        }

        /**
         * Runs the backward step of a recurrent processor and distributes its input sequence errors.
         *
         * <p>Step 0 of the sequence is this node's vector, so its errors are summed into
         * {@code vectorErrors}. That field is never cleared by resetNodeErrors$simplednn(). Step
         * {@code k + 1} is the encoding of the k-th child, whose backward errors are overwritten.
         *
         * @param processor the recurrent processor that encoded {@code children}
         * @param children the children in the order they were fed to the processor
         * @param outputErrors errors of the processor output
         */
        private final void backwardChildren(RecurrentNeuralProcessor<DenseNDArray> processor, List<Node> children, DenseNDArray outputErrors) {
            RecurrentNeuralProcessor.backward$default((RecurrentNeuralProcessor) processor, outputErrors, true, (List) null, 4, (Object) null);
            List sequenceErrors = RecurrentNeuralProcessor.getInputSequenceErrors$default(processor, false, 1, null);
            vectorErrors.assignSum((NDArray<?>) sequenceErrors.get(0));
            for (int k = 0; k < children.size(); k++) {
                children.get(k).backwardErrors = (DenseNDArray) sequenceErrors.get(k + 1);
            }
        }

        /**
         * Splits the concat input errors into the part of the left encoding and that of the right one.
         *
         * <p>NOTE(review): the split is at half of the length, which assumes the left and right RNN
         * outputs have the same size — confirm in TreeRNN.
         *
         * @param errors the concat processor input errors
         * @return the first half and the second half of {@code errors}
         */
        private final Pair<DenseNDArray, DenseNDArray> splitLeftAndRightErrors(DenseNDArray errors) {
            int length = errors.getLength();
            int half = length / 2;
            DenseNDArray leftErrors = errors.getRange(0, half);
            DenseNDArray rightErrors = errors.getRange(half, length);
            return new Pair<>(leftErrors, rightErrors);
        }

        /**
         * Re-encodes the head of this node and, recursively, every further ancestor; does nothing for
         * a node without a head.
         */
        private final void encodeHead() {
            Node parent = head;
            if (parent == null) {
                return;
            }
            parent.encoding = parent.encode();
            parent.encodeHead();
        }

        /** Returns the identifier of this node. */
        public final int getId() {
            return id;
        }

        /** Returns the input vector of this node. */
        @NotNull
        public final DenseNDArray getVector() {
            return vector;
        }

        /**
         * Creates a node without children, flagged as root, and computes its initial encoding.
         *
         * <p>One processor is taken from each of the three pools of {@code treeEncoder}.
         *
         * @param treeEncoder the enclosing encoder
         * @param i the id of the node
         * @param denseNDArray the input vector of the node (must not be null)
         */
        public Node(TreeEncoder treeEncoder, int i, @NotNull DenseNDArray denseNDArray) {
            Intrinsics.checkParameterIsNotNull(denseNDArray, "vector");
            this.this$0 = treeEncoder;
            this.id = i;
            this.vector = denseNDArray;
            this.vectorErrors = DenseNDArrayFactory.INSTANCE.zeros(this.vector.getShape());
            this.isRoot = true;
            this.leftChildren = new ArrayList<>();
            this.rightChildren = new ArrayList<>();
            this.leftProcessor = (RecurrentNeuralProcessor) treeEncoder.leftProcessorsPool.getItem();
            this.rightProcessor = (RecurrentNeuralProcessor) treeEncoder.rightProcessorsPool.getItem();
            this.concatProcessor = (FeedforwardNeuralProcessor) treeEncoder.concatProcessorsPool.getItem();
            // encode() reads the vector, both children lists and all three processors, so it must
            // run after they are all initialized.
            this.encoding = encode();
        }
    }

    /**
     * Adds a new node, initially a root, to the tree.
     *
     * @param id the id of the new node
     * @param vector the input vector of the new node
     * @return the new node, with its initial encoding already computed
     * @throws IllegalArgumentException if a node with the same id was already added
     */
    @NotNull
    public final Node addNode(int id, @NotNull DenseNDArray vector) {
        Intrinsics.checkParameterIsNotNull(vector, "vector");
        Integer key = Integer.valueOf(id);
        if (this.nodes.containsKey(key)) {
            throw new IllegalArgumentException("Node " + id + " already inserted");
        }
        Node created = new Node(this, id, vector);
        this.nodes.put(key, created);
        return created;
    }

    /**
     * Returns the node with the given id.
     *
     * @param id the id of the node
     * @return the node
     * @throws IllegalArgumentException if no node with that id was added
     */
    @NotNull
    public final Node getNode(int id) {
        Node found = this.nodes.get(Integer.valueOf(id));
        if (found == null) {
            throw new IllegalArgumentException("Node " + id + " not found");
        }
        return found;
    }

    /**
     * Makes the node {@code i} a child of the node {@code i2}.
     *
     * <p>The child joins the head's left children when its id is lower than the head's id, and its
     * right children otherwise. The encodings of the head and of all its ancestors are then
     * recomputed.
     *
     * @param i the id of the node that becomes a child
     * @param i2 the id of the node that becomes its head
     * @throws IllegalArgumentException if {@code i == i2}, if either node has not been added, if node
     *     {@code i} already has a head, or if node {@code i2} is a descendant of node {@code i}
     */
    public final void setHead(int i, int i2) {
        if (i == i2) {
            throw new IllegalArgumentException("Cannot set node " + i + " as head of itself");
        }
        Node child = this.nodes.get(Integer.valueOf(i));
        if (child == null) {
            throw new IllegalArgumentException("Node " + i + " not found");
        }
        Node head = this.nodes.get(Integer.valueOf(i2));
        if (head == null) {
            throw new IllegalArgumentException("Head node " + i2 + " not found");
        }
        // Children are never detached, so re-heading would leave the child in two children lists:
        // it would then be encoded and back-propagated through twice.
        if (child.head != null) {
            throw new IllegalArgumentException("Node " + i + " already has a head");
        }
        // If the new head lies in the child's subtree, the new arc closes a cycle and the recursive
        // re-encoding of the ancestors would never terminate (StackOverflowError).
        for (Node ancestor = head; ancestor != null; ancestor = ancestor.head) {
            if (ancestor == child) {
                throw new IllegalArgumentException(
                    "Cannot set node " + i2 + " as head of node " + i + ": it is a descendant of node " + i);
            }
        }
        head.addChild$simplednn(child);
    }

    /**
     * Adds errors to the encoding of a node.
     *
     * <p>The node is remembered so that the next {@code propagateErrors()} starts back-propagating
     * from it.
     *
     * @param nodeId the id of the node
     * @param errors errors of the node encoding; their shape must equal that of the node's vector
     * @throws IllegalArgumentException if the node does not exist or the shapes do not match
     */
    public final void addEncodingErrors(int nodeId, @NotNull DenseNDArray errors) {
        Intrinsics.checkParameterIsNotNull(errors, "errors");
        Node node = this.nodes.get(Integer.valueOf(nodeId));
        if (node == null) {
            throw new IllegalArgumentException("Node " + nodeId + " not found");
        }
        DenseNDArray vector = node.getVector();
        if (!Intrinsics.areEqual(errors.getShape(), vector.getShape())) {
            String message = String.format(
                "Errors size (%d) not compatible with vector size (%d)",
                Integer.valueOf(errors.getLength()),
                Integer.valueOf(vector.getLength()));
            throw new IllegalArgumentException(message);
        }
        node.addEncodingErrors$simplednn(errors);
        this.nodesWithEncodingErrors.add(node);
    }

    /**
     * Returns the ids of the nodes currently flagged as roots, in insertion order.
     *
     * @return a new mutable list of ids
     */
    @NotNull
    public final List<Integer> getRootsIds() {
        List<Integer> rootIds = new ArrayList<>();
        for (Node node : this.nodes.values()) {
            if (node.isRoot$simplednn()) {
                rootIds.add(Integer.valueOf(node.getId()));
            }
        }
        return rootIds;
    }

    /**
     * Back-propagates the encoding errors added since the last call.
     *
     * <p>The propagation starts from every root. Each node's params errors are then accumulated
     * into {@code optimizer}. Finally the errors of all nodes are cleared.
     *
     * @param optimizer the optimizer that receives the params errors
     */
    public final void propagateErrors(@NotNull ParamsOptimizer<TreeRNNParameters> optimizer) {
        Intrinsics.checkParameterIsNotNull(optimizer, "optimizer");
        for (Node node : this.nodes.values()) {
            if (node.isRoot$simplednn()) {
                launchErrorsPropagation(node);
            }
        }
        accumulateParamsErrors(optimizer);
        clearNodeErrors();
    }

    /**
     * Removes every node and every pending encoding error, and releases all the processors taken
     * from the pools.
     */
    public final void clearTree() {
        // Forget the tree structure first...
        nodes.clear();
        nodesWithEncodingErrors.clear();
        // ...then hand every processor back to its pool.
        leftProcessorsPool.releaseAll();
        rightProcessorsPool.releaseAll();
        concatProcessorsPool.releaseAll();
    }

    /**
     * Searches the subtree of {@code node} for the topmost nodes that received encoding errors and
     * starts the errors propagation from each of them.
     *
     * <p>Each started propagation covers that node's whole subtree, so the search does not descend
     * below it.
     *
     * @param node the root of the subtree to search
     */
    private final void launchErrorsPropagation(Node node) {
        boolean hasEncodingErrors = this.nodesWithEncodingErrors.contains(node);
        if (hasEncodingErrors) {
            node.propagateErrors$simplednn();
        } else {
            for (Node child : node.getLeftChildren$simplednn()) {
                launchErrorsPropagation(child);
            }
            for (Node child : node.getRightChildren$simplednn()) {
                launchErrorsPropagation(child);
            }
        }
    }

    /**
     * Starts a new batch in the optimizer and adds one example for each node whose errors contain at
     * least one non-zero value. Each example holds the params errors of the node's three processors.
     *
     * @param paramsOptimizer the optimizer receiving the params errors
     */
    private final void accumulateParamsErrors(ParamsOptimizer<TreeRNNParameters> paramsOptimizer) {
        paramsOptimizer.newBatch();
        for (Node node : this.nodes.values()) {
            DenseNDArray errors = node.getNodeErrors$simplednn();
            if (errors == null) {
                continue;
            }
            boolean hasNonZero = false;
            for (int k = 0; !hasNonZero && k < errors.getLength(); k++) {
                hasNonZero = errors.get(k).doubleValue() != 0.0d;
            }
            if (!hasNonZero) {
                continue;
            }
            paramsOptimizer.newExample();
            TreeRNNParameters paramsErrors = new TreeRNNParameters(
                node.getLeftProcessor$simplednn().getParamsErrors(false),
                node.getRightProcessor$simplednn().getParamsErrors(false),
                node.getConcatProcessor$simplednn().getParamsErrors(false));
            Optimizer.accumulate$default(paramsOptimizer, paramsErrors, false, 2, null);
        }
    }

    /** Resets the encoding and backward errors of every node and forgets which nodes received errors. */
    private final void clearNodeErrors() {
        for (Node node : nodes.values()) {
            node.resetNodeErrors$simplednn();
        }
        nodesWithEncodingErrors.clear();
    }

    /** Returns the TreeRNN this encoder was built on. */
    @NotNull
    public final TreeRNN getNetwork() {
        return network;
    }

    /**
     * Creates an empty tree encoder.
     *
     * <p>It builds one processors pool for each sub-network of {@code network}: the left RNN, the
     * right RNN and the concat network.
     *
     * @param network the TreeRNN providing the sub-networks
     */
    public TreeEncoder(@NotNull TreeRNN network) {
        Intrinsics.checkParameterIsNotNull(network, "network");
        this.network = network;
        this.nodes = new LinkedHashMap<>();
        this.nodesWithEncodingErrors = new LinkedHashSet<>();
        this.leftProcessorsPool = new RecurrentNeuralProcessorsPool<>(network.getLeftRNN());
        this.rightProcessorsPool = new RecurrentNeuralProcessorsPool<>(network.getRightRNN());
        this.concatProcessorsPool = new FeedforwardNeuralProcessorsPool<>(network.getConcatNetwork());
    }
}
