package ai.sklearn4j.preprocessing.label;

import ai.sklearn4j.base.TransformerMixin;
import ai.sklearn4j.core.ScikitLearnCoreException;
import ai.sklearn4j.core.libraries.numpy.NumpyArray;
import ai.sklearn4j.core.libraries.numpy.NumpyArrayFactory;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:ai/sklearn4j/preprocessing/label/MultiLabelBinarizer.class */
/**
 * Converts between a list of label sets and a binary indicator matrix.
 *
 * <p>Column {@code j} of the indicator matrix corresponds to {@code getClasses().get(j)}. This
 * class does not learn the class list itself; it must be supplied via {@link #setClasses(List)}
 * before {@link #transform(List)} or {@link #inverseTransform(NumpyArray)} is called.
 */
public class MultiLabelBinarizer extends TransformerMixin<List<Set<Object>>, NumpyArray<Long>> {
    /** Ordered class labels; the position of a label in this list is its matrix column. */
    private List<Object> classes = null;

    /** Opaque dictionary holder; not read by the transform methods of this class. */
    private Map<String, Object> cachedDict = null;

    /**
     * Sets the ordered class labels that define the indicator-matrix columns.
     *
     * @param list the class labels; stored by reference, not copied
     */
    public void setClasses(List<Object> list) {
        this.classes = list;
    }

    /**
     * Returns the ordered class labels that define the indicator-matrix columns.
     *
     * @return the stored class list (the same reference that was set), or {@code null} if unset
     */
    public List<Object> getClasses() {
        return this.classes;
    }

    /**
     * Sets the cached dictionary.
     *
     * @param map the dictionary; stored by reference, not copied
     */
    public void setCachedDict(Map<String, Object> map) {
        this.cachedDict = map;
    }

    /**
     * Returns the cached dictionary.
     *
     * @return the stored dictionary (the same reference that was set), or {@code null} if unset
     */
    public Map<String, Object> getCachedDict() {
        return this.cachedDict;
    }

    /**
     * Encodes each label set as one row of a binary indicator matrix.
     *
     * @param list the label sets to encode, one per sample
     * @return an int64 array of shape {@code [list.size(), classes.size()]} in which cell
     *     {@code (i, j)} is set to 1 when sample {@code i} contains class {@code j}; all other
     *     cells keep the initial value produced by {@code NumpyArrayFactory.arrayOfInt64WithShape}
     *     (presumably 0 — confirm against that factory)
     * @throws ScikitLearnCoreException if the classes have not been set, or if a sample contains a
     *     label that is not one of the known classes
     */
    @Override // ai.sklearn4j.base.TransformerMixin
    public NumpyArray<Long> transform(List<Set<Object>> list) {
        List<Object> knownClasses = requireClasses();

        // Label -> column lookup. On duplicate labels the last occurrence wins, as before.
        Map<Object, Integer> columnOf = new HashMap<>();
        for (int column = 0; column < knownClasses.size(); column++) {
            columnOf.put(knownClasses.get(column), column);
        }

        NumpyArray<Long> indicator =
                NumpyArrayFactory.arrayOfInt64WithShape(new int[]{list.size(), knownClasses.size()});
        for (int row = 0; row < list.size(); row++) {
            for (Object label : list.get(row)) {
                Integer boxedColumn = columnOf.get(label);
                // containsKey (rather than a null check on get) keeps support for a null class label.
                if (boxedColumn == null && !columnOf.containsKey(label)) {
                    // String.valueOf so that an unknown null label yields this error, not an NPE.
                    throw new ScikitLearnCoreException(String.format(
                            "The class '%s' was not defined during the MultiLabelBinarizer training.",
                            String.valueOf(label)));
                }
                int column = boxedColumn;
                indicator.set(1, row, column);
            }
        }
        return indicator;
    }

    /**
     * Decodes a binary indicator matrix back into one label set per row.
     *
     * @param numpyArray an indicator matrix with one column per known class; any non-zero cell
     *     {@code (i, j)} places class {@code j} into the set for row {@code i}
     * @return a list with one set of class labels per row of {@code numpyArray}
     * @throws ScikitLearnCoreException if the classes have not been set, or if the number of
     *     columns in {@code numpyArray} differs from the number of known classes
     */
    @Override // ai.sklearn4j.base.TransformerMixin
    public List<Set<Object>> inverseTransform(NumpyArray<Long> numpyArray) {
        List<Object> knownClasses = requireClasses();

        // Reject width mismatches up front, as scikit-learn does, rather than silently ignoring
        // extra columns or indexing past the end of a narrower matrix.
        int columns = numpyArray.getShape().length > 1 ? numpyArray.getShape()[1] : 0;
        if (columns != knownClasses.size()) {
            throw new ScikitLearnCoreException(String.format(
                    "Expected indicator for %d classes, but got %d.", knownClasses.size(), columns));
        }

        int rows = numpyArray.getShape()[0];
        List<Set<Object>> labelSets = new ArrayList<>(rows);
        for (int row = 0; row < rows; row++) {
            Set<Object> labels = new HashSet<>();
            for (int column = 0; column < knownClasses.size(); column++) {
                if (numpyArray.get(row, column).longValue() != 0) {
                    labels.add(knownClasses.get(column));
                }
            }
            labelSets.add(labels);
        }
        return labelSets;
    }

    /**
     * Returns the configured class list, failing with a descriptive error instead of an NPE.
     *
     * @return the non-null class list
     * @throws ScikitLearnCoreException if {@link #setClasses(List)} has not supplied a list
     */
    private List<Object> requireClasses() {
        if (this.classes == null) {
            throw new ScikitLearnCoreException(
                    "MultiLabelBinarizer is not fitted: the classes have not been set.");
        }
        return this.classes;
    }
}
