package ai.sklearn4j.core.packaging;

import ai.sklearn4j.core.ScikitLearnCoreException;
import ai.sklearn4j.core.libraries.numpy.NumpyArray;
import ai.sklearn4j.core.libraries.numpy.NumpyArrayFactory;
import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Array;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.InvalidPathException;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:ai/sklearn4j/core/packaging/BinaryModelPackage.class */
/**
 * Sequential reader for the binary model-package format.
 *
 * <p>Every read method consumes bytes from the underlying stream in order, so values must be read
 * back in exactly the order they were written. Multi-byte integers are decoded little-endian
 * (least significant byte first). Nullable values ({@code float}, {@code double}, strings, lists,
 * dictionaries, numpy arrays and string arrays) are preceded by a one-byte presence flag: a flag of
 * {@code 1} means a value follows; any other flag means the value is absent, and {@code NaN} or
 * {@code null} is returned without consuming further bytes.
 *
 * <p>NOTE(review): the reader keeps mutable stream position and is not synchronized; presumably it
 * is meant for single-threaded use.
 */
public class BinaryModelPackage {
    // Element-type tags that precede list entries and dictionary values, and that describe the
    // element type of a numpy array payload. Unsigned tags are decoded into the signed Java type of
    // the same width: the bit pattern is kept, so values above the signed maximum read as negative.
    private static final int ELEMENT_TYPE_BYTE = 1;
    private static final int ELEMENT_TYPE_SHORT = 2;
    private static final int ELEMENT_TYPE_INT = 4;
    private static final int ELEMENT_TYPE_LONG = 8;
    private static final int ELEMENT_TYPE_NULL = 16;
    private static final int ELEMENT_TYPE_UNSIGNED_BYTE = 17;
    private static final int ELEMENT_TYPE_UNSIGNED_SHORT = 18;
    private static final int ELEMENT_TYPE_UNSIGNED_INT = 20;
    private static final int ELEMENT_TYPE_UNSIGNED_LONG = 24;
    private static final int ELEMENT_TYPE_FLOAT = 32;
    private static final int ELEMENT_TYPE_DOUBLE = 33;
    private static final int ELEMENT_TYPE_STRING = 48;
    private static final int ELEMENT_TYPE_LIST = 64;
    private static final int ELEMENT_TYPE_DICTIONARY = 65;
    private static final int ELEMENT_TYPE_NUMPY_ARRAY = 66;
    private static final int ELEMENT_TYPE_STRING_ARRAY = 67;

    /** Source of the package bytes; consumed sequentially by the read methods. */
    private final InputStream stream;

    private BinaryModelPackage(InputStream inputStream) {
        this.stream = inputStream;
    }

    /**
     * Loads the whole file into memory and returns a reader over its contents.
     *
     * @param str path of the package file
     * @return a reader positioned at the start of the file contents
     * @throws ScikitLearnCoreException if the path is invalid or the file cannot be read
     */
    public static BinaryModelPackage fromFile(String str) {
        try {
            // Files.readAllBytes reads until EOF and always closes the file. A single read() sized
            // by available() is not guaranteed to return the whole file and leaked the stream on
            // failure.
            byte[] content = Files.readAllBytes(Paths.get(str));
            return fromStream(new ByteArrayInputStream(content));
        } catch (IOException | InvalidPathException e) {
            // NOTE(review): the cause is dropped because only a (String) constructor of
            // ScikitLearnCoreException is visible here; chain it if a (String, Throwable) one exists.
            throw new ScikitLearnCoreException("An error occurred while loading a package from file:\n" + e.getMessage());
        }
    }

    /**
     * Wraps an existing stream. The stream is not closed by this class.
     *
     * @param inputStream source of package bytes
     * @return a reader over {@code inputStream}
     */
    public static BinaryModelPackage fromStream(InputStream inputStream) {
        return new BinaryModelPackage(inputStream);
    }

    /** Reads one raw byte. */
    public byte readByte() {
        return readBuffer(1)[0];
    }

    /** Reads a 2-byte little-endian integer. */
    public short readShort() {
        return (short) decodeLittleEndian(readBuffer(2));
    }

    /** Reads a 4-byte little-endian integer. */
    public int readInteger() {
        return (int) decodeLittleEndian(readBuffer(4));
    }

    /** Reads an 8-byte little-endian integer. */
    public long readLongInteger() {
        return decodeLittleEndian(readBuffer(8));
    }

    /**
     * Reads a presence flag and, if set, a 4-byte IEEE-754 float.
     *
     * @return the float, or {@code NaN} when the flag is not {@code 1}
     */
    public float readFloat() {
        if (readByte() != 1) {
            return Float.NaN;
        }
        return Float.intBitsToFloat(readInteger());
    }

