package fi.evolver.ai.spring.provider.replicate;

import com.fasterxml.jackson.annotation.JsonValue;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import fi.evolver.ai.spring.ApiResponseException;
import fi.evolver.ai.spring.connector.AbstractConnector;
import fi.evolver.ai.spring.connector.BasicConnector;
import fi.evolver.ai.spring.connector.GenericConnector;
import fi.evolver.ai.spring.image.ImageApi;
import fi.evolver.ai.spring.image.ImageResponse;
import fi.evolver.ai.spring.image.prompt.ImageGenerationPrompt;
import fi.evolver.ai.spring.image.prompt.ImageVariationPrompt;
import fi.evolver.ai.spring.model.Flux;
import fi.evolver.ai.spring.provider.ConditionalOnProviderConfigured;
import fi.evolver.ai.spring.provider.replicate.response.RStatus;
import fi.evolver.ai.spring.provider.replicate.response.ReplicateFluxImageResponse;
import fi.evolver.ai.spring.util.Json;
import fi.evolver.basics.spring.http.LoggingHttpClient;
import fi.evolver.utils.ContextUtils;
import jakarta.annotation.PreDestroy;
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpMethod;
import org.springframework.stereotype.Component;

/**
 * {@link ImageApi} implementation backed by the Replicate predictions API.
 *
 * <p>Each image request resolves the latest version of the prompt's model (cached for
 * {@code evolver.replicate.max-model-version-age-h} hours), creates one or more predictions and
 * polls their status every 500 ms on a shared scheduler until they finish. Predictions that are
 * still in progress when the prompt timeout elapses are cancelled.
 */
@AbstractConnector.UseDefaultConnector(BasicConnector.class)
@ConditionalOnProviderConfigured(ReplicateService.class)
@Component
public class ReplicateService implements ImageApi {
    /** Prompt property that selects a {@link ReplicateRequestMethod} by its code. */
    public static final String REPLICATE_REQUEST_METHOD = "request-method";

    private static final Logger LOG = LoggerFactory.getLogger(ReplicateService.class);

    /** Models routed to the single-image flow (one prediction per output) unless overridden. */
    private static final List<String> FLUX_SINGLE_IMAGE_MODELS =
            List.of(Flux.FLUX_1_1_PRO.name(), Flux.FLUX_1_1_PRO_ULTRA.name());

    /** Delay between consecutive prediction status polls, in milliseconds. */
    private static final long POLL_INTERVAL_MS = 500L;

    /**
     * Latest known version per model name. Read and written by concurrent requests, so it must be
     * a concurrent map.
     */
    private static final Map<String, ModelVersionInfo> latestModelVersions = new ConcurrentHashMap<>();

    private final int maxModelVersionAge;
    private final GenericConnector connector;
    // Runs the status pollers; every poll performs a blocking HTTP call on one of these threads.
    private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(2);

    /** A model version together with the time it was fetched from Replicate. */
    private record ModelVersionInfo(String version, LocalDateTime fetchedAt) {}

    /** Request strategy that can be forced with the {@code "request-method"} prompt property. */
    public enum ReplicateRequestMethod {
        /** A single prediction whose output is the list of all generated images. */
        MULTI_IMAGE("multi-image"),
        /** One prediction per requested image, each producing a single output. */
        SINGLE_IMAGE("single-image");

        private final String code;

        ReplicateRequestMethod(String code) {
            this.code = code;
        }

        /** @return the external code of this method, also used as its JSON value */
        @JsonValue
        public String getCode() {
            return this.code;
        }

        /**
         * Resolves a method from its external code.
         *
         * @param code the code, e.g. {@code "single-image"}
         * @return the matching method, or {@code null} when no method has that code
         */
        public static ReplicateRequestMethod of(String code) {
            for (ReplicateRequestMethod method : values()) {
                if (method.code.equals(code)) {
                    return method;
                }
            }
            return null;
        }
    }

    /**
     * @param genericConnector   connector used to build every Replicate HTTP request
     * @param maxModelVersionAge hours a cached model version stays valid before being re-fetched
     */
    @Autowired
    public ReplicateService(
            GenericConnector genericConnector,
            @Value("${evolver.replicate.max-model-version-age-h:24}") int maxModelVersionAge) {
        this.connector = genericConnector;
        this.maxModelVersionAge = maxModelVersionAge;
    }

    @PreDestroy
    private void shutdownScheduler() {
        this.scheduler.shutdown();
    }

