package ai.libs.jaicore.ml.weka.dataset;

import ai.libs.jaicore.ml.core.dataset.schema.LabeledInstanceSchema;
import ai.libs.jaicore.ml.core.dataset.schema.attribute.IntBasedCategoricalAttribute;
import ai.libs.jaicore.ml.core.dataset.schema.attribute.NumericAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.api4.java.ai.ml.core.dataset.schema.ILabeledInstanceSchema;
import org.api4.java.ai.ml.core.dataset.schema.attribute.IAttribute;
import org.api4.java.ai.ml.core.dataset.schema.attribute.ICategoricalAttribute;
import org.api4.java.ai.ml.core.dataset.schema.attribute.ICategoricalAttributeValue;
import org.api4.java.ai.ml.core.dataset.schema.attribute.INumericAttribute;
import org.api4.java.ai.ml.core.dataset.schema.attribute.INumericAttributeValue;
import org.api4.java.ai.ml.core.dataset.serialization.UnsupportedAttributeTypeException;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/jaicore/ml/weka/dataset/WekaInstancesUtil.class */
/**
 * Static helpers for converting between the api4/jaicore labeled-dataset abstractions
 * ({@link ILabeledDataset}, {@link ILabeledInstance}, {@link ILabeledInstanceSchema}) and
 * WEKA's {@link Instances} / {@link Instance} representation.
 *
 * <p>WEKA datasets built by this class always contain the schema's feature attributes in
 * schema order, followed by the label attribute as the last attribute (which is also set as
 * the WEKA class index). Only numeric and categorical attributes are supported.
 */
public class WekaInstancesUtil {

    private WekaInstancesUtil() {
        // static utility class; not meant to be instantiated
    }

    /**
     * Derives a labeled instance schema from a WEKA dataset.
     *
     * @param instances the WEKA dataset; its class index must be set
     * @return a schema named after the WEKA relation, whose attribute list holds every WEKA
     *         attribute except the class attribute, which becomes the label attribute
     * @throws NullPointerException if {@code instances} is null
     * @throws IllegalArgumentException if the class index is not set, or an attribute is
     *         neither numeric nor nominal
     */
    public static ILabeledInstanceSchema extractSchema(final Instances instances) {
        Objects.requireNonNull(instances);
        int classIndex = instances.classIndex();
        if (classIndex < 0) {
            throw new IllegalArgumentException("Class index of Instances object is not set!");
        }
        // Collect into an explicitly mutable list: the class attribute is removed from it below,
        // and Collectors.toList() makes no mutability guarantee.
        List<IAttribute> attributes = IntStream.range(0, instances.numAttributes())
                .mapToObj(instances::attribute)
                .map(WekaInstancesUtil::transformWEKAAttributeToAttributeType)
                .collect(Collectors.toCollection(ArrayList::new));
        IAttribute labelAttribute = attributes.remove(classIndex);
        return new LabeledInstanceSchema(instances.relationName(), attributes, labelAttribute);
    }

