package org.nd4j.linalg.dataset.api.preprocessor.serializer;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.linalg.dataset.api.preprocessor.MinMaxStrategy;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerHybrid;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStrategy;
import org.nd4j.linalg.dataset.api.preprocessor.StandardizeStrategy;
import org.nd4j.linalg.dataset.api.preprocessor.stats.DistributionStats;
import org.nd4j.linalg.dataset.api.preprocessor.stats.MinMaxStats;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/serializer/MultiHybridSerializerStrategy.class */
public class MultiHybridSerializerStrategy implements NormalizerSerializerStrategy<MultiNormalizerHybrid> {

    /**
     * Tag written as an {@code int} (the constant's ordinal) in front of every serialized strategy and
     * every serialized stats object.
     *
     * <p>The ordinal values are part of the persisted binary format. Never reorder or remove
     * constants; only append new ones at the end.
     */
    public enum Strategy {
        NULL,
        STANDARDIZE,
        MIN_MAX
    }

    /** Cached copy of {@link Strategy#values()} so decoding a tag does not clone the array each time. */
    private static final Strategy[] STRATEGIES = Strategy.values();

    /**
     * Serializes the statistics and strategies of a {@link MultiNormalizerHybrid}.
     *
     * <p>Written in this order: input stats map, output stats map, global input strategy, global
     * output strategy, per-input strategy map, per-output strategy map. {@link #restore} reads them
     * back in the same order.
     *
     * <p>The given stream is wrapped and closed before this method returns, whether it succeeds or
     * fails.
     *
     * @param normalizer the normalizer to serialize
     * @param stream destination stream; closed by this method
     * @throws NullPointerException if either argument is null, or a stats map contains a null entry
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void write(@NonNull MultiNormalizerHybrid normalizer, @NonNull OutputStream stream) throws IOException {
        if (normalizer == null) {
            throw new NullPointerException("normalizer");
        }
        if (stream == null) {
            throw new NullPointerException("stream");
        }
        // try-with-resources closes the wrapper (and so the caller's stream) exactly once. Any
        // close() failure is attached as a suppressed exception to a primary failure.
        try (DataOutputStream out = new DataOutputStream(stream)) {
            writeStatsMap(normalizer.getInputStats(), out);
            writeStatsMap(normalizer.getOutputStats(), out);
            writeStrategy(normalizer.getGlobalInputStrategy(), out);
            writeStrategy(normalizer.getGlobalOutputStrategy(), out);
            writeStrategyMap(normalizer.getPerInputStrategies(), out);
            writeStrategyMap(normalizer.getPerOutputStrategies(), out);
        }
    }

    /**
     * Reconstructs a {@link MultiNormalizerHybrid} from data produced by {@link #write}.
     *
     * <p>The stream is read but not closed.
     *
     * @param stream source stream positioned at the start of the serialized normalizer data
     * @return a new normalizer populated with the restored stats and strategies
     * @throws NullPointerException if {@code stream} is null
     * @throws IOException if reading fails or the data is malformed (e.g. an unknown strategy tag)
     */
    @Override
    public MultiNormalizerHybrid restore(@NonNull InputStream stream) throws IOException {
        if (stream == null) {
            throw new NullPointerException("stream");
        }
        DataInputStream in = new DataInputStream(stream);
        MultiNormalizerHybrid result = new MultiNormalizerHybrid();
        result.setInputStats(readStatsMap(in));
        result.setOutputStats(readStatsMap(in));
        result.setGlobalInputStrategy(readStrategy(in));
        result.setGlobalOutputStrategy(readStrategy(in));
        result.setPerInputStrategies(readStrategyMap(in));
        result.setPerOutputStrategies(readStrategyMap(in));
        return result;
    }

    @Override
    public NormalizerType getSupportedType() {
        return NormalizerType.MULTI_HYBRID;
    }

    /**
     * Decodes a {@link Strategy} tag, rejecting out-of-range values.
     *
     * <p>Without this check, corrupt input would surface as an ArrayIndexOutOfBoundsException.
     */
    private static Strategy readTag(DataInputStream in) throws IOException {
        int ordinal = in.readInt();
        if (ordinal < 0 || ordinal >= STRATEGIES.length) {
            throw new IOException("Invalid strategy tag " + ordinal
                    + "; data is corrupt or was written by an incompatible version");
        }
        return STRATEGIES[ordinal];
    }

    /** Reads an element count written by {@link DataOutputStream#writeInt}, rejecting negative values. */
    private static int readCount(DataInputStream in) throws IOException {
        int count = in.readInt();
        if (count < 0) {
            throw new IOException("Invalid negative element count " + count);
        }
        return count;
    }

    /** Writes the entry count, then each (int key, stats) pair. */
    private static void writeStatsMap(Map<Integer, NormalizerStats> map, DataOutputStream out) throws IOException {
        out.writeInt(map.size());
        for (Map.Entry<Integer, NormalizerStats> entry : map.entrySet()) {
            int key = entry.getKey();
            NormalizerStats stats = entry.getValue();
            if (stats == null) {
                throw new NullPointerException("No stats for index " + key);
            }
            out.writeInt(key);
            writeNormalizerStats(stats, out);
        }
    }

    /** Inverse of {@link #writeStatsMap}. */
    private static Map<Integer, NormalizerStats> readStatsMap(DataInputStream in) throws IOException {
        int count = readCount(in);
        // Not pre-sized: a corrupt count must not trigger a huge allocation.
        Map<Integer, NormalizerStats> result = new HashMap<>();
        for (int i = 0; i < count; i++) {
            int key = in.readInt();
            result.put(key, readNormalizerStats(in));
        }
        return result;
    }

