package io.kroxylicious.filter.encryption.inband;

import com.github.benmanes.caffeine.cache.AsyncLoadingCache;
import com.github.benmanes.caffeine.cache.Caffeine;
import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;
import io.kroxylicious.filter.encryption.AadSpec;
import io.kroxylicious.filter.encryption.BackoffStrategy;
import io.kroxylicious.filter.encryption.CipherCode;
import io.kroxylicious.filter.encryption.EncryptionException;
import io.kroxylicious.filter.encryption.EncryptionScheme;
import io.kroxylicious.filter.encryption.EncryptionVersion;
import io.kroxylicious.filter.encryption.EnvelopeEncryptionFilter;
import io.kroxylicious.filter.encryption.KeyManager;
import io.kroxylicious.filter.encryption.Receiver;
import io.kroxylicious.filter.encryption.RecordField;
import io.kroxylicious.filter.encryption.ResilientKms;
import io.kroxylicious.filter.encryption.WrapperVersion;
import io.kroxylicious.kms.service.Kms;
import io.kroxylicious.kms.service.Serde;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.nio.ByteBuffer;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ScheduledExecutorService;
import org.apache.kafka.common.header.Header;
import org.apache.kafka.common.header.internals.RecordHeader;
import org.apache.kafka.common.record.Record;
import org.apache.kafka.common.utils.ByteUtils;

/* loaded from: input_file:io/kroxylicious/filter/encryption/inband/InBandKeyManager.class */
/**
 * A {@link KeyManager} that keeps the encrypted DEK (EDEK) "in band".
 * <p>
 * Each encrypted record value is replaced by a wrapper holding the following, in order:
 * <ul>
 *   <li>the serialized EDEK (length-prefixed with an unsigned varint);</li>
 *   <li>an AAD-spec byte;</li>
 *   <li>a cipher-code byte;</li>
 *   <li>the output of {@link KeyContext#encode} for a parcel built from the record.</li>
 * </ul>
 * Encrypted records are tagged with the {@value #ENCRYPTION_HEADER_NAME} header, whose single byte is
 * the {@link EncryptionVersion} code.
 * <p>
 * Concurrency:
 * <ul>
 *   <li>Encryption key contexts are cached per KEK id. Use of a shared {@link KeyContext} is guarded by
 *   synchronizing on it.</li>
 *   <li>Decryptors are cached per EDEK. Use of a shared {@link AesGcmEncryptor} is guarded by
 *   synchronizing on it.</li>
 * </ul>
 *
 * @param <K> the type of KEK identifier
 * @param <E> the type of encrypted DEK
 */
public class InBandKeyManager<K, E> implements KeyManager<K> {
    private static final int MAX_ATTEMPTS = 3;
    static final String ENCRYPTION_HEADER_NAME = "kroxylicious.io/encryption";
    private final Kms<K, E> kms;
    private final BufferPool bufferPool;
    private final Serde<E> edekSerde;
    private final int maxEncryptionsPerDek;
    // 5 seconds in nanoseconds; added to System.nanoTime() to form the deadline handed to each KeyContext.
    private final long dekTtlNanos = 5_000_000_000L;
    private final AsyncLoadingCache<K, KeyContext> keyContextCache = Caffeine.newBuilder()
            .buildAsync((kekId, executor) -> makeKeyContext(kekId));
    private final AsyncLoadingCache<E, AesGcmEncryptor> decryptorCache = Caffeine.newBuilder()
            .buildAsync((edek, executor) -> makeDecryptor(edek));
    private final EncryptionVersion encryptionVersion = EncryptionVersion.V1;
    // Shared across all encrypted records; receivers must not mutate it (NOTE(review): confirm Receiver contract).
    private final Header[] encryptionHeader = { new RecordHeader(ENCRYPTION_HEADER_NAME, new byte[]{ this.encryptionVersion.code() }) };

    /**
     * The per-record state gathered before decryption.
     * <p>
     * When the record carried the encryption header, {@code valueWrapper} is positioned just after the
     * EDEK. {@code decryptionVersion} and {@code encryptor} are null for records without the header.
     */
    private record DecryptState(@NonNull Record kafkaRecord,
                                @NonNull ByteBuffer valueWrapper,
                                @Nullable EncryptionVersion decryptionVersion,
                                @Nullable AesGcmEncryptor encryptor) {}

