package ai.libs.jaicore.ml.weka.classification.learner.reduction;

import ai.libs.jaicore.basic.StringUtil;
import ai.libs.jaicore.ml.weka.WekaUtil;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.api4.java.ai.ml.core.exception.TrainingException;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.rules.ZeroR;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.WekaException;

/**
 * Inner node of a multi-class tree (MCTree) that decomposes a multi-class problem along a class hierarchy.
 *
 * <p>Every child owns a subset of the classes together with a classifier that discriminates among them; a
 * child classifier may itself be an {@code MCTreeNodeReD}. The inner node classifier is trained on a copy of
 * the data in which the classes of each child are merged into one (see {@link #buildClassifier(Instances)}),
 * and its i-th output is used as the probability of the i-th child. The probability of a class is the
 * probability of its child multiplied by the child's probability for that class.
 *
 * <p>A node may be trained only once, and children cannot be added after training.
 */
public class MCTreeNodeReD extends AMCTreeNode<String> {
    private static final long serialVersionUID = 8873192747068561266L;

    /** If true, {@link #buildClassifier(Instances)} checks that the declared classes equal the classes in the data. */
    private boolean debugMode;
    /** Classifier choosing among the children; may be null until the node is configured. */
    private Classifier innerNodeClassifier;
    private List<ChildNode> children;
    private boolean trained;

    /**
     * A child of this node: the classes it is responsible for and the classifier discriminating among them.
     *
     * <p>NOTE(review): kept as a non-static inner class with non-final fields on purpose; changing either
     * would alter the serialized form of previously persisted trees.
     */
    public class ChildNode implements Serializable {
        private List<String> containedClasses;
        private Classifier childNodeClassifier;

        private ChildNode(final List<String> containedClasses, final Classifier childNodeClassifier) {
            this.containedClasses = containedClasses;
            this.childNodeClassifier = childNodeClassifier;
        }

        /**
         * Returns the subtree representation for nested tree nodes, otherwise
         * {@code SimpleClassName(class1,class2,...)}.
         */
        @Override
        public String toString() {
            if (this.childNodeClassifier instanceof MCTreeNodeReD) {
                return this.childNodeClassifier.toString();
            }
            StringBuilder sb = new StringBuilder();
            sb.append(this.childNodeClassifier.getClass().getSimpleName()).append('(');
            sb.append(StringUtil.implode(this.containedClasses, ","));
            sb.append(')');
            return sb.toString();
        }

        /**
         * Renders this child for the indented tree dump of {@link MCTreeNodeReD#toStringWithOffset(String)}.
         *
         * @param offset prefix written before the child's own line
         * @return the rendered child (recursively rendered subtree for nested tree nodes)
         */
        public String toStringWithOffset(final String offset) {
            if (this.childNodeClassifier instanceof MCTreeNodeReD) {
                // nested tree nodes are rendered with one extra tab of indentation
                return ((MCTreeNodeReD) this.childNodeClassifier).toStringWithOffset(offset + "\t");
            }
            StringBuilder sb = new StringBuilder();
            sb.append(offset);
            sb.append('(');
            sb.append(this.containedClasses);
            sb.append(':');
            sb.append(this.childNodeClassifier.getClass().getSimpleName());
            sb.append(')');
            return sb.toString();
        }
    }

    /**
     * Creates a binary node whose classifiers are instantiated from class names (with no options).
     *
     * @throws Exception if one of the classifiers cannot be instantiated by Weka
     */
    public MCTreeNodeReD(final String innerNodeClassifierName, final Collection<String> firstChildClasses, final String firstChildClassifierName,
            final Collection<String> secondChildClasses, final String secondChildClassifierName) throws Exception {
        this(innerNodeClassifierName, firstChildClasses, AbstractClassifier.forName(firstChildClassifierName, (String[]) null), secondChildClasses,
                AbstractClassifier.forName(secondChildClassifierName, (String[]) null));
    }

    /**
     * Creates a binary node from the given classifier instances.
     */
    public MCTreeNodeReD(final Classifier innerNodeClassifier, final Collection<String> firstChildClasses, final Classifier firstChildClassifier,
            final Collection<String> secondChildClasses, final Classifier secondChildClassifier) {
        this(innerNodeClassifier, Arrays.asList(firstChildClasses, secondChildClasses), Arrays.asList(firstChildClassifier, secondChildClassifier));
    }

