package org.flinkextended.flink.ml.tensorflow.util;

import com.google.common.base.Preconditions;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;
import org.tensorflow.DataType;
import org.tensorflow.Tensor;
import org.tensorflow.framework.TensorInfo;

/**
 * Static helpers for converting between plain Java arrays and TensorFlow {@link Tensor}s.
 *
 * <p>{@link #toTensor} packs a batch of primitive rows into a rank-2 tensor; {@link #fromTensor}
 * unpacks a rank-1 tensor into boxed values. The class holds no state and cannot be instantiated.
 */
public final class TFTensorConversion {

    private TFTensorConversion() {
    }

    /**
     * Packs a batch of equally sized primitive rows into a rank-2 tensor of shape
     * {@code [objArr.length, rowLength]}, where {@code rowLength} is the length of the first row.
     *
     * <p>The Java array type every row must have is selected by {@code tensorInfo.getDtype()}:
     * <ul>
     *   <li>{@code DT_INT8}, {@code DT_INT16}, {@code DT_INT32}: {@code int[]}</li>
     *   <li>{@code DT_INT64}: {@code long[]}</li>
     *   <li>{@code DT_FLOAT}: {@code float[]}</li>
     *   <li>{@code DT_DOUBLE}: {@code double[]}</li>
     * </ul>
     *
     * <p>NOTE(review): {@code DT_INT8} and {@code DT_INT16} rows are written through the same
     * {@code IntBuffer} path as {@code DT_INT32}; confirm that the tensor produced by
     * {@code Tensor.create(long[], IntBuffer)} is acceptable wherever those narrower types are
     * declared.
     *
     * @param objArr the rows of the batch; must be non-empty and every row must have the same length
     * @param tensorInfo metadata whose dtype selects the element type
     * @return the tensor created from the packed rows
     * @throws IllegalArgumentException if {@code objArr} is empty, the rows differ in length, or the
     *     total element count does not fit in an {@code int}
     * @throws ClassCastException if a row is not of the array type required by the dtype
     * @throws UnsupportedOperationException if the dtype is not one of the types listed above
     */
    public static Tensor<?> toTensor(Object[] objArr, TensorInfo tensorInfo) {
        switch (tensorInfo.getDtype()) {
            case DT_INT8:
            case DT_INT16:
            case DT_INT32:
                return toIntTensor(objArr);
            case DT_INT64:
                return toLongTensor(objArr);
            case DT_FLOAT:
                return toFloatTensor(objArr);
            case DT_DOUBLE:
                return toDoubleTensor(objArr);
            default:
                throw new UnsupportedOperationException(
                        "Type can't be converted to tensor: " + tensorInfo.getDtype().name());
        }
    }

    /** Packs {@code int[]} rows into a rank-2 tensor through an {@link IntBuffer}. */
    private static Tensor<?> toIntTensor(Object[] rows) {
        checkNonEmpty(rows);
        int width = ((int[]) rows[0]).length;
        long[] shape = {rows.length, width};
        IntBuffer buffer = IntBuffer.allocate(getCapacity(shape));
        for (int r = 0; r < rows.length; r++) {
            int[] row = (int[]) rows[r];
            checkRowLength(r, row.length, width);
            buffer.put(row);
        }
        buffer.flip();
        return Tensor.create(shape, buffer);
    }

    /** Packs {@code long[]} rows into a rank-2 tensor through a {@link LongBuffer}. */
    private static Tensor<?> toLongTensor(Object[] rows) {
        checkNonEmpty(rows);
        int width = ((long[]) rows[0]).length;
        long[] shape = {rows.length, width};
        LongBuffer buffer = LongBuffer.allocate(getCapacity(shape));
        for (int r = 0; r < rows.length; r++) {
            long[] row = (long[]) rows[r];
            checkRowLength(r, row.length, width);
            buffer.put(row);
        }
        buffer.flip();
        return Tensor.create(shape, buffer);
    }

    /** Packs {@code float[]} rows into a rank-2 tensor through a {@link FloatBuffer}. */
    private static Tensor<?> toFloatTensor(Object[] rows) {
        checkNonEmpty(rows);
        int width = ((float[]) rows[0]).length;
        long[] shape = {rows.length, width};
        FloatBuffer buffer = FloatBuffer.allocate(getCapacity(shape));
        for (int r = 0; r < rows.length; r++) {
            float[] row = (float[]) rows[r];
            checkRowLength(r, row.length, width);
            buffer.put(row);
        }
        buffer.flip();
        return Tensor.create(shape, buffer);
    }