    /**
     * Creates a key manager.
     *
     * @param kms the KMS used to generate DEK pairs and decrypt EDEKs; it is wrapped by {@link ResilientKms}
     * @param bufferPool the pool supplying the scratch buffers used during encryption
     * @param maxEncryptionsPerDek how many record encryptions a single DEK may be used for
     * @param executorService the executor used by {@link ResilientKms} for retries
     * @param backoffStrategy the backoff used by {@link ResilientKms} between retries
     */
    public InBandKeyManager(Kms<K, E> kms,
                            BufferPool bufferPool,
                            int maxEncryptionsPerDek,
                            ScheduledExecutorService executorService,
                            BackoffStrategy backoffStrategy) {
        this.kms = ResilientKms.get(kms, executorService, backoffStrategy, MAX_ATTEMPTS);
        this.bufferPool = bufferPool;
        this.edekSerde = kms.edekSerde();
        this.maxEncryptionsPerDek = maxEncryptionsPerDek;
    }

    /** Returns (loading it if necessary) the cached key context for the given KEK. */
    private CompletionStage<KeyContext> currentDekContext(@NonNull K kekId) {
        return this.keyContextCache.get(kekId);
    }

    /** Generates a fresh DEK pair for the KEK and wraps it, with its serialized EDEK, in a new KeyContext. */
    private CompletableFuture<KeyContext> makeKeyContext(@NonNull K kekId) {
        return this.kms.generateDekPair(kekId).thenApply(dekPair -> {
            E edek = dekPair.edek();
            // Keep the full int size: the wire format length-prefixes the EDEK with an unsigned varint,
            // so narrowing to short would silently corrupt EDEKs larger than Short.MAX_VALUE bytes.
            int edekSize = this.edekSerde.sizeOf(edek);
            ByteBuffer serializedEdek = ByteBuffer.allocate(edekSize);
            this.edekSerde.serialize(edek, serializedEdek);
            serializedEdek.flip();
            return new KeyContext(serializedEdek,
                    System.nanoTime() + this.dekTtlNanos,
                    this.maxEncryptionsPerDek,
                    AesGcmEncryptor.forEncrypt(new AesGcmIvGenerator(new SecureRandom()), dekPair.dek()));
        }).toCompletableFuture();
    }

    /**
     * Encrypts the given records with a DEK for the scheme's KEK and passes each result to the receiver.
     * An empty list completes immediately.
     */
    @Override
    @NonNull
    public CompletionStage<Void> encrypt(@NonNull String topicName,
                                         int partition,
                                         @NonNull EncryptionScheme<K> encryptionScheme,
                                         @NonNull List<? extends Record> records,
                                         @NonNull Receiver receiver) {
        if (records.isEmpty()) {
            return CompletableFuture.completedFuture(null);
        }
        return attemptEncrypt(topicName, partition, encryptionScheme, records, receiver, 0);
    }

    /**
     * Tries to encrypt the whole batch with the current key context.
     * <p>
     * If the context is destroyed, this retries. If the context cannot cover the batch, it is destroyed and
     * evicted (rotated), then this retries. The returned stage fails with {@link RequestNotSatisfiable}
     * after {@link #MAX_ATTEMPTS} attempts.
     */
    private CompletionStage<Void> attemptEncrypt(String topicName,
                                                 int partition,
                                                 @NonNull EncryptionScheme<K> encryptionScheme,
                                                 @NonNull List<? extends Record> records,
                                                 @NonNull Receiver receiver,
                                                 int attempt) {
        if (attempt >= MAX_ATTEMPTS) {
            return CompletableFuture.failedFuture(new RequestNotSatisfiable("failed to reserve an EDEK to encrypt " + records.size()
                    + " records for topic " + topicName + " partition " + partition + " after " + attempt + " attempts"));
        }
        return currentDekContext(encryptionScheme.kekId()).thenCompose(keyContext -> {
            // The context is shared via the cache: checking remaining capacity and using it must be atomic.
            synchronized (keyContext) {
                if (!keyContext.isDestroyed()) {
                    if (keyContext.hasAtLeastRemainingEncryptions(records.size())) {
                        return encrypt(encryptionScheme, records, receiver, keyContext);
                    }
                    rotateKeyContext(encryptionScheme, keyContext);
                }
                return attemptEncrypt(topicName, partition, encryptionScheme, records, receiver, attempt + 1);
            }
        });
    }