    /** Writes a tagged stats object. Only {@link DistributionStats} and {@link MinMaxStats} are supported. */
    private static void writeNormalizerStats(NormalizerStats stats, DataOutputStream out) throws IOException {
        if (stats instanceof DistributionStats) {
            writeDistributionStats((DistributionStats) stats, out);
        } else if (stats instanceof MinMaxStats) {
            writeMinMaxStats((MinMaxStats) stats, out);
        } else {
            throw new RuntimeException("Unsupported stats class " + stats.getClass());
        }
    }

    /** Reads a tagged stats object written by {@link #writeNormalizerStats}. */
    private static NormalizerStats readNormalizerStats(DataInputStream in) throws IOException {
        Strategy strategy = readTag(in);
        switch (strategy) {
            case STANDARDIZE:
                return readDistributionStats(in);
            case MIN_MAX:
                return readMinMaxStats(in);
            default:
                throw new RuntimeException("Unsupported strategy " + strategy.name());
        }
    }

    private static void writeDistributionStats(DistributionStats stats, DataOutputStream out) throws IOException {
        out.writeInt(Strategy.STANDARDIZE.ordinal());
        Nd4j.write(stats.getMean(), out);
        Nd4j.write(stats.getStd(), out);
    }

    private static NormalizerStats readDistributionStats(DataInputStream in) throws IOException {
        // Read into locals so the on-disk order (mean, then std) is explicit.
        org.nd4j.linalg.api.ndarray.INDArray mean = Nd4j.read(in);
        org.nd4j.linalg.api.ndarray.INDArray std = Nd4j.read(in);
        return new DistributionStats(mean, std);
    }

    private static void writeMinMaxStats(MinMaxStats stats, DataOutputStream out) throws IOException {
        out.writeInt(Strategy.MIN_MAX.ordinal());
        Nd4j.write(stats.getLower(), out);
        Nd4j.write(stats.getUpper(), out);
    }

    private static NormalizerStats readMinMaxStats(DataInputStream in) throws IOException {
        // Read into locals so the on-disk order (lower, then upper) is explicit.
        org.nd4j.linalg.api.ndarray.INDArray lower = Nd4j.read(in);
        org.nd4j.linalg.api.ndarray.INDArray upper = Nd4j.read(in);
        return new MinMaxStats(lower, upper);
    }

    /** Writes the entry count, then each (int key, tagged strategy) pair. Null strategies are written as {@link Strategy#NULL}. */
    private static void writeStrategyMap(Map<Integer, NormalizerStrategy> map, DataOutputStream out) throws IOException {
        out.writeInt(map.size());
        for (Map.Entry<Integer, NormalizerStrategy> entry : map.entrySet()) {
            out.writeInt(entry.getKey());
            writeStrategy(entry.getValue(), out);
        }
    }

    /** Inverse of {@link #writeStrategyMap}. */
    private static Map<Integer, NormalizerStrategy> readStrategyMap(DataInputStream in) throws IOException {
        int count = readCount(in);
        Map<Integer, NormalizerStrategy> result = new HashMap<>();
        for (int i = 0; i < count; i++) {
            int key = in.readInt();
            result.put(key, readStrategy(in));
        }
        return result;
    }

    /** Writes a tagged strategy; {@code null} is encoded as {@link Strategy#NULL}. */
    private static void writeStrategy(NormalizerStrategy strategy, DataOutputStream out) throws IOException {
        if (strategy == null) {
            writeNoStrategy(out);
        } else if (strategy instanceof StandardizeStrategy) {
            writeStandardizeStrategy(out);
        } else if (strategy instanceof MinMaxStrategy) {
            writeMinMaxStrategy((MinMaxStrategy) strategy, out);
        } else {
            throw new RuntimeException("Unsupported strategy class " + strategy.getClass());
        }
    }

    /** Reads a tagged strategy written by {@link #writeStrategy}; returns {@code null} for {@link Strategy#NULL}. */
    private static NormalizerStrategy readStrategy(DataInputStream in) throws IOException {
        Strategy strategy = readTag(in);
        switch (strategy) {
            case STANDARDIZE:
                return readStandardizeStrategy();
            case MIN_MAX:
                return readMinMaxStrategy(in);
            case NULL:
                return null;
            default:
                throw new RuntimeException("Unsupported strategy " + strategy.name());
        }
    }

    private static void writeNoStrategy(DataOutputStream out) throws IOException {
        out.writeInt(Strategy.NULL.ordinal());
    }

    /** A standardize strategy has no parameters, so only its tag is written. */
    private static void writeStandardizeStrategy(DataOutputStream out) throws IOException {
        out.writeInt(Strategy.STANDARDIZE.ordinal());
    }

    private static NormalizerStrategy readStandardizeStrategy() {
        return new StandardizeStrategy();
    }

    private static void writeMinMaxStrategy(MinMaxStrategy strategy, DataOutputStream out) throws IOException {
        out.writeInt(Strategy.MIN_MAX.ordinal());
        out.writeDouble(strategy.getMinRange());
        out.writeDouble(strategy.getMaxRange());
    }

    private static NormalizerStrategy readMinMaxStrategy(DataInputStream in) throws IOException {
        // Read into locals so the on-disk order (min range, then max range) is explicit.
        double minRange = in.readDouble();
        double maxRange = in.readDouble();
        return new MinMaxStrategy(minRange, maxRange);
    }
}