    /**
     * Generates images for the prompt.
     *
     * <p>An explicit {@code "request-method"} property decides the flow. Without one (or with an
     * unknown code) models listed in {@link #FLUX_SINGLE_IMAGE_MODELS} use the single-image flow
     * and all other models the multi-image flow.
     *
     * @throws ApiResponseException if any Replicate request fails or no image is produced
     */
    @Override
    public ImageResponse send(ImageGenerationPrompt imageGenerationPrompt) {
        Optional<ReplicateRequestMethod> requestedMethod = imageGenerationPrompt
                .getStringProperty(REPLICATE_REQUEST_METHOD)
                .map(ReplicateRequestMethod::of);
        boolean singleImage = requestedMethod
                .map(method -> method == ReplicateRequestMethod.SINGLE_IMAGE)
                .orElseGet(() -> FLUX_SINGLE_IMAGE_MODELS.contains(imageGenerationPrompt.model().name()));
        return singleImage
                ? sendFluxSingleImageRequest(imageGenerationPrompt)
                : sendFluxMultiImageRequest(imageGenerationPrompt);
    }

    /**
     * Returns the cached latest version of the prompt's model, fetching it again when missing or
     * older than {@code maxModelVersionAge} hours. Concurrent callers may fetch the same version
     * twice; the last write wins, which is harmless.
     */
    private String getOrFetchModelVersion(ImageGenerationPrompt imageGenerationPrompt) {
        LocalDateTime now = LocalDateTime.now();
        String modelName = imageGenerationPrompt.model().name();
        ModelVersionInfo cached = latestModelVersions.get(modelName);
        if (cached != null && cached.fetchedAt().isAfter(now.minusHours(this.maxModelVersionAge))) {
            return cached.version();
        }
        String version = fetchModelVersion(imageGenerationPrompt);
        latestModelVersions.put(modelName, new ModelVersionInfo(version, now));
        return version;
    }

    /**
     * Fetches {@code latest_version.id} from the Replicate model info endpoint.
     *
     * @throws ApiResponseException if the request fails or the response has no textual version id
     */
    private String fetchModelVersion(ImageGenerationPrompt imageGenerationPrompt) {
        JsonNode modelInfo;
        try {
            modelInfo = Json.OBJECT_MAPPER.readTree(this.connector
                    .builder(ReplicateService.class, imageGenerationPrompt, ImageApi.class)
                    .prepareUri("models", imageGenerationPrompt.model().name())
                    .send(new LoggingHttpClient.LogParameters<>("ModelInfoRequest")));
        } catch (Exception e) {
            throw new ApiResponseException(e, "Failed Replicate model info request");
        }
        // isTextual() rejects a JSON null id, which asText() would otherwise turn into "null".
        return Optional.ofNullable(modelInfo)
                .map(node -> node.get("latest_version"))
                .map(node -> node.get("id"))
                .filter(JsonNode::isTextual)
                .map(JsonNode::asText)
                .orElseThrow(() -> new ApiResponseException(
                        "Failed Replicate model info request. Could not find model version"));
    }

    /**
     * Creates one prediction whose output lists all generated image URLs and waits for it.
     *
     * @throws ApiResponseException if the prediction does not succeed within the prompt timeout
     */
    private ReplicateFluxImageResponse sendFluxMultiImageRequest(ImageGenerationPrompt imageGenerationPrompt) {
        String requestBody = ReplicateRequestGenerator.generate(
                imageGenerationPrompt, getOrFetchModelVersion(imageGenerationPrompt));
        AbstractConnector.ApiRequestBuilder builder =
                this.connector.builder(ReplicateService.class, imageGenerationPrompt, ImageApi.class);

        RStatus<List<String>> created =
                sendImageRequest(builder.newBuilder().prepareUri("predictions").body(requestBody));
        RStatus<List<String>> status;
        try {
            status = pollStatus(created, builder,
                    imageGenerationPrompt.timeout().orElse(DEFAULT_TIMEOUT).toMillis());
        } catch (InterruptedException | ExecutionException e) {
            throw imageRequestFailure(e);
        }

        if (!status.isSuccess()) {
            LOG.error("Replicate image request did not succeed: {}", status.logs());
            if (status.isFailed()) {
                throw new ApiResponseException(new Exception(status.error()), "Failed Replicate image request");
            }
            // Cancelled, or still unfinished after a cancellation: there is no usable output.
            String reason = status.isCancelled() ? "Request cancelled" : "Request did not complete";
            throw new ApiResponseException(new Exception(reason), "Failed Replicate image request");
        }

        List<URL> urls = status.output().stream()
                .map(ReplicateService::toUrlNoThrow)
                .filter(Objects::nonNull)
                .toList();
        return new ReplicateFluxImageResponse(imageGenerationPrompt, urls, status.completedAt().toString());
    }

