package org.opensearch.ml.common.model;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.regex.Pattern;
import lombok.Generated;
import lombok.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.repackage.com.google.common.base.Ascii;
import org.opensearch.transport.client.Client;

/* loaded from: input_file:org/opensearch/ml/common/model/ModelGuardrail.class */
/**
 * {@link Guardrail} that delegates validation to a remote ML model.
 *
 * <p>{@link #validate} sends the {@code question} parameter (plus the optional
 * {@code response_filter}) to the model identified by {@code model_id}. It serializes the
 * {@code response} entry of the first returned tensor to JSON. The input is rejected only when
 * that JSON does not fully match the {@code response_validation_regex} pattern.
 *
 * <p>{@link #init} must be called before {@link #validate}: it supplies the client and compiles
 * the acceptance regex.
 */
public class ModelGuardrail extends Guardrail {

    @Generated
    private static final Logger log = LogManager.getLogger(ModelGuardrail.class);

    public static final String MODEL_ID_FIELD = "model_id";
    public static final String RESPONSE_FILTER_FIELD = "response_filter";
    public static final String RESPONSE_VALIDATION_REGEX_FIELD = "response_validation_regex";

    /** Maximum time {@link #validate} blocks waiting for the guardrail model's answer. */
    private static final long VALIDATION_TIMEOUT_SECONDS = 5L;

    private String modelId;
    private String responseFilter;
    private String responseAccept;
    private NamedXContentRegistry xContentRegistry;
    private Client client;
    private Pattern regexAcceptPattern;

    /** Builder for {@link ModelGuardrail}; collects the three configurable fields. */
    @Generated
    public static class ModelGuardrailBuilder {

        @Generated
        private String modelId;

        @Generated
        private String responseFilter;

        @Generated
        private String responseAccept;

        @Generated
        ModelGuardrailBuilder() {
        }

        @Generated
        public ModelGuardrailBuilder modelId(String modelId) {
            this.modelId = modelId;
            return this;
        }

        @Generated
        public ModelGuardrailBuilder responseFilter(String responseFilter) {
            this.responseFilter = responseFilter;
            return this;
        }

        @Generated
        public ModelGuardrailBuilder responseAccept(String responseAccept) {
            this.responseAccept = responseAccept;
            return this;
        }

        @Generated
        public ModelGuardrail build() {
            return new ModelGuardrail(this.modelId, this.responseFilter, this.responseAccept);
        }

        @Generated
        @Override
        public String toString() {
            return "ModelGuardrail.ModelGuardrailBuilder(modelId=" + this.modelId + ", responseFilter=" + this.responseFilter + ", responseAccept=" + this.responseAccept + ")";
        }
    }

    /**
     * Creates a guardrail.
     *
     * @param modelId        id of the model invoked during validation
     * @param responseFilter optional value forwarded to the model as {@code response_filter}
     * @param responseAccept regex the serialized model response must fully match
     */
    public ModelGuardrail(String modelId, String responseFilter, String responseAccept) {
        this.modelId = modelId;
        this.responseFilter = responseFilter;
        this.responseAccept = responseAccept;
    }

    /**
     * Creates a guardrail from a parameter map keyed by {@link #MODEL_ID_FIELD},
     * {@link #RESPONSE_FILTER_FIELD} and {@link #RESPONSE_VALIDATION_REGEX_FIELD}.
     *
     * @param params configuration map; values are cast to {@code String}
     * @throws NullPointerException if {@code params} is null
     * @throws ClassCastException   if a present value is not a {@code String}
     */
    public ModelGuardrail(@NonNull Map<String, Object> params) {
        // The null check sits inside the first argument so it runs before any map access;
        // a statement after this(...) would only execute after map.get had already thrown.
        this(
            (String) Objects.requireNonNull(params, "params is marked non-null but is null").get(MODEL_ID_FIELD),
            (String) params.get(RESPONSE_FILTER_FIELD),
            (String) params.get(RESPONSE_VALIDATION_REGEX_FIELD)
        );
    }

    /**
     * Deserializes a guardrail written by {@link #writeTo}.
     *
     * @param input stream positioned at the serialized guardrail
     * @throws IOException if reading fails
     */
    public ModelGuardrail(StreamInput input) throws IOException {
        this.modelId = input.readString();
        this.responseFilter = input.readString();
        this.responseAccept = input.readString();
    }