    /**
     * Reads a presence flag and, if set, an 8-byte IEEE-754 double.
     *
     * @return the double, or {@code NaN} when the flag is not {@code 1}
     */
    public double readDouble() {
        if (readByte() != 1) {
            return Double.NaN;
        }
        return Double.longBitsToDouble(readLongInteger());
    }

    /**
     * Reads a presence flag and, if set, a 4-byte byte length followed by that many UTF-8 bytes.
     *
     * @return the decoded string, or {@code null} when the flag is not {@code 1}
     */
    public String readString() {
        if (readByte() != 1) {
            return null;
        }
        return new String(readBuffer(readInteger()), StandardCharsets.UTF_8);
    }

    /**
     * Reads a presence flag and, if set, a numpy array. The layout is: rank (4 bytes), element-type
     * tag (1 byte), one 4-byte extent per axis, then the elements in row-major order.
     *
     * @return the array, or {@code null} when the flag is not {@code 1}
     * @throws ScikitLearnCoreException if the element type is not a numeric type
     */
    public NumpyArray readNumpyArray() {
        if (readByte() != 1) {
            return null;
        }
        // The rank comes before the element type, and the extents come after it.
        int[] shape = new int[readInteger()];
        byte elementType = readByte();
        for (int axis = 0; axis < shape.length; axis++) {
            shape[axis] = readInteger();
        }
        // createNumpyArray validates the element type before any element bytes are consumed.
        NumpyArray array = createNumpyArray(elementType, shape);
        // NOTE(review): a rank-0 shape fails inside readNumpyDataFromStream (shape[0]), exactly as
        // before; confirm whether the writer can emit 0-d arrays.
        readNumpyDataFromStream(array.getWrapper().getRawArray(), shape, 0, getPrimitiveDataReader(elementType));
        return array;
    }

    /** Allocates an empty array of the given numeric element type and shape. */
    private NumpyArray createNumpyArray(int elementType, int[] shape) {
        switch (elementType) {
            case ELEMENT_TYPE_BYTE:
            case ELEMENT_TYPE_UNSIGNED_BYTE:
                return NumpyArrayFactory.arrayOfInt8WithShape(shape);
            case ELEMENT_TYPE_SHORT:
            case ELEMENT_TYPE_UNSIGNED_SHORT:
                return NumpyArrayFactory.arrayOfInt16WithShape(shape);
            case ELEMENT_TYPE_INT:
            case ELEMENT_TYPE_UNSIGNED_INT:
                return NumpyArrayFactory.arrayOfInt32WithShape(shape);
            case ELEMENT_TYPE_LONG:
            case ELEMENT_TYPE_UNSIGNED_LONG:
                return NumpyArrayFactory.arrayOfInt64WithShape(shape);
            case ELEMENT_TYPE_FLOAT:
                return NumpyArrayFactory.arrayOfFloatWithShape(shape);
            case ELEMENT_TYPE_DOUBLE:
                return NumpyArrayFactory.arrayOfDoubleWithShape(shape);
            default:
                throw new ScikitLearnCoreException(String.format("Numpy array with element type %d is not supported.", Integer.valueOf(elementType)));
        }
    }

    /**
     * Reads a presence flag and, if set, a 4-byte count followed by that many tagged elements.
     *
     * @return the list, or {@code null} when the flag is not {@code 1}
     */
    public List<Object> readList() {
        if (readByte() != 1) {
            return null;
        }
        int size = readInteger();
        List<Object> list = new ArrayList<>();
        for (int index = 0; index < size; index++) {
            list.add(readElement(readByte()));
        }
        return list;
    }

    /**
     * Reads a presence flag and, if set, a 4-byte count followed by that many entries. Each entry is
     * a string key (see {@link #readString()}) followed by a tagged value.
     *
     * @return the dictionary, or {@code null} when the flag is not {@code 1}
     */
    public Map<String, Object> readDictionary() {
        if (readByte() != 1) {
            return null;
        }
        int size = readInteger();
        Map<String, Object> dictionary = new HashMap<>();
        for (int index = 0; index < size; index++) {
            // The key must be consumed before the value's type tag.
            String key = readString();
            dictionary.put(key, readElement(readByte()));
        }
        return dictionary;
    }

    /**
     * Reads a presence flag and, if set, a 4-byte count followed by that many strings.
     *
     * @return the array, or {@code null} when the flag is not {@code 1}
     */
    public String[] readStringArray() {
        if (readByte() != 1) {
            return null;
        }
        int size = readInteger();
        String[] strings = new String[size];
        for (int index = 0; index < size; index++) {
            strings[index] = readString();
        }
        return strings;
    }

