package org.mlflow.artifacts;

import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import org.mlflow.api.proto.Service;
import org.mlflow.tracking.MlflowClientException;
import org.mlflow.tracking.creds.MlflowHostCreds;
import org.mlflow.tracking.creds.MlflowHostCredsProvider;
import org.mlflow_project.apachecommons.io.IOUtils;
import org.mlflow_project.google.common.annotations.VisibleForTesting;
import org.mlflow_project.google.common.collect.Lists;
import org.mlflow_project.google.gson.Gson;
import org.mlflow_project.google.gson.reflect.TypeToken;
import org.mlflow_project.google.protobuf.InvalidProtocolBufferException;
import org.mlflow_project.google.protobuf.util.JsonFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/mlflow/artifacts/CliBasedArtifactRepository.class */
public class CliBasedArtifactRepository implements ArtifactRepository {
    private static final Logger logger = LoggerFactory.getLogger(CliBasedArtifactRepository.class);
    private static final AtomicBoolean mlflowSuccessfullyLoaded = new AtomicBoolean(false);
    private final String PYTHON_EXECUTABLE = (String) Optional.ofNullable(System.getenv("MLFLOW_PYTHON_EXECUTABLE")).orElse("python");
    private final String artifactBaseDir;
    private final String runId;
    private final MlflowHostCredsProvider hostCredsProvider;

    public CliBasedArtifactRepository(String str, String str2, MlflowHostCredsProvider mlflowHostCredsProvider) {
        this.artifactBaseDir = str;
        this.runId = str2;
        this.hostCredsProvider = mlflowHostCredsProvider;
    }

    @Override // org.mlflow.artifacts.ArtifactRepository
    public void logArtifact(File file, String str) {
        checkMlflowAccessible();
        if (!file.exists()) {
            throw new MlflowClientException("Local file does not exist: " + file);
        }
        if (file.isDirectory()) {
            throw new MlflowClientException("Local path points to a directory. Use logArtifacts instead: " + file);
        }
        forkMlflowProcess(appendRunIdArtifactPath(Lists.newArrayList("log-artifact", "--local-file", file.toString()), this.runId, str), "log file " + file + " to " + getTargetIdentifier(str));
    }

    @Override // org.mlflow.artifacts.ArtifactRepository
    public void logArtifact(File file) {
        logArtifact(file, null);
    }

    @Override // org.mlflow.artifacts.ArtifactRepository
    public void logArtifacts(File file, String str) {
        checkMlflowAccessible();
        if (!file.exists()) {
            throw new MlflowClientException("Local file does not exist: " + file);
        }
        if (file.isFile()) {
            throw new MlflowClientException("Local path points to a file. Use logArtifact instead: " + file);
        }
        forkMlflowProcess(appendRunIdArtifactPath(Lists.newArrayList("log-artifacts", "--local-dir", file.toString()), this.runId, str), "log dir " + file + " to " + getTargetIdentifier(str));
    }

    @Override // org.mlflow.artifacts.ArtifactRepository
    public void logArtifacts(File file) {
        logArtifacts(file, null);
    }

    @Override // org.mlflow.artifacts.ArtifactRepository
    public File downloadArtifacts(String str) {
        checkMlflowAccessible();
        return new File(forkMlflowProcess(appendRunIdArtifactPath(Lists.newArrayList("download"), this.runId, str), "download artifacts for " + getTargetIdentifier(str)).trim());
    }

    @Override // org.mlflow.artifacts.ArtifactRepository
    public File downloadArtifacts() {
        return downloadArtifacts(null);
    }

    @Override // org.mlflow.artifacts.ArtifactRepository
    public List<Service.FileInfo> listArtifacts(String str) {
        checkMlflowAccessible();
        return parseFileInfos(forkMlflowProcess(appendRunIdArtifactPath(Lists.newArrayList("list"), this.runId, str), "list artifacts in " + getTargetIdentifier(str)));
    }

    @Override // org.mlflow.artifacts.ArtifactRepository
    public List<Service.FileInfo> listArtifacts() {
        return listArtifacts(null);
    }

