package org.flinkextended.flink.ml.operator.util;

import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.typeutils.PojoField;
import org.apache.flink.api.java.typeutils.PojoTypeInfo;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.runtime.TupleSerializerBase;
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo;
import org.apache.flink.types.Row;
import org.flinkextended.flink.ml.util.MLException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Product;

/**
 * Column layout of a Flink record type.
 *
 * <p>A layout is an ordered list of column names with their {@link TypeInformation}s. It also
 * provides helpers to read a column out of a record and to build a record back from column
 * values.
 *
 * <p>Scala case classes, Flink tuples, POJOs and {@link Row}s are "decomposed" into one column
 * per field. Any other type is exposed as a single column named {@code "input"} whose type is the
 * record type itself.
 *
 * <p>Use {@link #fromTypeInformation(TypeInformation)} or {@link #dummy()} to obtain instances.
 */
public class ColumnInfos implements Serializable {
    public static final Logger LOG = LoggerFactory.getLogger(ColumnInfos.class);

    // Fields are deliberately left non-final. The class declares no serialVersionUID, so changing
    // field modifiers would change the computed one.

    /** Column names, in column order. */
    private List<String> colNames = new ArrayList<>();

    /** Column types, parallel to {@link #colNames}. */
    private List<TypeInformation> tiInfos = new ArrayList<>();

    /** The type this layout was derived from; {@code null} for {@link #dummy()}. */
    private TypeInformation originalTI = null;

    /** Whether {@link #originalTI} was split into one column per field. */
    private boolean decomposed = false;

    /** Returns the number of columns. */
    public int count() {
        return colNames.size();
    }

    /** Returns the name of column {@code i}. */
    public String getColName(int i) {
        return colNames.get(i);
    }

    /** Returns the Flink type information of column {@code i}. */
    public TypeInformation getTiInfo(int i) {
        return tiInfos.get(i);
    }

    /**
     * Returns the {@link DataTypes} that {@link DataTypeConversion#fromJavaClass} maps the Java
     * class of column {@code i} to.
     */
    public DataTypes getDataTypes(int i) {
        return DataTypeConversion.fromJavaClass(getTiInfo(i).getTypeClass());
    }

    /**
     * Derives the column layout of {@code typeInformation}.
     *
     * @param typeInformation type of the records to describe
     * @return the layout. Case class, tuple, POJO and row types get one column per field; any
     *     other type gets a single column named {@code "input"}.
     * @throws IllegalArgumentException if two columns share a name, or if {@link
     *     DataTypeConversion#fromJavaClass} returns {@code null} for a column's Java class
     */
    public static ColumnInfos fromTypeInformation(TypeInformation typeInformation) {
        ColumnInfos columnInfos = new ColumnInfos();
        columnInfos.originalTI = typeInformation;
        if (isDecomposable(typeInformation)) {
            CompositeType compositeType = (CompositeType) typeInformation;
            String[] fieldNames = compositeType.getFieldNames();
            for (int i = 0; i < fieldNames.length; i++) {
                columnInfos.colNames.add(fieldNames[i]);
                columnInfos.tiInfos.add(compositeType.getTypeAt(i));
            }
            columnInfos.decomposed = true;
        } else {
            columnInfos.colNames.add("input");
            columnInfos.tiInfos.add(typeInformation);
        }
        columnInfos.checkColumns();
        return columnInfos;
    }

    /** Returns whether {@code ti} is one of the composite types split into one column per field. */
    private static boolean isDecomposable(TypeInformation ti) {
        return ti instanceof CaseClassTypeInfo
                || ti instanceof TupleTypeInfo
                || ti instanceof PojoTypeInfo
                || ti instanceof RowTypeInfo;
    }

    /** Returns an empty layout with no columns and no original type. */
    public static ColumnInfos dummy() {
        return new ColumnInfos();
    }

    /**
     * Reads the value of column {@code i} from the record {@code obj}.
     *
     * <p>For decomposed layouts, the field is read according to the original type: POJO field,
     * case-class product element, tuple field or row field.
     *
     * <p>For non-decomposed layouts, the single column is the record itself. An {@code Object[]}
     * argument is still indexed into, as this method has always done.
     *
     * @param obj a record of the original type
     * @param i column index
     * @return the column value
     * @throws IllegalStateException if a POJO field cannot be read reflectively
     */
    public Object getField(Object obj, int i) {
        if (!decomposed) {
            // The only column's type is the record type, so a plain record is its own value.
            // Casting every record to Object[] made e.g. String inputs fail with a
            // ClassCastException.
            return obj instanceof Object[] ? ((Object[]) obj)[i] : obj;
        }
        if (originalTI instanceof PojoTypeInfo) {
            PojoField pojoField = ((PojoTypeInfo<?>) originalTI).getPojoFieldAt(i);
            try {
                return pojoField.getField().get(obj);
            } catch (IllegalAccessException e) {
                // Fail loudly. Falling through would hand the whole record back as the field value.
                throw new IllegalStateException("Fail to get field " + pojoField, e);
            }
        }
        if (originalTI instanceof CaseClassTypeInfo) {
            return ((Product) obj).productElement(i);
        }
        if (originalTI instanceof TupleTypeInfo) {
            return ((Tuple) obj).getField(i);
        }
        if (originalTI instanceof RowTypeInfo) {
            return ((Row) obj).getField(i);
        }
        // Not reached for layouts built by fromTypeInformation: decomposed implies one of the
        // types handled above.
        return obj;
    }

