package org.flinkextended.flink.ml.tensorflow.io;

import com.google.common.base.Preconditions;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.types.Row;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.proto.example.Example;
import org.tensorflow.proto.example.Feature;
import org.tensorflow.proto.example.Features;

/* loaded from: input_file:org/flinkextended/flink/ml/tensorflow/io/TFRExtractRowHelper.class */
public class TFRExtractRowHelper implements Serializable {
    // Class logger; no code in this class currently references it.
    private static final Logger LOG = LoggerFactory.getLogger(TFRExtractRowHelper.class);
    // Field names and field types of the rows produced by extract(byte[]).
    private final RowTypeInfo outputRowType;
    // One aggregator per output field, index-aligned with outputRowType; the
    // constructor builds it from the ScalarConverter array it is given.
    private final AggFunc[] aggs;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/flinkextended/flink/ml/tensorflow/io/TFRExtractRowHelper$AggFunc.class */
    /**
     * Serializable reduction of a list of values to a single result.
     *
     * @param <T> element type of the input list
     * @param <R> type of the aggregated result
     */
    public interface AggFunc<T, R> extends Serializable {
        /**
         * Reduces {@code list} to one value.
         *
         * @param list values to aggregate
         * @return the aggregated value
         */
        R aggregate(List<T> list);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/flinkextended/flink/ml/tensorflow/io/TFRExtractRowHelper$FirstAgg.class */
    public static class FirstAgg<T> implements AggFunc<T, T> {
        /** Instances are created only inside the enclosing helper. */
        private FirstAgg() {}