    /**
     * Creates a binary node whose inner classifier is instantiated from a class name (with no options).
     *
     * @throws Exception if the inner classifier cannot be instantiated by Weka
     */
    public MCTreeNodeReD(final String innerNodeClassifierName, final Collection<String> firstChildClasses, final Classifier firstChildClassifier,
            final Collection<String> secondChildClasses, final Classifier secondChildClassifier) throws Exception {
        this(AbstractClassifier.forName(innerNodeClassifierName, new String[0]), firstChildClasses, firstChildClassifier, secondChildClasses,
                secondChildClassifier);
    }

    /**
     * Creates a node with an arbitrary number of children.
     *
     * <p>A child declared with at most one class gets a {@code ConstantClassifier} instead of the supplied
     * classifier.
     *
     * @param innerNodeClassifier classifier that chooses among the children
     * @param childClasses classes of each child; each collection is copied
     * @param childClassifiers classifier of each child, aligned by index with {@code childClasses}
     * @throws IllegalArgumentException if the two lists differ in size
     */
    public MCTreeNodeReD(final Classifier innerNodeClassifier, final List<Collection<String>> childClasses, final List<Classifier> childClassifiers) {
        this();
        if (childClasses.size() != childClassifiers.size()) {
            throw new IllegalArgumentException("Number of child classes does not equal the number of child classifiers");
        }
        this.innerNodeClassifier = innerNodeClassifier;
        for (int i = 0; i < childClasses.size(); i++) {
            Collection<String> classesOfChild = childClasses.get(i);
            Classifier classifierOfChild = classesOfChild.size() > 1 ? childClassifiers.get(i) : new ConstantClassifier();
            addChild(new ArrayList<>(classesOfChild), classifierOfChild);
        }
    }

    /**
     * Creates an untrained copy of the given node covering all of its children.
     *
     * <p>The inner node classifier is re-instantiated from its class name with no options, so options of the
     * original inner classifier are not carried over. The child classifiers are copied via
     * {@code WekaUtil.cloneClassifier}.
     *
     * @param other node to copy
     * @throws Exception if instantiating or cloning a classifier fails
     */
    public MCTreeNodeReD(final MCTreeNodeReD other) throws Exception {
        this(AbstractClassifier.forName(other.innerNodeClassifier.getClass().getName(), new String[0]), classesOfChildren(other),
                cloneChildClassifiers(other));
    }

    /** Collects the declared classes of every child of {@code node}, in child order. */
    private static List<Collection<String>> classesOfChildren(final MCTreeNodeReD node) {
        List<Collection<String>> result = new ArrayList<>(node.children.size());
        for (ChildNode child : node.children) {
            result.add(child.containedClasses);
        }
        return result;
    }

    /** Clones the classifier of every child of {@code node}, in child order. */
    private static List<Classifier> cloneChildClassifiers(final MCTreeNodeReD node) throws Exception {
        List<Classifier> result = new ArrayList<>(node.children.size());
        for (ChildNode child : node.children) {
            result.add(WekaUtil.cloneClassifier(child.childNodeClassifier));
        }
        return result;
    }

    /**
     * Creates an empty, unconfigured node without inner classifier and without children.
     */
    public MCTreeNodeReD() {
        super(new ArrayList<>());
        this.debugMode = false;
        this.children = new ArrayList<>();
        this.trained = false;
    }

    /**
     * Adds a child responsible for the given classes.
     *
     * <p>If {@code classifier} is an {@code MCTreeMergeNode}, its children are added directly to this node
     * instead of the merge node itself.
     *
     * @param classes classes handled by the new child (also added to this node's contained classes)
     * @param classifier classifier of the new child
     */
    public void addChild(final List<String> classes, final Classifier classifier) {
        assert !this.trained : "Cannot insert children after the tree node has been trained!";
        if (classifier instanceof MCTreeMergeNode) {
            this.children.addAll(((MCTreeMergeNode) classifier).getChildren());
        } else {
            this.children.add(new ChildNode(classes, classifier));
        }
        getContainedClasses2().addAll(classes);
    }

    /**
     * Returns the live (mutable) list of children of this node.
     */
    public List<ChildNode> getChildren() {
        return this.children;
    }

