package ai.sklearn4j.preprocessing.label;

import ai.sklearn4j.base.TransformerMixin;
import ai.sklearn4j.core.ScikitLearnCoreException;
import ai.sklearn4j.core.ScikitLearnFeatureNotImplementedException;
import ai.sklearn4j.core.libraries.numpy.NumpyArray;
import ai.sklearn4j.core.libraries.numpy.NumpyArrayFactory;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:ai/sklearn4j/preprocessing/label/LabelBinarizer.class */
/**
 * Converts a list of class labels into an int64 indicator array and back.
 *
 * <p>Two label layouts are supported, selected by {@code yType}:
 * <ul>
 *   <li>{@code "binary"}: the output has shape {@code (n, 1)}; a label equal to the first entry of
 *       {@code classes} is written as {@code negativeLabel}, any other known label as
 *       {@code positiveLabel}.</li>
 *   <li>{@code "multiclass"}: the output has shape {@code (n, classes.size())}; the column at the
 *       label's position in {@code classes} holds {@code positiveLabel} and every other column
 *       holds {@code negativeLabel}.</li>
 * </ul>
 *
 * <p>The instance must be configured through the setters ({@code classes} and {@code yType} in
 * particular) before {@link #transform} or {@link #inverseTransform} is called.
 * NOTE(review): presumably mirrors scikit-learn's {@code LabelBinarizer} - confirm against the
 * model loader that populates these fields.
 */
public class LabelBinarizer extends TransformerMixin<List<Object>, NumpyArray<Long>> {
    private static final String Y_TYPE_BINARY = "binary";
    private static final String Y_TYPE_MULTICLASS = "multiclass";

    /** Known class labels; a label's position in this list is its class index. */
    private List<Object> classes = null;

    /** Label layout, expected to be {@code "binary"} or {@code "multiclass"}. */
    private String yType = null;

    /** Value written for "not this class" cells. */
    private long negativeLabel = 0;

    /** Value written for "is this class" cells. */
    private long positiveLabel = 1;

    /**
     * Sets the known class labels.
     *
     * @param classes labels in class-index order; the list is stored by reference, not copied
     */
    public void setClasses(List<Object> classes) {
        this.classes = classes;
    }

    /**
     * @return the known class labels (the same list instance passed to {@link #setClasses})
     */
    public List<Object> getClasses() {
        return this.classes;
    }

    /**
     * Sets the label layout.
     *
     * @param yType {@code "binary"} or {@code "multiclass"}; any other value makes
     *              {@link #transform} and {@link #inverseTransform} throw
     */
    public void setYType(String yType) {
        this.yType = yType;
    }

    /**
     * @return the configured label layout, or {@code null} if none was set
     */
    public String getYType() {
        return this.yType;
    }

    /**
     * @param negativeLabel value to write for "not this class" cells
     */
    public void setNegativeLabel(long negativeLabel) {
        this.negativeLabel = negativeLabel;
    }

    /**
     * @return the value written for "not this class" cells
     */
    public long getNegativeLabel() {
        return this.negativeLabel;
    }

    /**
     * @param positiveLabel value to write for "is this class" cells
     */
    public void setPositiveLabel(long positiveLabel) {
        this.positiveLabel = positiveLabel;
    }

    /**
     * @return the value written for "is this class" cells
     */
    public long getPositiveLabel() {
        return this.positiveLabel;
    }

    /**
     * Encodes labels as an indicator array according to the configured {@code yType}.
     *
     * @param list labels to encode
     * @return an int64 array of shape {@code (list.size(), 1)} for binary or
     *         {@code (list.size(), classes.size())} for multiclass
     * @throws ScikitLearnFeatureNotImplementedException if {@code yType} is neither
     *         {@code "binary"} nor {@code "multiclass"} (including {@code null})
     * @throws ScikitLearnCoreException if a label is not present in {@code classes}
     */
    @Override // ai.sklearn4j.base.TransformerMixin
    public NumpyArray<Long> transform(List<Object> list) {
        // Constant-first equals so an unset yType reaches the descriptive exception below
        // instead of failing with a bare NullPointerException.
        if (Y_TYPE_BINARY.equals(this.yType)) {
            return transformBinary(list);
        }
        if (Y_TYPE_MULTICLASS.equals(this.yType)) {
            return transformMulticlass(list);
        }
        throw new ScikitLearnFeatureNotImplementedException(String.format("The yType=%s in LabelBinarizer is not implemented.", this.yType));
    }

    /** Encodes each label into a single column: class index 0 is negative, anything else positive. */
    private NumpyArray<Long> transformBinary(List<Object> list) {
        Map<Object, Integer> indexByClass = buildClassIndex();
        NumpyArray<Long> result = NumpyArrayFactory.arrayOfInt64WithShape(new int[]{list.size(), 1});
        int row = 0;
        for (Object label : list) {
            int classIndex = requireClassIndex(indexByClass, label);
            result.set(Long.valueOf(classIndex == 0 ? this.negativeLabel : this.positiveLabel), row, 0);
            row++;
        }
        return result;
    }

