package org.flinkextended.flink.ml.tensorflow.coding;

import com.alibaba.fastjson.JSONObject;
import com.google.protobuf.InvalidProtocolBufferException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.flinkextended.flink.ml.cluster.node.MLContext;
import org.flinkextended.flink.ml.coding.Coding;
import org.flinkextended.flink.ml.coding.CodingException;
import org.flinkextended.flink.ml.operator.util.DataTypes;
import org.tensorflow.proto.example.Example;
import org.tensorflow.proto.example.Feature;
import org.tensorflow.proto.example.Features;

/* loaded from: input_file:org/flinkextended/flink/ml/tensorflow/coding/ExampleCoding.class */
public class ExampleCoding implements Coding<Object> {
    private MLContext mlContext;
    private ExampleCodingConfig inputConfig = new ExampleCodingConfig();
    private ExampleCodingConfig outputConfig;

    public ExampleCoding(MLContext mLContext) throws CodingException {
        this.mlContext = mLContext;
        JSONObject parseObject = JSONObject.parseObject((String) mLContext.getProperties().get("sys:input_tf_example_config"));
        if (parseObject != null) {
            this.inputConfig.fromJsonObject(parseObject);
        }
        this.outputConfig = new ExampleCodingConfig();
        JSONObject parseObject2 = JSONObject.parseObject((String) mLContext.getProperties().get("sys:output_tf_example_config"));
        if (parseObject2 != null) {
            this.outputConfig.fromJsonObject(parseObject2);
        }
    }

    public Object decode(byte[] bArr) throws CodingException {
        Feature feature;
        try {
            Map featureMap = Example.parseFrom(bArr).getFeatures().getFeatureMap();
            ArrayList arrayList = new ArrayList(this.outputConfig.count());
            for (int i = 0; i < this.outputConfig.count(); i++) {
                String colName = this.outputConfig.getColName(i);
                DataTypes type = this.outputConfig.getType(i);
                if (colName != null && (feature = (Feature) featureMap.get(colName)) != null) {
                    arrayList.add(TFExampleConversion.featureToJava(type, feature));
                }
            }
            return this.outputConfig.createResultObject(arrayList);
        } catch (InvalidProtocolBufferException e) {
            e.printStackTrace();
            throw new CodingException(e.getMessage());
        }
    }

    public byte[] encode(Object obj) throws CodingException {
        Example.Builder newBuilder = Example.newBuilder();
        Features.Builder featuresBuilder = newBuilder.getFeaturesBuilder();
        for (int i = 0; i < this.inputConfig.count(); i++) {
            featuresBuilder.putFeature(this.inputConfig.getColName(i), TFExampleConversion.javaToFeature(this.inputConfig.getType(i), this.inputConfig.getField(obj, i)));
        }
        newBuilder.setFeatures(featuresBuilder);
        return newBuilder.build().toByteArray();
    }
}