    /**
     * Builds a record of the original type from column values.
     *
     * @param list column values, in column order
     * @param executionConfig used to create the serializer that instantiates case classes
     * @return the new record; for non-decomposed layouts, {@code list.get(0)}
     * @throws MLException if {@code list} does not have exactly {@link #count()} elements, or if
     *     a POJO or tuple instance cannot be created
     */
    public Object createResultObject(List<Object> list, ExecutionConfig executionConfig)
            throws MLException {
        if (list.size() != count()) {
            // Null-safe so that dummy() layouts report the size mismatch instead of an NPE.
            Object typeClass = originalTI == null ? null : originalTI.getTypeClass();
            throw new MLException(
                    "Invalid field number for create object for class " + typeClass
                            + ". Needs " + count() + " fields, while having " + list.size()
                            + " fields.");
        }
        if (!decomposed) {
            return list.get(0);
        }
        if (originalTI instanceof PojoTypeInfo) {
            return createPojo(list);
        }
        if (originalTI instanceof CaseClassTypeInfo) {
            // Case-class serializers derive from TupleSerializerBase, which declares
            // createInstance(Object[] fields). The plain TypeSerializer interface does not.
            TupleSerializerBase<?> serializer =
                    (TupleSerializerBase<?>) originalTI.createSerializer(executionConfig);
            return serializer.createInstance(list.toArray());
        }
        if (originalTI instanceof TupleTypeInfo) {
            return createTuple(list);
        }
        if (originalTI instanceof RowTypeInfo) {
            return createRow(list);
        }
        return list.get(0);
    }

    /** Creates the POJO through its no-arg constructor and assigns each value to its field. */
    private Object createPojo(List<Object> values) throws MLException {
        PojoTypeInfo<?> pojoTypeInfo = (PojoTypeInfo<?>) originalTI;
        try {
            Object instance = pojoTypeInfo.getTypeClass().getDeclaredConstructor().newInstance();
            for (int i = 0; i < values.size(); i++) {
                pojoTypeInfo.getPojoFieldAt(i).getField().set(instance, values.get(i));
            }
            return instance;
        } catch (Exception e) {
            throw new MLException(
                    "Fail to initiate POJO object. The type is " + originalTI.getTypeClass(), e);
        }
    }

    /** Creates the Flink tuple class whose arity equals the value count and fills its fields. */
    private Tuple createTuple(List<Object> values) throws MLException {
        try {
            Tuple tuple = Tuple.getTupleClass(values.size()).getDeclaredConstructor().newInstance();
            for (int i = 0; i < values.size(); i++) {
                tuple.setField(values.get(i), i);
            }
            return tuple;
        } catch (ReflectiveOperationException e) {
            throw new MLException(
                    "Failed to create Tuple object for type "
                            + originalTI.getTypeClass().getCanonicalName(),
                    e);
        }
    }

    /** Builds a {@link Row} with one field per value, in order. */
    private static Row createRow(List<Object> values) {
        Row row = new Row(values.size());
        for (int i = 0; i < values.size(); i++) {
            row.setField(i, values.get(i));
        }
        return row;
    }

    /**
     * Validates the column layout.
     *
     * @throws IllegalArgumentException if a column name occurs more than once, or if {@link
     *     DataTypeConversion#fromJavaClass} returns {@code null} for a column's Java class
     */
    private void checkColumns() {
        HashSet<String> seen = new HashSet<>();
        List<String> duplicates = new ArrayList<>();
        for (String name : colNames) {
            // Report each duplicated name once, in the order its first repeat appears.
            if (!seen.add(name) && !duplicates.contains(name)) {
                duplicates.add(name);
            }
        }
        if (!duplicates.isEmpty()) {
            throw new IllegalArgumentException(
                    "Found duplicated column name(s): " + Joiner.on(", ").join(duplicates));
        }
        for (TypeInformation ti : tiInfos) {
            Class<?> typeClass = ti.getTypeClass();
            Preconditions.checkArgument(
                    DataTypeConversion.fromJavaClass(typeClass) != null,
                    "Data type %s CAN NOT convert to a Tensorflow data type.",
                    typeClass.getName());
        }
    }

    /** Returns the type this layout was derived from, or {@code null} for {@link #dummy()}. */
    public TypeInformation getOriginalTI() {
        return originalTI;
    }

    /** Returns whether the original type was split into one column per field. */
    public boolean isDecomposed() {
        return decomposed;
    }

    /**
     * Maps each column name to the {@code toString()} of its {@link DataTypes}.
     *
     * @return a new mutable map with one entry per column
     */
    public Map<String, String> getNameToTypeMap() {
        Map<String, String> nameToType = new HashMap<>();
        for (int i = 0; i < count(); i++) {
            nameToType.put(getColName(i), getDataTypes(i).toString());
        }
        return nameToType;
    }
}
