package ai.libs.mlplan.sklearn.builder;

import ai.libs.jaicore.ml.core.evaluation.evaluator.factory.AMonteCarloCrossValidationBasedEvaluatorFactory;
import ai.libs.jaicore.ml.core.evaluation.evaluator.factory.ISupervisedLearnerEvaluatorFactory;
import ai.libs.jaicore.ml.scikitwrapper.ScikitLearnWrapper;
import ai.libs.mlplan.core.AMLPlanBuilder;
import ai.libs.mlplan.core.IProblemType;
import ai.libs.mlplan.core.MLPlan;
import ai.libs.mlplan.sklearn.AScikitLearnLearnerFactory;
import ai.libs.mlplan.sklearn.EMLPlanScikitLearnProblemType;
import ai.libs.python.IPythonConfig;
import ai.libs.python.PythonRequirementDefinition;
import java.io.IOException;
import org.apache.commons.lang3.ArrayUtils;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.IPrediction;
import org.api4.java.ai.ml.core.evaluation.IPredictionBatch;
import org.api4.java.algorithm.Timeout;

/* loaded from: input_file:ai/libs/mlplan/sklearn/builder/MLPlanScikitLearnBuilder.class */
/**
 * Builder for {@link MLPlan} instances whose candidate learners are scikit-learn pipelines wrapped in
 * {@link ScikitLearnWrapper}.
 *
 * <p>Unless the setup check is skipped, {@link #build()} first verifies the Python installation
 * (interpreter version and importable modules) via {@link PythonRequirementDefinition}, then enables
 * split-set caching on Monte-Carlo cross-validation evaluator factories, and finally delegates to
 * {@link AMLPlanBuilder#build()}.
 *
 * <p>NOTE(review): the {@code m<N>...} method names are decompiler renames of bridge-merged methods
 * ({@code withSeed}, {@code withCandidateEvaluationTimeOut}, {@code getLearnerFactory}, {@code getSelf},
 * {@code withProblemType}). They are kept verbatim so that existing callers keep compiling.
 */
public class MLPlanScikitLearnBuilder extends AMLPlanBuilder<ScikitLearnWrapper<IPrediction, IPredictionBatch>, MLPlanScikitLearnBuilder> {
    // Minimum Python interpreter version, passed to PythonRequirementDefinition in this order
    // (presumably release.major.minor, i.e. 3.5.0 — confirm against PythonRequirementDefinition).
    private static final int PYTHON_MINIMUM_REQUIRED_VERSION_REL = 3;
    private static final int PYTHON_MINIMUM_REQUIRED_VERSION_MAJ = 5;
    private static final int PYTHON_MINIMUM_REQUIRED_VERSION_MIN = 0;
    // Python modules checked for every problem type; problem-type-specific modules are appended in build().
    private static final String[] PYTHON_REQUIRED_MODULES = {"arff", "numpy", "json", "pickle", "os", "sys", "warnings", "scipy", "sklearn", "tpot", "pandas", "xgboost"};
    // NOTE(review): never assigned in this class, so build() always passes null to check(); presumably
    // that selects a default Python configuration — confirm against PythonRequirementDefinition.
    private IPythonConfig pythonConfig;
    // Extra modules required by the current problem type; set by the constructor and by withProblemType.
    private String[] pythonAdditionalRequiredModules;
    // When true, build() skips the Python version/module check entirely.
    private final boolean skipSetupCheck;

    /**
     * Creates a builder for the {@code CLASSIFICATION_MULTICLASS} problem type with the setup check enabled.
     *
     * @return a new builder
     * @throws IOException if the underlying builder cannot be initialized
     */
    public static MLPlanScikitLearnBuilder forClassification() throws IOException {
        return new MLPlanScikitLearnBuilder(EMLPlanScikitLearnProblemType.CLASSIFICATION_MULTICLASS);
    }