    /**
     * Encrypts the batch using scratch buffers sized for its largest record, then counts the encryptions
     * against the key context. The caller must hold the key context's monitor.
     * <p>
     * NOTE(review): {@code orElseThrow()} throws {@link java.util.NoSuchElementException} if no record has a
     * positive parcel size. Confirm whether {@code Parcel.sizeOfParcel} can return 0 for a batch of tombstones.
     */
    @NonNull
    private CompletableFuture<Void> encrypt(@NonNull EncryptionScheme<K> encryptionScheme,
                                            @NonNull List<? extends Record> records,
                                            @NonNull Receiver receiver,
                                            KeyContext keyContext) {
        int maxParcelSize = records.stream()
                .mapToInt(kafkaRecord -> Parcel.sizeOfParcel(this.encryptionVersion.parcelVersion(), encryptionScheme.recordFields(), kafkaRecord))
                .filter(size -> size > 0)
                .max()
                .orElseThrow();
        // All records share one key context, so the largest wrapper is the one for the largest parcel.
        int maxWrapperSize = sizeOfWrapper(keyContext, maxParcelSize);
        ByteBuffer parcelBuffer = this.bufferPool.acquire(maxParcelSize);
        try {
            ByteBuffer wrapperBuffer = this.bufferPool.acquire(maxWrapperSize);
            try {
                encryptRecords(encryptionScheme, keyContext, records, parcelBuffer, wrapperBuffer, receiver);
            }
            finally {
                if (wrapperBuffer != null) {
                    this.bufferPool.release(wrapperBuffer);
                }
            }
        }
        finally {
            // Released even if acquiring the wrapper buffer fails.
            if (parcelBuffer != null) {
                this.bufferPool.release(parcelBuffer);
            }
        }
        keyContext.recordEncryptions(records.size());
        return CompletableFuture.completedFuture(null);
    }

    /** Destroys the exhausted context and evicts it, so the next lookup generates a new DEK. */
    private void rotateKeyContext(@NonNull EncryptionScheme<K> encryptionScheme, KeyContext keyContext) {
        keyContext.destroy();
        this.keyContextCache.synchronous().invalidate(encryptionScheme.kekId());
    }

    /**
     * Encrypts each record into the shared scratch buffers and hands the result to the receiver.
     * <p>
     * Records without a value are passed through unchanged. The receiver is given {@code wrapperBuffer}
     * itself, which is reused for the next record and released afterwards. It presumably must consume the
     * buffer synchronously (TODO confirm the Receiver contract).
     *
     * @throws IllegalStateException if header values are to be encrypted but a record has headers and a null value
     */
    private void encryptRecords(@NonNull EncryptionScheme<K> encryptionScheme,
                                @NonNull KeyContext keyContext,
                                @NonNull List<? extends Record> records,
                                @NonNull ByteBuffer parcelBuffer,
                                @NonNull ByteBuffer wrapperBuffer,
                                @NonNull Receiver receiver) {
        for (Record kafkaRecord : records) {
            if (encryptionScheme.recordFields().contains(RecordField.RECORD_HEADER_VALUES)
                    && kafkaRecord.headers().length > 0
                    && !kafkaRecord.hasValue()) {
                throw new IllegalStateException("encrypting headers prohibited when original record value null, we must preserve the null for tombstoning");
            }
            if (!kafkaRecord.hasValue()) {
                receiver.accept(kafkaRecord, null, kafkaRecord.headers());
                continue;
            }
            Parcel.writeParcel(this.encryptionVersion.parcelVersion(), encryptionScheme.recordFields(), kafkaRecord, parcelBuffer);
            parcelBuffer.flip();
            ByteBuffer transformedValue = writeWrapper(keyContext, parcelBuffer, wrapperBuffer);
            receiver.accept(kafkaRecord, transformedValue, transformHeaders(encryptionScheme, kafkaRecord));
            // clear(), not rewind(): flip() shrank each limit to this record's size. rewind() would keep
            // that limit, and a larger following record would overflow the buffer.
            wrapperBuffer.clear();
            parcelBuffer.clear();
        }
    }