    /**
     * Checks whether this node has an inner classifier and at least one child, and whether all nested tree
     * nodes are completely configured as well.
     */
    public boolean isCompletelyConfigured() {
        if (this.innerNodeClassifier == null || this.children.isEmpty()) {
            return false;
        }
        for (ChildNode child : this.children) {
            if (child.childNodeClassifier instanceof MCTreeNodeReD && !((MCTreeNodeReD) child.childNodeClassifier).isCompletelyConfigured()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the classes covered by this node. The returned collection is the list passed to the super
     * constructor, so it can be used positionally.
     */
    @Override
    public Collection<String> getContainedClasses2() {
        return (List<String>) super.getContainedClasses2();
    }

    /**
     * Trains all child classifiers on their slice of the data, then the inner node classifier on the data
     * with each child's classes merged into one.
     *
     * <p>If the inner classifier fails with a {@link WekaException}, it is replaced by {@link ZeroR} and
     * trained again. An exception thrown by that fallback propagates unchanged.
     *
     * @param instances non-empty training data
     * @throws IllegalArgumentException if {@code instances} is empty
     * @throws IllegalStateException in debug mode, if the declared classes differ from those in the data
     * @throws TrainingException if a child classifier, or the inner classifier (other than via
     *             {@link WekaException}), fails to train
     * @throws Exception if the ZeroR fallback fails
     */
    public void buildClassifier(final Instances instances) throws Exception {
        if (instances.isEmpty()) {
            throw new IllegalArgumentException("Cannot train MCTree with empty set of instances.");
        }
        assert !this.children.isEmpty() : "Cannot train MCTree without children";
        assert !this.trained : "Cannot retrain MCTreeNodeReD";

        if (this.debugMode) {
            Collection<String> classesInData = WekaUtil.getClassesActuallyContainedInDataset(instances);
            if (!getContainedClasses2().containsAll(classesInData)) {
                throw new IllegalStateException("The classes for which this MCTreeNodeReD has been defined (" + getContainedClasses2()
                        + ") is not a superset of the given training data (" + classesInData + ") ...");
            }
            if (!classesInData.containsAll(getContainedClasses2())) {
                throw new IllegalStateException("The classes for which this MCTreeNodeReD has been defined (" + getContainedClasses2()
                        + ") is not a subset of the given training data (" + classesInData + ") ...");
            }
        }

        // Re-populate the contained classes in the order of the dataset's class attribute; distributionForInstance
        // writes each class's probability at that class's position in this list.
        getContainedClasses2().clear();
        for (int i = 0; i < instances.numClasses(); i++) {
            getContainedClasses2().add(instances.classAttribute().value(i));
        }

        List<Set<String>> childClassGroups = new ArrayList<>();
        int childNumber = 0;
        for (ChildNode child : getChildren()) {
            childNumber++;
            assert !child.containedClasses.isEmpty() : "Contained classes of child must not be empty";

            // Keep only the instances of this child's classes, re-expressed over the child's class attribute.
            Instances childData = WekaUtil.getEmptySetOfInstancesWithRefactoredClass(instances, child.containedClasses);
            for (Instance instance : instances) {
                String className = instance.classAttribute().value((int) Math.round(instance.classValue()));
                if (child.containedClasses.contains(className)) {
                    Instance refactored = WekaUtil.getRefactoredInstance(instance, child.containedClasses);
                    refactored.setClassValue(className);
                    refactored.setDataset(childData);
                    childData.add(refactored);
                }
            }
            assert child.containedClasses.containsAll(WekaUtil.getClassesActuallyContainedInDataset(childData)) : "There are data for the child node that are not contained in its declaration";
            assert WekaUtil.getClassesActuallyContainedInDataset(childData).containsAll(child.containedClasses) : "There are classes declared in the child, but no corresponding data have been passed";

            try {
                child.childNodeClassifier.buildClassifier(childData);
                childClassGroups.add(new HashSet<>(child.containedClasses));
            } catch (Exception e) {
                throw new TrainingException("Cannot train classifier in child #" + childNumber, e);
            }
        }

        Instances innerTrainingData = WekaUtil.mergeClassesOfInstances(instances, childClassGroups);
        // The more specific WekaException must be caught first; otherwise the fallback is unreachable (and
        // javac rejects the ordering).
        try {
            this.innerNodeClassifier.buildClassifier(innerTrainingData);
        } catch (WekaException e) {
            this.innerNodeClassifier = new ZeroR();
            this.innerNodeClassifier.buildClassifier(innerTrainingData);
        } catch (Exception e) {
            throw new TrainingException("Cannot train inner classifier", e);
        }
        this.trained = true;
    }

    /**
     * Computes the class distribution for {@code instance}.
     *
     * <p>Each class's probability is the inner classifier's probability for the child holding the class times
     * the child's probability for that class. Entries are ordered like {@link #getContainedClasses2()}.
     *
     * @param instance instance to classify
     * @return distribution over the contained classes
     * @throws Exception if the inner or a child classifier fails
     */
    public double[] distributionForInstance(final Instance instance) throws Exception {
        assert this.trained : "Cannot get distribution from untrained classifier " + toStringWithOffset();
        // The contained classes are backed by a List (see constructor), which provides positional lookup.
        final List<String> classes = (List<String>) getContainedClasses2();
        double[] childProbabilities = this.innerNodeClassifier.distributionForInstance(WekaUtil.getRefactoredInstance(instance));
        double[] distribution = new double[classes.size()];
        for (int i = 0; i < this.children.size(); i++) {
            ChildNode child = this.children.get(i);
            double[] withinChild = child.childNodeClassifier.distributionForInstance(WekaUtil.getRefactoredInstance(instance, child.containedClasses));
            assert withinChild.length == child.containedClasses.size() : "Mismatch of child classes (" + child.containedClasses.size()
                    + ") and distribution in child (" + withinChild.length + ")";
            for (int j = 0; j < withinChild.length; j++) {
                distribution[classes.indexOf(child.containedClasses.get(j))] = withinChild[j] * childProbabilities[i];
            }
        }
        double sum = Arrays.stream(distribution).sum();
        assert sum - 1.0E-8d <= 1.0d && sum + 1.0E-8d >= 1.0d : "Distribution does not sum up to 1; actual sum of distribution entries: " + sum;
        return distribution;
    }

    /**
     * Returns the capabilities of the inner node classifier.
     */
    public Capabilities getCapabilities() {
        return this.innerNodeClassifier.getCapabilities();
    }

    /**
     * Returns the height of the subtree rooted here: 1 plus the maximum height of nested tree-node children.
     */
    public int getHeight() {
        int maxChildHeight = 0;
        for (ChildNode child : this.children) {
            if (child.childNodeClassifier instanceof MCTreeNodeReD) {
                maxChildHeight = Math.max(((MCTreeNodeReD) child.childNodeClassifier).getHeight(), maxChildHeight);
            }
        }
        return 1 + maxChildHeight;
    }

    /**
     * Follows the first child whose classes contain all of {@code classes} and counts the levels descended.
     *
     * @param classes classes whose common ancestor is sought
     * @return 1 if no child contains all classes or the containing child is not a tree node, otherwise 1 plus
     *         the value computed in that child
     */
    public int getDepthOfFirstCommonParent(final List<String> classes) {
        for (ChildNode child : this.children) {
            if (child.containedClasses.containsAll(classes)) {
                return child.childNodeClassifier instanceof MCTreeNodeReD ? 1 + ((MCTreeNodeReD) child.childNodeClassifier).getDepthOfFirstCommonParent(classes) : 1;
            }
        }
        return 1;
    }

    /**
     * Returns the inner node classifier (may be null for an unconfigured node).
     */
    public Classifier getClassifier() {
        return this.innerNodeClassifier;
    }

    /**
     * Sets the inner node classifier.
     *
     * @throws IllegalArgumentException if {@code classifier} is null
     */
    public void setBaseClassifier(final Classifier classifier) {
        if (classifier == null) {
            throw new IllegalArgumentException("Cannot set null classifier!");
        }
        this.innerNodeClassifier = classifier;
    }

    /** Simple class name of the inner classifier, or "null" for a node that is not configured yet. */
    private String innerClassifierName() {
        return this.innerNodeClassifier == null ? "null" : this.innerNodeClassifier.getClass().getSimpleName();
    }

    /**
     * Returns a compact one-line representation {@code (InnerClassifier){child1,child2,...}}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append('(').append(innerClassifierName()).append(')');
        sb.append('{');
        boolean first = true;
        for (ChildNode child : this.children) {
            if (!first) {
                sb.append(',');
            }
            first = false;
            sb.append(child);
        }
        sb.append('}');
        return sb.toString();
    }

    /**
     * Returns a multi-line, indented representation of the subtree rooted here.
     */
    public String toStringWithOffset() {
        return toStringWithOffset("");
    }

    /**
     * Returns a multi-line representation of the subtree rooted here, each line prefixed by {@code offset}.
     *
     * @param offset indentation prefix for this node; children get two additional spaces
     */
    public String toStringWithOffset(final String offset) {
        StringBuilder sb = new StringBuilder();
        sb.append(offset);
        sb.append('(');
        sb.append(getContainedClasses2());
        sb.append(':');
        sb.append(innerClassifierName());
        sb.append(") {");
        boolean first = true;
        for (ChildNode child : this.children) {
            if (!first) {
                sb.append(',');
            }
            first = false;
            sb.append('\n');
            sb.append(child.toStringWithOffset(offset + "  "));
        }
        sb.append('\n');
        sb.append(offset);
        sb.append('}');
        return sb.toString();
    }
}
