package org.deeplearning4j.clustering.cluster;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/clustering/cluster/ClusterSet.class */
/**
 * A mutable set of {@link Cluster}s plus a record of which cluster each {@link Point} was last
 * assigned to.
 *
 * <p>The assignment record ({@link #getPointDistribution()}) maps a point id ({@link
 * Point#getId()}) to the id of the cluster it was last placed in ({@link Cluster#getId()}).
 *
 * <p>The default cluster list and assignment map are wrapped with {@link
 * Collections#synchronizedList} and {@link Collections#synchronizedMap}. Single operations on them
 * are thread-safe, but traversal requires holding the collection's own lock. Collections installed
 * through {@link #setClusters(List)} or {@link #setPointDistribution(Map)} are used as given.
 */
public class ClusterSet implements Serializable {
    /** Name of the Nd4j accumulation op used as the distance function; may be {@code null}. */
    private String distanceFunction;
    private List<Cluster> clusters;
    /** Point id -> id of the cluster the point was last assigned to. */
    private Map<String, String> pointDistribution;

    /** Creates an empty cluster set with no distance function configured. */
    public ClusterSet() {
        this(null);
    }

    /**
     * Creates an empty cluster set.
     *
     * @param str name of the Nd4j accumulation op used as the distance function. It is handed to
     *     every cluster created by {@link #addNewClusterWithCenter(Point)} and used by {@link
     *     #getDistance(Point, Point)}.
     */
    public ClusterSet(String str) {
        this.distanceFunction = str;
        this.clusters = Collections.synchronizedList(new ArrayList<Cluster>());
        this.pointDistribution = Collections.synchronizedMap(new HashMap<String, String>());
    }

    /**
     * Creates a new cluster centred on {@code point}, adds it to this set and records {@code
     * point} as belonging to it.
     *
     * @param point the center of the new cluster
     * @return the newly created cluster
     */
    public Cluster addNewClusterWithCenter(Point point) {
        Cluster cluster = new Cluster(point, this.distanceFunction);
        getClusters().add(cluster);
        setPointLocation(point, cluster);
        return cluster;
    }

    /**
     * Equivalent to {@code classifyPoint(point, true)}.
     *
     * @param point the point to classify
     * @return the classification result
     */
    public PointClassification classifyPoint(Point point) {
        return classifyPoint(point, true);
    }

    /**
     * Equivalent to {@code classifyPoints(list, true)}.
     *
     * @param list the points to classify, in order
     */
    public void classifyPoints(List<Point> list) {
        classifyPoints(list, true);
    }

    /**
     * Classifies each point of {@code list} in iteration order via {@link #classifyPoint(Point,
     * boolean)}.
     *
     * @param list the points to classify
     * @param z forwarded unchanged to {@link #classifyPoint(Point, boolean)}
     */
    public void classifyPoints(List<Point> list, boolean z) {
        for (Point point : list) {
            classifyPoint(point, z);
        }
    }

    /**
     * Assigns {@code point} to its nearest cluster and records that assignment.
     *
     * @param point the point to classify
     * @param z forwarded to {@code Cluster.addPoint(Point, boolean)}. NOTE(review): presumably
     *     controls whether the cluster center is updated; confirm against {@code Cluster}.
     * @return the chosen cluster and its distance to the point. The flag is {@code true} when the
     *     point had no recorded cluster or its recorded cluster differs from the chosen one.
     * @throws IllegalStateException if no cluster could be selected, i.e. the set is empty or no
     *     cluster's distance is below {@link Float#MAX_VALUE}
     */
    public PointClassification classifyPoint(Point point, boolean z) {
        Pair<Cluster, Double> nearest = nearestCluster(point);
        Cluster cluster = nearest.getFirst();
        if (cluster == null) {
            // Fail with a descriptive error rather than an NPE inside isPointLocationChange.
            throw new IllegalStateException("No nearest cluster found for point " + point.getId()
                    + " among " + getClusterCount() + " cluster(s)");
        }
        // Must be evaluated before addPointToCluster overwrites the recorded location.
        boolean locationChanged = isPointLocationChange(point, cluster);
        addPointToCluster(point, cluster, z);
        return new PointClassification(cluster, nearest.getSecond().doubleValue(), locationChanged);
    }