    /**
     * Converts a labeled dataset into a WEKA {@link Instances} object.
     *
     * <p>Each instance's {@code getPoint()} vector becomes the feature values, and the label is
     * written to the trailing class attribute. A {@code null} label is stored as a missing class
     * value.
     *
     * @param dataset the dataset to convert
     * @return a new WEKA dataset with class index set to the last attribute
     * @throws UnsupportedAttributeTypeException if the schema contains an attribute that is
     *         neither numeric nor categorical
     * @throws IllegalStateException if an instance's attribute count differs from the schema's
     */
    public static Instances datasetToWekaInstances(final ILabeledDataset<? extends ILabeledInstance> dataset) throws UnsupportedAttributeTypeException {
        Instances wekaInstances = createDatasetFromSchema(dataset.getInstanceSchema());
        int numAttributes = dataset.getInstanceSchema().getNumAttributes();
        IAttribute labelAttribute = dataset.getLabelAttribute();
        for (ILabeledInstance instance : dataset) {
            if (instance.getNumAttributes() != numAttributes) {
                throw new IllegalStateException("Dataset scheme defines a number of " + numAttributes + " attributes, but instance has " + instance.getNumAttributes() + ".");
            }
            double[] point = instance.getPoint();
            // One extra trailing slot for the class value (the class is the last WEKA attribute).
            DenseInstance wekaInstance = new DenseInstance(1.0d, Arrays.copyOf(point, point.length + 1));
            wekaInstance.setDataset(wekaInstances);
            Object label = instance.getLabel();
            if (label == null) {
                // Without this, the padded 0.0 would silently become the class value
                // (i.e. the first category, or a numeric target of 0).
                wekaInstance.setClassMissing();
            } else if (labelAttribute instanceof ICategoricalAttribute) {
                // Categorical labels are category indices; WEKA is given the category's name.
                int categoryIndex = (Integer) label;
                wekaInstance.setClassValue(((ICategoricalAttribute) labelAttribute).getLabelOfCategory(categoryIndex));
            } else {
                wekaInstance.setClassValue(Double.parseDouble(label.toString()));
            }
            wekaInstances.add(wekaInstance);
        }
        return wekaInstances;
    }

    /**
     * Creates an empty WEKA dataset whose attributes mirror the given schema.
     *
     * <p>The feature attributes come first in schema order, followed by the label attribute. The
     * class index is set to the last attribute.
     *
     * @param schema the schema to mirror
     * @return an empty WEKA dataset named {@code "weka-instances"}
     * @throws NullPointerException if {@code schema} is null
     * @throws UnsupportedAttributeTypeException if a feature or the label attribute is neither
     *         numeric nor categorical
     */
    public static Instances createDatasetFromSchema(final ILabeledInstanceSchema schema) throws UnsupportedAttributeTypeException {
        Objects.requireNonNull(schema);
        int numAttributes = schema.getNumAttributes();
        // WEKA's Instances constructor requires an ArrayList; +1 for the class attribute.
        ArrayList<Attribute> wekaAttributes = new ArrayList<>(numAttributes + 1);
        for (int i = 0; i < numAttributes; i++) {
            IAttribute attribute = schema.getAttributeList().get(i);
            if (attribute instanceof INumericAttribute) {
                wekaAttributes.add(new Attribute(attribute.getName()));
            } else if (attribute instanceof ICategoricalAttribute) {
                wekaAttributes.add(new Attribute(attribute.getName(), ((ICategoricalAttribute) attribute).getLabels()));
            } else {
                throw new UnsupportedAttributeTypeException("The attribute has an unsupported attribute type " + attribute.getClass().getName() + " of attribute " + attribute.getName() + ".");
            }
        }

        IAttribute labelAttribute = schema.getLabelAttribute();
        Attribute classAttribute;
        if (labelAttribute instanceof INumericAttribute) {
            classAttribute = new Attribute(labelAttribute.getName());
        } else if (labelAttribute instanceof ICategoricalAttribute) {
            classAttribute = new Attribute(labelAttribute.getName(), ((ICategoricalAttribute) labelAttribute).getLabels());
        } else {
            throw new UnsupportedAttributeTypeException("The class attribute has an unsupported attribute type.");
        }
        wekaAttributes.add(classAttribute);

        Instances instances = new Instances("weka-instances", wekaAttributes, 0);
        instances.setClassIndex(instances.numAttributes() - 1);
        return instances;
    }

    /**
     * Converts a single WEKA attribute into the corresponding jaicore attribute.
     *
     * @param attribute the WEKA attribute
     * @return a {@link NumericAttribute} for numeric attributes, or an
     *         {@link IntBasedCategoricalAttribute} listing the nominal values in WEKA's order
     * @throws IllegalArgumentException if the attribute is neither numeric nor nominal
     */
    public static IAttribute transformWEKAAttributeToAttributeType(final Attribute attribute) {
        String name = attribute.name();
        if (attribute.isNumeric()) {
            return new NumericAttribute(name);
        }
        if (!attribute.isNominal()) {
            throw new IllegalArgumentException("Can only transform numeric or categorical attributes");
        }
        List<String> categories = new ArrayList<>(attribute.numValues());
        for (int i = 0; i < attribute.numValues(); i++) {
            categories.add(attribute.value(i));
        }
        return new IntBasedCategoricalAttribute(name, categories);
    }

