package org.elasticsearch.xpack.core.ml.dataframe.analyses;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.elasticsearch.Version;
import org.elasticsearch.action.fieldcaps.FieldCapabilities;
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

/**
 * Configuration of a classification data frame analysis.
 * <p>
 * Holds the dependent variable to predict, the boosted tree hyperparameters and the
 * classification-specific options. It knows how to:
 * <ul>
 *   <li>serialize itself to the transport wire ({@link StreamInput}/{@link StreamOutput});</li>
 *   <li>serialize itself to XContent;</li>
 *   <li>build the parameter map returned by {@link #getParams};</li>
 *   <li>describe the result-field mappings returned by {@link #getResultMappings}.</li>
 * </ul>
 */
public class Classification implements DataFrameAnalysis {

    /** Separates the job id from the remainder of a state document id. */
    private static final String STATE_DOC_ID_INFIX = "_classification_state#";

    /** Key under which the dependent variable's cardinality is put into {@link #getParams}. */
    private static final String NUM_CLASSES = "num_classes";

    /** Upper cardinality bound used by {@link #getFieldCardinalityConstraints()} for the dependent variable. */
    public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;

    /** Key under which the prediction field type is put into {@link #getParams}, when it can be determined. */
    private static final String PREDICTION_FIELD_TYPE = "prediction_field_type";

    /** Value used for {@code num_top_classes} when none is supplied. */
    private static final int DEFAULT_NUM_TOP_CLASSES = 2;

    /** Mapping of the {@code feature_importance} results field; built by {@link #createFeatureImportanceMapping()}. */
    static final Map<String, Object> FEATURE_IMPORTANCE_MAPPING = createFeatureImportanceMapping();

    private final String dependentVariable;
    private final BoostedTreeParams boostedTreeParams;
    private final String predictionFieldName;
    private final ClassAssignmentObjective classAssignmentObjective;
    private final int numTopClasses;
    private final double trainingPercent;
    private final long randomizeSeed;
    private final List<PreProcessor> featureProcessors;
    private final boolean earlyStoppingEnabled;

    public static final ParseField NAME = new ParseField(ClassificationInferenceResults.NAME);
    public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
    public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
    public static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
    public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
    public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
    public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
    public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
    public static final ParseField EARLY_STOPPING_ENABLED = new ParseField("early_stopping_enabled");

    // The parsers must be created after the ParseFields above, which createParser reads.
    private static final ConstructingObjectParser<Classification, Void> LENIENT_PARSER = createParser(true);
    private static final ConstructingObjectParser<Classification, Void> STRICT_PARSER = createParser(false);

    /** Union of the categorical, discrete numerical and boolean types; the types allowed for the dependent variable. */
    private static final Set<String> ALLOWED_DEPENDENT_VARIABLE_TYPES = Stream.of(
        Types.categorical(),
        Types.discreteNumerical(),
        Types.bool()
    ).flatMap(Set::stream).collect(Collectors.toUnmodifiableSet());

    private static final List<String> PROGRESS_PHASES = Collections.unmodifiableList(
        Arrays.asList("feature_selection", "coarse_parameter_search", "fine_tuning_parameters", "final_training")
    );

    /**
     * Objective used to assign a class to each document.
     * Rendered and parsed in lower case, independently of the default locale.
     */
    public enum ClassAssignmentObjective {
        MAXIMIZE_ACCURACY,
        MAXIMIZE_MINIMUM_RECALL;

        /** Parses a case-insensitive objective name. */
        public static ClassAssignmentObjective fromString(String value) {
            return valueOf(value.toUpperCase(Locale.ROOT));
        }

        @Override
        public String toString() {
            return name().toLowerCase(Locale.ROOT);
        }
    }