    /**
     * Polls a single prediction until it leaves the in-progress state.
     *
     * @param initialStatus  status returned when the prediction was created (provides its URLs)
     * @param requestBuilder template builder; a fresh {@code newBuilder()} is used per request
     * @param timeoutMillis  how long to wait before cancelling the prediction
     * @return the final status, or the cancellation response if the timeout elapsed
     */
    private RStatus<List<String>> pollStatus(
            RStatus<List<String>> initialStatus,
            AbstractConnector.ApiRequestBuilder requestBuilder,
            long timeoutMillis) throws InterruptedException, ExecutionException {
        CompletableFuture<RStatus<List<String>>> completion = new CompletableFuture<>();
        ContextUtils.Context context = ContextUtils.getContext();
        ScheduledFuture<?> polling = this.scheduler.scheduleAtFixedRate(() -> {
            if (completion.isDone()) {
                return;
            }
            // Re-establish the caller's context on the scheduler thread for the duration of the poll.
            try (ContextUtils.ContextCloser ignored = ContextUtils.ensureContext(context)) {
                RStatus<List<String>> current =
                        fetchRequestStatus(requestBuilder.newBuilder(), initialStatus.urls().get());
                if (!current.isInProgress()) {
                    completion.complete(current);
                }
            } catch (Exception e) {
                completion.completeExceptionally(e);
            }
        }, 0L, POLL_INTERVAL_MS, TimeUnit.MILLISECONDS);

        try {
            return completion.get(timeoutMillis, TimeUnit.MILLISECONDS);
        } catch (TimeoutException e) {
            // Stop polling before cancelling so the poller does not keep hitting the API.
            polling.cancel(false);
            return cancelRequest(requestBuilder.newBuilder(), initialStatus.urls().cancel());
        } finally {
            polling.cancel(false);
        }
    }

    /**
     * Creates one prediction per requested output ({@code num_outputs}, default 1) and waits for
     * all of them. Failed predictions are logged; successful ones contribute their image URL.
     *
     * @throws ApiResponseException if no prediction succeeds
     */
    private ReplicateFluxImageResponse sendFluxSingleImageRequest(ImageGenerationPrompt imageGenerationPrompt) {
        String requestBody = ReplicateRequestGenerator.generate(
                imageGenerationPrompt, getOrFetchModelVersion(imageGenerationPrompt));
        AbstractConnector.ApiRequestBuilder builder =
                this.connector.builder(ReplicateService.class, imageGenerationPrompt, ImageApi.class);
        int outputCount = imageGenerationPrompt.getIntProperty(ReplicateRequestParameters.NUM_OUTPUTS).orElse(1);

        // Written by the scheduler thread while polling and by this thread on timeout.
        Map<Integer, RStatus<String>> statuses = new ConcurrentHashMap<>();
        for (int i = 0; i < outputCount; i++) {
            statuses.put(i, sendImageRequest(builder.newBuilder().prepareUri("predictions").body(requestBody)));
        }

        try {
            pollStatuses(statuses, builder, imageGenerationPrompt.timeout().orElse(DEFAULT_TIMEOUT).toMillis());
        } catch (InterruptedException | ExecutionException e) {
            throw imageRequestFailure(e);
        }

        statuses.values().stream()
                .filter(RStatus::isFailed)
                .forEach(status -> LOG.error("Replicate image request failed: {}", status.logs()));

        List<RStatus<String>> successful = statuses.entrySet().stream()
                .sorted(Map.Entry.comparingByKey())
                .map(Map.Entry::getValue)
                .filter(RStatus::isSuccess)
                .toList();
        if (successful.isEmpty()) {
            throw new ApiResponseException("Failed Replicate image request");
        }

        List<URL> urls = successful.stream()
                .map(RStatus::output)
                .map(ReplicateService::toUrlNoThrow)
                .filter(Objects::nonNull)
                .toList();
        // Latest completion time among successful predictions, rendered as the timestamp itself
        // (not as "Optional[...]").
        String completedAt = successful.stream()
                .map(RStatus::completedAt)
                .filter(Objects::nonNull)
                .max((a, b) -> a.compareTo(b))
                .map(Object::toString)
                .orElse(null);
        return new ReplicateFluxImageResponse(imageGenerationPrompt, urls, completedAt);
    }