    /**
     * Serializes modelId, responseFilter and responseAccept, in that order, as plain strings.
     *
     * <p>NOTE(review): plain (non-optional) strings are used, so all three fields presumably must
     * be non-null to serialize. Switching to optional strings would change the wire format, so
     * that change needs version gating.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(this.modelId);
        out.writeString(this.responseFilter);
        out.writeString(this.responseAccept);
    }

    /** Returns whether {@code response} fully matches the compiled acceptance pattern. */
    private boolean validateAcceptRegex(String response) {
        return this.regexAcceptPattern.matcher(response).matches();
    }

    /**
     * Asks the guardrail model whether the question in {@code parameters} is acceptable.
     *
     * <p>The call blocks for at most {@link #VALIDATION_TIMEOUT_SECONDS} seconds. Validation is
     * fail-open: the result is {@code true} unless the model answers in time and its serialized
     * response does not match the acceptance regex. Model failures, timeouts and interruptions
     * all yield {@code true}.
     *
     * @param in         unused; the validated text is read from {@code parameters}
     * @param parameters request parameters; the question is read from {@link MLInput#QUESTION_FIELD}
     * @return {@code false} only when the model's response fails the acceptance regex
     */
    @Override
    public Boolean validate(String in, Map<String, String> parameters) {
        String question = parameters == null ? null : parameters.get(MLInput.QUESTION_FIELD);
        if (question == null || question.isEmpty()) {
            log.info("Guardrail request is empty.");
            return true;
        }
        log.info("Guardrail request: {}", question);

        AtomicBoolean accepted = new AtomicBoolean(true);
        ActionListener<MLTaskResponse> responseListener = ActionListener.wrap(mlTaskResponse -> {
            ModelTensor modelTensor = ((ModelTensorOutput) mlTaskResponse.getOutput())
                .getMlModelOutputs()
                .get(0)
                .getMlModelTensors()
                .get(0);
            // Explicit PrivilegedAction type: a bare lambda is ambiguous between the
            // PrivilegedAction and PrivilegedExceptionAction overloads of doPrivileged.
            String guardrailResponse = AccessController.doPrivileged(
                (PrivilegedAction<String>) () -> StringUtils.gson.toJson(modelTensor.getDataAsMap().get("response"))
            );
            log.info("Guardrail response: {}", guardrailResponse);
            if (!validateAcceptRegex(guardrailResponse)) {
                accepted.set(false);
            }
        }, e -> log.error("[ModelGuardrail] Failed to get prediction response.", e));
        ActionListener<MLTaskResponse> wrappedListener = wrapActionListener(responseListener, MLTaskResponse::fromActionResponse);

        Map<String, String> modelParameters = new HashMap<>();
        modelParameters.put(MLInput.QUESTION_FIELD, question);
        if (this.responseFilter != null && !this.responseFilter.isEmpty()) {
            modelParameters.put(RESPONSE_FILTER_FIELD, this.responseFilter);
        }
        log.info("Guardrail resFilter: {}", this.responseFilter);

        CountDownLatch latch = new CountDownLatch(1);
        this.client.execute(
            MLPredictionTaskAction.INSTANCE,
            new MLPredictionTaskRequest(
                this.modelId,
                RemoteInferenceMLInput.builder()
                    .algorithm(FunctionName.REMOTE)
                    .inputDataset(RemoteInferenceInputDataSet.builder().parameters(modelParameters).build())
                    .build()
            ),
            new LatchedActionListener<>(wrappedListener, latch)
        );
        try {
            if (!latch.await(VALIDATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                log.warn("[ModelGuardrail] Validation timed out after {} seconds.", VALIDATION_TIMEOUT_SECONDS);
            }
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            log.error("[ModelGuardrail] Validation was interrupted.", e);
        }
        return accepted.get();
    }

    /**
     * Stores the client and registry and compiles the acceptance regex.
     *
     * @throws NullPointerException if {@code responseAccept} is null
     * @throws java.util.regex.PatternSyntaxException if {@code responseAccept} is not a valid regex
     */
    @Override
    public void init(NamedXContentRegistry xContentRegistry, Client client) {
        this.xContentRegistry = xContentRegistry;
        this.client = client;
        this.regexAcceptPattern = Pattern.compile(this.responseAccept);
    }

    /** Writes the non-null configuration fields as a single JSON object. */
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (this.modelId != null) {
            builder.field(MODEL_ID_FIELD, this.modelId);
        }
        if (this.responseFilter != null) {
            builder.field(RESPONSE_FILTER_FIELD, this.responseFilter);
        }
        if (this.responseAccept != null) {
            builder.field(RESPONSE_VALIDATION_REGEX_FIELD, this.responseAccept);
        }
        builder.endObject();
        return builder;
    }