    /**
     * Builds the XContent parser for this analysis.
     * <p>
     * The constructor argument indices follow the declaration order below. BoostedTreeParams
     * declares its fields in positions 1..12. The declarations must therefore not be reordered
     * without updating the index mapping.
     *
     * @param lenient whether unknown fields (and leniently parsed pre-processors) are accepted
     */
    @SuppressWarnings("unchecked") // a[18] is the List<PreProcessor> produced by declareNamedObjects below
    private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
        ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            a -> new Classification(
                (String) a[0],
                new BoostedTreeParams(
                    (Double) a[1],
                    (Double) a[2],
                    (Double) a[3],
                    (Integer) a[4],
                    (Double) a[5],
                    (Integer) a[6],
                    (Double) a[7],
                    (Double) a[8],
                    (Double) a[9],
                    (Double) a[10],
                    (Double) a[11],
                    (Integer) a[12]
                ),
                (String) a[13],
                (ClassAssignmentObjective) a[14],
                (Integer) a[15],
                (Double) a[16],
                (Long) a[17],
                (List<PreProcessor>) a[18],
                (Boolean) a[19]
            )
        );
        parser.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
        BoostedTreeParams.declareFields(parser);
        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
        parser.declareString(
            ConstructingObjectParser.optionalConstructorArg(),
            ClassAssignmentObjective::fromString,
            CLASS_ASSIGNMENT_OBJECTIVE
        );
        parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
        parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
        parser.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
        parser.declareNamedObjects(
            ConstructingObjectParser.optionalConstructorArg(),
            // Both variants build the pre-processor with a PreProcessorParseContext(true).
            (p, c, n) -> lenient
                ? p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true))
                : p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
            // The ordered (array) form needs no extra handling.
            classification -> {},
            FEATURE_PROCESSORS
        );
        parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), EARLY_STOPPING_ENABLED);
        return parser;
    }

    /**
     * Parses a classification analysis from XContent.
     *
     * @param parser  the parser positioned on the analysis object
     * @param lenient {@code true} to use the lenient parser, {@code false} for the strict one
     */
    public static Classification fromXContent(XContentParser parser, boolean lenient) {
        return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
    }

    /**
     * Creates a classification analysis. Every {@code @Nullable} argument falls back to a default when {@code null}.
     *
     * @param dependentVariable        the field to predict; must not be {@code null}
     * @param boostedTreeParams        boosted tree hyperparameters; must not be {@code null}
     * @param predictionFieldName      defaults to {@code dependentVariable + "_prediction"}
     * @param classAssignmentObjective defaults to {@link ClassAssignmentObjective#MAXIMIZE_MINIMUM_RECALL}
     * @param numTopClasses            must be in {@code [0, 1000]} or {@code -1}; defaults to 2
     * @param trainingPercent          must be in {@code (0, 100]} (NaN is rejected); defaults to 100
     * @param randomizeSeed            defaults to a random long from {@link Randomness#get()}
     * @param featureProcessors        copied defensively; defaults to an empty list
     * @param earlyStoppingEnabled     defaults to {@code true}
     * @throws RuntimeException a bad-request exception from {@link ExceptionsHelper} when
     *                          {@code numTopClasses} or {@code trainingPercent} is out of range,
     *                          or when a required argument is {@code null}
     */
    public Classification(
        String dependentVariable,
        BoostedTreeParams boostedTreeParams,
        @Nullable String predictionFieldName,
        @Nullable ClassAssignmentObjective classAssignmentObjective,
        @Nullable Integer numTopClasses,
        @Nullable Double trainingPercent,
        @Nullable Long randomizeSeed,
        @Nullable List<PreProcessor> featureProcessors,
        @Nullable Boolean earlyStoppingEnabled
    ) {
        if (numTopClasses != null && (numTopClasses < -1 || numTopClasses > 1000)) {
            throw ExceptionsHelper.badRequestException(
                "[{}] must be an integer in [0, 1000] or a special value -1",
                NUM_TOP_CLASSES.getPreferredName()
            );
        }
        // Phrased as "not inside the range" so that NaN, for which every comparison is false, is rejected too.
        if (trainingPercent != null && !(trainingPercent > 0.0 && trainingPercent <= 100.0)) {
            throw ExceptionsHelper.badRequestException(
                "[{}] must be a positive double in (0, 100]",
                TRAINING_PERCENT.getPreferredName()
            );
        }
        this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
        this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, "boosted_tree_params");
        this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
        this.classAssignmentObjective = classAssignmentObjective == null
            ? ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL
            : classAssignmentObjective;
        this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
        this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
        this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
        // Copy so that later changes to the caller's list cannot alter this (otherwise immutable) object.
        this.featureProcessors = featureProcessors == null
            ? Collections.emptyList()
            : Collections.unmodifiableList(new ArrayList<>(featureProcessors));
        this.earlyStoppingEnabled = earlyStoppingEnabled == null || earlyStoppingEnabled;
    }

    /** Creates a classification analysis for {@code dependentVariable} with default settings everywhere else. */
    public Classification(String dependentVariable) {
        this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null, null);
    }

    /**
     * Reads a classification analysis from the wire.
     * <p>
     * The read order and version gates must mirror {@link #writeTo(StreamOutput)} exactly. Fields
     * missing from older streams get fallbacks:
     * <ul>
     *   <li>objective: {@link ClassAssignmentObjective#MAXIMIZE_MINIMUM_RECALL};</li>
     *   <li>seed: a random value;</li>
     *   <li>feature processors: an empty list.</li>
     * </ul>
     */
    public Classification(StreamInput in) throws IOException {
        this.dependentVariable = in.readString();
        this.boostedTreeParams = new BoostedTreeParams(in);
        this.predictionFieldName = in.readOptionalString();
        if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
            this.classAssignmentObjective = in.readEnum(ClassAssignmentObjective.class);
        } else {
            this.classAssignmentObjective = ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL;
        }
        // writeTo always writes a non-null value, so unboxing is safe here.
        this.numTopClasses = in.readOptionalVInt();
        this.trainingPercent = in.readDouble();
        if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
            this.randomizeSeed = in.readOptionalLong();
        } else {
            this.randomizeSeed = Randomness.get().nextLong();
        }
        if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
            this.featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class));
        } else {
            this.featureProcessors = Collections.emptyList();
        }
        this.earlyStoppingEnabled = in.readBoolean();
    }

    public String getDependentVariable() {
        return dependentVariable;
    }

    public BoostedTreeParams getBoostedTreeParams() {
        return boostedTreeParams;
    }

    public String getPredictionFieldName() {
        return predictionFieldName;
    }

    public ClassAssignmentObjective getClassAssignmentObjective() {
        return classAssignmentObjective;
    }

    public int getNumTopClasses() {
        return numTopClasses;
    }

    @Override
    public double getTrainingPercent() {
        return trainingPercent;
    }

    public long getRandomizeSeed() {
        return randomizeSeed;
    }

    /** Returns the feature pre-processors as an unmodifiable list. */
    public List<PreProcessor> getFeatureProcessors() {
        return featureProcessors;
    }

    public Boolean getEarlyStoppingEnabled() {
        return earlyStoppingEnabled;
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    /** Writes this analysis to the wire; must stay symmetric with {@link #Classification(StreamInput)}. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(dependentVariable);
        boostedTreeParams.writeTo(out);
        out.writeOptionalString(predictionFieldName);
        if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
            out.writeEnum(classAssignmentObjective);
        }
        out.writeOptionalVInt(numTopClasses);
        out.writeDouble(trainingPercent);
        if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
            out.writeOptionalLong(randomizeSeed);
        }
        if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
            out.writeNamedWriteableList(featureProcessors);
        }
        out.writeBoolean(earlyStoppingEnabled);
    }

    /**
     * Renders this analysis as an XContent object.
     * <p>
     * The optional {@code version} param (default: current version) controls whether
     * {@code randomize_seed} is emitted; it is emitted only for version 7.6.0 or later.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        Version version = Version.fromString(params.param("version", Version.CURRENT.toString()));
        builder.startObject();
        builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
        boostedTreeParams.toXContent(builder, params);
        builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
        builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
        if (predictionFieldName != null) {
            builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
        }
        builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
        if (version.onOrAfter(Version.V_7_6_0)) {
            builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
        }
        if (!featureProcessors.isEmpty()) {
            NamedXContentObjectHelper.writeNamedObjects(
                builder,
                params,
                true,
                FEATURE_PROCESSORS.getPreferredName(),
                featureProcessors
            );
        }
        builder.field(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
        builder.endObject();
        return builder;
    }

    /**
     * Builds the analysis parameter map.
     * <p>
     * The map contains:
     * <ul>
     *   <li>the boosted tree params;</li>
     *   <li>the classification options;</li>
     *   <li>the dependent variable's cardinality (under {@code num_classes});</li>
     *   <li>the prediction field type, when it can be derived from the dependent variable's mapped types.</li>
     * </ul>
     * {@code randomize_seed} is not included.
     */
    @Override
    public Map<String, Object> getParams(DataFrameAnalysis.FieldInfo fieldInfo) {
        Map<String, Object> params = new HashMap<>();
        params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
        params.putAll(boostedTreeParams.getParams());
        params.put(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
        params.put(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
        if (predictionFieldName != null) {
            params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
        }
        String predictionFieldType = getPredictionFieldTypeParamString(
            getPredictionFieldType(fieldInfo.getTypes(dependentVariable))
        );
        if (predictionFieldType != null) {
            params.put(PREDICTION_FIELD_TYPE, predictionFieldType);
        }
        params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable));
        params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent);
        if (!featureProcessors.isEmpty()) {
            params.put(
                FEATURE_PROCESSORS.getPreferredName(),
                featureProcessors.stream()
                    .map(p -> Collections.singletonMap(p.getName(), p))
                    .collect(Collectors.toList())
            );
        }
        params.put(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
        return params;
    }

    /** Maps a prediction field type to the string put into {@link #getParams}, or {@code null} if unknown. */
    private static String getPredictionFieldTypeParamString(PredictionFieldType predictionFieldType) {
        if (predictionFieldType == null) {
            return null;
        }
        switch (predictionFieldType) {
            case NUMBER:
                return "int";
            case STRING:
                return "string";
            case BOOLEAN:
                return "bool";
            default:
                return null;
        }
    }

    /**
     * Derives the prediction field type from the dependent variable's mapped types.
     * <p>
     * The type groups are checked in order: categorical, then boolean, then discrete numerical.
     * The first group that contains every type wins.
     *
     * @param dependentVariableTypes the mapped types, or {@code null}
     * @return the matching type, or {@code null} if the types are {@code null} or span several groups
     */
    public static PredictionFieldType getPredictionFieldType(Set<String> dependentVariableTypes) {
        if (dependentVariableTypes == null) {
            return null;
        }
        if (Types.categorical().containsAll(dependentVariableTypes)) {
            return PredictionFieldType.STRING;
        }
        if (Types.bool().containsAll(dependentVariableTypes)) {
            return PredictionFieldType.BOOLEAN;
        }
        if (Types.discreteNumerical().containsAll(dependentVariableTypes)) {
            return PredictionFieldType.NUMBER;
        }
        return null;
    }

    @Override
    public boolean supportsCategoricalFields() {
        return true;
    }

    /** The dependent variable also accepts numerical and boolean types; other fields only categorical ones. */
    @Override
    public Set<String> getAllowedCategoricalTypes(String fieldName) {
        return dependentVariable.equals(fieldName) ? ALLOWED_DEPENDENT_VARIABLE_TYPES : Types.categorical();
    }

    @Override
    public List<RequiredField> getRequiredFields() {
        return Collections.singletonList(new RequiredField(dependentVariable, ALLOWED_DEPENDENT_VARIABLE_TYPES));
    }

    @Override
    public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
        return Collections.singletonList(
            FieldCardinalityConstraint.between(dependentVariable, 2L, MAX_DEPENDENT_VARIABLE_CARDINALITY)
        );
    }

    /**
     * Builds the mappings of the fields this analysis writes under {@code resultsFieldName}.
     * <p>
     * The prediction field and {@code top_classes.class_name} reuse the dependent variable's mapped type.
     *
     * @throws RuntimeException a bad-request exception when the dependent variable has no mappings
     */
    @Override
    public Map<String, Object> getResultMappings(String resultsFieldName, FieldCapabilitiesResponse fieldCapabilitiesResponse) {
        Map<String, FieldCapabilities> dependentVariableFieldCaps = fieldCapabilitiesResponse.getField(dependentVariable);
        if (dependentVariableFieldCaps == null || dependentVariableFieldCaps.isEmpty()) {
            // Name the actual field, not the "dependent_variable" parameter, so the user can act on the message.
            throw ExceptionsHelper.badRequestException(
                "no mappings could be found for required field [{}]",
                dependentVariable
            );
        }
        // NOTE(review): if the field is mapped to several types, whichever the map iterates first is used.
        String dependentVariableType = dependentVariableFieldCaps.values().iterator().next().getType();
        String doubleType = NumberFieldMapper.NumberType.DOUBLE.typeName();

        Map<String, Object> mappings = new HashMap<>();
        mappings.put(resultsFieldName + ".is_training", Collections.singletonMap("type", ElasticsearchMappings.BOOLEAN));
        mappings.put(resultsFieldName + ".prediction_probability", Collections.singletonMap("type", doubleType));
        mappings.put(resultsFieldName + ".prediction_score", Collections.singletonMap("type", doubleType));
        mappings.put(resultsFieldName + ".feature_importance", FEATURE_IMPORTANCE_MAPPING);
        mappings.put(resultsFieldName + "." + predictionFieldName, Collections.singletonMap("type", dependentVariableType));

        Map<String, Object> topClassesProperties = new HashMap<>();
        topClassesProperties.put("class_name", Collections.singletonMap("type", dependentVariableType));
        topClassesProperties.put("class_probability", Collections.singletonMap("type", doubleType));
        topClassesProperties.put("class_score", Collections.singletonMap("type", doubleType));

        Map<String, Object> topClassesMapping = new HashMap<>();
        topClassesMapping.put("type", ElasticsearchMappings.NESTED);
        topClassesMapping.put(ElasticsearchMappings.PROPERTIES, topClassesProperties);
        mappings.put(resultsFieldName + ".top_classes", topClassesMapping);
        return mappings;
    }

    @Override
    public boolean supportsMissingValues() {
        return true;
    }

    @Override
    public boolean persistsState() {
        return true;
    }

    @Override
    public String getStateDocIdPrefix(String jobId) {
        return jobId + STATE_DOC_ID_INFIX;
    }

    @Override
    public List<String> getProgressPhases() {
        return PROGRESS_PHASES;
    }

    @Override
    public InferenceConfig inferenceConfig(DataFrameAnalysis.FieldInfo fieldInfo) {
        return ClassificationConfig.builder()
            .setResultsField(predictionFieldName)
            .setNumTopClasses(numTopClasses)
            .setNumTopFeatureImportanceValues(getBoostedTreeParams().getNumTopFeatureImportanceValues())
            .setPredictionFieldType(getPredictionFieldType(fieldInfo.getTypes(dependentVariable)))
            .build();
    }

    @Override
    public boolean supportsInference() {
        return true;
    }

    /**
     * Extracts the job id from a state document id.
     *
     * @return the text before the last occurrence of the state infix, or {@code null} if the infix
     *         is absent or the prefix would be empty
     */
    public static String extractJobIdFromStateDoc(String stateDocId) {
        int infixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_INFIX);
        return infixIndex <= 0 ? null : stateDocId.substring(0, infixIndex);
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        Classification that = (Classification) obj;
        // Double.compare keeps equals consistent with the Double-based hashCode (NaN, -0.0 vs 0.0).
        return Objects.equals(dependentVariable, that.dependentVariable)
            && Objects.equals(boostedTreeParams, that.boostedTreeParams)
            && Objects.equals(predictionFieldName, that.predictionFieldName)
            && classAssignmentObjective == that.classAssignmentObjective
            && numTopClasses == that.numTopClasses
            && Objects.equals(featureProcessors, that.featureProcessors)
            && earlyStoppingEnabled == that.earlyStoppingEnabled
            && Double.compare(trainingPercent, that.trainingPercent) == 0
            && randomizeSeed == that.randomizeSeed;
    }

    @Override
    public int hashCode() {
        return Objects.hash(
            dependentVariable,
            boostedTreeParams,
            predictionFieldName,
            classAssignmentObjective,
            numTopClasses,
            trainingPercent,
            randomizeSeed,
            featureProcessors,
            earlyStoppingEnabled
        );
    }

    /**
     * Builds the mapping of the {@code feature_importance} results field.
     * <p>
     * The field is a non-dynamic nested object with:
     * <ul>
     *   <li>a keyword {@code feature_name};</li>
     *   <li>a non-dynamic nested {@code classes} array of {@code class_name} (keyword) and {@code importance} (double).</li>
     * </ul>
     * Only the outer map is unmodifiable.
     */
    private static Map<String, Object> createFeatureImportanceMapping() {
        Map<String, Object> classesProperties = new HashMap<>();
        classesProperties.put("class_name", Collections.singletonMap("type", ElasticsearchMappings.KEYWORD));
        classesProperties.put("importance", Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));

        Map<String, Object> classesMapping = new HashMap<>();
        classesMapping.put(ElasticsearchMappings.DYNAMIC, false);
        classesMapping.put("type", ElasticsearchMappings.NESTED);
        classesMapping.put(ElasticsearchMappings.PROPERTIES, classesProperties);

        Map<String, Object> featureProperties = new HashMap<>();
        featureProperties.put("feature_name", Collections.singletonMap("type", ElasticsearchMappings.KEYWORD));
        featureProperties.put("classes", classesMapping);

        Map<String, Object> mapping = new HashMap<>();
        mapping.put(ElasticsearchMappings.DYNAMIC, false);
        mapping.put("type", ElasticsearchMappings.NESTED);
        mapping.put(ElasticsearchMappings.PROPERTIES, featureProperties);
        return Collections.unmodifiableMap(mapping);
    }
}
