package ml.shifu.guagua.mapreduce.example.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import ml.shifu.guagua.GuaguaRuntimeException;
import ml.shifu.guagua.io.Bytable;
import ml.shifu.guagua.master.MasterComputable;
import ml.shifu.guagua.master.MasterContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ml/shifu/guagua/mapreduce/example/kmeans/KMeansMaster.class */
public class KMeansMaster implements MasterComputable<KMeansMasterParams, KMeansWorkerParams> {
    private static final Logger LOG = LoggerFactory.getLogger(KMeansMaster.class);

    public KMeansMasterParams compute(MasterContext<KMeansMasterParams, KMeansWorkerParams> masterContext) {
        if (masterContext.getWorkerResults() == null) {
            throw new NullPointerException("No worker results received in Master.");
        }
        return masterContext.getCurrentIteration() == 1 ? doFirstIteration(masterContext) : doOtherIterations(masterContext);
    }

    private KMeansMasterParams doFirstIteration(MasterContext<KMeansMasterParams, KMeansWorkerParams> masterContext) {
        ArrayList arrayList = new ArrayList();
        int i = 0;
        int i2 = 0;
        for (KMeansWorkerParams kMeansWorkerParams : masterContext.getWorkerResults()) {
            arrayList.addAll(kMeansWorkerParams.getPointList());
            if (0 == 0) {
                i = kMeansWorkerParams.getK();
                i2 = kMeansWorkerParams.getC();
            }
        }
        if (arrayList.size() < i) {
            throw new GuaguaRuntimeException("Error: data size is smaller than k, please check your input and k settings.");
        }
        Collections.sort(arrayList, new Comparator<double[]>() { // from class: ml.shifu.guagua.mapreduce.example.kmeans.KMeansMaster.1
            @Override // java.util.Comparator
            public int compare(double[] dArr, double[] dArr2) {
                return Double.valueOf(KMeansMaster.this.distance(dArr) - KMeansMaster.this.distance(dArr2)).compareTo(Double.valueOf(0.0d));
            }
        });
        ArrayList arrayList2 = new ArrayList(i);
        int size = arrayList.size() / i;
        for (int i3 = 0; i3 < i; i3++) {
            arrayList2.add(arrayList.get(i3 * size));
        }
        KMeansMasterParams kMeansMasterParams = new KMeansMasterParams();
        kMeansMasterParams.setK(i);
        kMeansMasterParams.setC(i2);
        kMeansMasterParams.setPointList(arrayList2);
        return kMeansMasterParams;
    }

    /* JADX INFO: Access modifiers changed from: private */
    public double distance(double[] dArr) {
        double d = 0.0d;
        for (int i = 0; i < dArr.length; i++) {
            d += dArr[i] * dArr[i];
        }
        return Math.sqrt(d);
    }

    private KMeansMasterParams doOtherIterations(MasterContext<KMeansMasterParams, KMeansWorkerParams> masterContext) {
        LinkedList linkedList = new LinkedList();
        LinkedList linkedList2 = new LinkedList();
        boolean z = false;
        int i = 0;
        int i2 = 0;
        for (KMeansWorkerParams kMeansWorkerParams : masterContext.getWorkerResults()) {
            LOG.debug("Worker result: %s", kMeansWorkerParams);
            if (!z) {
                i = kMeansWorkerParams.getK();
                i2 = kMeansWorkerParams.getC();
            }
            for (int i3 = 0; i3 < i; i3++) {
                if (!z) {
                    linkedList.add(new double[i2]);
                    linkedList2.add(0L);
                }
                linkedList2.set(i3, Long.valueOf(((Long) linkedList2.get(i3)).longValue() + kMeansWorkerParams.getCountList().get(i3).intValue()));
                double[] dArr = (double[]) linkedList.get(i3);
                for (int i4 = 0; i4 < i2; i4++) {
                    int i5 = i4;
                    dArr[i5] = dArr[i5] + kMeansWorkerParams.getPointList().get(i3)[i4];
                }
            }
            z = true;
        }
        LOG.debug("sumList: %s", linkedList);
        LOG.debug("countList: %s", linkedList2);
        LinkedList linkedList3 = new LinkedList();
        for (int i6 = 0; i6 < i; i6++) {
            double[] dArr2 = new double[i2];
            for (int i7 = 0; i7 < i2; i7++) {
                dArr2[i7] = ((double[]) linkedList.get(i6))[i7] / ((Long) linkedList2.get(i6)).longValue();
            }
            linkedList3.add(dArr2);
        }
        LOG.debug("meanList: %s", linkedList3);
        KMeansMasterParams kMeansMasterParams = new KMeansMasterParams();
        kMeansMasterParams.setK(i);
        kMeansMasterParams.setC(i2);
        kMeansMasterParams.setPointList(linkedList3);
        return kMeansMasterParams;
    }

    /* renamed from: compute, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Bytable m3compute(MasterContext masterContext) {
        return compute((MasterContext<KMeansMasterParams, KMeansWorkerParams>) masterContext);
    }
}