    /** Packs {@code double[]} rows into a rank-2 tensor through a {@link DoubleBuffer}. */
    private static Tensor<?> toDoubleTensor(Object[] rows) {
        checkNonEmpty(rows);
        int width = ((double[]) rows[0]).length;
        long[] shape = {rows.length, width};
        DoubleBuffer buffer = DoubleBuffer.allocate(getCapacity(shape));
        for (int r = 0; r < rows.length; r++) {
            double[] row = (double[]) rows[r];
            checkRowLength(r, row.length, width);
            buffer.put(row);
        }
        buffer.flip();
        return Tensor.create(shape, buffer);
    }

    /** Rejects an empty batch up front, because the row length is taken from the first row. */
    private static void checkNonEmpty(Object[] rows) {
        Preconditions.checkArgument(rows.length > 0, "Cannot convert an empty array to a tensor");
    }

    /**
     * Ensures row {@code index} has the same length as row 0. Without this check a longer row
     * overflows the pre-sized buffer, and a shorter one leaves it with fewer elements than the
     * declared shape.
     */
    private static void checkRowLength(int index, int length, int expected) {
        Preconditions.checkArgument(length == expected,
                "Row %s has length %s but row 0 has length %s; all rows must have the same length",
                index, length, expected);
    }

    /**
     * Unpacks a rank-1 tensor into an array of boxed values, one per element, in order.
     *
     * <p>Elements are boxed according to the tensor's data type: {@code INT32} as {@link Integer},
     * {@code FLOAT} as {@link Float}, {@code INT64} as {@link Long}, {@code DOUBLE} as
     * {@link Double}.
     *
     * @param tensor a rank-1 tensor
     * @return the tensor's elements as boxed values
     * @throws IllegalArgumentException if the tensor is not rank-1 or its length does not fit in an
     *     {@code int}
     * @throws UnsupportedOperationException if the tensor's data type is not one of those above
     */
    public static Object[] fromTensor(Tensor<?> tensor) {
        long[] shape = tensor.shape();
        Preconditions.checkArgument(shape.length == 1, "Can only convert tensors with shape long[]");
        // getCapacity range-checks the length instead of silently truncating it to an int.
        int size = getCapacity(shape);
        Object[] values = new Object[size];
        switch (tensor.dataType()) {
            case INT32: {
                int[] data = tensor.copyTo(new int[size]);
                for (int k = 0; k < size; k++) {
                    values[k] = data[k];
                }
                break;
            }
            case FLOAT: {
                float[] data = tensor.copyTo(new float[size]);
                for (int k = 0; k < size; k++) {
                    values[k] = data[k];
                }
                break;
            }
            case INT64: {
                long[] data = tensor.copyTo(new long[size]);
                for (int k = 0; k < size; k++) {
                    values[k] = data[k];
                }
                break;
            }
            case DOUBLE: {
                double[] data = tensor.copyTo(new double[size]);
                for (int k = 0; k < size; k++) {
                    values[k] = data[k];
                }
                break;
            }
            default:
                throw new UnsupportedOperationException(
                        "Type can't be converted from tensor: " + tensor.dataType().name());
        }
        return values;
    }

    /**
     * Returns the number of elements described by {@code shape} (the product of its dimensions).
     *
     * <p>Returns 0 for a {@code null} or empty shape; the callers in this class always pass a
     * shape of rank 1 or 2.
     *
     * @throws IllegalArgumentException if any dimension is negative or the product exceeds
     *     {@link Integer#MAX_VALUE}
     */
    private static int getCapacity(long[] shape) {
        if (shape == null || shape.length == 0) {
            return 0;
        }
        long capacity = 1L;
        for (long dim : shape) {
            // Check each dimension: two negative dimensions would otherwise multiply to a positive
            // product and pass the final range check.
            if (dim < 0) {
                throw new IllegalArgumentException("Invalid shape: " + Arrays.toString(shape));
            }
            // Saturate instead of wrapping around on long overflow. A later zero dimension still
            // yields 0, and any saturated value is rejected by the range check below.
            capacity = (dim == 0 || capacity <= Long.MAX_VALUE / dim) ? capacity * dim : Long.MAX_VALUE;
        }
        Preconditions.checkArgument(capacity <= Integer.MAX_VALUE, "Invalid shape: " + Arrays.toString(shape));
        return (int) capacity;
    }
}
