package io.kroxylicious.proxy.filter.schema;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletionStage;

import org.apache.kafka.common.message.ProduceRequestData;
import org.apache.kafka.common.message.ProduceResponseData;
import org.apache.kafka.common.message.RequestHeaderData;
import org.apache.kafka.common.message.ResponseHeaderData;
import org.apache.kafka.common.protocol.Errors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.kroxylicious.proxy.filter.FilterContext;
import io.kroxylicious.proxy.filter.ProduceRequestFilter;
import io.kroxylicious.proxy.filter.ProduceResponseFilter;
import io.kroxylicious.proxy.filter.RequestFilterResult;
import io.kroxylicious.proxy.filter.ResponseFilterResult;
import io.kroxylicious.proxy.filter.schema.validation.request.ProduceRequestValidationResult;
import io.kroxylicious.proxy.filter.schema.validation.request.ProduceRequestValidator;
import io.kroxylicious.proxy.filter.schema.validation.topic.PartitionValidationResult;
import io.kroxylicious.proxy.filter.schema.validation.topic.RecordValidationFailure;
import io.kroxylicious.proxy.filter.schema.validation.topic.TopicValidationResult;

/**
 * Filter that validates the records of every topic-partition in a Produce request using a
 * {@link ProduceRequestValidator}.
 * <p>
 * When validation finds invalid topic-partitions:
 * <ul>
 *   <li>if every topic-partition is invalid, the request is short-circuited with an
 *       {@code INVALID_RECORD} partition response for each partition;</li>
 *   <li>if only some are invalid and the request is transactional, or partial forwarding is
 *       disabled, the whole request is short-circuited the same way;</li>
 *   <li>otherwise the invalid topic-partitions are removed from the request, the remainder is
 *       forwarded, and the invalid topic-partitions are merged back into the matching response
 *       (matched by correlation id).</li>
 * </ul>
 */
public class ProduceValidationFilter implements ProduceRequestFilter, ProduceResponseFilter {
    private static final Logger LOGGER = LoggerFactory.getLogger(ProduceValidationFilter.class);

    /** Whether partially-invalid, non-transactional requests are forwarded with their invalid topic-partitions removed. */
    private final boolean forwardPartialRequests;

    private final ProduceRequestValidator validator;

    /**
     * Validation results of partially-forwarded requests, keyed by request correlation id, waiting
     * for the response they must be merged into.
     * NOTE(review): plain HashMap — assumes this filter instance is only invoked from one thread; confirm.
     */
    private final Map<Integer, ProduceRequestValidationResult> correlatedResults = new HashMap<>();

    /**
     * @param forwardPartialRequests whether partially-invalid, non-transactional requests may be forwarded
     *                               with the invalid topic-partitions stripped out
     * @param validator              validator applied to every Produce request
     * @throws IllegalArgumentException if {@code validator} is null
     */
    public ProduceValidationFilter(boolean forwardPartialRequests, ProduceRequestValidator validator) {
        if (validator == null) {
            throw new IllegalArgumentException("validator is null");
        }
        this.forwardPartialRequests = forwardPartialRequests;
        this.validator = validator;
    }

    /**
     * Validates the request and forwards it unchanged when every topic-partition is valid;
     * otherwise delegates to {@link #handleInvalidTopicPartitions}.
     */
    public CompletionStage<RequestFilterResult> onProduceRequest(short apiVersion, RequestHeaderData header, ProduceRequestData request,
                                                                 FilterContext context) {
        ProduceRequestValidationResult result = validator.validateRequest(request);
        if (result.isAnyTopicPartitionInvalid()) {
            return handleInvalidTopicPartitions(header, request, context, result);
        }
        return context.forwardRequest(header, request);
    }

