package de.rub.nds.tlsattacker.core.record.cipher;

import de.rub.nds.modifiablevariable.util.ArrayConverter;
import de.rub.nds.tlsattacker.core.constants.AlgorithmResolver;
import de.rub.nds.tlsattacker.core.constants.CipherAlgorithm;
import de.rub.nds.tlsattacker.core.constants.ProtocolMessageType;
import de.rub.nds.tlsattacker.core.crypto.cipher.CipherWrapper;
import de.rub.nds.tlsattacker.core.exceptions.CryptoException;
import de.rub.nds.tlsattacker.core.protocol.Parser;
import de.rub.nds.tlsattacker.core.record.BlobRecord;
import de.rub.nds.tlsattacker.core.record.Record;
import de.rub.nds.tlsattacker.core.record.cipher.cryptohelper.KeySet;
import de.rub.nds.tlsattacker.core.state.TlsContext;
import de.rub.nds.tlsattacker.transport.ConnectionEndType;
import java.math.BigInteger;
import java.util.Arrays;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/* loaded from: input_file:de/rub/nds/tlsattacker/core/record/cipher/RecordAEADCipher.class */
public class RecordAEADCipher extends RecordCipher {
    private static final Logger LOGGER = LogManager.getLogger();

    /** Authentication tag length in bytes for non-CCM_8 AEAD cipher suites. */
    private static final int AEAD_TAG_LENGTH = 16;

    /** Authentication tag length in bytes for CCM_8 cipher suites. */
    private static final int AEAD_CCM_8_TAG_LENGTH = 8;

    /**
     * Upper bound applied to the configured TLS 1.3 additional padding.
     * NOTE(review): RFC 8446 caps TLSInnerPlaintext at 2^14 + 1 bytes, so this clamp is looser than
     * the spec - presumably intentional for a testing tool; confirm.
     */
    private static final int MAX_ADDITIONAL_PADDING = 65536;

    /** Tag length in bytes used by this instance (8 for CCM_8 suites, 16 otherwise). */
    private final int aeadTagLength;

    /** Number of explicit nonce bytes carried in each record (always 0 for TLS 1.3). */
    private final int aeadExplicitLength;

    /**
     * Minimal {@link Parser} that only exposes the byte-level parsing primitives, used to split a
     * record into explicit nonce / ciphertext / tag and a TLS 1.3 inner plaintext into content /
     * content type / padding.
     */
    class DecryptionParser extends Parser<Object> {
        public DecryptionParser(int startPosition, byte[] array) {
            super(startPosition, array);
        }

        @Override
        public Object parse() {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public byte[] parseByteArrayField(int length) {
            return super.parseByteArrayField(length);
        }

        @Override
        public int getBytesLeft() {
            return super.getBytesLeft();
        }

        @Override
        public int getPointer() {
            return super.getPointer();
        }
    }

    /**
     * Creates the AEAD record cipher, instantiating encryption and decryption ciphers for the
     * local connection end and deriving the tag and explicit nonce lengths from the cipher suite.
     *
     * @param tlsContext the TLS context
     * @param keySet the key material to use
     */
    public RecordAEADCipher(TlsContext tlsContext, KeySet keySet) {
        super(tlsContext, keySet);
        ConnectionEndType localConnectionEndType = tlsContext.getConnection().getLocalConnectionEndType();
        this.encryptCipher = CipherWrapper.getEncryptionCipher(this.cipherSuite, localConnectionEndType, getKeySet());
        this.decryptCipher = CipherWrapper.getDecryptionCipher(this.cipherSuite, localConnectionEndType, getKeySet());
        this.aeadTagLength = this.cipherSuite.isCCM_8() ? AEAD_CCM_8_TAG_LENGTH : AEAD_TAG_LENGTH;
        if (this.version.isTLS13()) {
            // TLS 1.3 records never carry an explicit nonce.
            this.aeadExplicitLength = 0;
        } else {
            this.aeadExplicitLength = AlgorithmResolver.getCipher(this.cipherSuite).getNonceBytesFromRecord();
        }
    }

    /**
     * Returns how many bytes encryption adds to a record: the tag, plus the explicit nonce for
     * versions other than TLS 1.3.
     *
     * @return the size increase in bytes
     */
    public int getAeadSizeIncrease() {
        return this.version.isTLS13() ? this.aeadTagLength : this.aeadExplicitLength + this.aeadTagLength;
    }

    /**
     * Reads the record's sequence number as a {@code long}.
     *
     * @param record the record
     * @return the (possibly modified) sequence number value
     */
    private static long sequenceNumberOf(Record record) {
        return ((BigInteger) record.getSequenceNumber().getValue()).longValue();
    }

