package org.jpmml.sparkml.xgboost;

import com.google.common.io.MoreFiles;
import com.google.common.io.RecursiveDeleteOption;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.spark.params.GeneralParams;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Schema;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.xgboost.Learner;
import org.jpmml.xgboost.XGBoostUtil;

/**
 * Helpers for turning an XGBoost4J-Spark {@link Booster} into a PMML {@link MiningModel}
 * using the JPMML-XGBoost library.
 */
public class BoosterUtil {

    private BoosterUtil() {
    }

    /**
     * Encodes a trained booster as a PMML mining model.
     *
     * <p>The booster is written to a temporary JSON file and re-read as a JPMML-XGBoost
     * {@link Learner}. The temporary file is always deleted, on success and on failure.
     * When the {@code OPTION_INPUT_FLOAT} option is {@code TRUE}, double-typed continuous
     * features in {@code schema} are re-typed as float before encoding.
     *
     * <p>Encoder options read from the converter (with defaults):
     * {@code missing} (the model's missing value, or {@code null} when it is NaN),
     * {@code compact} (false), {@code numeric} (true), {@code prune} (false),
     * {@code ntree_limit} (null).
     *
     * @param c the converter of the owning Spark ML model; supplies the model and the options
     * @param booster the trained booster to encode
     * @param schema the feature/label schema to encode against
     * @return the encoded mining model
     * @throws RuntimeException wrapping any exception raised while saving, loading or encoding
     */
    public static <M extends Model<M> & HasPredictionCol & GeneralParams, C extends ModelConverter<M> & HasSparkMLXGBoostOptions> MiningModel encodeBooster(C c, Booster booster, Schema schema) {
        GeneralParams model = c.getModel();

        try {
            Learner learner = toLearner(booster);

            if (Boolean.TRUE.equals((Boolean) c.getOption(HasSparkMLXGBoostOptions.OPTION_INPUT_FLOAT, (Object) null))) {
                schema = schema.toTransformedSchema(BoosterUtil::toFloatFeature);
            }

            // A NaN "missing" value is passed to the encoder as "no missing value" (null).
            Float missing = Float.valueOf(model.getMissing());
            if (missing.isNaN()) {
                missing = null;
            }

            Map<String, Object> options = new LinkedHashMap<>();
            options.put("missing", c.getOption("missing", missing));
            options.put("compact", c.getOption("compact", false));
            options.put("numeric", c.getOption("numeric", true));
            options.put("prune", c.getOption("prune", false));
            options.put("ntree_limit", c.getOption("ntree_limit", (Object) null));

            boolean numeric = ((Boolean) options.get("numeric")).booleanValue();

            return learner.encodeMiningModel(options, learner.toXGBoostSchema(numeric, schema));
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Round-trips the booster through a temporary JSON file to obtain a JPMML-XGBoost learner.
     * The temporary file is removed on every path. If removal fails while another exception
     * is already propagating, that failure is attached as suppressed instead of replacing it.
     */
    private static Learner toLearner(Booster booster) throws Exception {
        File tmpFile = File.createTempFile("Booster", ".json");
        Path tmpPath = tmpFile.toPath();

        Learner learner;
        try {
            booster.saveModel(tmpFile.getAbsolutePath());

            try (InputStream is = new FileInputStream(tmpFile)) {
                learner = XGBoostUtil.loadLearner(is);
            }
        } catch (Throwable t) {
            try {
                MoreFiles.deleteRecursively(tmpPath);
            } catch (IOException cleanupFailure) {
                t.addSuppressed(cleanupFailure);
            }
            throw t;
        }

        MoreFiles.deleteRecursively(tmpPath);

        return learner;
    }

    /**
     * Re-types a double-valued continuous feature as float; every other feature is returned as-is.
     * Note: this mutates the data type of the feature's underlying {@link Field} in place.
     */
    private static Feature toFloatFeature(Feature feature) {
        if (feature instanceof ContinuousFeature) {
            ContinuousFeature continuousFeature = (ContinuousFeature) feature;

            if (continuousFeature.getDataType() == DataType.DOUBLE) {
                Field field = continuousFeature.getField();
                field.setDataType(DataType.FLOAT);

                return new ContinuousFeature(continuousFeature.getEncoder(), field);
            }
        }

        return feature;
    }
}