    /**
     * Returns {@code true} unless {@code point} is already recorded as belonging to {@code
     * cluster}.
     */
    private boolean isPointLocationChange(Point point, Cluster cluster) {
        // Single lookup: avoids a race between containsKey/get and an NPE on a null mapped value.
        String previousClusterId = getPointDistribution().get(point.getId());
        return previousClusterId == null || !previousClusterId.equals(cluster.getId());
    }

    /** Adds {@code point} to {@code cluster} and records the assignment. */
    private void addPointToCluster(Point point, Cluster cluster, boolean z) {
        cluster.addPoint(point, z);
        setPointLocation(point, cluster);
    }

    /** Records that {@code point} belongs to {@code cluster}, replacing any earlier record. */
    private void setPointLocation(Point point, Cluster cluster) {
        this.pointDistribution.put(point.getId(), cluster.getId());
    }

    /**
     * Finds the cluster whose center is closest to {@code point}.
     *
     * <p>Only distances strictly below {@link Float#MAX_VALUE} are considered. On ties, the first
     * cluster in list order wins.
     *
     * @param point the query point
     * @return the nearest cluster and its distance. If the set is empty, or no distance is below
     *     {@link Float#MAX_VALUE} (including NaN distances), the result is {@code (null,
     *     Float.MAX_VALUE)}.
     */
    public Pair<Cluster, Double> nearestCluster(Point point) {
        Cluster nearest = null;
        // Sentinel is Float.MAX_VALUE (widened exactly to double), not Double.MAX_VALUE; callers
        // of getDistanceFromNearestCluster observe this value for an empty set.
        double minDistance = Float.MAX_VALUE;
        for (Cluster candidate : getClusters()) {
            double distance = candidate.getDistanceToCenter(point);
            if (distance < minDistance) {
                minDistance = distance;
                nearest = candidate;
            }
        }
        return new Pair<>(nearest, Double.valueOf(minDistance));
    }