    /**
     * Derives the per-record AEAD nonce, stores it in the record computations and returns the
     * (possibly modified) stored value. Shared by encryption and decryption.
     *
     * <p>The nonce is salt || explicitNonce. For TLS 1.3 and CHACHA20_POLY1305 it is then XORed
     * with the sequence number via {@link #preprocessIv(long, byte[])}. For
     * UNOFFICIAL_CHACHA20_POLY1305 it is replaced by the 8-byte sequence number.
     *
     * @param salt the implicit IV / salt from the key set
     * @param explicitNonce the explicit nonce bytes (may be empty)
     * @param record the record being processed
     * @return the nonce to pass to the AEAD cipher
     */
    private byte[] prepareGcmNonce(byte[] salt, byte[] explicitNonce, Record record) {
        byte[] nonce = ArrayConverter.concatenate(salt, explicitNonce);
        if (this.version.isTLS13() || this.cipherAlg == CipherAlgorithm.CHACHA20_POLY1305) {
            nonce = preprocessIv(sequenceNumberOf(record), nonce);
        } else if (this.cipherAlg == CipherAlgorithm.UNOFFICIAL_CHACHA20_POLY1305) {
            nonce = ArrayConverter.longToUint64Bytes(sequenceNumberOf(record));
        }
        record.getComputations().setGcmNonce(nonce);
        return (byte[]) record.getComputations().getGcmNonce().getValue();
    }

    /**
     * Stores the local write IV as the AEAD salt and returns the (possibly modified) stored value.
     */
    private byte[] prepareEncryptionAeadSalt(Record record) {
        record.getComputations().setAeadSalt(getKeySet().getWriteIv(this.context.getConnection().getLocalConnectionEndType()));
        return (byte[]) record.getComputations().getAeadSalt().getValue();
    }

    /**
     * Stores a freshly created explicit nonce and returns the (possibly modified) stored value.
     */
    private byte[] prepareEncryptionExplicitNonce(Record record) {
        record.getComputations().setExplicitNonce(createExplicitNonce());
        return (byte[]) record.getComputations().getExplicitNonce().getValue();
    }

    /**
     * Builds the explicit nonce from the context's write sequence number, encoded in
     * {@code aeadExplicitLength} bytes, or an empty array when no explicit nonce is used.
     */
    private byte[] createExplicitNonce() {
        return this.aeadExplicitLength > 0
                ? ArrayConverter.longToBytes(this.context.getWriteSequenceNumber(), this.aeadExplicitLength)
                : new byte[this.aeadExplicitLength];
    }

    /**
     * Builds the TLS 1.3 inner plaintext (content || content type || zero padding), updates the
     * record length and sets the outer content type to application data.
     *
     * @param record the record to prepare
     */
    private void prepareTls13PlainRecordBytes(Record record) {
        int paddingLength = this.context.getConfig().getDefaultAdditionalPadding().intValue();
        if (paddingLength > MAX_ADDITIONAL_PADDING) {
            LOGGER.warn("Additional padding is too big. setting it to max possible value");
            paddingLength = MAX_ADDITIONAL_PADDING;
        } else if (paddingLength < 0) {
            LOGGER.warn("Additional padding is negative, setting it to 0");
            paddingLength = 0;
        }
        record.getComputations().setPadding(new byte[paddingLength]);
        byte[] content = (byte[]) record.getCleanProtocolMessageBytes().getValue();
        byte contentType = ((Byte) record.getContentType().getValue()).byteValue();
        byte[] padding = (byte[]) record.getComputations().getPadding().getValue();
        record.getComputations().setPlainRecordBytes(ArrayConverter.concatenate(content, new byte[] {contentType}, padding));
        byte[] plainRecordBytes = (byte[]) record.getComputations().getPlainRecordBytes().getValue();
        // The encrypted record carries the inner plaintext plus the tag; use the suite's actual
        // tag length (8 for CCM_8) rather than assuming 16.
        record.setLength(plainRecordBytes.length + this.aeadTagLength);
        record.setContentType(ProtocolMessageType.APPLICATION_DATA.getValue());
    }

