package org.nd4j.evaluation.classification;

import java.io.Serializable;
import java.lang.Comparable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import org.nd4j.shade.guava.collect.HashMultiset;
import org.nd4j.shade.guava.collect.Multiset;

/* loaded from: input_file:org/nd4j/evaluation/classification/ConfusionMatrix.class */
public class ConfusionMatrix<T extends Comparable<? super T>> implements Serializable {
    private volatile Map<T, Multiset<T>> matrix;
    private List<T> classes;

    public ConfusionMatrix(List<T> list) {
        this.matrix = new ConcurrentHashMap();
        this.classes = list;
    }

    public ConfusionMatrix() {
        this(new ArrayList());
    }

    public ConfusionMatrix(ConfusionMatrix<T> confusionMatrix) {
        this(confusionMatrix.getClasses());
        add(confusionMatrix);
    }

    public synchronized void add(T t, T t2) {
        add(t, t2, 1);
    }

    public synchronized void add(T t, T t2, int i) {
        if (this.matrix.containsKey(t)) {
            this.matrix.get(t).add(t2, i);
            return;
        }
        HashMultiset create = HashMultiset.create();
        create.add(t2, i);
        this.matrix.put(t, create);
    }

    public synchronized void add(ConfusionMatrix<T> confusionMatrix) {
        for (T t : confusionMatrix.matrix.keySet()) {
            Multiset<T> multiset = confusionMatrix.matrix.get(t);
            for (T t2 : multiset.elementSet()) {
                add(t, t2, multiset.count(t2));
            }
        }
    }

    public List<T> getClasses() {
        if (this.classes == null) {
            this.classes = new ArrayList();
        }
        return this.classes;
    }

    public synchronized int getCount(T t, T t2) {
        if (this.matrix.containsKey(t)) {
            return this.matrix.get(t).count(t2);
        }
        return 0;
    }

    public synchronized int getPredictedTotal(T t) {
        int i = 0;
        Iterator<T> it2 = this.classes.iterator();
        while (it2.hasNext()) {
            i += getCount(it2.next(), t);
        }
        return i;
    }

    public synchronized int getActualTotal(T t) {
        if (!this.matrix.containsKey(t)) {
            return 0;
        }
        int i = 0;
        Iterator<T> it2 = this.matrix.get(t).elementSet().iterator();
        while (it2.hasNext()) {
            i += this.matrix.get(t).count(it2.next());
        }
        return i;
    }

    public String toString() {
        return this.matrix.toString();
    }

    public String toCSV() {
        StringBuilder sb = new StringBuilder();
        sb.append(",,Predicted Class,\n");
        sb.append(",,");
        Iterator<T> it2 = this.classes.iterator();
        while (it2.hasNext()) {
            sb.append(String.format("%s,", it2.next()));
        }
        sb.append("Total\n");
        String str = "Actual Class,";
        for (T t : this.classes) {
            sb.append(str);
            str = ",";
            sb.append(String.format("%s,", t));
            Iterator<T> it3 = this.classes.iterator();
            while (it3.hasNext()) {
                sb.append(getCount(t, it3.next()));
                sb.append(",");
            }
            sb.append(getActualTotal(t));
            sb.append("\n");
        }
        sb.append(",Total,");
        Iterator<T> it4 = this.classes.iterator();
        while (it4.hasNext()) {
            sb.append(getPredictedTotal(it4.next()));
            sb.append(",");
        }
        sb.append("\n");
        return sb.toString();
    }

    public String toHTML() {
        StringBuilder sb = new StringBuilder();
        int size = this.classes.size();
        sb.append("<table>\n");
        sb.append("<tr><th class=\"empty-space\" colspan=\"2\" rowspan=\"2\">");
        sb.append(String.format("<th class=\"predicted-class-header\" colspan=\"%d\">Predicted Class</th></tr>%n", Integer.valueOf(size + 1)));
        sb.append("<tr>");
        for (T t : this.classes) {
            sb.append("<th class=\"predicted-class-header\">");
            sb.append(t);
            sb.append("</th>");
        }
        sb.append("<th class=\"predicted-class-header\">Total</th>");
        sb.append("</tr>\n");
        String format = String.format("<tr><th class=\"actual-class-header\" rowspan=\"%d\">Actual Class</th>", Integer.valueOf(size + 1));
        for (T t2 : this.classes) {
            sb.append(format);
            format = "<tr>";
            sb.append(String.format("<th class=\"actual-class-header\" >%s</th>", t2));
            for (T t3 : this.classes) {
                sb.append("<td class=\"count-element\">");
                sb.append(getCount(t2, t3));
                sb.append("</td>");
            }
            sb.append("<td class=\"count-element\">");
            sb.append(getActualTotal(t2));
            sb.append("</td>");
            sb.append("</tr>\n");
        }
        sb.append("<tr><th class=\"actual-class-header\">Total</th>");
        for (T t4 : this.classes) {
            sb.append("<td class=\"count-element\">");
            sb.append(getPredictedTotal(t4));
            sb.append("</td>");
        }
        sb.append("<td class=\"empty-space\"></td>\n");
        sb.append("</tr>\n");
        sb.append("</table>\n");
        return sb.toString();
    }

    public boolean equals(Object obj) {
        if (!(obj instanceof ConfusionMatrix)) {
            return false;
        }
        ConfusionMatrix confusionMatrix = (ConfusionMatrix) obj;
        return this.matrix.equals(confusionMatrix.matrix) && this.classes.equals(confusionMatrix.classes);
    }

    public int hashCode() {
        return (31 * ((31 * 17) + (this.matrix == null ? 0 : this.matrix.hashCode()))) + (this.classes == null ? 0 : this.classes.hashCode());
    }

    public Map<T, Multiset<T>> getMatrix() {
        return this.matrix;
    }
}
