package org.flinkextended.flink.ml.tensorflow.util;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.flinkextended.flink.ml.operator.util.DataTypes;
import org.flinkextended.flink.ml.tensorflow.client.TFConfigBase;
import org.flinkextended.flink.ml.tensorflow.coding.ExampleCoding;
import org.flinkextended.flink.ml.tensorflow.coding.ExampleCodingConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Utilities for converting between Flink {@link TypeInformation} and the project's {@link DataTypes}, and for
 * registering {@link ExampleCoding} as the input encoder and/or output decoder in a {@link TFConfigBase}.
 *
 * <p>The set of supported types is defined once, in the static mapping table of this class. The mapping is
 * one-to-one, so both conversion directions are derived from it. Any type outside the table is rejected with an
 * {@link IllegalArgumentException}.
 */
public class ExampleCodingConfigUtil {
    private static final Logger LOG = LoggerFactory.getLogger(ExampleCodingConfigUtil.class);

    // Property keys written into TFConfigBase#getProperties(); values must stay byte-identical.
    private static final String INPUT_EXAMPLE_CONFIG_KEY = "sys:input_tf_example_config";
    private static final String OUTPUT_EXAMPLE_CONFIG_KEY = "sys:output_tf_example_config";
    private static final String ENCODING_CLASS_KEY = "sys:encoding_class";
    private static final String DECODING_CLASS_KEY = "sys:decoding_class";

    /** Supported {@link DataTypes} and their Flink type: the single source of truth for both directions. */
    private static final Map<DataTypes, TypeInformation<?>> DATA_TYPES_TO_TYPE_INFO;

    /**
     * Inverse of {@link #DATA_TYPES_TO_TYPE_INFO}. Lookups go through {@code equals}/{@code hashCode} rather
     * than {@code ==}, so an equal but non-identical {@link TypeInformation} instance still resolves.
     */
    private static final Map<TypeInformation<?>, DataTypes> TYPE_INFO_TO_DATA_TYPES;

    static {
        Map<DataTypes, TypeInformation<?>> forward = new HashMap<>();
        forward.put(DataTypes.STRING, BasicTypeInfo.STRING_TYPE_INFO);
        forward.put(DataTypes.BOOL, BasicTypeInfo.BOOLEAN_TYPE_INFO);
        forward.put(DataTypes.INT_8, BasicTypeInfo.BYTE_TYPE_INFO);
        forward.put(DataTypes.INT_16, BasicTypeInfo.SHORT_TYPE_INFO);
        forward.put(DataTypes.INT_32, BasicTypeInfo.INT_TYPE_INFO);
        forward.put(DataTypes.INT_64, BasicTypeInfo.LONG_TYPE_INFO);
        forward.put(DataTypes.FLOAT_32, BasicTypeInfo.FLOAT_TYPE_INFO);
        forward.put(DataTypes.FLOAT_64, BasicTypeInfo.DOUBLE_TYPE_INFO);
        forward.put(DataTypes.UINT_16, BasicTypeInfo.CHAR_TYPE_INFO);
        forward.put(DataTypes.FLOAT_32_ARRAY, BasicArrayTypeInfo.FLOAT_ARRAY_TYPE_INFO);

        // The table above is one-to-one, so inverting it loses nothing.
        Map<TypeInformation<?>, DataTypes> inverse = new HashMap<>();
        for (Map.Entry<DataTypes, TypeInformation<?>> entry : forward.entrySet()) {
            inverse.put(entry.getValue(), entry.getKey());
        }
        DATA_TYPES_TO_TYPE_INFO = Collections.unmodifiableMap(forward);
        TYPE_INFO_TO_DATA_TYPES = Collections.unmodifiableMap(inverse);
    }

    /**
     * Converts each element with {@link #dataTypesToTypeInformation(DataTypes)}.
     *
     * @param dataTypesArr types to convert; {@code null} elements map to {@code null}
     * @return a new array of the same length, in the same order
     * @throws IllegalArgumentException if any non-null element is unsupported
     */
    public static TypeInformation[] dataTypesListToTypeInformation(DataTypes[] dataTypesArr) {
        return Arrays.stream(dataTypesArr)
                .map(ExampleCodingConfigUtil::dataTypesToTypeInformation)
                .toArray(TypeInformation[]::new);
    }