    /**
     * Either rejects the whole request, or strips the invalid topic-partitions and forwards the rest,
     * remembering the validation result so it can be merged into the response.
     */
    private CompletionStage<RequestFilterResult> handleInvalidTopicPartitions(RequestHeaderData header, ProduceRequestData request,
                                                                              FilterContext context, ProduceRequestValidationResult result) {
        if (result.isAllTopicPartitionsInvalid()) {
            LOGGER.debug("all topic-partitions for request contained invalid data: {}", result);
            return rejectEntireRequest(request, context, result);
        }
        if (request.transactionalId() != null) {
            // Transactional requests are never partially forwarded.
            LOGGER.debug("some topic-partitions for transactional request with id: {}, contained invalid data: {}, invalidating entire request",
                    request.transactionalId(), result);
            return rejectEntireRequest(request, context, result);
        }
        if (!forwardPartialRequests) {
            LOGGER.debug("some topic-partitions contained invalid data: {}, partial forwarding disabled, invalidating entire request", result);
            return rejectEntireRequest(request, context, result);
        }
        LOGGER.debug("some topic-partitions contained invalid data: {}, forwarding valid topic-partitions", result);
        stripInvalidTopicPartitions(request, result);
        // A request with acks=0 receives no response, so onProduceResponse would never remove
        // the stored result; storing it would leak an entry per request.
        if (request.acks() != 0) {
            correlatedResults.put(header.correlationId(), result);
        }
        return context.forwardRequest(header, request);
    }

    /** Short-circuits the request with an INVALID_RECORD response covering every topic-partition it contained. */
    private static CompletionStage<RequestFilterResult> rejectEntireRequest(ProduceRequestData request, FilterContext context,
                                                                            ProduceRequestValidationResult result) {
        return context.requestFilterResultBuilder()
                .shortCircuitResponse(invalidateEntireRequest(request, result))
                .completed();
    }

    /** Removes, in place, every topic whose partitions are all invalid and every invalid partition of the remaining topics. */
    private static void stripInvalidTopicPartitions(ProduceRequestData request, ProduceRequestValidationResult result) {
        request.topicData().removeIf(topic -> result.isAllPartitionsInvalid(topic.name()));
        for (ProduceRequestData.TopicProduceData topic : request.topicData()) {
            topic.partitionData().removeIf(partition -> !result.isPartitionValid(topic.name(), partition.index()));
        }
    }

    /** Builds a response that marks every topic-partition of {@code request} as INVALID_RECORD. */
    private static ProduceResponseData invalidateEntireRequest(ProduceRequestData request, ProduceRequestValidationResult result) {
        ProduceResponseData.TopicProduceResponseCollection topicResponses = new ProduceResponseData.TopicProduceResponseCollection();
        for (ProduceRequestData.TopicProduceData topic : request.topicData()) {
            String name = topic.name();
            topicResponses.add(createInvalidatedTopicProduceResponse(name, topic, result.topicResult(name)));
        }
        ProduceResponseData response = new ProduceResponseData();
        response.setResponses(topicResponses);
        return response;
    }

    /** Builds the INVALID_RECORD response for one topic, with one partition response per requested partition. */
    private static ProduceResponseData.TopicProduceResponse createInvalidatedTopicProduceResponse(String name, ProduceRequestData.TopicProduceData topicData,
                                                                                                  TopicValidationResult topicResult) {
        // Use a mutable list: code handling the response later may add to it, which an
        // unmodifiable Stream.toList() result would reject.
        List<ProduceResponseData.PartitionProduceResponse> partitionResponses = new ArrayList<>(topicData.partitionData().size());
        for (ProduceRequestData.PartitionProduceData partition : topicData.partitionData()) {
            partitionResponses.add(createInvalidatedPartitionProduceResponse(partition, topicResult.getPartitionResult(partition.index())));
        }
        ProduceResponseData.TopicProduceResponse topicResponse = new ProduceResponseData.TopicProduceResponse();
        topicResponse.setName(name);
        topicResponse.setPartitionResponses(partitionResponses);
        return topicResponse;
    }