    /** Encodes each label into one row with the positive label at its class index. */
    private NumpyArray<Long> transformMulticlass(List<Object> list) {
        Map<Object, Integer> indexByClass = buildClassIndex();
        int classCount = this.classes.size();
        NumpyArray<Long> result = NumpyArrayFactory.arrayOfInt64WithShape(new int[]{list.size(), classCount});
        int row = 0;
        for (Object label : list) {
            int classIndex = requireClassIndex(indexByClass, label);
            // Every cell is written explicitly, so the result does not depend on how the
            // factory initialises the array.
            for (int col = 0; col < classCount; col++) {
                long cell = (col == classIndex) ? this.positiveLabel : this.negativeLabel;
                result.set(Long.valueOf(cell), row, col);
            }
            row++;
        }
        return result;
    }

    /**
     * Builds the label-to-class-index lookup from {@code classes}.
     *
     * <p>The index assigned is the map size at insertion time, so a label that occurs more than
     * once in {@code classes} does not advance the counter.
     */
    private Map<Object, Integer> buildClassIndex() {
        Map<Object, Integer> indexByClass = new HashMap<>();
        for (Object label : this.classes) {
            indexByClass.put(label, Integer.valueOf(indexByClass.size()));
        }
        return indexByClass;
    }

    /**
     * Looks up the class index of a label.
     *
     * @throws ScikitLearnCoreException if the label is not a known class
     */
    private static int requireClassIndex(Map<Object, Integer> indexByClass, Object label) {
        Integer classIndex = indexByClass.get(label);
        if (classIndex == null) {
            // %s formats a null argument as "null", so an unknown null label is reported with
            // the intended exception instead of a NullPointerException from toString().
            throw new ScikitLearnCoreException(String.format("The class '%s' was not defined during the LabelBinarizer training.", label));
        }
        return classIndex.intValue();
    }

    /**
     * Decodes an indicator array back into class labels according to the configured {@code yType}.
     *
     * @param numpyArray array whose raw storage is a {@code long[][]} with one row per sample
     * @return one decoded label per row
     * @throws ScikitLearnFeatureNotImplementedException if {@code yType} is neither
     *         {@code "binary"} nor {@code "multiclass"} (including {@code null})
     * @throws ScikitLearnCoreException in multiclass mode, if a row contains no positive label
     *         or the positive label sits outside the range of {@code classes}
     */
    @Override // ai.sklearn4j.base.TransformerMixin
    public List<Object> inverseTransform(NumpyArray<Long> numpyArray) {
        if (Y_TYPE_BINARY.equals(this.yType)) {
            return inverseTransformBinary(numpyArray);
        }
        if (Y_TYPE_MULTICLASS.equals(this.yType)) {
            return inverseTransformMulticlass(numpyArray);
        }
        throw new ScikitLearnFeatureNotImplementedException(String.format("The yType=%s in LabelBinarizer is not implemented.", this.yType));
    }

    /** Decodes single-column rows: the negative label maps to class 0, anything else to class 1. */
    private List<Object> inverseTransformBinary(NumpyArray<Long> numpyArray) {
        long[][] rows = (long[][]) numpyArray.getWrapper().getRawArray();
        List<Object> labels = new ArrayList<>(rows.length);
        for (long[] row : rows) {
            // Compare as long: narrowing the cell to int first would truncate values outside
            // the 32-bit range and mis-decode labels configured with such values.
            labels.add(row[0] == this.negativeLabel ? this.classes.get(0) : this.classes.get(1));
        }
        return labels;
    }

    /** Decodes indicator rows: the column of the first positive label selects the class. */
    private List<Object> inverseTransformMulticlass(NumpyArray<Long> numpyArray) {
        long[][] rows = (long[][]) numpyArray.getWrapper().getRawArray();
        List<Object> labels = new ArrayList<>(rows.length);
        for (long[] row : rows) {
            int classIndex = getPositiveLabelIndex(row);
            if (classIndex < 0 || classIndex >= this.classes.size()) {
                throw new ScikitLearnCoreException(String.format("The class '%d' is not in valid range.", Integer.valueOf(classIndex)));
            }
            labels.add(this.classes.get(classIndex));
        }
        return labels;
    }

    /**
     * Finds the first column holding the positive label.
     *
     * @return the column index, or {@code -1} if the row contains no positive label
     */
    private int getPositiveLabelIndex(long[] row) {
        for (int col = 0; col < row.length; col++) {
            if (row[col] == this.positiveLabel) {
                return col;
            }
        }
        return -1;
    }
}