    /**
     * Converts a single labeled instance into a WEKA instance bound to a fresh dataset built
     * from {@code schema}.
     *
     * <p>If the instance already wraps a WEKA instance, the wrapped object is returned as-is.
     * Feature and label values for which the schema attribute yields no attribute value are
     * stored as missing.
     *
     * @param schema the schema describing {@code instance}
     * @param instance the instance to convert
     * @return the WEKA instance; the class value occupies the last attribute
     * @throws NullPointerException if {@code schema} is null
     * @throws IllegalArgumentException if the attribute counts of schema and instance differ
     * @throws UnsupportedAttributeTypeException if an attribute is neither numeric nor categorical
     */
    public static Instance transformInstanceToWekaInstance(final ILabeledInstanceSchema schema, final ILabeledInstance instance) throws UnsupportedAttributeTypeException {
        Objects.requireNonNull(schema);
        if (instance.getNumAttributes() != schema.getNumAttributes()) {
            throw new IllegalArgumentException("Schema and instance do not coincide. The schema defines " + schema.getNumAttributes() + " attributes but the instance has " + instance.getNumAttributes() + " attributes.");
        }
        if (instance instanceof WekaInstance) {
            // NOTE(review): `m60getElement` looks like a decompiler-generated rename of
            // `getElement()`; confirm against WekaInstance before recompiling from source.
            return (Instance) ((WekaInstance) instance).m60getElement();
        }

        Instances dataset = createDatasetFromSchema(schema);
        DenseInstance wekaInstance = new DenseInstance(dataset.numAttributes());
        wekaInstance.setDataset(dataset);

        for (int i = 0; i < instance.getNumAttributes(); i++) {
            IAttribute attribute = schema.getAttribute(i);
            Object rawValue = instance.getAttributeValue(i);
            if (attribute instanceof INumericAttribute) {
                INumericAttributeValue value = ((INumericAttribute) attribute).getAsAttributeValue(rawValue);
                if (value != null) {
                    wekaInstance.setValue(i, value.getValue().doubleValue());
                } else {
                    wekaInstance.setMissing(i);
                }
            } else if (attribute instanceof ICategoricalAttribute) {
                ICategoricalAttributeValue value = ((ICategoricalAttribute) attribute).getAsAttributeValue(rawValue);
                if (value != null) {
                    wekaInstance.setValue(i, value.getValue().intValue());
                } else {
                    wekaInstance.setMissing(i);
                }
            } else {
                throw new UnsupportedAttributeTypeException("Only categorical and numeric attributes are supported!");
            }
        }

        // The class attribute is always the last one (see createDatasetFromSchema).
        int classIndex = wekaInstance.numAttributes() - 1;
        IAttribute labelAttribute = schema.getLabelAttribute();
        if (labelAttribute instanceof INumericAttribute) {
            INumericAttributeValue labelValue = ((INumericAttribute) labelAttribute).getAsAttributeValue(instance.getLabel());
            if (labelValue != null) {
                wekaInstance.setValue(classIndex, labelValue.getValue().doubleValue());
            } else {
                wekaInstance.setClassMissing();
            }
        } else if (labelAttribute instanceof ICategoricalAttribute) {
            ICategoricalAttributeValue labelValue = ((ICategoricalAttribute) labelAttribute).getAsAttributeValue(instance.getLabel());
            if (labelValue != null) {
                wekaInstance.setValue(classIndex, labelValue.getValue().intValue());
            } else {
                wekaInstance.setClassMissing();
            }
        } else {
            throw new UnsupportedAttributeTypeException("Only categorical and numeric attributes are supported!");
        }
        return wekaInstance;
    }
}