    /**
     * Parses a guardrail from an object whose current token is {@code START_OBJECT}.
     * Unknown fields are skipped.
     *
     * @param parser parser positioned on the object's start token
     * @return the parsed guardrail; missing fields are null
     * @throws IOException if parsing fails or the current token is not {@code START_OBJECT}
     */
    public static ModelGuardrail parse(XContentParser parser) throws IOException {
        String modelId = null;
        String responseFilter = null;
        String responseAccept = null;
        XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();
            switch (fieldName) {
                case MODEL_ID_FIELD:
                    modelId = parser.text();
                    break;
                case RESPONSE_FILTER_FIELD:
                    responseFilter = parser.text();
                    break;
                case RESPONSE_VALIDATION_REGEX_FIELD:
                    responseAccept = parser.text();
                    break;
                default:
                    parser.skipChildren();
                    break;
            }
        }
        return builder().modelId(modelId).responseFilter(responseFilter).responseAccept(responseAccept).build();
    }

    /**
     * Adapts {@code listener} so every response first passes through {@code recreate}; failures
     * are forwarded unchanged.
     *
     * <p>NOTE(review): presumably this lets a response arriving as a generic/foreign
     * {@link ActionResponse} be rebuilt as the expected concrete type; confirm with callers.
     */
    private <T extends ActionResponse> ActionListener<T> wrapActionListener(
        ActionListener<T> listener,
        Function<ActionResponse, T> recreate
    ) {
        return ActionListener.wrap(response -> listener.onResponse(recreate.apply(response)), listener::onFailure);
    }

    @Generated
    public static ModelGuardrailBuilder builder() {
        return new ModelGuardrailBuilder();
    }

    @Generated
    public ModelGuardrailBuilder toBuilder() {
        return new ModelGuardrailBuilder().modelId(this.modelId).responseFilter(this.responseFilter).responseAccept(this.responseAccept);
    }

    /**
     * Field-wise equality over all six fields, including the runtime-only ones set by {@link #init}.
     *
     * <p>NOTE(review): {@link Pattern} does not override {@code equals}, so two initialized
     * guardrails compare equal only when they share the same compiled Pattern instance.
     */
    @Generated
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ModelGuardrail)) {
            return false;
        }
        ModelGuardrail other = (ModelGuardrail) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(getModelId(), other.getModelId())
            && Objects.equals(getResponseFilter(), other.getResponseFilter())
            && Objects.equals(getResponseAccept(), other.getResponseAccept())
            && Objects.equals(getXContentRegistry(), other.getXContentRegistry())
            && Objects.equals(getClient(), other.getClient())
            && Objects.equals(getRegexAcceptPattern(), other.getRegexAcceptPattern());
    }

    @Generated
    protected boolean canEqual(Object obj) {
        return obj instanceof ModelGuardrail;
    }

    /** Hash over the same fields as {@link #equals}, using Lombok's 59/43 scheme. */
    @Generated
    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + hashOrDefault(getModelId());
        result = result * 59 + hashOrDefault(getResponseFilter());
        result = result * 59 + hashOrDefault(getResponseAccept());
        result = result * 59 + hashOrDefault(getXContentRegistry());
        result = result * 59 + hashOrDefault(getClient());
        result = result * 59 + hashOrDefault(getRegexAcceptPattern());
        return result;
    }

    /** Returns {@code value.hashCode()}, or 43 for null (Lombok's null sentinel). */
    private static int hashOrDefault(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    @Generated
    public String getModelId() {
        return this.modelId;
    }

    @Generated
    public String getResponseFilter() {
        return this.responseFilter;
    }

    @Generated
    public String getResponseAccept() {
        return this.responseAccept;
    }

    @Generated
    public NamedXContentRegistry getXContentRegistry() {
        return this.xContentRegistry;
    }

    @Generated
    public Client getClient() {
        return this.client;
    }

    @Generated
    public Pattern getRegexAcceptPattern() {
        return this.regexAcceptPattern;
    }
}
