package org.deeplearning4j.clustering.cluster;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.clustering.algorithm.optimisation.ClusteringOptimizationType;
import org.deeplearning4j.clustering.algorithm.strategy.OptimisationStrategy;
import org.deeplearning4j.clustering.cluster.info.ClusterInfo;
import org.deeplearning4j.clustering.cluster.info.ClusterSetInfo;
import org.deeplearning4j.util.MathUtils;
import org.deeplearning4j.util.MultiThreadUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/clustering/cluster/ClusterUtils.class */
/**
 * Static helpers shared by the clustering algorithms: classifying points against a
 * {@link ClusterSet}, recomputing cluster centers and per-cluster distance statistics, and
 * splitting clusters to satisfy an {@link OptimisationStrategy}.
 *
 * <p>Methods that take an {@link ExecutorService} build one {@link Runnable} per point or per
 * cluster and hand the list to {@code MultiThreadUtils.parallelTasks}. They read the results
 * right after that call returns, so they assume it blocks until every task has finished.
 */
public class ClusterUtils {

    /** Number of leading "farther than the threshold" candidates a threshold split picks from. */
    private static final int SPLIT_CANDIDATE_POOL_SIZE = 3;

    private ClusterUtils() {
        // Static utility class; not instantiable.
    }

    /**
     * Classifies every point against {@code clusterSet} in parallel and records the outcome.
     *
     * @param clusterSet      the clusters to classify against
     * @param list            the points to classify
     * @param executorService executor that runs one classification task per point
     * @return a fresh {@link ClusterSetInfo} (from {@code ClusterSetInfo.initialize(clusterSet, true)}).
     *         For each classified point it holds the distance to its cluster's center, stored under
     *         that cluster's info. It also counts the points whose classification reported a new location.
     *         Failures of individual points are printed and skipped.
     */
    public static ClusterSetInfo classifyPoints(final ClusterSet clusterSet, List<Point> list,
                    ExecutorService executorService) {
        final ClusterSetInfo clusterSetInfo = ClusterSetInfo.initialize(clusterSet, true);
        List<Runnable> tasks = new ArrayList<>();
        for (final Point point : list) {
            tasks.add(new Runnable() {
                @Override
                public void run() {
                    try {
                        PointClassification classification = classifyPoint(clusterSet, point);
                        if (classification.isNewLocation()) {
                            clusterSetInfo.getPointLocationChange().incrementAndGet();
                        }
                        clusterSetInfo.getClusterInfo(classification.getCluster().getId())
                                        .getPointDistancesFromCenter()
                                        .put(point.getId(), classification.getDistanceFromCenter());
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            });
        }
        MultiThreadUtils.parallelTasks(tasks, executorService);
        return clusterSetInfo;
    }

    /**
     * Classifies a single point by delegating to {@code clusterSet.classifyPoint(point, false)}.
     *
     * @param clusterSet the clusters to classify against
     * @param point      the point to classify
     * @return the classification produced by the cluster set
     */
    public static PointClassification classifyPoint(ClusterSet clusterSet, Point point) {
        return clusterSet.classifyPoint(point, false);
    }

    /**
     * Recomputes, in parallel, every cluster's center and the distance statistics stored in its
     * {@link ClusterInfo}.
     *
     * @param clusterSet      the clusters to refresh
     * @param clusterSetInfo  per-cluster info, looked up by cluster id
     * @param executorService executor that runs one task per cluster
     */
    public static void refreshClustersCenters(ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
                    ExecutorService executorService) {
        List<Runnable> tasks = new ArrayList<>();
        int clusterCount = clusterSet.getClusterCount();
        for (int i = 0; i < clusterCount; i++) {
            final Cluster cluster = clusterSet.getClusters().get(i);
            tasks.add(new Runnable() {
                @Override
                public void run() {
                    try {
                        ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
                        refreshClusterCenter(cluster, clusterInfo);
                        deriveClusterInfoDistanceStatistics(clusterInfo);
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            });
        }
        MultiThreadUtils.parallelTasks(tasks, executorService);
    }

    /**
     * Replaces the cluster's center with the element-wise mean of its points. Empty clusters are
     * left untouched.
     *
     * @param cluster     the cluster whose center is recomputed
     * @param clusterInfo unused; kept for signature compatibility
     */
    public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) {
        List<Point> points = cluster.getPoints();
        int size = points.size();
        if (size == 0) {
            return;
        }
        Point center = new Point(Nd4j.create(points.get(0).getArray().length()));
        INDArray sum = center.getArray();
        for (Point point : points) {
            sum.addi(point.getArray());
        }
        sum.divi(size);
        cluster.setCenter(center);
    }

    /**
     * Fills in max, total, average and variance of the point-to-center distances already
     * recorded in {@code clusterInfo}. Does nothing if no distances are recorded.
     *
     * @param clusterInfo the info whose distance map is summarized and whose statistics are set
     */
    public static void deriveClusterInfoDistanceStatistics(ClusterInfo clusterInfo) {
        if (clusterInfo.getPointDistancesFromCenter().isEmpty()) {
            return;
        }
        double[] distances = ArrayUtils.toPrimitive(
                        (Double[]) clusterInfo.getPointDistancesFromCenter().values().toArray(new Double[0]));
        double sum = MathUtils.sum(distances);
        clusterInfo.setMaxPointDistanceFromCenter(MathUtils.max(distances));
        clusterInfo.setTotalPointDistanceFromCenter(sum);
        clusterInfo.setAveragePointDistanceFromCenter(sum / distances.length);
        clusterInfo.setPointDistanceFromCenterVariance(MathUtils.variance(distances));
    }

    /**
     * Computes, for each point, the element-wise minimum of its previous squared distance and its
     * squared distance to the last cluster in {@code clusterSet}. The last cluster is presumably the
     * most recently added one (confirm against ClusterSet).
     *
     * @param clusterSet      the clusters; must be non-empty
     * @param list            the points
     * @param iNDArray        previous squared distances, indexed like {@code list}
     * @param executorService executor that runs one distance computation per point
     * @return a new array of length {@code list.size()} holding the updated squared distances
     */
    public static INDArray computeSquareDistancesFromNearestCluster(ClusterSet clusterSet, final List<Point> list,
                    INDArray iNDArray, ExecutorService executorService) {
        int pointCount = list.size();
        final INDArray squaredDistances = Nd4j.create(pointCount);
        List<Cluster> clusters = clusterSet.getClusters();
        final Cluster lastCluster = clusters.get(clusters.size() - 1);

        List<Runnable> tasks = new ArrayList<>();
        for (int i = 0; i < pointCount; i++) {
            final int index = i;
            tasks.add(new Runnable() {
                @Override
                public void run() {
                    squaredDistances.putScalar(index,
                                    Math.pow(lastCluster.getDistanceToCenter(list.get(index)), 2.0d));
                }
            });
        }
        MultiThreadUtils.parallelTasks(tasks, executorService);

        // Keep the smaller of the new and previous value (explicit comparison preserves NaN handling).
        for (int i = 0; i < pointCount; i++) {
            double previous = iNDArray.getDouble(i);
            if (squaredDistances.getDouble(i) > previous) {
                squaredDistances.putScalar(i, previous);
            }
        }
        return squaredDistances;
    }

    /**
     * Computes cluster statistics using a temporary executor. The executor is shut down even if
     * the computation fails.
     *
     * @param clusterSet the clusters to analyze
     * @return the computed statistics
     */
    public static ClusterSetInfo computeClusterSetInfo(ClusterSet clusterSet) {
        ExecutorService executorService = MultiThreadUtils.newExecutorService();
        try {
            return computeClusterSetInfo(clusterSet, executorService);
        } finally {
            executorService.shutdownNow();
        }
    }

    /**
     * Computes per-cluster point-distance statistics and the pairwise distances between cluster
     * centers. Both are measured with the cluster set's accumulation op.
     *
     * @param clusterSet      the clusters to analyze
     * @param executorService executor that runs the per-cluster tasks
     * @return a new {@link ClusterSetInfo} holding the results. Each unordered pair of clusters is
     *         stored once, keyed (earlier-index id, later-index id).
     */
    public static ClusterSetInfo computeClusterSetInfo(final ClusterSet clusterSet, ExecutorService executorService) {
        final ClusterSetInfo clusterSetInfo = new ClusterSetInfo(true);
        final int clusterCount = clusterSet.getClusterCount();

        List<Runnable> infoTasks = new ArrayList<>();
        for (int i = 0; i < clusterCount; i++) {
            final Cluster cluster = clusterSet.getClusters().get(i);
            infoTasks.add(new Runnable() {
                @Override
                public void run() {
                    clusterSetInfo.getClustersInfos().put(cluster.getId(),
                                    computeClusterInfos(cluster, clusterSet.getAccumulation()));
                }
            });
        }
        MultiThreadUtils.parallelTasks(infoTasks, executorService);

        List<Runnable> distanceTasks = new ArrayList<>();
        for (int i = 0; i < clusterCount; i++) {
            final int firstIndex = i;
            final Cluster first = clusterSet.getClusters().get(i);
            distanceTasks.add(new Runnable() {
                @Override
                public void run() {
                    try {
                        // Only pairs with a higher index, so each unordered pair is computed once.
                        for (int j = firstIndex + 1; j < clusterCount; j++) {
                            Cluster second = clusterSet.getClusters().get(j);
                            double distance = accumulate(clusterSet.getAccumulation(),
                                            first.getCenter().getArray(), second.getCenter().getArray());
                            clusterSetInfo.getDistancesBetweenClustersCenters().put(first.getId(), second.getId(),
                                            distance);
                        }
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            });
        }
        MultiThreadUtils.parallelTasks(distanceTasks, executorService);
        return clusterSetInfo;
    }

    /**
     * Measures every point of {@code cluster} against its center with the accumulation op named
     * {@code str}. Records each distance, plus the total and average.
     *
     * @param cluster the cluster to measure; its center is only read when it has points
     * @param str     name of the ND4J accumulation op used as the distance function
     * @return a new {@link ClusterInfo}; left with no distances for an empty cluster
     */
    public static ClusterInfo computeClusterInfos(Cluster cluster, String str) {
        ClusterInfo clusterInfo = new ClusterInfo(true);
        List<Point> points = cluster.getPoints();
        if (points.isEmpty()) {
            return clusterInfo;
        }
        INDArray center = cluster.getCenter().getArray();
        for (Point point : points) {
            double distance = accumulate(str, center, point.getArray());
            clusterInfo.getPointDistancesFromCenter().put(point.getId(), distance);
            clusterInfo.setTotalPointDistanceFromCenter(clusterInfo.getTotalPointDistanceFromCenter() + distance);
        }
        clusterInfo.setAveragePointDistanceFromCenter(clusterInfo.getTotalPointDistanceFromCenter() / points.size());
        return clusterInfo;
    }

    /** Runs the named ND4J accumulation op on {@code x} and {@code y} and returns its scalar result. */
    private static double accumulate(String accumulation, INDArray x, INDArray y) {
        return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createAccum(accumulation, x, y))
                        .getFinalResult().doubleValue();
    }

    /**
     * Splits clusters according to the strategy's optimization type and threshold.
     *
     * @return {@code true} if at least one cluster exceeded the threshold (and so was selected for
     *         splitting); {@code false} otherwise or for unsupported optimization types
     */
    public static boolean applyOptimization(OptimisationStrategy optimisationStrategy, ClusterSet clusterSet,
                    ClusterSetInfo clusterSetInfo, ExecutorService executorService) {
        if (optimisationStrategy.isClusteringOptimizationType(
                        ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE)) {
            return splitClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo,
                            optimisationStrategy.getClusteringOptimizationValue(), executorService) > 0;
        }
        if (optimisationStrategy.isClusteringOptimizationType(
                        ClusteringOptimizationType.MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE)) {
            return splitClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo,
                            optimisationStrategy.getClusteringOptimizationValue(), executorService) > 0;
        }
        return false;
    }

    /**
     * Returns up to {@code i} clusters with the largest total point-to-center distance, in
     * descending order.
     *
     * @param i maximum number of clusters to return; clamped to the number of clusters
     * @return a view of a sorted private copy of the cluster list
     */
    public static List<Cluster> getMostSpreadOutClusters(ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
                    int i) {
        List<Cluster> clusters = new ArrayList<>(clusterSet.getClusters());
        Collections.sort(clusters, new Comparator<Cluster>() {
            @Override
            public int compare(Cluster first, Cluster second) {
                double firstTotal = clusterSetInfo.getClusterInfo(first.getId()).getTotalPointDistanceFromCenter();
                double secondTotal = clusterSetInfo.getClusterInfo(second.getId()).getTotalPointDistanceFromCenter();
                // Arguments reversed: largest total first.
                return Double.compare(secondTotal, firstTotal);
            }
        });
        return clusters.subList(0, Math.min(i, clusters.size()));
    }

    /**
     * Returns the clusters whose recorded average point-to-center distance exceeds {@code d}.
     * Clusters without info are skipped.
     */
    public static List<Cluster> getClustersWhereAverageDistanceFromCenterGreaterThan(ClusterSet clusterSet,
                    ClusterSetInfo clusterSetInfo, double d) {
        List<Cluster> result = new ArrayList<>();
        for (Cluster cluster : clusterSet.getClusters()) {
            ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
            if (clusterInfo != null && clusterInfo.getAveragePointDistanceFromCenter() > d) {
                result.add(cluster);
            }
        }
        return result;
    }

    /**
     * Returns the clusters whose recorded maximum point-to-center distance exceeds {@code d}.
     * Clusters without info are skipped.
     */
    public static List<Cluster> getClustersWhereMaximumDistanceFromCenterGreaterThan(ClusterSet clusterSet,
                    ClusterSetInfo clusterSetInfo, double d) {
        List<Cluster> result = new ArrayList<>();
        for (Cluster cluster : clusterSet.getClusters()) {
            ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
            if (clusterInfo != null && clusterInfo.getMaxPointDistanceFromCenter() > d) {
                result.add(cluster);
            }
        }
        return result;
    }

    /**
     * Splits up to {@code i} of the most spread-out clusters by moving a random point into a new cluster.
     *
     * @return the number of clusters selected for splitting
     */
    public static int splitMostSpreadOutClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int i,
                    ExecutorService executorService) {
        List<Cluster> selected = getMostSpreadOutClusters(clusterSet, clusterSetInfo, i);
        splitClusters(clusterSet, clusterSetInfo, selected, executorService);
        return selected.size();
    }

    /**
     * Splits every cluster whose average point-to-center distance exceeds {@code d}.
     *
     * @return the number of clusters selected for splitting
     */
    public static int splitClustersWhereAverageDistanceFromCenterGreaterThan(ClusterSet clusterSet,
                    ClusterSetInfo clusterSetInfo, double d, ExecutorService executorService) {
        List<Cluster> selected = getClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, d);
        splitClusters(clusterSet, clusterSetInfo, selected, d, executorService);
        return selected.size();
    }

