package org.deeplearning4j.spark.util;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.Array;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.commons.io.IOUtils;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.storage.StorageLevel;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.impl.common.CountPartitionsFunction;
import org.deeplearning4j.spark.impl.common.SplitPartitionsFunction;
import org.deeplearning4j.spark.impl.common.SplitPartitionsFunction2;
import org.deeplearning4j.spark.impl.common.repartition.AssignIndexFunction;
import org.deeplearning4j.spark.impl.common.repartition.BalancedPartitioner;
import org.deeplearning4j.spark.impl.common.repartition.MapTupleToPairFlatMap;
import org.slf4j.Logger;
import scala.Tuple2;

/* loaded from: input_file:org/deeplearning4j/spark/util/SparkUtils.class */
/**
 * Static helpers for Spark jobs: checking the Kryo serializer configuration, reading and writing
 * strings/objects through the Hadoop {@link FileSystem} of a Spark context, repartitioning RDDs,
 * and splitting RDDs into several smaller RDDs.
 */
public class SparkUtils {

    private static final String KRYO_SERIALIZER = "org.apache.spark.serializer.KryoSerializer";
    private static final String ND4J_KRYO_REGISTRATOR = "org.nd4j.Nd4jRegistrator";

    private SparkUtils() {
        // Static utility class; not instantiable.
    }

    /**
     * Checks that, if Kryo serialization is enabled, the ND4J Kryo registrator is configured too.
     *
     * @param javaSparkContext context whose {@code SparkConf} is inspected
     * @param logger           receives warnings when Kryo is enabled without the ND4J registrator
     * @return {@code true} if {@code spark.serializer} is not the Kryo serializer, or if
     *         {@code spark.kryo.registrator} is {@code org.nd4j.Nd4jRegistrator};
     *         {@code false} otherwise (after logging warnings)
     */
    public static boolean checkKryoConfiguration(JavaSparkContext javaSparkContext, Logger logger) {
        String serializer = javaSparkContext.getConf().get("spark.serializer", (String) null);
        if (!KRYO_SERIALIZER.equals(serializer)) {
            return true;
        }
        String registrator = javaSparkContext.getConf().get("spark.kryo.registrator", (String) null);
        if (ND4J_KRYO_REGISTRATOR.equals(registrator)) {
            return true;
        }
        logger.warn("***** Kryo serialization detected without Nd4j Registrator *****");
        logger.warn("***** ND4J Kryo registrator is required to avoid serialization (NullPointerException) issues on NDArrays *****");
        logger.warn("***** Use nd4j-kryo_2.10 or _2.11 artifact, with sparkConf.set(\"spark.kryo.registrator\", \"org.nd4j.Nd4jRegistrator\"); *****");
        return false;
    }

    /**
     * Writes a string as UTF-8 to {@code path}.
     *
     * @see #writeStringToFile(String, String, SparkContext)
     */
    public static void writeStringToFile(String path, String toWrite, JavaSparkContext sc) throws IOException {
        writeStringToFile(path, toWrite, sc.sc());
    }

    /**
     * Writes a string, encoded as UTF-8, to {@code path} using the context's Hadoop FileSystem.
     *
     * @param path    destination path (created via {@link FileSystem#create(Path)})
     * @param toWrite string content
     * @param sc      Spark context supplying the Hadoop configuration
     * @throws IOException if the file cannot be created or written
     */
    public static void writeStringToFile(String path, String toWrite, SparkContext sc) throws IOException {
        // FileSystem.get returns a cached, shared instance: close only the stream, never the FileSystem.
        FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
        try (BufferedOutputStream out = new BufferedOutputStream(fileSystem.create(new Path(path)))) {
            out.write(toWrite.getBytes(StandardCharsets.UTF_8));
        }
    }

    /**
     * Reads a whole file as a UTF-8 string.
     *
     * @see #readStringFromFile(String, SparkContext)
     */
    public static String readStringFromFile(String path, JavaSparkContext sc) throws IOException {
        return readStringFromFile(path, sc.sc());
    }

    /**
     * Reads the entire content of {@code path}, using the context's Hadoop FileSystem, and decodes it as UTF-8.
     *
     * @param path file to read
     * @param sc   Spark context supplying the Hadoop configuration
     * @return the file content as a string
     * @throws IOException if the file cannot be opened or read
     */
    public static String readStringFromFile(String path, SparkContext sc) throws IOException {
        FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
        try (BufferedInputStream in = new BufferedInputStream(fileSystem.open(new Path(path)))) {
            return new String(IOUtils.toByteArray(in), StandardCharsets.UTF_8);
        }
    }