    /**
     * Encrypts the record's clean protocol message bytes with the AEAD write cipher.
     *
     * <p>Stores the cipher key, nonce, AAD, ciphertext and tag in the record computations. Sets
     * the protocol message bytes to explicitNonce || ciphertext || tag.
     *
     * @param record the record to encrypt
     * @throws CryptoException if encryption fails or the output is shorter than the tag
     */
    @Override
    public void encrypt(Record record) throws CryptoException {
        LOGGER.debug("Encrypting Record");
        record.getComputations().setCipherKey(getKeySet().getWriteKey(this.context.getChooser().getConnectionEndType()));
        if (this.version.isTLS13()) {
            prepareTls13PlainRecordBytes(record);
        } else {
            record.getComputations().setPlainRecordBytes((byte[]) record.getCleanProtocolMessageBytes().getValue());
        }
        byte[] explicitNonce = prepareEncryptionExplicitNonce(record);
        byte[] gcmNonce = prepareGcmNonce(prepareEncryptionAeadSalt(record), explicitNonce, record);
        LOGGER.debug("Encrypting AEAD with the following IV: {}", ArrayConverter.bytesToHexString(gcmNonce));

        record.getComputations().setAuthenticatedMetaData(collectAdditionalAuthenticatedData(record, this.context.getChooser().getSelectedProtocolVersion()));
        byte[] additionalAuthenticatedData = (byte[]) record.getComputations().getAuthenticatedMetaData().getValue();
        LOGGER.debug("Encrypting AEAD with the following AAD: {}", ArrayConverter.bytesToHexString(additionalAuthenticatedData));

        byte[] plainRecordBytes = (byte[]) record.getComputations().getPlainRecordBytes().getValue();
        byte[] cipherTextWithTag = this.encryptCipher.encrypt(gcmNonce, this.aeadTagLength * 8, additionalAuthenticatedData, plainRecordBytes);
        if (this.aeadTagLength > cipherTextWithTag.length) {
            throw new CryptoException("Could not encrypt data. Supposed Tag is longer than the ciphertext");
        }
        // The cipher output is ciphertext || tag; split it so both parts can be modified independently.
        int tagOffset = cipherTextWithTag.length - this.aeadTagLength;
        byte[] cipherTextOnly = Arrays.copyOfRange(cipherTextWithTag, 0, tagOffset);
        record.getComputations().setAuthenticatedNonMetaData(cipherTextOnly);
        record.getComputations().setAuthenticationTag(Arrays.copyOfRange(cipherTextWithTag, tagOffset, cipherTextWithTag.length));
        byte[] authenticationTag = (byte[]) record.getComputations().getAuthenticationTag().getValue();
        record.getComputations().setCiphertext(cipherTextOnly);
        byte[] cipherText = (byte[]) record.getComputations().getCiphertext().getValue();
        record.setProtocolMessageBytes(ArrayConverter.concatenate(explicitNonce, cipherText, authenticationTag));
        record.getComputations().setAuthenticationTagValid(true);
    }

    /**
     * Encrypts a blob record's clean bytes directly with the write cipher.
     */
    @Override
    public void encrypt(BlobRecord blobRecord) throws CryptoException {
        LOGGER.debug("Encrypting BlobRecord");
        blobRecord.setProtocolMessageBytes(this.encryptCipher.encrypt((byte[]) blobRecord.getCleanProtocolMessageBytes().getValue()));
    }

    /**
     * Decrypts and authenticates the record's protocol message bytes.
     *
     * <p>The bytes are parsed as explicitNonce || ciphertext || tag. For TLS 1.3 the recovered
     * inner plaintext is further split into content, content type and padding.
     *
     * @param record the record to decrypt
     * @throws CryptoException if decryption or tag verification fails; the record's
     *     authentication-tag-valid flag is set to {@code false} in that case
     */
    @Override
    public void decrypt(Record record) throws CryptoException {
        LOGGER.debug("Decrypting Record");
        record.getComputations().setCipherKey(getKeySet().getReadKey(this.context.getChooser().getConnectionEndType()));
        DecryptionParser parser = new DecryptionParser(0, (byte[]) record.getProtocolMessageBytes().getValue());

        record.getComputations().setExplicitNonce(parser.parseByteArrayField(this.aeadExplicitLength));
        byte[] explicitNonce = (byte[]) record.getComputations().getExplicitNonce().getValue();
        record.getComputations().setAeadSalt(getKeySet().getReadIv(this.context.getConnection().getLocalConnectionEndType()));
        byte[] salt = (byte[]) record.getComputations().getAeadSalt().getValue();

        byte[] cipherTextOnly = parser.parseByteArrayField(parser.getBytesLeft() - this.aeadTagLength);
        record.getComputations().setCiphertext(cipherTextOnly);
        record.getComputations().setAuthenticatedNonMetaData((byte[]) record.getComputations().getCiphertext().getValue());
        record.getComputations().setAuthenticatedMetaData(collectAdditionalAuthenticatedData(record, this.context.getChooser().getSelectedProtocolVersion()));
        byte[] additionalAuthenticatedData = (byte[]) record.getComputations().getAuthenticatedMetaData().getValue();
        LOGGER.debug("Decrypting AEAD with the following AAD: {}", ArrayConverter.bytesToHexString(additionalAuthenticatedData));

        byte[] gcmNonce = prepareGcmNonce(salt, explicitNonce, record);
        LOGGER.debug("Decrypting AEAD with the following IV: {}", ArrayConverter.bytesToHexString(gcmNonce));

        // Everything after the ciphertext is the authentication tag.
        record.getComputations().setAuthenticationTag(parser.parseByteArrayField(parser.getBytesLeft()));
        byte[] authenticationTag = (byte[]) record.getComputations().getAuthenticationTag().getValue();

        byte[] decrypted;
        try {
            decrypted = this.decryptCipher.decrypt(gcmNonce, this.aeadTagLength * 8, additionalAuthenticatedData, ArrayConverter.concatenate(cipherTextOnly, authenticationTag));
        } catch (CryptoException e) {
            LOGGER.warn("Tag invalid", e);
            record.getComputations().setAuthenticationTagValid(false);
            throw new CryptoException(e);
        }
        record.getComputations().setAuthenticationTagValid(true);
        record.getComputations().setPlainRecordBytes(decrypted);
        byte[] plainRecordBytes = (byte[]) record.getComputations().getPlainRecordBytes().getValue();
        if (this.version.isTLS13()) {
            parseTls13InnerPlaintext(record, plainRecordBytes);
        } else {
            record.setCleanProtocolMessageBytes(plainRecordBytes);
        }
    }

