package org.deeplearning4j.util;

import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for writing network configurations, parameter arrays and updaters to files
 * and reading them back.
 *
 * <p>All methods are best-effort: an {@link IOException} raised while reading or writing is
 * logged and not rethrown. Load methods signal such a failure through their return value
 * (see the individual methods).
 *
 * <p>This class is marked deprecated.
 */
@Deprecated
public class NetSaverLoaderUtils {
    private static final Logger log = LoggerFactory.getLogger(NetSaverLoaderUtils.class);

    /** Utility class; not instantiable. */
    private NetSaverLoaderUtils() {
    }

    /**
     * Writes the network's parameters and its JSON configuration into a directory.
     *
     * <p>The file names are derived from {@code multiLayerNetwork.toString()}: parameters go to
     * {@code <toString>.bin} and the configuration to {@code <toString>-conf.json}, both resolved
     * against {@code str}. I/O failures are logged and swallowed.
     *
     * @param multiLayerNetwork network whose parameters and configuration are saved
     * @param str               base directory the two files are resolved against
     */
    public static void saveNetworkAndParameters(MultiLayerNetwork multiLayerNetwork, String str) {
        String baseName = multiLayerNetwork.toString();
        String confPath = FilenameUtils.concat(str, baseName + "-conf.json");
        String paramPath = FilenameUtils.concat(str, baseName + ".bin");
        log.info("Saving model and parameters to {} and {} ...", confPath, paramPath);
        try (DataOutputStream out =
                     new DataOutputStream(new BufferedOutputStream(new FileOutputStream(paramPath)))) {
            Nd4j.write(multiLayerNetwork.params(), out);
            out.flush();
            // The configuration is written only after the parameters were written successfully.
            FileUtils.write(new File(confPath), multiLayerNetwork.getLayerWiseConfigurations().toJson());
        } catch (IOException e) {
            log.error("Failed to save model and parameters to {} and {}", confPath, paramPath, e);
        }
    }

    /**
     * Builds a network from a configuration and a saved parameter file.
     *
     * <p>NOTE(review): {@code str} is handed directly to
     * {@code MultiLayerConfiguration.fromJson}, i.e. it is used as the JSON text itself and is
     * not opened as a file here — confirm callers pass JSON content rather than a path.
     *
     * @param str  value passed to {@code MultiLayerConfiguration.fromJson}
     * @param str2 path of the parameter file to read
     * @return the initialised network with the loaded parameters, or {@code null} if an
     *         {@link IOException} occurred
     */
    public static MultiLayerNetwork loadNetworkAndParameters(String str, String str2) {
        log.info("Loading saved model and parameters...");
        try {
            MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(str);
            INDArray params;
            // try-with-resources so the stream is closed even when the read fails.
            try (DataInputStream in = new DataInputStream(new FileInputStream(str2))) {
                params = Nd4j.read(in);
            }
            MultiLayerNetwork network = new MultiLayerNetwork(conf);
            network.init();
            network.setParams(params);
            return network;
        } catch (IOException e) {
            log.error("Failed to load model parameters from {}", str2, e);
            return null;
        }
    }

    /**
     * Serialises the network's updater with Java object serialisation.
     *
     * <p>The target file is {@code multiLayerNetwork.toString() + "updators.bin"} resolved against
     * {@code str}. I/O failures are logged and swallowed.
     *
     * @param multiLayerNetwork network whose updater is written
     * @param str               base directory the file is resolved against
     */
    public static void saveUpdators(MultiLayerNetwork multiLayerNetwork, String str) {
        File target = new File(FilenameUtils.concat(str, multiLayerNetwork.toString() + "updators.bin"));
        // Both streams are declared as resources so the file stream is released even if the
        // ObjectOutputStream constructor (which writes a header) throws.
        try (FileOutputStream fileOut = new FileOutputStream(target);
             ObjectOutputStream out = new ObjectOutputStream(fileOut)) {
            out.writeObject(multiLayerNetwork.getUpdater());
        } catch (IOException e) {
            log.error("Failed to save updater to {}", target, e);
        }
    }

    /**
     * Reads an updater previously written with Java object serialisation.
     *
     * <p>SECURITY: this uses native Java deserialisation ({@link ObjectInputStream#readObject()});
     * only load files that come from a trusted source.
     *
     * @param str path of the serialised updater file
     * @return the deserialised updater, or {@code null} if reading failed or a class could not
     *         be resolved
     */
    public static Updater loadUpdators(String str) {
        try (FileInputStream fileIn = new FileInputStream(new File(str));
             ObjectInputStream in = new ObjectInputStream(fileIn)) {
            return (Updater) in.readObject();
        } catch (IOException | ClassNotFoundException e) {
            log.error("Failed to load updater from {}", str, e);
            return null;
        }
    }

    /**
     * Writes a parameter array to a file. I/O failures are logged and swallowed.
     *
     * @param iNDArray parameters to write
     * @param str      path of the file to write
     */
    public static void saveLayerParameters(INDArray iNDArray, String str) {
        log.info("Saving parameters to {} ...", str);
        try (DataOutputStream out =
                     new DataOutputStream(new BufferedOutputStream(new FileOutputStream(str)))) {
            Nd4j.write(iNDArray, out);
            out.flush();
        } catch (IOException e) {
            log.error("Failed to save parameters to {}", str, e);
        }
    }

