package io.kroxylicious.filter.encryption.encrypt;

import edu.umd.cs.findbugs.annotations.NonNull;
import io.kroxylicious.filter.encryption.common.EncryptionException;
import io.kroxylicious.filter.encryption.common.FilterThreadExecutor;
import io.kroxylicious.filter.encryption.common.RecordEncryptionUtil;
import io.kroxylicious.filter.encryption.crypto.Encryption;
import io.kroxylicious.filter.encryption.dek.BufferTooSmallException;
import io.kroxylicious.filter.encryption.dek.Dek;
import io.kroxylicious.filter.encryption.dek.DestroyedDekException;
import io.kroxylicious.filter.encryption.dek.ExhaustedDekException;
import io.kroxylicious.kafka.transform.RecordStream;
import io.kroxylicious.kms.service.Serde;
import java.nio.ByteBuffer;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.IntFunction;
import org.apache.kafka.common.errors.NetworkException;
import org.apache.kafka.common.record.MemoryRecords;
import org.apache.kafka.common.utils.ByteBufferOutputStream;

/* loaded from: input_file:io/kroxylicious/filter/encryption/encrypt/InBandEncryptionManager.class */
/**
 * {@link EncryptionManager} that encrypts the records of a {@link MemoryRecords} using a
 * {@link Dek} obtained from an {@link EncryptionDekCache}.
 *
 * <p>When the DEK handed out by the cache is already destroyed, or cannot supply an encryptor
 * for the requested number of records, the encryption is retried (up to {@link #MAX_ATTEMPTS}
 * times) with whatever DEK the cache returns next.</p>
 *
 * @param <K> the type parameter of the {@link EncryptionScheme} used to look up a DEK
 * @param <E> the type parameter of the {@link Dek} and of the {@link Serde} passed to the record encryptor
 */
public class InBandEncryptionManager<K, E> implements EncryptionManager<K> {
    /** Upper bound on the number of DEK lookups made for a single {@code encrypt} call. */
    private static final int MAX_ATTEMPTS = 100;
    private final Encryption encryption;
    private final Serde<E> edekSerde;
    private final EncryptionDekCache<K, E> dekCache;

    @NonNull
    private final FilterThreadExecutor filterThreadExecutor;
    /** Initial capacity, in bytes, of the scratch buffer handed to the record encryptor. */
    private final int recordBufferInitialBytes;
    /** Capacity, in bytes, beyond which the scratch buffer is never grown. */
    private final int recordBufferMaxBytes;

    /**
     * Creates a manager.
     *
     * @param encryption the encryption passed to each {@code RecordEncryptor}; must not be null
     * @param serde the serde passed to each {@code RecordEncryptor}; must not be null
     * @param i initial size in bytes of the per-record scratch buffer; must be positive
     * @param i2 maximum size in bytes the scratch buffer may grow to; must be positive
     * @param encryptionDekCache the cache DEKs are obtained from
     * @param filterThreadExecutor the executor passed to the cache on every DEK lookup
     * @throws NullPointerException if {@code encryption} or {@code serde} is null
     * @throws IllegalArgumentException if either buffer size is not positive
     */
    public InBandEncryptionManager(@NonNull Encryption encryption, @NonNull Serde<E> serde, int i, int i2, @NonNull EncryptionDekCache<K, E> encryptionDekCache, @NonNull FilterThreadExecutor filterThreadExecutor) {
        this.filterThreadExecutor = filterThreadExecutor;
        this.encryption = Objects.requireNonNull(encryption);
        this.edekSerde = Objects.requireNonNull(serde);
        if (i <= 0) {
            throw new IllegalArgumentException("recordBufferInitialBytes must be positive but was " + i);
        }
        this.recordBufferInitialBytes = i;
        if (i2 <= 0) {
            throw new IllegalArgumentException("recordBufferMaxBytes must be positive but was " + i2);
        }
        this.recordBufferMaxBytes = i2;
        this.dekCache = encryptionDekCache;
    }

    /**
     * Asks the DEK cache for the DEK to use with the given scheme.
     *
     * @param encryptionScheme the scheme to look up
     * @return the stage returned by the cache
     */
    public CompletionStage<Dek<E>> currentDek(@NonNull EncryptionScheme<K> encryptionScheme) {
        return this.dekCache.get(encryptionScheme, this.filterThreadExecutor);
    }

    /**
     * Encrypts the given records.
     *
     * <p>If {@code memoryRecords} has a size of zero bytes, or contains zero records according to
     * {@link RecordEncryptionUtil#totalRecordsInBatches}, it is returned as-is in an already
     * completed stage.</p>
     *
     * @param str the topic name (used for the record encryptor and in error messages)
     * @param i the partition (used for the record encryptor and in error messages)
     * @param encryptionScheme the scheme used to obtain a DEK
     * @param memoryRecords the records to encrypt
     * @param intFunction allocator for the output buffer; it is called with the requested size in bytes
     * @return a stage yielding the encrypted records, or failing if encryption was not possible
     */
    @Override // io.kroxylicious.filter.encryption.encrypt.EncryptionManager
    @NonNull
    public CompletionStage<MemoryRecords> encrypt(@NonNull String str, int i, @NonNull EncryptionScheme<K> encryptionScheme, @NonNull MemoryRecords memoryRecords, @NonNull IntFunction<ByteBufferOutputStream> intFunction) {
        if (memoryRecords.sizeInBytes() == 0) {
            return CompletableFuture.completedFuture(memoryRecords);
        }
        int recordCount = RecordEncryptionUtil.totalRecordsInBatches(memoryRecords);
        if (recordCount == 0) {
            return CompletableFuture.completedFuture(memoryRecords);
        }
        return attemptEncrypt(str, i, encryptionScheme, memoryRecords, 0, intFunction, recordCount);
    }