    /**
     * Creates a builder for the {@code CLASSIFICATION_MULTICLASS_UNLIMITED_LENGTH_PIPELINES} problem type
     * with the setup check enabled.
     *
     * @return a new builder
     * @throws IOException if the underlying builder cannot be initialized
     */
    public static MLPlanScikitLearnBuilder forClassificationWithUnlimitedLength() throws IOException {
        return new MLPlanScikitLearnBuilder(EMLPlanScikitLearnProblemType.CLASSIFICATION_MULTICLASS_UNLIMITED_LENGTH_PIPELINES);
    }

    /**
     * Creates a builder for the {@code REGRESSION} problem type with the setup check enabled.
     *
     * @return a new builder
     * @throws IOException if the underlying builder cannot be initialized
     */
    public static MLPlanScikitLearnBuilder forRegression() throws IOException {
        return new MLPlanScikitLearnBuilder(EMLPlanScikitLearnProblemType.REGRESSION);
    }

    /**
     * Creates a builder for the {@code RUL} problem type with the setup check enabled.
     *
     * @return a new builder
     * @throws IOException if the underlying builder cannot be initialized
     */
    public static MLPlanScikitLearnBuilder forRUL() throws IOException {
        return new MLPlanScikitLearnBuilder(EMLPlanScikitLearnProblemType.RUL);
    }

    /**
     * Creates a builder for the given problem type with the Python setup check enabled.
     *
     * @param eMLPlanScikitLearnProblemType the scikit-learn problem type to configure
     * @throws IOException if the underlying builder cannot be initialized
     */
    protected MLPlanScikitLearnBuilder(EMLPlanScikitLearnProblemType eMLPlanScikitLearnProblemType) throws IOException {
        this(eMLPlanScikitLearnProblemType, false);
    }

    /**
     * Creates a builder for the given problem type.
     *
     * @param eMLPlanScikitLearnProblemType the scikit-learn problem type to configure
     * @param z {@code true} to skip the Python version/module check in {@link #build()}
     * @throws IOException if the underlying builder cannot be initialized
     */
    public MLPlanScikitLearnBuilder(EMLPlanScikitLearnProblemType eMLPlanScikitLearnProblemType, boolean z) throws IOException {
        super(eMLPlanScikitLearnProblemType);
        this.skipSetupCheck = z;
        // Set here (the constructor every path goes through) so that builders created via this public
        // constructor also check the problem type's own Python modules in build().
        this.pythonAdditionalRequiredModules = eMLPlanScikitLearnProblemType.getSkLearnProblemType().getPythonRequiredModules();
    }

    /**
     * Switches the problem type and refreshes the problem-type-specific Python modules checked in {@link #build()}.
     *
     * @param iProblemType the new problem type; must be an {@link EMLPlanScikitLearnProblemType}
     * @return this builder
     * @throws IOException if the superclass fails to apply the problem type
     * @throws ClassCastException if {@code iProblemType} is not an {@link EMLPlanScikitLearnProblemType}
     */
    public MLPlanScikitLearnBuilder withProblemType(IProblemType<ScikitLearnWrapper<IPrediction, IPredictionBatch>> iProblemType) throws IOException {
        // Cast before mutating any state: an unsupported problem type still fails with ClassCastException,
        // but no longer leaves the builder reconfigured by super.withProblemType with stale module requirements.
        EMLPlanScikitLearnProblemType scikitLearnProblemType = (EMLPlanScikitLearnProblemType) iProblemType;
        super.withProblemType(iProblemType);
        this.pythonAdditionalRequiredModules = scikitLearnProblemType.getSkLearnProblemType().getPythonRequiredModules();
        return m10getSelf();
    }

    /* renamed from: withSeed, reason: merged with bridge method [inline-methods] */
    /**
     * Sets the seed on the builder and, if a learner factory is already configured, on that factory too.
     *
     * @param j the seed
     * @return this builder
     */
    public MLPlanScikitLearnBuilder m7withSeed(long j) {
        super.withSeed(j);
        AScikitLearnLearnerFactory learnerFactory = m6getLearnerFactory();
        if (learnerFactory != null) {
            learnerFactory.setSeed(j);
        }
        return m10getSelf();
    }