    /**
     * Java-serializes an object to {@code path}.
     *
     * @see #writeObjectToFile(String, Object, SparkContext)
     */
    public static void writeObjectToFile(String path, Object toWrite, JavaSparkContext sc) throws IOException {
        writeObjectToFile(path, toWrite, sc.sc());
    }

    /**
     * Writes {@code toWrite} to {@code path} with Java serialization ({@link ObjectOutputStream}),
     * using the context's Hadoop FileSystem.
     *
     * @param path    destination path (created via {@link FileSystem#create(Path)})
     * @param toWrite object to serialize; must be {@code Serializable}
     * @param sc      Spark context supplying the Hadoop configuration
     * @throws IOException if the file cannot be written or the object cannot be serialized
     */
    public static void writeObjectToFile(String path, Object toWrite, SparkContext sc) throws IOException {
        FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
        // Only the file stream is a try-with-resources resource, so the Hadoop stream is closed exactly
        // once even if the ObjectOutputStream constructor fails.
        try (BufferedOutputStream out = new BufferedOutputStream(fileSystem.create(new Path(path)))) {
            ObjectOutputStream oos = new ObjectOutputStream(out);
            oos.writeObject(toWrite);
            // Push any bytes still buffered inside the ObjectOutputStream before the file is closed.
            oos.flush();
        }
    }

    /**
     * Reads a Java-serialized object from {@code path}.
     *
     * @see #readObjectFromFile(String, Class, SparkContext)
     */
    public static <T> T readObjectFromFile(String path, Class<T> type, JavaSparkContext sc) throws IOException {
        return readObjectFromFile(path, type, sc.sc());
    }