    /**
     * Obtains the output buffer for an encryption pass, requesting twice the size of the input records.
     *
     * @param records the records about to be encrypted
     * @param bufferAllocator the allocator to ask
     * @return whatever the allocator returns for the requested size
     */
    private ByteBufferOutputStream allocateBufferForEncrypt(@NonNull MemoryRecords records, @NonNull IntFunction<ByteBufferOutputStream> bufferAllocator) {
        return bufferAllocator.apply(2 * records.sizeInBytes());
    }

    /**
     * Makes one attempt at encrypting {@code records}, recursing with {@code attempt + 1} when the
     * DEK obtained from the cache could not be used.
     *
     * @param topicName the topic name
     * @param partition the partition
     * @param encryptionScheme the scheme used to obtain a DEK
     * @param records the records to encrypt
     * @param attempt the number of attempts already made
     * @param bufferAllocator allocator for the output buffer
     * @param recordCount the number of records an encryptor is requested for
     * @return a stage yielding the encrypted records; it fails with {@code RequestNotSatisfiable}
     *         once {@link #MAX_ATTEMPTS} attempts have been made, or with the exception raised
     *         while encrypting
     */
    private CompletionStage<MemoryRecords> attemptEncrypt(@NonNull String topicName, int partition, @NonNull EncryptionScheme<K> encryptionScheme, @NonNull MemoryRecords records, int attempt, @NonNull IntFunction<ByteBufferOutputStream> bufferAllocator, int recordCount) {
        if (attempt >= MAX_ATTEMPTS) {
            return CompletableFuture.failedFuture(new RequestNotSatisfiable("failed to reserve an EDEK to encrypt " + recordCount + " records for topic " + topicName + " partition " + partition + " after " + attempt + " attempts", new NetworkException("Failed to encrypt record(s) because there were no valid encryption keys")));
        }
        return currentDek(encryptionScheme).thenCompose(dek -> {
            if (!dek.isDestroyed()) {
                try {
                    try (Dek<E>.Encryptor encryptor = dek.encryptor(recordCount)) {
                        return CompletableFuture.completedFuture(encryptBatches(topicName, partition, encryptionScheme, records, encryptor, bufferAllocator));
                    } catch (DestroyedDekException | ExhaustedDekException unusableDek) {
                        // This DEK can no longer be used for encryption: evict it and fall through to the retry below.
                        rotateKeyContext(encryptionScheme, dek);
                    }
                } catch (Exception e) {
                    // Any other failure (including one raised while rotating) is reported through the returned stage.
                    return CompletableFuture.failedFuture(e);
                }
            }
            return attemptEncrypt(topicName, partition, encryptionScheme, records, attempt + 1, bufferAllocator, recordCount);
        });
    }

    /**
     * Encrypts {@code records} with the given encryptor, doubling the scratch buffer and starting
     * over whenever a {@link BufferTooSmallException} is raised.
     *
     * @param topicName the topic name
     * @param partition the partition
     * @param encryptionScheme the scheme passed to the record encryptor
     * @param records the records to encrypt
     * @param encryptor the encryptor applied to every record
     * @param bufferAllocator allocator for the output buffer; a new output buffer is requested on every pass
     * @return the encrypted records
     * @throws EncryptionException if the scratch buffer would have to grow beyond {@code recordBufferMaxBytes}
     */
    @NonNull
    private MemoryRecords encryptBatches(@NonNull String topicName, int partition, @NonNull EncryptionScheme<K> encryptionScheme, @NonNull MemoryRecords records, @NonNull Dek<E>.Encryptor encryptor, @NonNull IntFunction<ByteBufferOutputStream> bufferAllocator) {
        ByteBuffer recordBuffer = ByteBuffer.allocate(this.recordBufferInitialBytes);
        while (true) {
            try {
                return RecordStream.ofRecords(records).mapConstant(encryptor).toMemoryRecords(allocateBufferForEncrypt(records, bufferAllocator), new RecordEncryptor(topicName, partition, this.encryption, encryptionScheme, this.edekSerde, recordBuffer));
            } catch (BufferTooSmallException tooSmall) {
                // Double in long arithmetic: an int multiplication could wrap negative and slip past the limit check.
                long grownCapacity = 2L * recordBuffer.capacity();
                if (grownCapacity > this.recordBufferMaxBytes) {
                    throw new EncryptionException("Record buffer cannot grow greater than " + this.recordBufferMaxBytes + " bytes");
                }
                // Safe narrowing: grownCapacity is at most recordBufferMaxBytes, which is an int.
                recordBuffer = ByteBuffer.allocate((int) grownCapacity);
            }
        }
    }

    /**
     * Removes the given DEK from the cache entry for the scheme, then marks it destroyed for encryption.
     *
     * @param encryptionScheme the scheme the DEK was obtained for
     * @param dek the DEK to retire
     */
    private void rotateKeyContext(@NonNull EncryptionScheme<K> encryptionScheme, @NonNull Dek<E> dek) {
        this.dekCache.invalidate(encryptionScheme, dek);
        dek.destroyForEncrypt();
    }
}