    /**
     * Reads a parameter array from a file and sets it on the given layer.
     *
     * @param layer layer that receives the parameters
     * @param str   path of the parameter file to read
     * @return the same {@code layer} instance; its parameters are left untouched if an
     *         {@link IOException} occurred
     */
    public static Layer loadLayerParameters(Layer layer, String str) {
        log.info("Loading saved parameters for layer {} ...", layer.conf().getLayer().getLayerName());
        try {
            INDArray params;
            // try-with-resources so the stream is closed even when the read fails.
            try (DataInputStream in = new DataInputStream(new FileInputStream(str))) {
                params = Nd4j.read(in);
            }
            layer.setParams(params);
        } catch (IOException e) {
            log.error("Failed to load parameters from {}", str, e);
        }
        return layer;
    }

    /**
     * Saves the parameters of the layers selected by index, skipping layers whose parameter
     * table is empty.
     *
     * @param multiLayerNetwork network that owns the layers
     * @param iArr              indices of the layers to save
     * @param map               layer index to target file path
     */
    public static void saveParameters(MultiLayerNetwork multiLayerNetwork, int[] iArr, Map<Integer, String> map) {
        for (int layerId : iArr) {
            Layer layer = multiLayerNetwork.getLayer(layerId);
            if (!layer.paramTable().isEmpty()) {
                saveLayerParameters(layer.params(), map.get(Integer.valueOf(layerId)));
            }
        }
    }

    /**
     * Saves the parameters of the layers selected by name, skipping layers whose parameter
     * table is empty.
     *
     * @param multiLayerNetwork network that owns the layers
     * @param strArr            names of the layers to save
     * @param map               layer name to target file path
     */
    public static void saveParameters(MultiLayerNetwork multiLayerNetwork, String[] strArr, Map<String, String> map) {
        for (String layerName : strArr) {
            Layer layer = multiLayerNetwork.getLayer(layerName);
            if (!layer.paramTable().isEmpty()) {
                saveLayerParameters(layer.params(), map.get(layerName));
            }
        }
    }

    /**
     * Loads parameters into the layers selected by index.
     *
     * @param multiLayerNetwork network that owns the layers
     * @param iArr              indices of the layers to load
     * @param map               layer index to source file path
     * @return the same {@code multiLayerNetwork} instance
     */
    public static MultiLayerNetwork loadParameters(MultiLayerNetwork multiLayerNetwork, int[] iArr, Map<Integer, String> map) {
        for (int layerId : iArr) {
            loadLayerParameters(multiLayerNetwork.getLayer(layerId), map.get(Integer.valueOf(layerId)));
        }
        return multiLayerNetwork;
    }

    /**
     * Loads parameters into the layers selected by name.
     *
     * @param multiLayerNetwork network that owns the layers
     * @param strArr            names of the layers to load
     * @param map               layer name to source file path
     * @return the same {@code multiLayerNetwork} instance
     */
    public static MultiLayerNetwork loadParameters(MultiLayerNetwork multiLayerNetwork, String[] strArr, Map<String, String> map) {
        for (String layerName : strArr) {
            loadLayerParameters(multiLayerNetwork.getLayer(layerName), map.get(layerName));
        }
        return multiLayerNetwork;
    }

    /**
     * Maps each layer index to the path {@code <index>.bin} resolved against a base directory.
     *
     * @param str  base directory
     * @param iArr layer indices
     * @return a new mutable map from layer index to file path
     */
    public static Map<Integer, String> getIdParamPaths(String str, int[] iArr) {
        Map<Integer, String> paths = new HashMap<>();
        for (int layerId : iArr) {
            paths.put(Integer.valueOf(layerId), FilenameUtils.concat(str, layerId + ".bin"));
        }
        return paths;
    }

    /**
     * Maps each layer name to the path {@code <name>.bin} resolved against a base directory.
     *
     * @param str    base directory
     * @param strArr layer names
     * @return a new mutable map from layer name to file path
     */
    public static Map<String, String> getStringParamPaths(String str, String[] strArr) {
        Map<String, String> paths = new HashMap<>();
        for (String layerName : strArr) {
            paths.put(layerName, FilenameUtils.concat(str, layerName + ".bin"));
        }
        return paths;
    }

    /**
     * Resolves {@code <java.io.tmpdir>/<str>/output} and creates the directory if it is missing.
     *
     * <p>The directory is created whenever it does not exist yet (previously creation was
     * skipped when only the parent directory existed). A failure to create it is logged; the
     * path is returned regardless.
     *
     * @param str name of the sub-directory under the system temp directory
     * @return the path of the output directory as a string
     */
    public static String defineOutputDir(String str) {
        File outputDir = new File(System.getProperty("java.io.tmpdir"), File.separator + str + File.separator + "output");
        // mkdirs() also returns false when the directory already exists (e.g. created
        // concurrently), so re-check before reporting a failure.
        if (!outputDir.exists() && !outputDir.mkdirs() && !outputDir.isDirectory()) {
            log.warn("Could not create output directory {}", outputDir);
        }
        return outputDir.toString();
    }
}