    /**
     * Maps a {@link DataTypes} value to its Flink {@link TypeInformation}.
     *
     * @param dataTypes the type to convert, may be {@code null}
     * @return the matching type information, or {@code null} when {@code dataTypes} is {@code null}
     * @throws IllegalArgumentException if {@code dataTypes} is not in the supported mapping table
     */
    public static TypeInformation dataTypesToTypeInformation(DataTypes dataTypes) {
        if (dataTypes == null) {
            return null;
        }
        TypeInformation<?> typeInformation = DATA_TYPES_TO_TYPE_INFO.get(dataTypes);
        if (typeInformation == null) {
            throw new IllegalArgumentException("Unsupported data type of " + dataTypes);
        }
        return typeInformation;
    }

    /**
     * Converts each element with {@link #typeInformationToDataTypes(TypeInformation)}.
     *
     * @param typeInformationArr types to convert; {@code null} elements map to {@code null}
     * @return a new array of the same length, in the same order
     * @throws IllegalArgumentException if any non-null element is unsupported
     */
    public static DataTypes[] typeInformationListToDataTypes(TypeInformation[] typeInformationArr) {
        return Arrays.stream(typeInformationArr)
                .map(ExampleCodingConfigUtil::typeInformationToDataTypes)
                .toArray(DataTypes[]::new);
    }

    /**
     * Misspelled alias kept for backward compatibility.
     *
     * @deprecated use {@link #typeInformationListToDataTypes(TypeInformation[])} instead.
     */
    @Deprecated
    public static DataTypes[] typeInormationListToDataTypes(TypeInformation[] typeInformationArr) {
        return typeInformationListToDataTypes(typeInformationArr);
    }

    /**
     * Maps a Flink {@link TypeInformation} to the corresponding {@link DataTypes} value.
     *
     * @param typeInformation the type to convert, may be {@code null}
     * @return the matching data type, or {@code null} when {@code typeInformation} is {@code null}
     * @throws IllegalArgumentException if {@code typeInformation} is not in the supported mapping table
     */
    public static DataTypes typeInformationToDataTypes(TypeInformation typeInformation) {
        if (typeInformation == null) {
            return null;
        }
        DataTypes dataTypes = TYPE_INFO_TO_DATA_TYPES.get(typeInformation);
        if (dataTypes == null) {
            throw new IllegalArgumentException("Unsupported data type of " + typeInformation);
        }
        return dataTypes;
    }

    /**
     * Registers {@link ExampleCoding} as the input encoder of {@code tFConfigBase}.
     *
     * <p>Builds the config string with {@link ExampleCodingConfig#createExampleConfigStr}, logs it, and stores it
     * under {@code sys:input_tf_example_config}. It also stores {@link ExampleCoding}'s canonical class name under
     * {@code sys:encoding_class}.
     *
     * @param tFConfigBase config whose properties are updated in place
     * @param strArr field names, passed through to {@code createExampleConfigStr}
     * @param dataTypesArr field types, positionally matching {@code strArr} (presumably; confirm in ExampleCodingConfig)
     * @param objectType record representation, passed through to {@code createExampleConfigStr}
     * @param cls record class, passed through (e.g. {@code Row.class} together with {@code ObjectType.ROW})
     */
    public static void configureEncodeExampleCoding(TFConfigBase tFConfigBase, String[] strArr, DataTypes[] dataTypesArr,
            ExampleCodingConfig.ObjectType objectType, Class cls) {
        String exampleConfig = ExampleCodingConfig.createExampleConfigStr(strArr, dataTypesArr, objectType, cls);
        LOG.info("input tf example config: {}", exampleConfig);
        tFConfigBase.getProperties().put(INPUT_EXAMPLE_CONFIG_KEY, exampleConfig);
        tFConfigBase.getProperties().put(ENCODING_CLASS_KEY, ExampleCoding.class.getCanonicalName());
    }

