package org.flinkextended.flink.ml.tensorflow.util;

import org.hamcrest.CoreMatchers;
import org.junit.Assert;
import org.junit.Test;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;

/**
 * Unit tests for {@link TFTensorConversion}, covering conversion of Java arrays into TensorFlow
 * tensors and of TensorFlow tensors back into Java object arrays.
 */
public class TFTensorConversionTest {

    /**
     * Converts a 1x3 nested Java array of each supported primitive type (int, long, float, double)
     * into a tensor and checks that the resulting tensor has shape {@code [1, 3]}.
     */
    @Test
    public void testToTensor() {
        // A primitive 2-D array such as int[][] is an Object[] whose single element is the row
        // {1, 2, 3}; converting it is expected to yield a rank-2 tensor of shape [1, 3].
        assertToTensorHasShape1x3(new int[][] {{1, 2, 3}}, DataType.DT_INT32);
        assertToTensorHasShape1x3(new long[][] {{1L, 2L, 3L}}, DataType.DT_INT64);
        assertToTensorHasShape1x3(new float[][] {{1.0f, 2.0f, 3.0f}}, DataType.DT_FLOAT);
        assertToTensorHasShape1x3(new double[][] {{1.0d, 2.0d, 3.0d}}, DataType.DT_DOUBLE);
    }

    /**
     * Converts rank-1 tensors of length 3 for each supported numeric type back into Java arrays.
     * Checks the array length and the boxed element type.
     */
    @Test
    public void testFromTensor() {
        // Tensors own native memory, so each one is closed via try-with-resources.
        try (TInt32 tensor = TInt32.tensorOf(Shape.of(new long[] {3}))) {
            Object[] values = TFTensorConversion.fromTensor(tensor);
            Assert.assertEquals(3L, values.length);
            Assert.assertThat(values[0], CoreMatchers.instanceOf(Integer.class));
        }
        try (TFloat32 tensor = TFloat32.tensorOf(Shape.of(new long[] {3}))) {
            Object[] values = TFTensorConversion.fromTensor(tensor);
            Assert.assertEquals(3L, values.length);
            Assert.assertThat(values[0], CoreMatchers.instanceOf(Float.class));
        }
        try (TInt64 tensor = TInt64.tensorOf(Shape.of(new long[] {3}))) {
            Object[] values = TFTensorConversion.fromTensor(tensor);
            Assert.assertEquals(3L, values.length);
            Assert.assertThat(values[0], CoreMatchers.instanceOf(Long.class));
        }
        try (TFloat64 tensor = TFloat64.tensorOf(Shape.of(new long[] {3}))) {
            Object[] values = TFTensorConversion.fromTensor(tensor);
            Assert.assertEquals(3L, values.length);
            Assert.assertThat(values[0], CoreMatchers.instanceOf(Double.class));
        }
    }

    /**
     * Converts {@code data} with the given dtype and asserts that the result has shape [1, 3].
     *
     * @param data a 1x3 nested primitive array (e.g. {@code int[][]})
     * @param dtype the TensorFlow dtype to request in the {@link TensorInfo}
     */
    private static void assertToTensorHasShape1x3(Object[] data, DataType dtype) {
        TensorInfo info = TensorInfo.newBuilder().setDtype(dtype).build();
        // Close the produced tensor so its native buffer is released after the assertions.
        try (Tensor tensor = TFTensorConversion.toTensor(data, info)) {
            Assert.assertEquals(2L, tensor.shape().numDimensions());
            Assert.assertEquals(1L, tensor.shape().size(0));
            Assert.assertEquals(3L, tensor.shape().size(1));
        }
    }
}
