package org.deeplearning4j.clustering.sptree;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/clustering/sptree/SpTree.class */
/**
 * Space-partitioning tree over the rows of a data matrix.
 *
 * <p>Each node covers an axis-aligned {@link Cell} described by a center ("corner") and a
 * half-width per dimension. A leaf stores up to {@code nodeCapacity} row indices; once a leaf is
 * full it is split into {@code 2^D} children, each covering one half of the cell along every
 * dimension. Every node additionally tracks how many points were inserted below it
 * ({@code cumSize}) and the running mean of those points ({@code centerOfMass}).
 *
 * <p>No synchronization is performed by this class.
 */
public class SpTree implements Serializable {
    public static final String workspaceExternal = "SPTREE_LOOP_EXTERNAL";
    /** Number of columns (dimensions) of {@link #data}. */
    private int D;
    /** Data matrix; one point per row. */
    private INDArray data;
    /** Modulus used to derive the per-node capacity from the number of rows. */
    public static final int NODE_RATIO = 8000;
    /** Number of rows of {@link #data}. */
    private int N;
    /** Number of row indices currently stored directly in this node. */
    private int size;
    /** Number of points accepted by this node, including those handed down to children. */
    private int cumSize;
    private Cell boundary;
    /** Running mean of all points accepted by this node. */
    private INDArray centerOfMass;
    private SpTree parent;
    /** Row indices stored in this node; only the first {@link #size} entries are meaningful. */
    private int[] index;
    private int nodeCapacity;
    /** Starts at 2 and is doubled once per extra dimension in {@code init}, yielding 2^D. */
    private int numChildren = 2;
    private boolean isLeaf = true;
    private Collection<INDArray> indices;
    private SpTree[] children;
    private static final Logger log = LoggerFactory.getLogger(SpTree.class);
    private String similarityFunction = Distance.EUCLIDEAN.toString();