    /**
     * Returns the headers for an encrypted record.
     * <p>
     * When header values are part of the encrypted fields, or the record has no headers, only the
     * encryption header is returned. Otherwise the encryption header is prepended to the original headers.
     */
    private Header[] transformHeaders(@NonNull EncryptionScheme<K> encryptionScheme, Record kafkaRecord) {
        Header[] headers = kafkaRecord.headers();
        if (encryptionScheme.recordFields().contains(RecordField.RECORD_HEADER_VALUES) || headers.length == 0) {
            return this.encryptionHeader;
        }
        Header[] result = new Header[1 + headers.length];
        result[0] = this.encryptionHeader[0];
        System.arraycopy(headers, 0, result, 1, headers.length);
        return result;
    }

    /**
     * Returns the number of wrapper bytes needed for a parcel of the given size. The size is the sum of:
     * <ul>
     *   <li>the varint EDEK length;</li>
     *   <li>the EDEK;</li>
     *   <li>the AAD-spec byte;</li>
     *   <li>the cipher-code byte;</li>
     *   <li>the encoded size reported by the key context.</li>
     * </ul>
     */
    private int sizeOfWrapper(KeyContext keyContext, int parcelSize) {
        byte[] serializedEdek = keyContext.serializedEdek();
        return ByteUtils.sizeOfUnsignedVarint(serializedEdek.length)
                + serializedEdek.length
                + 1 // AAD spec code
                + 1 // cipher code
                + keyContext.encodedSize(parcelSize);
    }

    /**
     * Writes the wrapper for the flipped {@code parcel} into {@code wrapper} and returns the flipped wrapper.
     *
     * @throws EncryptionException if the configured wrapper version is not supported
     */
    @NonNull
    private ByteBuffer writeWrapper(KeyContext keyContext, ByteBuffer parcel, ByteBuffer wrapper) {
        switch (this.encryptionVersion.wrapperVersion()) {
            case V1:
                byte[] serializedEdek = keyContext.serializedEdek();
                ByteUtils.writeUnsignedVarint(serializedEdek.length, wrapper);
                wrapper.put(serializedEdek);
                wrapper.put(AadSpec.NONE.code());
                wrapper.put(CipherCode.AES_GCM_96_128.code());
                keyContext.encode(parcel, wrapper);
                break;
            default:
                throw new EncryptionException("Unknown wrapper version " + this.encryptionVersion.wrapperVersion());
        }
        wrapper.flip();
        return wrapper;
    }

    /**
     * Returns the encryption version recorded in the record's encryption header, or null if the record
     * has no such header (i.e. it was not encrypted by this manager).
     *
     * @throws EncryptionException if the header value is null or is not exactly one byte
     */
    @Nullable
    static EncryptionVersion decryptionVersion(String topicName, int partition, Record kafkaRecord) {
        for (Header header : kafkaRecord.headers()) {
            if (ENCRYPTION_HEADER_NAME.equals(header.key())) {
                byte[] value = header.value();
                // Kafka headers may carry a null value; treat that as invalid rather than failing with an NPE.
                if (value == null || value.length != 1) {
                    throw new EncryptionException("Invalid value for header with key '" + ENCRYPTION_HEADER_NAME + "'"
                            + " in record at offset " + kafkaRecord.offset()
                            + " in partition " + partition
                            + " of topic " + topicName);
                }
                return EncryptionVersion.fromCode(value[0]);
            }
        }
        return null;
    }

    /** Decrypts the EDEK via the KMS and builds a decryptor for the resulting DEK. */
    private CompletableFuture<AesGcmEncryptor> makeDecryptor(E edek) {
        return this.kms.decryptEdek(edek).thenApply(AesGcmEncryptor::forDecrypt).toCompletableFuture();
    }