    /**
     * Builds an INVALID_RECORD partition response. A partition whose own records were all valid gets a
     * message explaining that another topic-partition caused the rejection; otherwise the per-record
     * failures are reported as record errors.
     */
    public static ProduceResponseData.PartitionProduceResponse createInvalidatedPartitionProduceResponse(ProduceRequestData.PartitionProduceData partitionData,
                                                                                                         PartitionValidationResult partitionResult) {
        ProduceResponseData.PartitionProduceResponse partitionResponse = new ProduceResponseData.PartitionProduceResponse();
        partitionResponse.setIndex(partitionData.index());
        partitionResponse.setErrorCode(Errors.INVALID_RECORD.code());
        if (partitionResult.allRecordsValid()) {
            partitionResponse.setErrorMessage("Invalid record in another topic-partition caused whole ProduceRequest to be invalidated");
        }
        else {
            addRecordErrors(partitionResponse, partitionResult.recordValidationFailures());
            partitionResponse.setErrorMessage(toErrorString(partitionResult.recordValidationFailures()));
        }
        return partitionResponse;
    }

    /** Appends one batch-index error to {@code partitionResponse} for each record failure. */
    private static void addRecordErrors(ProduceResponseData.PartitionProduceResponse partitionResponse, List<RecordValidationFailure> failures) {
        for (RecordValidationFailure failure : failures) {
            partitionResponse.recordErrors().add(new ProduceResponseData.BatchIndexAndErrorMessage()
                    .setBatchIndex(failure.invalidIndex())
                    .setBatchIndexErrorMessage(failure.errorMessage()));
        }
    }

    /** Summarises failures using only the first failure's message (or a placeholder if there is none). */
    private static String toErrorString(List<RecordValidationFailure> failures) {
        String first = failures.stream()
                .findFirst()
                .map(RecordValidationFailure::errorMessage)
                .orElse("Failure List Empty");
        return "Records in batch were invalid: [" + first + "]";
    }

    /**
     * If the response answers a partially-forwarded request, merges the stripped invalid
     * topic-partitions back into it; then forwards the response.
     */
    public CompletionStage<ResponseFilterResult> onProduceResponse(short apiVersion, ResponseHeaderData header, ProduceResponseData response,
                                                                   FilterContext context) {
        ProduceRequestValidationResult result = correlatedResults.remove(header.correlationId());
        if (result != null) {
            LOGGER.debug("augmenting invalid topic-partition details into response: {}", result);
            augmentResponseWithInvalidTopicPartitions(response, result);
        }
        return context.forwardResponse(header, response);
    }

    /** Adds partition responses for every invalid topic-partition, creating topic entries missing from the response. */
    private static void augmentResponseWithInvalidTopicPartitions(ProduceResponseData response, ProduceRequestValidationResult result) {
        for (TopicValidationResult topicResult : result.topicsWithInvalidPartitions()) {
            ProduceResponseData.TopicProduceResponse topicResponse = response.responses().find(topicResult.topicName());
            if (topicResponse == null) {
                // Topics whose partitions were all invalid were removed from the request entirely,
                // so the broker's response has no entry for them.
                topicResponse = new ProduceResponseData.TopicProduceResponse();
                topicResponse.setName(topicResult.topicName());
                response.responses().add(topicResponse);
            }
            augmentTopicProduceResponse(topicResult, topicResponse);
        }
    }

    /** Appends an INVALID_RECORD partition response, with record errors, for each invalid partition of the topic. */
    private static void augmentTopicProduceResponse(TopicValidationResult topicResult, ProduceResponseData.TopicProduceResponse topicResponse) {
        topicResult.invalidPartitions().forEach(partitionResult -> {
            ProduceResponseData.PartitionProduceResponse partitionResponse = new ProduceResponseData.PartitionProduceResponse();
            partitionResponse.setIndex(partitionResult.index());
            addRecordErrors(partitionResponse, partitionResult.recordValidationFailures());
            partitionResponse.setErrorCode(Errors.INVALID_RECORD.code());
            partitionResponse.setErrorMessage(toErrorString(partitionResult.recordValidationFailures()));
            topicResponse.partitionResponses().add(partitionResponse);
        });
    }
}