    /**
     * Computes the distance between two points.
     *
     * <p>The computation executes the Nd4j accumulation op named by this set's distance function
     * over the two points' arrays.
     *
     * @param point first point
     * @param point2 second point
     * @return the op's final result as a {@code double}
     */
    public double getDistance(Point point, Point point2) {
        return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createAccum(this.distanceFunction, point.getArray(), point2.getArray())).getFinalResult().doubleValue();
    }

    /**
     * Returns the distance from {@code point} to its nearest cluster.
     *
     * <p>The value is as reported by {@link #nearestCluster(Point)}, so it is {@link
     * Float#MAX_VALUE} for an empty set.
     *
     * @param point the query point
     * @return the distance to the nearest cluster
     */
    public double getDistanceFromNearestCluster(Point point) {
        return nearestCluster(point).getSecond().doubleValue();
    }

    /**
     * Returns the id of the center point of the cluster with the given id.
     *
     * @param str cluster id
     * @return the center's id, or {@code null} if no such cluster exists or it has no center
     */
    public String getClusterCenterId(String str) {
        Point clusterCenter = getClusterCenter(str);
        if (clusterCenter == null) {
            return null;
        }
        return clusterCenter.getId();
    }

    /**
     * Returns the center point of the cluster with the given id.
     *
     * @param str cluster id
     * @return the center, or {@code null} if no cluster has that id
     */
    public Point getClusterCenter(String str) {
        Cluster cluster = getCluster(str);
        if (cluster == null) {
            return null;
        }
        return cluster.getCenter();
    }

    /**
     * Looks up a cluster by id.
     *
     * @param str cluster id; {@code null} never matches
     * @return the first cluster whose id equals {@code str}, or {@code null} if none does
     */
    public Cluster getCluster(String str) {
        if (str == null) {
            return null;
        }
        List<Cluster> current = this.clusters;
        // Collections.synchronizedList requires holding its lock while traversing; otherwise a
        // concurrent removal can throw during iteration.
        synchronized (current) {
            for (Cluster cluster : current) {
                if (str.equals(cluster.getId())) {
                    return cluster;
                }
            }
        }
        return null;
    }

    /**
     * Returns the number of clusters.
     *
     * @return the number of clusters, or 0 if the cluster list is {@code null}
     */
    public int getClusterCount() {
        List<Cluster> current = getClusters();
        return current == null ? 0 : current.size();
    }

    /** Calls {@code removePoints()} on every cluster in this set. */
    public void removePoints() {
        for (Cluster cluster : getClusters()) {
            cluster.removePoints();
        }
    }

    /**
     * Returns the clusters with the most points, largest first.
     *
     * @param i maximum number of clusters to return
     * @return a new list holding at most {@code i} clusters, ordered by descending point count.
     *     The result has fewer than {@code i} entries when the set contains fewer clusters.
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public List<Cluster> getMostPopulatedClusters(int i) {
        if (i < 0) {
            throw new IllegalArgumentException("Cluster count must be non-negative: " + i);
        }
        // Copying a synchronized list goes through toArray, which holds the list's lock.
        List<Cluster> byPopulation = new ArrayList<>(this.clusters);
        Collections.sort(byPopulation, new Comparator<Cluster>() {
            @Override
            public int compare(Cluster a, Cluster b) {
                // Reversed operands => descending order, so the most populated come first.
                return Integer.compare(b.getPoints().size(), a.getPoints().size());
            }
        });
        // Copy so callers get an independent list rather than a view of a temporary.
        return new ArrayList<>(byPopulation.subList(0, Math.min(i, byPopulation.size())));
    }

    /**
     * Removes every cluster whose {@code isEmpty()} returns {@code true}.
     *
     * @return the removed clusters, in their former list order
     */
    public List<Cluster> removeEmptyClusters() {
        List<Cluster> emptyClusters = new ArrayList<>();
        List<Cluster> current = this.clusters;
        // Hold the list's lock across the scan and the removal so no other thread can modify the
        // list in between (required for iterating a Collections.synchronizedList).
        synchronized (current) {
            for (Cluster cluster : current) {
                if (cluster.isEmpty()) {
                    emptyClusters.add(cluster);
                }
            }
            current.removeAll(emptyClusters);
        }
        return emptyClusters;
    }

    /**
     * Returns the live cluster list (not a copy).
     *
     * @return the cluster list
     */
    public List<Cluster> getClusters() {
        return this.clusters;
    }

    /**
     * Replaces the cluster list. The list is used as-is, without wrapping or copying.
     *
     * @param list the new cluster list
     */
    public void setClusters(List<Cluster> list) {
        this.clusters = list;
    }

    /**
     * Returns the name of the distance function (Nd4j accumulation op).
     *
     * @return the distance function name, possibly {@code null}
     */
    public String getAccumulation() {
        return this.distanceFunction;
    }

    /**
     * Sets the name of the distance function (Nd4j accumulation op).
     *
     * <p>Clusters that already exist keep the name they were created with.
     *
     * @param str the distance function name
     */
    public void setAccumulation(String str) {
        this.distanceFunction = str;
    }

    /**
     * Returns the live point-to-cluster assignment map (point id -> cluster id), not a copy.
     *
     * @return the assignment map
     */
    public Map<String, String> getPointDistribution() {
        return this.pointDistribution;
    }

    /**
     * Replaces the point-to-cluster assignment map. The map is used as-is, without wrapping or
     * copying.
     *
     * @param map the new assignment map (point id -> cluster id)
     */
    public void setPointDistribution(Map<String, String> map) {
        this.pointDistribution = map;
    }
}