    /**
     * Registers {@link ExampleCoding} as the output decoder of {@code tFConfigBase}.
     *
     * <p>Builds the config string with {@link ExampleCodingConfig#createExampleConfigStr}, logs it, and stores it
     * under {@code sys:output_tf_example_config}. It also stores {@link ExampleCoding}'s canonical class name under
     * {@code sys:decoding_class}.
     *
     * @param tFConfigBase config whose properties are updated in place
     * @param strArr field names, passed through to {@code createExampleConfigStr}
     * @param dataTypesArr field types, positionally matching {@code strArr} (presumably; confirm in ExampleCodingConfig)
     * @param objectType record representation, passed through to {@code createExampleConfigStr}
     * @param cls record class, passed through (e.g. {@code Row.class} together with {@code ObjectType.ROW})
     */
    public static void configureDecodeExampleCoding(TFConfigBase tFConfigBase, String[] strArr, DataTypes[] dataTypesArr,
            ExampleCodingConfig.ObjectType objectType, Class cls) {
        String exampleConfig = ExampleCodingConfig.createExampleConfigStr(strArr, dataTypesArr, objectType, cls);
        LOG.info("output tf example config: {}", exampleConfig);
        tFConfigBase.getProperties().put(OUTPUT_EXAMPLE_CONFIG_KEY, exampleConfig);
        tFConfigBase.getProperties().put(DECODING_CLASS_KEY, ExampleCoding.class.getCanonicalName());
    }

    /**
     * Same as the {@link DataTypes}-based overload, after converting {@code typeInformationArr}.
     *
     * @throws IllegalArgumentException if any field type is unsupported
     */
    public static void configureEncodeExampleCoding(TFConfigBase tFConfigBase, String[] strArr,
            TypeInformation[] typeInformationArr, ExampleCodingConfig.ObjectType objectType, Class cls) {
        configureEncodeExampleCoding(
                tFConfigBase, strArr, typeInformationListToDataTypes(typeInformationArr), objectType, cls);
    }

    /**
     * Same as the {@link DataTypes}-based overload, after converting {@code typeInformationArr}.
     *
     * @throws IllegalArgumentException if any field type is unsupported
     */
    public static void configureDecodeExampleCoding(TFConfigBase tFConfigBase, String[] strArr,
            TypeInformation[] typeInformationArr, ExampleCodingConfig.ObjectType objectType, Class cls) {
        configureDecodeExampleCoding(
                tFConfigBase, strArr, typeInformationListToDataTypes(typeInformationArr), objectType, cls);
    }

    /**
     * Configures example encoding from {@code tableSchema} and decoding from {@code tableSchema2}.
     *
     * <p>Each side is skipped when its schema is {@code null}.
     *
     * @param tFConfigBase config whose properties are updated in place
     * @param tableSchema schema of the encoded (input) side, or {@code null} to skip encoding
     * @param tableSchema2 schema of the decoded (output) side, or {@code null} to skip decoding
     * @param objectType record representation used for both sides
     * @param cls record class used for both sides
     */
    public static void configureExampleCoding(TFConfigBase tFConfigBase, TableSchema tableSchema,
            TableSchema tableSchema2, ExampleCodingConfig.ObjectType objectType, Class cls) {
        if (tableSchema != null) {
            configureEncodeExampleCoding(
                    tFConfigBase, tableSchema.getFieldNames(), tableSchema.getFieldTypes(), objectType, cls);
        }
        if (tableSchema2 != null) {
            configureDecodeExampleCoding(
                    tFConfigBase, tableSchema2.getFieldNames(), tableSchema2.getFieldTypes(), objectType, cls);
        }
    }

    /**
     * Same as {@link #configureExampleCoding(TFConfigBase, TableSchema, TableSchema, ExampleCodingConfig.ObjectType,
     * Class)} using {@code ObjectType.ROW} and {@link Row}.
     */
    public static void configureExampleCoding(TFConfigBase tFConfigBase, TableSchema tableSchema,
            TableSchema tableSchema2) {
        configureExampleCoding(tFConfigBase, tableSchema, tableSchema2, ExampleCodingConfig.ObjectType.ROW, Row.class);
    }
}