    /**
     * Polls all in-progress predictions in {@code statuses}, replacing each entry with its latest
     * status, until none is in progress or the timeout elapses; still-running predictions are then
     * cancelled and replaced with the cancellation response.
     *
     * @param statuses thread-safe map of prediction index to status; updated in place
     */
    private void pollStatuses(
            Map<Integer, RStatus<String>> statuses,
            AbstractConnector.ApiRequestBuilder requestBuilder,
            long timeoutMillis) throws InterruptedException, ExecutionException {
        CompletableFuture<Void> completion = new CompletableFuture<>();
        ContextUtils.Context context = ContextUtils.getContext();
        ScheduledFuture<?> polling = this.scheduler.scheduleAtFixedRate(() -> {
            if (completion.isDone()) {
                return;
            }
            try (ContextUtils.ContextCloser ignored = ContextUtils.ensureContext(context)) {
                for (Map.Entry<Integer, RStatus<String>> entry : inProgress(statuses)) {
                    statuses.put(entry.getKey(),
                            fetchRequestStatus(requestBuilder.newBuilder(), entry.getValue().urls().get()));
                }
                if (statuses.values().stream().noneMatch(RStatus::isInProgress)) {
                    completion.complete(null);
                }
            } catch (Exception e) {
                completion.completeExceptionally(e);
            }
        }, 0L, POLL_INTERVAL_MS, TimeUnit.MILLISECONDS);

        try {
            completion.get(timeoutMillis, TimeUnit.MILLISECONDS);
        } catch (TimeoutException e) {
            // Stop the poller first so it does not overwrite the cancellation results.
            polling.cancel(false);
            for (Map.Entry<Integer, RStatus<String>> entry : inProgress(statuses)) {
                statuses.put(entry.getKey(),
                        cancelRequest(requestBuilder.newBuilder(), entry.getValue().urls().cancel()));
            }
        } finally {
            polling.cancel(false);
        }
    }

    /** Snapshot of the entries whose prediction is still in progress. */
    private static List<Map.Entry<Integer, RStatus<String>>> inProgress(Map<Integer, RStatus<String>> statuses) {
        return statuses.entrySet().stream()
                .filter(entry -> entry.getValue().isInProgress())
                .map(entry -> Map.entry(entry.getKey(), entry.getValue()))
                .toList();
    }

    /** Wraps a polling failure, restoring the interrupt flag when the thread was interrupted. */
    private static ApiResponseException imageRequestFailure(Exception e) {
        if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
        return new ApiResponseException(e, "Failed Replicate image request");
    }

    /** Sends a prediction-creation request and parses the returned status. */
    private static <T> RStatus<T> sendImageRequest(AbstractConnector.ApiRequestBuilder apiRequestBuilder) {
        try {
            return Json.OBJECT_MAPPER.readValue(
                    apiRequestBuilder.send(new LoggingHttpClient.LogParameters<>("ImageGenerationRequest")),
                    new TypeReference<RStatus<T>>() {});
        } catch (IOException e) {
            throw new ApiResponseException(e, "Failed Replicate image request");
        }
    }

    /** Fetches the current status of a prediction with a GET to {@code uri}. */
    private static <T> RStatus<T> fetchRequestStatus(AbstractConnector.ApiRequestBuilder apiRequestBuilder, URI uri) {
        try {
            return Json.OBJECT_MAPPER.readValue(
                    apiRequestBuilder.method(HttpMethod.GET).overrideUri(uri)
                            .send(new LoggingHttpClient.LogParameters<>("ImageGenerationStatusRequest")),
                    new TypeReference<RStatus<T>>() {});
        } catch (IOException e) {
            throw new ApiResponseException(e, "Failed Replicate status request");
        }
    }

    /** Cancels a prediction via its cancel URL and parses the returned status. */
    private static <T> RStatus<T> cancelRequest(AbstractConnector.ApiRequestBuilder apiRequestBuilder, URI uri) {
        try {
            // Replicate's cancel endpoint (POST /v1/predictions/{id}/cancel) does not accept GET.
            return Json.OBJECT_MAPPER.readValue(
                    apiRequestBuilder.method(HttpMethod.POST).overrideUri(uri)
                            .send(new LoggingHttpClient.LogParameters<>("ImageCancelRequest")),
                    new TypeReference<RStatus<T>>() {});
        } catch (IOException e) {
            throw new ApiResponseException(e, "Failed Replicate cancellation request");
        }
    }

    /** Parses {@code str} as a URL, returning {@code null} (and logging) when it is malformed or null. */
    private static URL toUrlNoThrow(String str) {
        try {
            return new URL(str);
        } catch (MalformedURLException e) {
            LOG.warn("Ignoring malformed image URL returned by Replicate: {}", str);
            return null;
        }
    }

    /** Image variations are not supported by this provider. */
    @Override
    public ImageResponse send(ImageVariationPrompt imageVariationPrompt) {
        throw new UnsupportedOperationException("Unsupported method 'send ImageVariationPrompt'");
    }
}