    /* renamed from: withCandidateEvaluationTimeOut, reason: merged with bridge method [inline-methods] */
    /**
     * Sets the candidate evaluation timeout on the builder and, if a learner factory is already configured,
     * on that factory too.
     *
     * @param timeout the per-candidate evaluation timeout
     * @return this builder
     */
    public MLPlanScikitLearnBuilder m8withCandidateEvaluationTimeOut(Timeout timeout) {
        super.withCandidateEvaluationTimeOut(timeout);
        AScikitLearnLearnerFactory learnerFactory = m6getLearnerFactory();
        if (learnerFactory != null) {
            learnerFactory.setTimeout(timeout);
        }
        return m10getSelf();
    }

    /* renamed from: getLearnerFactory, reason: merged with bridge method [inline-methods] */
    /**
     * Returns the superclass's learner factory cast to {@link AScikitLearnLearnerFactory}.
     *
     * @return the learner factory, or {@code null} if none is configured yet
     */
    public AScikitLearnLearnerFactory m6getLearnerFactory() {
        return (AScikitLearnLearnerFactory) super.getLearnerFactory();
    }

    /* renamed from: getSelf, reason: merged with bridge method [inline-methods] */
    /**
     * Returns this builder, typed as the concrete builder class for fluent chaining.
     *
     * @return {@code this}
     */
    public MLPlanScikitLearnBuilder m10getSelf() {
        return this;
    }

    /**
     * Enables split-set caching when the given evaluator factory is Monte-Carlo cross-validation based;
     * any other factory is left untouched.
     *
     * @param iSupervisedLearnerEvaluatorFactory the evaluator factory to adjust (may be {@code null})
     */
    private void setDeterministicDatasetSplitter(ISupervisedLearnerEvaluatorFactory<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> iSupervisedLearnerEvaluatorFactory) {
        // instanceof is false for null, so a missing factory is silently skipped.
        if (iSupervisedLearnerEvaluatorFactory instanceof AMonteCarloCrossValidationBasedEvaluatorFactory) {
            ((AMonteCarloCrossValidationBasedEvaluatorFactory) iSupervisedLearnerEvaluatorFactory).withCacheSplitSets(true);
        }
    }

    /**
     * Validates the Python setup (unless skipped), enables split-set caching for the search and selection
     * evaluator factories, and builds the {@link MLPlan} instance.
     *
     * @return the configured ML-Plan instance
     */
    public MLPlan<ScikitLearnWrapper<IPrediction, IPredictionBatch>> build() {
        if (!this.skipSetupCheck) {
            // ArrayUtils.addAll returns a copy of the base array when the additional array is null.
            String[] requiredModules = (String[]) ArrayUtils.addAll(PYTHON_REQUIRED_MODULES, this.pythonAdditionalRequiredModules);
            new PythonRequirementDefinition(PYTHON_MINIMUM_REQUIRED_VERSION_REL, PYTHON_MINIMUM_REQUIRED_VERSION_MAJ, PYTHON_MINIMUM_REQUIRED_VERSION_MIN, requiredModules).check(this.pythonConfig);
        }
        setDeterministicDatasetSplitter(getLearnerEvaluationFactoryForSearchPhase());
        setDeterministicDatasetSplitter(getLearnerEvaluationFactoryForSelectionPhase());
        return super.build();
    }

    /* renamed from: withProblemType, reason: collision with other method in class */
    /**
     * Decompiled bridge method: forwards a raw {@link IProblemType} to the typed
     * {@code withProblemType} overload of this class.
     *
     * @param iProblemType the new problem type
     * @return this builder
     * @throws IOException if the problem type cannot be applied
     */
    public /* bridge */ /* synthetic */ AMLPlanBuilder m9withProblemType(IProblemType iProblemType) throws IOException {
        return withProblemType((IProblemType<ScikitLearnWrapper<IPrediction, IPredictionBatch>>) iProblemType);
    }
}