    /**
     * Reads a Java-serialized object from {@code path}, using the context's Hadoop FileSystem.
     * <p>
     * NOTE(review): this uses native Java deserialization; only use it on files written by trusted code.
     *
     * @param path file previously written by {@link #writeObjectToFile(String, Object, SparkContext)}
     * @param type expected type of the object; used only for generic typing (the cast is unchecked)
     * @param sc   Spark context supplying the Hadoop configuration
     * @return the deserialized object
     * @throws IOException      if the file cannot be opened or is not a valid serialization stream
     * @throws RuntimeException wrapping {@link ClassNotFoundException} if the object's class is not on the classpath
     */
    @SuppressWarnings("unchecked")
    public static <T> T readObjectFromFile(String path, Class<T> type, SparkContext sc) throws IOException {
        FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
        // The file stream is opened as its own resource, so it is closed even when the ObjectInputStream
        // constructor fails while reading the stream header.
        try (BufferedInputStream in = new BufferedInputStream(fileSystem.open(new Path(path)))) {
            ObjectInputStream ois = new ObjectInputStream(in);
            return (T) ois.readObject();
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Repartitions an RDD according to the given repartition setting and strategy.
     *
     * @param rdd                 input RDD
     * @param repartition         when to repartition; {@link Repartition#Never} returns {@code rdd} unchanged
     * @param repartitionStrategy {@code SparkDefault} uses {@link JavaRDD#repartition(int)};
     *                            {@code Balanced} delegates to {@link #repartitionBalanceIfRequired}
     * @param objectsPerPartition target number of elements per partition (Balanced strategy only)
     * @param numPartitions       target number of partitions
     * @return the (possibly) repartitioned RDD
     * @throws RuntimeException if the strategy is not recognized
     */
    public static <T> JavaRDD<T> repartition(JavaRDD<T> rdd, Repartition repartition,
                    RepartitionStrategy repartitionStrategy, int objectsPerPartition, int numPartitions) {
        if (repartition == Repartition.Never) {
            return rdd;
        }
        switch (repartitionStrategy) {
            case SparkDefault:
                if (repartition == Repartition.NumPartitionsWorkersDiffers
                                && rdd.partitions().size() == numPartitions) {
                    return rdd;
                }
                return rdd.repartition(numPartitions);
            case Balanced:
                return repartitionBalanceIfRequired(rdd, repartition, objectsPerPartition, numPartitions);
            default:
                throw new RuntimeException("Unknown repartition strategy: " + repartitionStrategy);
        }
    }

    /**
     * Repartitions {@code rdd} into {@code numPartitions} partitions of (as nearly as possible)
     * {@code objectsPerPartition} elements each, unless it already has that layout.
     * <p>
     * If {@code numPartitions * objectsPerPartition} cannot hold every element, {@code objectsPerPartition}
     * is raised by {@code ceil(shortfall / numPartitions)}. The input RDD is persisted with
     * {@code MEMORY_ONLY} (unless it already has a storage level) because it is evaluated once to count
     * elements and again to build the result; it is left persisted.
     *
     * @param rdd                 input RDD
     * @param repartition         {@code Never} returns {@code rdd}; {@code NumPartitionsWorkersDiffers} returns
     *                            {@code rdd} when it already has {@code numPartitions} partitions;
     *                            {@code Always} always checks the layout
     * @param objectsPerPartition desired number of elements per partition
     * @param numPartitions       desired number of partitions
     * @return {@code rdd} itself if no change is needed, otherwise a repartitioned RDD
     * @throws RuntimeException if the repartition setting is not recognized
     */
    public static <T> JavaRDD<T> repartitionBalanceIfRequired(JavaRDD<T> rdd, Repartition repartition,
                    int objectsPerPartition, int numPartitions) {
        int origNumPartitions = rdd.partitions().size();
        switch (repartition) {
            case Never:
                return rdd;
            case NumPartitionsWorkersDiffers:
                if (origNumPartitions == numPartitions) {
                    return rdd;
                }
                break;
            case Always:
                break;
            default:
                throw new RuntimeException("Unknown setting for repartition: " + repartition);
        }

        // Spark refuses to change an already-assigned storage level, so only cache when the caller hasn't.
        if (StorageLevel.NONE().equals(rdd.getStorageLevel())) {
            rdd.persist(StorageLevel.MEMORY_ONLY());
        }

        // Element count of each partition. The prefix sums below assume the list is ordered by partition index.
        @SuppressWarnings("unchecked")
        List<Tuple2<Integer, Integer>> partitionCounts =
                        rdd.mapPartitionsWithIndex(new CountPartitionsFunction(), true).collect();
        int initialPartitions = partitionCounts.size();
        int[] countPerPartition = new int[initialPartitions];
        int totalObjects = 0;
        boolean allCorrectSize = true;
        int p = 0;
        for (Tuple2<Integer, Integer> t2 : partitionCounts) {
            int partitionSize = t2._2();
            countPerPartition[p++] = partitionSize;
            allCorrectSize &= partitionSize == objectsPerPartition;
            totalObjects += partitionSize;
        }

        // long arithmetic: numPartitions * objectsPerPartition may exceed Integer.MAX_VALUE.
        if ((long) numPartitions * objectsPerPartition < totalObjects) {
            // Not enough capacity: grow each partition by ceil(shortfall / numPartitions).
            int shortfall = totalObjects - numPartitions * objectsPerPartition;
            int increase = shortfall / numPartitions;
            if (shortfall % numPartitions != 0) {
                increase++;
            }
            objectsPerPartition += increase;

            allCorrectSize = true;
            for (int count : countPerPartition) {
                allCorrectSize &= count == objectsPerPartition;
            }
        }
        if ((long) numPartitions * objectsPerPartition < totalObjects) {
            throw new RuntimeException("Balanced repartitioning capacity too small: numPartitions=" + numPartitions
                            + ", objectsPerPartition=" + objectsPerPartition + ", totalObjects=" + totalObjects);
        }

        if (initialPartitions == numPartitions && allCorrectSize) {
            // Already in the desired layout.
            return rdd;
        }

        // elementStartIndex[k] = global index of the first element of partition k.
        int[] elementStartIndex = new int[countPerPartition.length];
        for (int k = 1; k < elementStartIndex.length; k++) {
            elementStartIndex[k] = elementStartIndex[k - 1] + countPerPartition[k - 1];
        }

        // Key every element by its global index so that a custom Partitioner can place it.
        @SuppressWarnings("unchecked")
        JavaPairRDD<Integer, T> indexed = rdd.mapPartitionsWithIndex(new AssignIndexFunction(elementStartIndex), true)
                        .mapPartitionsToPair(new MapTupleToPairFlatMap(), true);

        // ceil(totalObjects / objectsPerPartition); its exact role is defined by BalancedPartitioner.
        int numPartitionsUsed = totalObjects / objectsPerPartition;
        if (totalObjects % objectsPerPartition != 0) {
            numPartitionsUsed++;
        }
        return indexed.partitionBy(new BalancedPartitioner(numPartitions, numPartitionsUsed, objectsPerPartition))
                        .values();
    }

    /**
     * Splits {@code data} using a random seed.
     *
     * @see #balancedRandomSplit(int, int, JavaRDD, long)
     */
    public static <T> JavaRDD<T>[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit, JavaRDD<T> data) {
        return balancedRandomSplit(totalObjectCount, numObjectsPerSplit, data, new Random().nextLong());
    }

    /**
     * Splits {@code data} into {@code floor(totalObjectCount / numObjectsPerSplit)} RDDs. Split {@code i}
     * is produced by applying {@code SplitPartitionsFunction(i, numSplits, rngSeed)} to every partition.
     * Presumably each element is routed to one split chosen pseudo-randomly from the seed; see
     * SplitPartitionsFunction to confirm.
     *
     * @param totalObjectCount   number of elements in {@code data}
     * @param numObjectsPerSplit desired number of elements per split
     * @param data               RDD to split
     * @param rngSeed            seed passed to every split function
     * @return a one-element array containing {@code data} when {@code totalObjectCount <= numObjectsPerSplit},
     *         otherwise the array of splits
     * @throws IllegalArgumentException if a split is needed and {@code numObjectsPerSplit <= 0}
     */
    @SuppressWarnings("unchecked")
    public static <T> JavaRDD<T>[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit,
                    JavaRDD<T> data, long rngSeed) {
        if (totalObjectCount <= numObjectsPerSplit) {
            JavaRDD<T>[] single = (JavaRDD<T>[]) Array.newInstance(JavaRDD.class, 1);
            single[0] = data;
            return single;
        }
        if (numObjectsPerSplit <= 0) {
            throw new IllegalArgumentException("numObjectsPerSplit must be positive, got " + numObjectsPerSplit);
        }
        int numSplits = totalObjectCount / numObjectsPerSplit; // intentional round-down
        JavaRDD<T>[] splits = (JavaRDD<T>[]) Array.newInstance(JavaRDD.class, numSplits);
        for (int i = 0; i < numSplits; i++) {
            splits[i] = data.mapPartitionsWithIndex(new SplitPartitionsFunction(i, numSplits, rngSeed), true);
        }
        return splits;
    }

    /**
     * Splits {@code data} using a random seed.
     *
     * @see #balancedRandomSplit(int, int, JavaPairRDD, long)
     */
    public static <T, U> JavaPairRDD<T, U>[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit,
                    JavaPairRDD<T, U> data) {
        return balancedRandomSplit(totalObjectCount, numObjectsPerSplit, data, new Random().nextLong());
    }

    /**
     * Pair-RDD version of {@link #balancedRandomSplit(int, int, JavaRDD, long)}. Split {@code i} is produced by
     * {@code SplitPartitionsFunction2(i, numSplits, rngSeed)} followed by {@code MapTupleToPairFlatMap}.
     *
     * @param totalObjectCount   number of elements in {@code data}
     * @param numObjectsPerSplit desired number of elements per split
     * @param data               pair RDD to split
     * @param rngSeed            seed passed to every split function
     * @return a one-element array containing {@code data} when {@code totalObjectCount <= numObjectsPerSplit},
     *         otherwise {@code floor(totalObjectCount / numObjectsPerSplit)} splits
     * @throws IllegalArgumentException if a split is needed and {@code numObjectsPerSplit <= 0}
     */
    @SuppressWarnings("unchecked")
    public static <T, U> JavaPairRDD<T, U>[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit,
                    JavaPairRDD<T, U> data, long rngSeed) {
        if (totalObjectCount <= numObjectsPerSplit) {
            JavaPairRDD<T, U>[] single = (JavaPairRDD<T, U>[]) Array.newInstance(JavaPairRDD.class, 1);
            single[0] = data;
            return single;
        }
        if (numObjectsPerSplit <= 0) {
            throw new IllegalArgumentException("numObjectsPerSplit must be positive, got " + numObjectsPerSplit);
        }
        int numSplits = totalObjectCount / numObjectsPerSplit; // intentional round-down
        JavaPairRDD<T, U>[] splits = (JavaPairRDD<T, U>[]) Array.newInstance(JavaPairRDD.class, numSplits);
        for (int i = 0; i < numSplits; i++) {
            splits[i] = data.mapPartitionsWithIndex(new SplitPartitionsFunction2(i, numSplits, rngSeed), true)
                            .mapPartitionsToPair(new MapTupleToPairFlatMap(), true);
        }
        return splits;
    }
}