    /** Reads one tagged element value; the NULL tag yields {@code null} and consumes nothing more. */
    private Object readElement(int elementType) {
        if (elementType == ELEMENT_TYPE_NULL) {
            return null;
        }
        return getPrimitiveDataReader(elementType).readPrimitiveValue();
    }

    /**
     * Recursively fills a nested Java array, descending one array level per axis and reading the
     * last axis's elements with {@code reader}.
     */
    private void readNumpyDataFromStream(Object array, int[] shape, int axis, IBinaryModelPackagePrimitiveValueReader reader) {
        int extent = shape[axis];
        if (axis < shape.length - 1) {
            for (int index = 0; index < extent; index++) {
                readNumpyDataFromStream(Array.get(array, index), shape, axis + 1, reader);
            }
        } else {
            for (int index = 0; index < extent; index++) {
                Array.set(array, index, reader.readPrimitiveValue());
            }
        }
    }

    /** Returns the reader method that decodes a value of the given element-type tag. */
    private IBinaryModelPackagePrimitiveValueReader getPrimitiveDataReader(int elementType) {
        switch (elementType) {
            case ELEMENT_TYPE_BYTE:
            case ELEMENT_TYPE_UNSIGNED_BYTE:
                return this::readByte;
            case ELEMENT_TYPE_SHORT:
            case ELEMENT_TYPE_UNSIGNED_SHORT:
                return this::readShort;
            case ELEMENT_TYPE_INT:
            case ELEMENT_TYPE_UNSIGNED_INT:
                return this::readInteger;
            case ELEMENT_TYPE_LONG:
            case ELEMENT_TYPE_UNSIGNED_LONG:
                return this::readLongInteger;
            case ELEMENT_TYPE_FLOAT:
                return this::readFloat;
            case ELEMENT_TYPE_DOUBLE:
                return this::readDouble;
            case ELEMENT_TYPE_STRING:
                return this::readString;
            case ELEMENT_TYPE_LIST:
                return this::readList;
            case ELEMENT_TYPE_DICTIONARY:
                return this::readDictionary;
            case ELEMENT_TYPE_NUMPY_ARRAY:
                return this::readNumpyArray;
            case ELEMENT_TYPE_STRING_ARRAY:
                // Previously handled only inside dictionaries; lists now accept it too.
                return this::readStringArray;
            default:
                throw new ScikitLearnCoreException(String.format("Numpy array with element type %d is not supported.", Integer.valueOf(elementType)));
        }
    }

    /** Decodes {@code bytes} as an unsigned little-endian integer; callers narrow to the width read. */
    private static long decodeLittleEndian(byte[] bytes) {
        long value = 0;
        for (int index = bytes.length - 1; index >= 0; index--) {
            value = (value << 8) | (bytes[index] & 0xFF);
        }
        return value;
    }

    /**
     * Reads exactly {@code length} bytes from the stream.
     *
     * @throws ScikitLearnCoreException if {@code length} is negative, the stream ends early, or an
     *     I/O error occurs
     */
    private byte[] readBuffer(int length) {
        if (length < 0) {
            // A negative length can only come from corrupt data; report it like a short read.
            throw new ScikitLearnCoreException(String.format("Unable to read %d bytes from the stream.", Integer.valueOf(length)));
        }
        byte[] buffer = new byte[length];
        // Zero-length reads return immediately: ByteArrayInputStream returns -1 at EOF even when
        // asked for 0 bytes, which used to reject an empty string at the end of the package.
        int offset = 0;
        try {
            // InputStream.read may return fewer bytes than requested before EOF, so keep reading
            // until the buffer is full.
            while (offset < length) {
                int count = this.stream.read(buffer, offset, length - offset);
                if (count < 0) {
                    throw new ScikitLearnCoreException(String.format("Unable to read %d bytes from the stream.", Integer.valueOf(length)));
                }
                offset += count;
            }
            return buffer;
        } catch (IOException e) {
            throw new ScikitLearnCoreException("Unable to read from buffer.");
        }
    }

    /**
     * Reports whether the underlying stream says bytes can be read without blocking.
     *
     * <p>For the {@code ByteArrayInputStream} created by {@link #fromFile(String)} this equals
     * "bytes remain". NOTE(review): for arbitrary streams passed to {@link #fromStream} a result of
     * {@code false} does not necessarily mean end of stream.
     *
     * @throws ScikitLearnCoreException if the stream reports an I/O error
     */
    public boolean canRead() {
        try {
            return this.stream.available() > 0;
        } catch (IOException e) {
            throw new ScikitLearnCoreException("An error occurred while assessing if the stream reached end or not:\n" + e.getMessage());
        }
    }
}