        /**
         * Picks the element at the head of the list.
         *
         * @param list values to choose from; must contain at least one element
         * @return the element at index 0
         * @throws IllegalArgumentException if {@code list} is empty
         */
        @Override
        public T aggregate(List<T> list) {
            final boolean hasValues = !list.isEmpty();
            Preconditions.checkArgument(hasValues, "Value list is empty");
            final int headIndex = 0;
            return list.get(headIndex);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/flinkextended/flink/ml/tensorflow/io/TFRExtractRowHelper$LastAgg.class */
    public static class LastAgg<T> implements AggFunc<T, T> {
        /** Instances are created only inside the enclosing helper. */
        private LastAgg() {}

        /**
         * Picks the element at the tail of the list.
         *
         * @param list values to choose from; must contain at least one element
         * @return the element at index {@code list.size() - 1}
         * @throws IllegalArgumentException if {@code list} is empty
         */
        @Override
        public T aggregate(List<T> list) {
            final boolean hasValues = !list.isEmpty();
            Preconditions.checkArgument(hasValues, "Value list is empty");
            final int tailIndex = list.size() - 1;
            return list.get(tailIndex);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/flinkextended/flink/ml/tensorflow/io/TFRExtractRowHelper$MaxAgg.class */
    public static class MaxAgg<T extends Comparable<T>> implements AggFunc<T, T> {
        /** Instances are created only inside the enclosing helper. */
        private MaxAgg() {}

        /**
         * Returns the greatest element according to {@link Comparable#compareTo}.
         * When several elements compare equal to the maximum, the earliest one
         * in the list is returned.
         *
         * @param list values to scan; must contain at least one element
         * @return the largest element
         * @throws IllegalArgumentException if {@code list} is empty
         */
        @Override
        public T aggregate(List<T> list) {
            Preconditions.checkArgument(!list.isEmpty(), "Value list is empty");
            T largest = null;
            for (T candidate : list) {
                // Keep the current holder unless the candidate is strictly greater.
                if (largest != null && largest.compareTo(candidate) >= 0) {
                    continue;
                }
                largest = candidate;
            }
            return largest;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/flinkextended/flink/ml/tensorflow/io/TFRExtractRowHelper$MinAgg.class */
    public static class MinAgg<T extends Comparable<T>> implements AggFunc<T, T> {
        /** Instances are created only inside the enclosing helper. */
        private MinAgg() {}

        /**
         * Returns the smallest element according to {@link Comparable#compareTo}.
         * When several elements compare equal to the minimum, the earliest one
         * in the list is returned.
         *
         * @param list values to scan; must contain at least one element
         * @return the smallest element
         * @throws IllegalArgumentException if {@code list} is empty
         */
        @Override
        public T aggregate(List<T> list) {
            Preconditions.checkArgument(!list.isEmpty(), "Value list is empty");
            T smallest = null;
            for (T candidate : list) {
                // Keep the current holder unless the candidate is strictly smaller.
                if (smallest != null && smallest.compareTo(candidate) <= 0) {
                    continue;
                }
                smallest = candidate;
            }
            return smallest;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/flinkextended/flink/ml/tensorflow/io/TFRExtractRowHelper$OneHotAgg.class */
    public static class OneHotAgg<T extends Number> implements AggFunc<T, Integer> {
        /**
         * Error template for malformed input. It is passed to Guava together with
         * the list so the (potentially long) list text is only rendered when a
         * check actually fails, instead of once per element on the success path.
         */
        private static final String INVALID_ONE_HOT_TEMPLATE = "Invalid one-hot list: %s";

        /** Instances are created only inside the enclosing helper. */
        private OneHotAgg() {}

        /**
         * Decodes a one-hot encoded list into the index of its single hot entry.
         *
         * <p>Elements are compared through {@link Number#longValue()}, so any
         * fractional part of a floating-point element is truncated before the
         * comparison with 0 and 1.
         *
         * @param list one-hot vector; must be non-empty, contain exactly one
         *     element equal to 1 and otherwise only elements equal to 0
         * @return zero-based position of the element equal to 1
         * @throws IllegalArgumentException if {@code list} is empty, has no hot
         *     entry, has more than one hot entry, or holds a value other than 0/1
         */
        @Override
        public Integer aggregate(List<T> list) {
            Preconditions.checkArgument(!list.isEmpty(), "Value list is empty");
            int hotIndex = -1;
            for (int i = 0; i < list.size(); i++) {
                final long value = list.get(i).longValue();
                if (value == 1) {
                    // A second 1 means the vector is not one-hot.
                    Preconditions.checkArgument(hotIndex == -1, INVALID_ONE_HOT_TEMPLATE, list);
                    hotIndex = i;
                } else {
                    Preconditions.checkArgument(value == 0, INVALID_ONE_HOT_TEMPLATE, list);
                }
            }
            // All zeros is not a valid one-hot vector either.
            Preconditions.checkArgument(hotIndex != -1, INVALID_ONE_HOT_TEMPLATE, list);
            return Integer.valueOf(hotIndex);
        }
    }

    /* loaded from: input_file:org/flinkextended/flink/ml/tensorflow/io/TFRExtractRowHelper$ScalarConverter.class */
    /**
     * Selects the aggregator applied to a field's value list. One converter is
     * supplied per output field; {@code getAggFunc} maps each constant to its
     * aggregator implementation.
     */
    public enum ScalarConverter {
        // Mapped to FirstAgg: the element at index 0.
        FIRST,
        // Mapped to LastAgg: the element at the last index.
        LAST,
        // Mapped to MaxAgg: the greatest element by compareTo.
        MAX,
        // Mapped to MinAgg: the smallest element by compareTo.
        MIN,
        // Mapped to OneHotAgg: the index of the single element equal to 1.
        ONE_HOT
    }

    /**
     * Creates the aggregator that implements the given converter.
     *
     * @param scalarConverter converter constant to resolve
     * @return a new aggregator instance for that constant
     * @throws IllegalArgumentException if the constant has no aggregator
     */
    private static AggFunc getAggFunc(ScalarConverter scalarConverter) {
        final AggFunc resolved;
        switch (scalarConverter) {
            case FIRST:
                resolved = new FirstAgg();
                break;
            case LAST:
                resolved = new LastAgg();
                break;
            case MAX:
                resolved = new MaxAgg();
                break;
            case MIN:
                resolved = new MinAgg();
                break;
            case ONE_HOT:
                resolved = new OneHotAgg();
                break;
            default:
                throw new IllegalArgumentException("Unsupported converter " + scalarConverter);
        }
        return resolved;
    }

    public TFRExtractRowHelper(RowTypeInfo rowTypeInfo, ScalarConverter[] scalarConverterArr) {
        Preconditions.checkArgument(rowTypeInfo.getArity() == scalarConverterArr.length);
        this.outputRowType = rowTypeInfo;
        this.aggs = new AggFunc[scalarConverterArr.length];
        for (int i = 0; i < this.aggs.length; i++) {
            this.aggs[i] = getAggFunc(scalarConverterArr[i]);
        }
    }

    /**
     * Parses a serialized TensorFlow {@code Example} and converts it into a row
     * laid out according to the configured output row type.
     *
     * <p>Each row field is looked up in the example by its field name and then
     * converted with the field's type and aggregator.
     *
     * @param bArr serialized {@code Example} protobuf message
     * @return a row with one value per output field
     * @throws InvalidProtocolBufferException if {@code bArr} is not a valid
     *     {@code Example} message
     * @throws IllegalArgumentException if the example's feature count differs
     *     from the row arity, or a feature cannot be converted to its field type
     * @throws NullPointerException if a row field name has no feature in the
     *     example
     */
    public Row extract(byte[] bArr) throws InvalidProtocolBufferException {
        final Features features = Example.parseFrom(bArr).getFeatures();
        final int arity = this.outputRowType.getArity();
        // Guava renders the template only on failure, so the success path (every
        // record) no longer pays for message formatting.
        Preconditions.checkArgument(
                arity == features.getFeatureCount(),
                "RowType arity (%s) and example feature count (%s) mismatch",
                arity,
                features.getFeatureCount());

        // Loop invariants: fetch the name/type arrays once per record.
        final String[] fieldNames = this.outputRowType.getFieldNames();
        final TypeInformation<?>[] fieldTypes = this.outputRowType.getFieldTypes();

        final Row row = new Row(arity);
        for (int i = 0; i < arity; i++) {
            final String fieldName = fieldNames[i];
            final Feature feature =
                    Preconditions.checkNotNull(
                            features.getFeatureOrDefault(fieldName, (Feature) null),
                            "Field name %s doesn't exist in example",
                            fieldName);
            row.setField(i, toObject(feature, fieldTypes[i], this.aggs[i]));
        }
        return row;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Converts one {@link Feature} into the Java value for a row field.
     *
     * <p>If {@code typeInformation} is a {@link PrimitiveArrayTypeInfo}, its
     * component type selects the conversion and all values are returned as an
     * array; otherwise the values are reduced to one value by {@code aggFunc}.
     * Supported element types and their feature lists:
     *
     * <ul>
     *   <li>STRING: bytes list, each entry decoded as ISO-8859-1
     *       (array form {@code String[]})
     *   <li>TINYINT logical type or {@code SHORT_TYPE_INFO}: int64 list narrowed
     *       to {@code Short} (array form {@code Short[]})
     *   <li>INT: int64 list narrowed to {@code Integer} (array form {@code Integer[]})
     *   <li>BIGINT: int64 list (array form {@code Long[]})
     *   <li>FLOAT: float list (array form {@code float[]})
     *   <li>BYTES logical type or {@code BYTE_TYPE_INFO}: bytes list; the
     *       aggregated entry is returned as {@code byte[]} (array form
     *       {@code byte[][]})
     * </ul>
     *
     * <p>NOTE(review): the TINYINT/SHORT and BYTES/BYTE pairings look
     * inconsistent with each other (TINYINT is normally a single byte) — confirm
     * against the intended Flink type mapping before changing them.
     *
     * @param feature feature to convert
     * @param typeInformation declared type of the target row field
     * @param aggFunc aggregator used when the target is not an array type
     * @return the converted scalar or array
     * @throws IllegalArgumentException if the feature does not hold the list
     *     kind required by the type, or the type is not supported
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // signature keeps the raw types callers pass
    private Object toObject(Feature feature, TypeInformation typeInformation, AggFunc aggFunc) {
        // The mismatch message names the declared type (before array unwrapping).
        // It is handed to Guava as template + args so that feature.toString(), a
        // full protobuf text dump, is only rendered when a check fails.
        final String mismatchTemplate = "Cannot convert %s to %s";
        final TypeInformation declaredType = typeInformation;

        final boolean isArray = typeInformation instanceof PrimitiveArrayTypeInfo;
        if (isArray) {
            typeInformation = ((PrimitiveArrayTypeInfo) typeInformation).getComponentType();
        }

        if (typeInformation.equals(InternalTypeInfo.of(DataTypes.STRING().getLogicalType()))
                || typeInformation.equals(BasicTypeInfo.STRING_TYPE_INFO)) {
            Preconditions.checkArgument(feature.hasBytesList(), mismatchTemplate, feature, declaredType);
            final List<ByteString> rawValues = feature.getBytesList().getValueList();
            final String[] decoded = new String[rawValues.size()];
            for (int i = 0; i < decoded.length; i++) {
                decoded[i] = rawValues.get(i).toString(StandardCharsets.ISO_8859_1);
            }
            return isArray ? decoded : aggFunc.aggregate(Arrays.asList(decoded));
        }

        if (typeInformation.equals(InternalTypeInfo.of(DataTypes.TINYINT().getLogicalType()))
                || typeInformation.equals(BasicTypeInfo.SHORT_TYPE_INFO)) {
            Preconditions.checkArgument(feature.hasInt64List(), mismatchTemplate, feature, declaredType);
            final List<Short> shorts =
                    feature.getInt64List().getValueList().stream()
                            .map(Long::shortValue)
                            .collect(Collectors.toList());
            return isArray ? shorts.toArray(new Short[0]) : aggFunc.aggregate(shorts);
        }

        if (typeInformation.equals(InternalTypeInfo.of(DataTypes.INT().getLogicalType()))
                || typeInformation.equals(BasicTypeInfo.INT_TYPE_INFO)) {
            Preconditions.checkArgument(feature.hasInt64List(), mismatchTemplate, feature, declaredType);
            final List<Integer> ints =
                    feature.getInt64List().getValueList().stream()
                            .map(Long::intValue)
                            .collect(Collectors.toList());
            return isArray ? ints.toArray(new Integer[0]) : aggFunc.aggregate(ints);
        }

        if (typeInformation.equals(InternalTypeInfo.of(DataTypes.BIGINT().getLogicalType()))
                || typeInformation.equals(BasicTypeInfo.LONG_TYPE_INFO)) {
            Preconditions.checkArgument(feature.hasInt64List(), mismatchTemplate, feature, declaredType);
            final List<Long> longs = feature.getInt64List().getValueList();
            return isArray ? longs.toArray(new Long[0]) : aggFunc.aggregate(longs);
        }

        if (typeInformation.equals(InternalTypeInfo.of(DataTypes.FLOAT().getLogicalType()))
                || typeInformation.equals(BasicTypeInfo.FLOAT_TYPE_INFO)) {
            Preconditions.checkArgument(feature.hasFloatList(), mismatchTemplate, feature, declaredType);
            final List<Float> floats = feature.getFloatList().getValueList();
            if (!isArray) {
                return aggFunc.aggregate(floats);
            }
            final float[] unboxed = new float[floats.size()];
            for (int i = 0; i < unboxed.length; i++) {
                unboxed[i] = floats.get(i);
            }
            return unboxed;
        }

        if (!typeInformation.equals(InternalTypeInfo.of(DataTypes.BYTES().getLogicalType()))
                && !typeInformation.equals(BasicTypeInfo.BYTE_TYPE_INFO)) {
            throw new IllegalArgumentException("Unsupported type " + typeInformation.toString());
        }

        Preconditions.checkArgument(feature.hasBytesList(), mismatchTemplate, feature, declaredType);
        final List<ByteString> blobs = feature.getBytesList().getValueList();
        if (!isArray) {
            return ((ByteString) aggFunc.aggregate(blobs)).toByteArray();
        }
        // One byte[] per bytes-list entry, hence a two-dimensional array.
        final byte[][] copies = new byte[blobs.size()][];
        for (int i = 0; i < copies.length; i++) {
            copies[i] = blobs.get(i).toByteArray();
        }
        return copies;
    }
}