    private List<Service.FileInfo> parseFileInfos(String str) {
        Gson gson = new Gson();
        List list = (List) gson.fromJson(str, new TypeToken<List<Map<String, Object>>>() { // from class: org.mlflow.artifacts.CliBasedArtifactRepository.1
        }.getType());
        ArrayList arrayList = new ArrayList();
        Iterator it = list.iterator();
        while (it.hasNext()) {
            String json = gson.toJson((Map) it.next());
            try {
                Service.FileInfo.Builder newBuilder = Service.FileInfo.newBuilder();
                JsonFormat.parser().merge(json, newBuilder);
                arrayList.add(newBuilder.build());
            } catch (InvalidProtocolBufferException e) {
                throw new MlflowClientException("Failed to deserialize JSON into FileInfo: " + str, e);
            }
        }
        return arrayList;
    }

    private void checkMlflowAccessible() {
        if (mlflowSuccessfullyLoaded.get()) {
            return;
        }
        try {
            forkMlflowProcess(Lists.newArrayList("--help"), "get mlflow version");
            logger.info("Found local mlflow executable");
            mlflowSuccessfullyLoaded.set(true);
        } catch (MlflowClientException e) {
            throw new MlflowClientException(String.format("Failed to exec '%s -m mlflow.store.cli', needed to access artifacts within the non-Java-native artifact store at '%s'. Please make sure mlflow is available on your local system path (e.g., from 'pip install mlflow')", this.PYTHON_EXECUTABLE, this.artifactBaseDir), e);
        }
    }

    private String forkMlflowProcess(List<String> list, String str) {
        Process process = null;
        try {
            MlflowHostCreds hostCreds = this.hostCredsProvider.getHostCreds();
            ArrayList newArrayList = Lists.newArrayList(this.PYTHON_EXECUTABLE, "-m", "mlflow.store.cli");
            newArrayList.addAll(list);
            ProcessBuilder processBuilder = new ProcessBuilder(newArrayList);
            setProcessEnvironment(processBuilder.environment(), hostCreds);
            process = processBuilder.start();
            String iOUtils = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
            if (process.waitFor() != 0) {
                throw new MlflowClientException("Failed to " + str + ". Error: " + getErrorBestEffort(process));
            }
            return iOUtils;
        } catch (IOException | InterruptedException e) {
            throw new MlflowClientException("Failed to fork mlflow process to " + str + ". Process stderr: " + getErrorBestEffort(process), e);
        }
    }

    @VisibleForTesting
    void setProcessEnvironment(Map<String, String> map, MlflowHostCreds mlflowHostCreds) {
        map.put("MLFLOW_TRACKING_URI", mlflowHostCreds.getHost());
        if (mlflowHostCreds.getUsername() != null) {
            map.put("MLFLOW_TRACKING_USERNAME", mlflowHostCreds.getUsername());
        }
        if (mlflowHostCreds.getPassword() != null) {
            map.put("MLFLOW_TRACKING_PASSWORD", mlflowHostCreds.getPassword());
        }
        if (mlflowHostCreds.getToken() != null) {
            map.put("MLFLOW_TRACKING_TOKEN", mlflowHostCreds.getToken());
        }
        if (mlflowHostCreds.shouldIgnoreTlsVerification()) {
            map.put("MLFLOW_TRACKING_INSECURE_TLS", "true");
        }
    }

    private String getErrorBestEffort(Process process) {
        if (process == null) {
            return "<process not started>";
        }
        try {
            return IOUtils.toString(process.getErrorStream(), StandardCharsets.UTF_8);
        } catch (IOException e) {
            return "<error unknown>";
        }
    }

    private List<String> appendRunIdArtifactPath(List<String> list, String str, String str2) {
        list.add("--run-id");
        list.add(str);
        if (str2 != null) {
            list.add("--artifact-path");
            list.add(str2);
        }
        return list;
    }

    private String getTargetIdentifier(String str) {
        String str2 = "runId=" + this.runId;
        return str != null ? str2 + ", artifactPath=" + str : str2;
    }
}