    /**
     * Splits a decrypted TLS 1.3 inner plaintext into content, content type and trailing zero
     * padding, and stores them on the record. If the plaintext consists only of zero bytes (or is
     * empty), a warning is logged and the whole plaintext is used as the clean bytes.
     *
     * @param record the record being decrypted
     * @param plainRecordBytes the decrypted inner plaintext
     */
    private void parseTls13InnerPlaintext(Record record, byte[] plainRecordBytes) {
        int paddingLength = countTrailingZeroBytes(plainRecordBytes);
        if (paddingLength == plainRecordBytes.length) {
            LOGGER.warn("Record contains ONLY padding and no content type. Setting clean bytes == plainbytes");
            record.setCleanProtocolMessageBytes(plainRecordBytes);
            return;
        }
        DecryptionParser parser = new DecryptionParser(0, plainRecordBytes);
        // The last non-zero byte is the real content type; everything before it is content.
        byte[] content = parser.parseByteArrayField(plainRecordBytes.length - paddingLength - 1);
        byte contentType = parser.parseByteArrayField(1)[0];
        record.getComputations().setPadding(parser.parseByteArrayField(paddingLength));
        record.setCleanProtocolMessageBytes(content);
        record.setContentType(contentType);
        record.setContentMessageType(ProtocolMessageType.getContentType(contentType));
    }

    /**
     * Decrypts a blob record's bytes directly with the read cipher.
     */
    @Override
    public void decrypt(BlobRecord blobRecord) throws CryptoException {
        LOGGER.debug("Decrypting BlobRecord");
        blobRecord.setCleanProtocolMessageBytes(this.decryptCipher.decrypt((byte[]) blobRecord.getProtocolMessageBytes().getValue()));
    }

    /**
     * Counts consecutive zero bytes at the end of the array.
     *
     * @param bytes the array to inspect
     * @return the number of trailing zero bytes; equals {@code bytes.length} if the array is all
     *     zeros or empty
     */
    private int countTrailingZeroBytes(byte[] bytes) {
        int count = 0;
        // Stop at index 0: the previous bound check never terminated and read bytes[-1].
        for (int i = bytes.length - 1; i >= 0 && bytes[i] == 0; i--) {
            count++;
        }
        return count;
    }

    /**
     * XORs the IV into a 12-byte buffer holding the sequence number as a big-endian uint64,
     * left-padded with four zero bytes. For a 12-byte IV this matches the per-record nonce
     * construction of RFC 8446 section 5.3.
     *
     * <p>An IV shorter than 12 bytes is XORed into the leading bytes only. An IV longer than 12
     * bytes causes an {@link ArrayIndexOutOfBoundsException}.
     *
     * @param sequenceNumber the record sequence number
     * @param iv the static IV to combine
     * @return a new 12-byte nonce
     */
    public byte[] preprocessIv(long sequenceNumber, byte[] iv) {
        byte[] nonce = ArrayConverter.concatenate(new byte[] {0, 0, 0, 0}, ArrayConverter.longToUint64Bytes(sequenceNumber));
        for (int i = 0; i < iv.length; i++) {
            nonce[i] ^= iv[i];
        }
        return nonce;
    }
}