    /**
     * Creates an empty node covering the given cell. No points are inserted.
     *
     * @param parent             the parent node, or {@code null} for a root
     * @param data               the data matrix, one point per row (stored by reference)
     * @param corner             the center of the cell (copied)
     * @param width              the half-width of the cell per dimension (copied)
     * @param indices            collection that receives every point stored in a leaf
     * @param similarityFunction name of the similarity function (stored only)
     */
    public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices,
                  String similarityFunction) {
        init(parent, data, corner, width, indices, similarityFunction);
    }

    /**
     * Builds a complete tree over a copy of {@code data}.
     *
     * <p>The root cell is centered on the column-wise mean and, per dimension, is wide enough to
     * reach the farther of the column minimum and maximum, plus {@code Nd4j.EPS_THRESHOLD}.
     *
     * @param data               the data matrix, one point per row; it is duplicated
     * @param indices            collection that receives every point stored in a leaf; the tree
     *                           is only populated when this collection is initially empty
     * @param similarityFunction name of the similarity function (stored only)
     */
    public SpTree(INDArray data, Collection<INDArray> indices, String similarityFunction) {
        this.indices = indices;
        this.N = data.rows();
        this.D = data.columns();
        this.similarityFunction = similarityFunction;
        INDArray points = data.dup();
        INDArray mean = points.mean(0);
        INDArray min = points.min(0);
        INDArray max = points.max(0);
        INDArray width = Nd4j.create(points.dataType(), mean.shape());
        for (int i = 0; i < width.length(); i++) {
            double center = mean.getDouble(i);
            double reach = Math.max(max.getDouble(i) - center, center - min.getDouble(i));
            width.putScalar(i, reach + Nd4j.EPS_THRESHOLD);
        }
        // The scope is closed on both normal and exceptional exit; a null scope is skipped.
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            init(null, points, mean, width, indices, similarityFunction);
            fill(this.N);
        }
    }

    /**
     * Creates an empty node covering the given cell, using {@code "euclidean"} as the similarity
     * function name.
     *
     * @param parent  the parent node, or {@code null} for a root
     * @param data    the data matrix, one point per row (stored by reference)
     * @param corner  the center of the cell (copied)
     * @param width   the half-width of the cell per dimension (copied)
     * @param indices collection that receives every point stored in a leaf
     */
    public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices) {
        this(parent, data, corner, width, indices, "euclidean");
    }

    /**
     * Builds a complete tree over a copy of {@code data}, using {@code "euclidean"} as the
     * similarity function name.
     *
     * @param data    the data matrix, one point per row
     * @param indices collection that receives every point stored in a leaf
     */
    public SpTree(INDArray data, Collection<INDArray> indices) {
        this(data, indices, "euclidean");
    }

    /**
     * Builds a complete tree over a copy of {@code data} with a fresh, empty index collection.
     *
     * @param data the data matrix, one point per row
     */
    public SpTree(INDArray data) {
        this(data, new ArrayList<INDArray>());
    }

    /**
     * @return always {@code null}; this implementation does not provide a workspace
     */
    public MemoryWorkspace workspace() {
        return null;
    }

    /** Resets this node to an empty leaf covering the cell given by {@code corner} and {@code width}. */
    private void init(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices,
                      String similarityFunction) {
        this.parent = parent;
        this.D = data.columns();
        this.N = data.rows();
        this.similarityFunction = similarityFunction;
        // N % NODE_RATIO is 0 whenever N is a multiple of NODE_RATIO. A zero-capacity leaf can
        // never store a point, so insert() would subdivide forever; keep at least one slot.
        this.nodeCapacity = Math.max(1, this.N % NODE_RATIO);
        this.index = new int[this.nodeCapacity];
        for (int d = 1; d < this.D; d++) {
            this.numChildren *= 2;
        }
        this.indices = indices;
        this.isLeaf = true;
        this.size = 0;
        this.cumSize = 0;
        this.children = new SpTree[this.numChildren];
        this.data = data;
        this.boundary = new Cell(this.D);
        this.boundary.setCorner(corner.dup());
        this.boundary.setWidth(width.dup());
        this.centerOfMass = Nd4j.create(data.dataType(), this.D);
    }

    /**
     * Inserts the data row {@code pointIndex} into this subtree.
     *
     * @return {@code false} if the point lies outside this node's cell, {@code true} otherwise
     * @throws IllegalStateException if the point is inside this cell but no child accepts it
     */
    private boolean insert(int pointIndex) {
        INDArray point = this.data.slice(pointIndex);
        if (!this.boundary.contains(point)) {
            return false;
        }
        // Incremental mean: com = com * (n - 1) / n + point / n. The ratio must be computed in
        // floating point; with int operands (n - 1) / n is always 0 and discards the old mean.
        this.cumSize++;
        double keepFraction = (double) (this.cumSize - 1) / (double) this.cumSize;
        double addFraction = 1.0d / (double) this.cumSize;
        this.centerOfMass.muli(keepFraction);
        this.centerOfMass.addi(point.mul(addFraction));

        // A leaf with spare capacity stores the point directly.
        if (isLeaf() && this.size < this.nodeCapacity) {
            this.index[this.size] = pointIndex;
            this.indices.add(point);
            this.size++;
            return true;
        }
        // An exact duplicate of a stored point is counted but not stored again.
        for (int i = 0; i < this.size; i++) {
            if (this.data.slice(this.index[i]).equals(point)) {
                return true;
            }
        }
        if (isLeaf()) {
            subDivide();
        }
        for (int c = 0; c < this.numChildren; c++) {
            if (this.children[c].insert(pointIndex)) {
                return true;
            }
        }
        throw new IllegalStateException("Shouldn't reach this state");
    }

    /**
     * Splits this node into {@code numChildren} children, each with half of this cell's width per
     * dimension, and moves the points stored here into them. Afterwards this node holds no points
     * of its own and is no longer a leaf.
     */
    public void subDivide() {
        INDArray childCorner = Nd4j.create(this.data.dataType(), this.D);
        INDArray childWidth = Nd4j.create(this.data.dataType(), this.D);
        for (int child = 0; child < this.numChildren; child++) {
            // Bit d of the child number selects the lower or the upper half along dimension d.
            int div = 1;
            for (int d = 0; d < this.D; d++) {
                double halfWidth = 0.5d * this.boundary.width(d);
                childWidth.putScalar(d, halfWidth);
                if ((child / div) % 2 == 1) {
                    childCorner.putScalar(d, this.boundary.corner(d) - halfWidth);
                } else {
                    childCorner.putScalar(d, this.boundary.corner(d) + halfWidth);
                }
                div *= 2;
            }
            // The child copies corner and width, so both buffers can be reused for the next child.
            this.children[child] = new SpTree(this, this.data, childCorner, childWidth, this.indices);
        }
        // Hand every stored point to the first child that accepts it.
        for (int i = 0; i < this.size; i++) {
            for (int c = 0; c < this.numChildren; c++) {
                if (this.children[c].insert(this.index[i])) {
                    break;
                }
            }
            this.index[i] = -1;
        }
        this.size = 0;
        this.isLeaf = false;
    }

    /**
     * Accumulates the contribution of this subtree for the data row {@code pointIndex}.
     *
     * <p>Empty nodes, and a leaf whose only stored point is {@code pointIndex} itself, contribute
     * nothing. A leaf, or a node for which {@code maxWidth / distance < theta} (distance measured
     * from the point to this node's center of mass), is treated as a single summary: with
     * {@code q = 1 / (1 + distance^2)}, {@code cumSize * q} is added to {@code sumQ} and
     * {@code (point - centerOfMass) * cumSize * q^2} is added to {@code negativeForce}.
     * Otherwise the call recurses into every child.
     *
     * @param pointIndex    row index of the point in the data matrix
     * @param theta         threshold for the width-to-distance ratio
     * @param negativeForce vector that the contribution is added to, in place
     * @param sumQ          accumulator that receives {@code cumSize * q}
     */
    public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble sumQ) {
        if (this.cumSize == 0) {
            return;
        }
        if (isLeaf() && this.size == 1 && this.index[0] == pointIndex) {
            return;
        }
        // Allocate the scratch buffer only once it is known that work will be done.
        INDArray diff = Nd4j.create(this.data.dataType(), this.D);
        this.data.slice(pointIndex).subi(this.centerOfMass, diff);
        double squaredDistance = Nd4j.getBlasWrapper().dot(diff, diff);
        double maxWidth = this.boundary.width().maxNumber().doubleValue();
        if (isLeaf() || maxWidth / Math.sqrt(squaredDistance) < theta) {
            double q = 1.0d / (1.0d + squaredDistance);
            double mult = this.cumSize * q;
            sumQ.addAndGet(mult);
            negativeForce.addi(diff.mul(mult * q));
            return;
        }
        for (int c = 0; c < this.numChildren; c++) {
            this.children[c].computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
        }
    }

    /**
     * Executes the {@code BarnesEdgeForces} op on this tree's data matrix. The arguments are
     * forwarded to the op unchanged and in order, with the data matrix inserted after
     * {@code valP}.
     *
     * @param rowP   first op input; must be a vector
     * @param colP   second op input (presumably the column indices matching {@code rowP} -
     *               confirm against {@code BarnesEdgeForces})
     * @param valP   third op input (presumably the values matching {@code colP} - confirm
     *               against {@code BarnesEdgeForces})
     * @param n      integer argument of the op
     * @param output last op argument (presumably the array receiving the result - confirm
     *               against {@code BarnesEdgeForces})
     * @throws IllegalArgumentException if {@code rowP} is not a vector
     */
    public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int n, INDArray output) {
        if (!rowP.isVector()) {
            throw new IllegalArgumentException("RowP must be a vector");
        }
        Nd4j.exec(new BarnesEdgeForces(rowP, colP, valP, this.data, n, output));
    }

    /**
     * @return {@code true} while this node has not been subdivided
     */
    public boolean isLeaf() {
        return this.isLeaf;
    }

    /**
     * Checks that every point stored in this subtree lies inside the cell of the node storing it.
     *
     * @return {@code true} if this node and all of its descendants are consistent
     */
    public boolean isCorrect() {
        for (int i = 0; i < this.size; i++) {
            if (!this.boundary.contains(this.data.slice(this.index[i]))) {
                return false;
            }
        }
        if (isLeaf()) {
            return true;
        }
        for (int c = 0; c < this.numChildren; c++) {
            if (!this.children[c].isCorrect()) {
                return false;
            }
        }
        return true;
    }

    /**
     * @return the number of nodes on the longest path from this node down to a leaf; 1 for a leaf
     */
    public int depth() {
        if (isLeaf()) {
            return 1;
        }
        int maxChildDepth = 0;
        for (int c = 0; c < this.numChildren; c++) {
            // Every child must be inspected; looking only at children[0] under-reports the depth.
            maxChildDepth = Math.max(maxChildDepth, this.children[c].depth());
        }
        return 1 + maxChildDepth;
    }

    /**
     * Inserts rows {@code 0 .. count-1} of the data matrix. Only acts on a root node whose index
     * collection is still empty; otherwise a warning is logged and nothing is inserted.
     */
    private void fill(int count) {
        if (!this.indices.isEmpty() || this.parent != null) {
            log.warn("Called fill already");
            return;
        }
        for (int i = 0; i < count; i++) {
            log.trace("Inserted {}", i);
            insert(i);
        }
    }

    /**
     * @return the internal child array (not a copy); entries are {@code null} until this node is
     *         subdivided
     */
    public SpTree[] getChildren() {
        return this.children;
    }

    /**
     * @return the number of dimensions (columns of the data matrix)
     */
    public int getD() {
        return this.D;
    }

    /**
     * @return the internal center-of-mass vector (not a copy)
     */
    public INDArray getCenterOfMass() {
        return this.centerOfMass;
    }

    /**
     * @return the cell covered by this node
     */
    public Cell getBoundary() {
        return this.boundary;
    }

    /**
     * @return the internal array of stored row indices (not a copy)
     */
    public int[] getIndex() {
        return this.index;
    }

    /**
     * @return the number of points accepted by this node, including those stored in descendants
     */
    public int getCumSize() {
        return this.cumSize;
    }

    /**
     * Overrides the cumulative point count of this node.
     *
     * @param cumSize the new count
     */
    public void setCumSize(int cumSize) {
        this.cumSize = cumSize;
    }

    /**
     * @return the number of children a subdivided node has
     */
    public int getNumChildren() {
        return this.numChildren;
    }

    /**
     * Overrides the child count. The child array is not resized by this call.
     *
     * @param numChildren the new child count
     */
    public void setNumChildren(int numChildren) {
        this.numChildren = numChildren;
    }
}