    /**
     * Decrypts the given records and passes each result to the receiver, in the original order.
     * <p>
     * Records without the encryption header are passed through with their original value and headers.
     *
     * @throws EncryptionException synchronously if a record's encryption header is invalid
     */
    @Override
    @NonNull
    public CompletionStage<Void> decrypt(String topicName,
                                         int partition,
                                         @NonNull List<? extends Record> records,
                                         @NonNull Receiver receiver) {
        List<CompletionStage<DecryptState>> pending = new ArrayList<>(records.size());
        for (Record kafkaRecord : records) {
            EncryptionVersion version = decryptionVersion(topicName, partition, kafkaRecord);
            if (version == null) {
                pending.add(CompletableFuture.completedStage(new DecryptState(kafkaRecord, kafkaRecord.value(), null, null)));
            }
            else {
                // One buffer is used for both steps: resolveEncryptor advances it past the EDEK, and
                // decryptRecord later continues reading from that position.
                ByteBuffer wrapper = kafkaRecord.value();
                pending.add(resolveEncryptor(version.wrapperVersion(), wrapper)
                        .thenApply(encryptor -> new DecryptState(kafkaRecord, wrapper, version, encryptor)));
            }
        }
        return EnvelopeEncryptionFilter.join(pending).thenApply(states -> {
            states.forEach(state -> {
                if (state.encryptor() == null) {
                    receiver.accept(state.kafkaRecord(), state.valueWrapper(), state.kafkaRecord().headers());
                }
                else {
                    decryptRecord(state.decryptionVersion(), state.encryptor(), state.valueWrapper(), state.kafkaRecord(), receiver);
                }
            });
            return null;
        });
    }

    /**
     * Reads the AAD-spec and cipher-code bytes from the wrapper (positioned after the EDEK). It then
     * decrypts the remainder and passes the decoded parcel to the receiver.
     *
     * @throws EncryptionException if the AAD spec is not supported
     */
    private void decryptRecord(EncryptionVersion decryptionVersion,
                               AesGcmEncryptor encryptor,
                               ByteBuffer wrapper,
                               Record kafkaRecord,
                               @NonNull Receiver receiver) {
        AadSpec aadSpec = AadSpec.fromCode(wrapper.get());
        switch (aadSpec) {
            case NONE: {
                // The cipher-code byte must be consumed to reach the ciphertext. Its value is not otherwise
                // used here; writeWrapper only ever writes AES_GCM_96_128.
                CipherCode.fromCode(wrapper.get());
                ByteBuffer plaintextParcel;
                // Decryptors are shared through decryptorCache, so serialize their use.
                synchronized (encryptor) {
                    plaintextParcel = decryptParcel(wrapper.slice(), encryptor);
                }
                Parcel.readParcel(decryptionVersion.parcelVersion(), plaintextParcel, kafkaRecord, receiver);
                break;
            }
            default:
                throw new EncryptionException("Unsupported AAD spec " + aadSpec);
        }
    }

    /**
     * Reads the length-prefixed EDEK from the wrapper, advancing its position past the EDEK, and returns
     * the (cached) decryptor for it.
     *
     * @throws EncryptionException if the wrapper version is not supported
     */
    private CompletionStage<AesGcmEncryptor> resolveEncryptor(WrapperVersion wrapperVersion, ByteBuffer wrapper) {
        switch (wrapperVersion) {
            case V1:
                int edekLength = ByteUtils.readUnsignedVarint(wrapper);
                E edek = this.edekSerde.deserialize(wrapper.slice(wrapper.position(), edekLength));
                wrapper.position(wrapper.position() + edekLength);
                return this.decryptorCache.get(edek);
            default:
                throw new EncryptionException("Unknown wrapper version " + String.valueOf(wrapperVersion));
        }
    }

    /**
     * Decrypts {@code ciphertext} into a duplicate view of the same buffer and returns that view, flipped
     * for reading. NOTE(review): this relies on AesGcmEncryptor.decrypt tolerating overlapping input and
     * output.
     */
    private ByteBuffer decryptParcel(ByteBuffer ciphertext, AesGcmEncryptor encryptor) {
        ByteBuffer plaintext = ciphertext.duplicate();
        encryptor.decrypt(ciphertext, plaintext);
        plaintext.flip();
        return plaintext;
    }
}