    /**
     * Splits every cluster whose maximum point-to-center distance exceeds {@code d}.
     *
     * @return the number of clusters selected for splitting
     */
    public static int splitClustersWhereMaximumDistanceFromCenterGreaterThan(ClusterSet clusterSet,
                    ClusterSetInfo clusterSetInfo, double d, ExecutorService executorService) {
        List<Cluster> selected = getClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, d);
        splitClusters(clusterSet, clusterSetInfo, selected, d, executorService);
        return selected.size();
    }

    /** Splits the {@code i} most populated clusters (as chosen by the cluster set). */
    public static void splitMostPopulatedClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int i,
                    ExecutorService executorService) {
        splitClusters(clusterSet, clusterSetInfo, clusterSet.getMostPopulatedClusters(i), executorService);
    }

    /**
     * For each listed cluster, removes one point lying farther than {@code d} from the center and
     * makes it the center of a new cluster. The point is chosen at random among the first
     * {@value #SPLIT_CANDIDATE_POOL_SIZE} entries returned by
     * {@code ClusterInfo.getPointsFartherFromCenterThan}; that list's ordering is defined by
     * ClusterInfo (presumably farthest first, so confirm). Clusters with no info or no such point are
     * skipped.
     *
     * <p>NOTE(review): {@code addNewClusterWithCenter} is called concurrently from several tasks;
     * confirm that ClusterSet tolerates this.
     */
    public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
                    List<Cluster> list, final double d, ExecutorService executorService) {
        final Random random = new Random();
        List<Runnable> tasks = new ArrayList<>();
        for (final Cluster cluster : list) {
            tasks.add(new Runnable() {
                @Override
                public void run() {
                    try {
                        ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
                        if (clusterInfo == null) {
                            return;
                        }
                        List<String> candidates = clusterInfo.getPointsFartherFromCenterThan(d);
                        if (candidates.isEmpty()) {
                            // Random.nextInt(0) would throw; there is nothing to split off.
                            return;
                        }
                        int poolSize = Math.min(candidates.size(), SPLIT_CANDIDATE_POOL_SIZE);
                        String pointId = candidates.get(random.nextInt(poolSize));
                        clusterSet.addNewClusterWithCenter(cluster.removePoint(pointId));
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            });
        }
        MultiThreadUtils.parallelTasks(tasks, executorService);
    }

    /**
     * For each listed cluster, removes a uniformly random point and makes it the center of a new
     * cluster. Empty clusters are skipped.
     *
     * @param clusterSetInfo unused; kept for signature compatibility
     */
    public static void splitClusters(final ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, List<Cluster> list,
                    ExecutorService executorService) {
        final Random random = new Random();
        List<Runnable> tasks = new ArrayList<>();
        for (final Cluster cluster : list) {
            tasks.add(new Runnable() {
                @Override
                public void run() {
                    List<Point> points = cluster.getPoints();
                    if (points.isEmpty()) {
                        // Random.nextInt(0) would throw; nothing to split off.
                        return;
                    }
                    clusterSet.addNewClusterWithCenter(points.remove(random.nextInt(points.size())));
                }
            });
        }
        MultiThreadUtils.parallelTasks(tasks, executorService);
    }
}
