package de.rub.nds.tlsattacker.core.workflow;

import de.rub.nds.tlsattacker.core.state.State;
import de.rub.nds.tlsattacker.core.workflow.task.ITask;
import de.rub.nds.tlsattacker.core.workflow.task.StateExecutionTask;
import de.rub.nds.tlsattacker.core.workflow.task.TlsTask;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/* loaded from: input_file:de/rub/nds/tlsattacker/core/workflow/ParallelExecutor.class */
/**
 * Runs {@link TlsTask}s and {@link State} executions concurrently on a {@link ThreadPoolExecutor}.
 *
 * <p>Default callbacks configured on this executor are copied onto every submitted task that does
 * not already define its own callback of the same kind. An optional timeout action can be armed to
 * react when the pool has active tasks but stops completing them.
 */
public class ParallelExecutor {
    private static final Logger LOGGER = LogManager.getLogger();

    /** Delay between two progress checks of the timeout monitor, so it does not spin a core. */
    private static final long MONITOR_POLL_INTERVAL_MILLIS = 10L;

    private final ThreadPoolExecutor executorService;
    private Callable<Integer> timeoutAction;
    private final int size;
    // Written by shutdown() and polled by the monitor thread started in armTimeoutAction(),
    // so it must be volatile for the monitor to reliably observe the change.
    private volatile boolean shouldShutdown = false;
    private final int reexecutions;
    private Function<State, Integer> defaultBeforeTransportPreInitCallback = null;
    private Function<State, Integer> defaultBeforeTransportInitCallback = null;
    private Function<State, Integer> defaultAfterTransportInitCallback = null;
    private Function<State, Integer> defaultAfterExecutionCallback = null;

    /**
     * Creates an executor that runs its tasks on the given thread pool.
     *
     * @param size the value reported by {@link #getSize()}
     * @param reexecutions re-execution count handed to every {@link StateExecutionTask}
     * @param threadPoolExecutor the pool that executes submitted tasks
     * @throws IllegalArgumentException if {@code reexecutions} is negative
     */
    public ParallelExecutor(int size, int reexecutions, ThreadPoolExecutor threadPoolExecutor) {
        if (reexecutions < 0) {
            throw new IllegalArgumentException("Reexecutions is below zero");
        }
        this.executorService = threadPoolExecutor;
        this.reexecutions = reexecutions;
        this.size = size;
    }

    /**
     * Creates an executor on an externally supplied pool; {@link #getSize()} then returns -1.
     *
     * @param threadPoolExecutor the pool that executes submitted tasks
     * @param reexecutions re-execution count handed to every {@link StateExecutionTask}
     */
    public ParallelExecutor(ThreadPoolExecutor threadPoolExecutor, int reexecutions) {
        this(-1, reexecutions, threadPoolExecutor);
    }

    /**
     * Creates an executor backed by a fixed-size pool of {@code size} threads with an unbounded
     * queue.
     *
     * @param size number of pool threads (core and maximum)
     * @param reexecutions re-execution count handed to every {@link StateExecutionTask}
     */
    public ParallelExecutor(int size, int reexecutions) {
        this(
                size,
                reexecutions,
                new ThreadPoolExecutor(
                        size, size, 10L, TimeUnit.DAYS, new LinkedBlockingDeque<Runnable>()));
    }

    /**
     * Creates an executor backed by a fixed-size pool whose threads come from {@code threadFactory}.
     *
     * @param size number of pool threads (core and maximum)
     * @param reexecutions re-execution count handed to every {@link StateExecutionTask}
     * @param threadFactory factory used by the pool to create its threads
     */
    public ParallelExecutor(int size, int reexecutions, ThreadFactory threadFactory) {
        this(
                size,
                reexecutions,
                new ThreadPoolExecutor(
                        size,
                        size,
                        5L,
                        TimeUnit.MINUTES,
                        new LinkedBlockingDeque<Runnable>(),
                        threadFactory));
    }

    /**
     * Applies the default callbacks to {@code tlsTask} where it has none and submits it.
     *
     * @throws IllegalStateException if the underlying pool has already been shut down
     */
    private Future<ITask> addTask(TlsTask tlsTask) {
        if (this.executorService.isShutdown()) {
            throw new IllegalStateException("Cannot add Tasks to already shutdown executor");
        }
        // Only fill in callbacks the task does not already define itself.
        if (this.defaultBeforeTransportPreInitCallback != null
                && tlsTask.getBeforeTransportPreInitCallback() == null) {
            tlsTask.setBeforeTransportPreInitCallback(this.defaultBeforeTransportPreInitCallback);
        }
        if (this.defaultBeforeTransportInitCallback != null
                && tlsTask.getBeforeTransportInitCallback() == null) {
            tlsTask.setBeforeTransportInitCallback(this.defaultBeforeTransportInitCallback);
        }
        if (this.defaultAfterTransportInitCallback != null
                && tlsTask.getAfterTransportInitCallback() == null) {
            tlsTask.setAfterTransportInitCallback(this.defaultAfterTransportInitCallback);
        }
        if (this.defaultAfterExecutionCallback != null
                && tlsTask.getAfterExecutionCallback() == null) {
            tlsTask.setAfterExecutionCallback(this.defaultAfterExecutionCallback);
        }
        return this.executorService.submit(tlsTask);
    }

    /** Wraps {@code state} in a {@link StateExecutionTask} and submits it. */
    private Future<ITask> addStateTask(State state) {
        return addTask(new StateExecutionTask(state, this.reexecutions));
    }

    /**
     * Blocks until {@code future} completes and returns its result.
     *
     * @throws RuntimeException wrapping the cause if the task failed or the wait was interrupted;
     *     on interruption the thread's interrupt flag is restored first
     */
    private static <T> T awaitResult(Future<T> future) {
        try {
            return future.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Failed to execute tasks!", e);
        } catch (ExecutionException e) {
            throw new RuntimeException("Failed to execute tasks!", e);
        }
    }

    /**
     * Submits one {@link StateExecutionTask} per state and blocks until all have completed.
     *
     * @param iterable the states to execute
     * @throws RuntimeException if any task fails or the wait is interrupted
     */
    public void bulkExecuteStateTasks(Iterable<State> iterable) {
        List<Future<ITask>> futures = new ArrayList<>();
        for (State state : iterable) {
            futures.add(addStateTask(state));
        }
        for (Future<ITask> future : futures) {
            awaitResult(future);
        }
    }

    /** Varargs convenience for {@link #bulkExecuteStateTasks(Iterable)}. */
    public void bulkExecuteStateTasks(State... stateArr) {
        bulkExecuteStateTasks(Arrays.asList(stateArr));
    }

    /**
     * Submits all tasks and blocks until every one has completed.
     *
     * @param iterable the tasks to execute
     * @return the finished tasks, in submission order
     * @throws RuntimeException if any task fails or the wait is interrupted
     */
    public List<ITask> bulkExecuteTasks(Iterable<TlsTask> iterable) {
        List<Future<ITask>> futures = new ArrayList<>();
        for (TlsTask task : iterable) {
            futures.add(addTask(task));
        }
        List<ITask> results = new ArrayList<>(futures.size());
        for (Future<ITask> future : futures) {
            results.add(awaitResult(future));
        }
        return results;
    }

    /** Varargs convenience for {@link #bulkExecuteTasks(Iterable)}. */
    public List<ITask> bulkExecuteTasks(TlsTask... tlsTaskArr) {
        return bulkExecuteTasks(Arrays.asList(tlsTaskArr));
    }

    /** Returns the size passed at construction, or -1 for the executor-supplied constructor. */
    public int getSize() {
        return this.size;
    }

    /** Shuts down the pool and signals a running timeout monitor to stop. */
    public void shutdown() {
        this.shouldShutdown = true;
        this.executorService.shutdown();
    }

    /**
     * Starts a background thread that runs the timeout action whenever the pool has active tasks
     * but completes none for {@code i} milliseconds. Does nothing (besides logging a warning) if
     * no timeout action is set.
     *
     * @param i inactivity timeout in milliseconds
     */
    public void armTimeoutAction(int i) {
        if (this.timeoutAction == null) {
            LOGGER.warn("No TimeoutAction set, this won't do anything");
            return;
        }
        new Thread(() -> monitorExecution(i)).start();
    }

    /**
     * Polls the pool until {@link #shutdown()} is called. The deadline is pushed forward whenever
     * the pool is idle or has completed further tasks; once it passes, the timeout action runs.
     * Monitoring stops early if the action fails, returns a non-zero code, or this thread is
     * interrupted.
     */
    private void monitorExecution(int i) {
        long deadline = System.currentTimeMillis() + i;
        long lastCompletedCount = 0;
        while (!this.shouldShutdown) {
            long completedCount = this.executorService.getCompletedTaskCount();
            if (this.executorService.getActiveCount() == 0 || completedCount != lastCompletedCount) {
                // Idle or progressing: restart the inactivity window.
                deadline = System.currentTimeMillis() + i;
                lastCompletedCount = completedCount;
            } else if (System.currentTimeMillis() > deadline) {
                LOGGER.debug("Timeout");
                if (!runTimeoutAction()) {
                    break;
                }
                deadline = System.currentTimeMillis() + i;
            }
            try {
                Thread.sleep(MONITOR_POLL_INTERVAL_MILLIS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
        }
    }

    /**
     * Invokes the timeout action once.
     *
     * @return {@code true} if it returned 0; {@code false} (after logging a warning) if it threw or
     *     returned any other code
     */
    private boolean runTimeoutAction() {
        try {
            int exitCode = this.timeoutAction.call();
            if (exitCode != 0) {
                // Routed through the catch below so both failure modes are logged identically.
                throw new RuntimeException("TimeoutAction did terminate with code " + exitCode);
            }
            return true;
        } catch (Exception e) {
            LOGGER.warn("TimeoutAction did not succeed", e);
            return false;
        }
    }

    public int getReexecutions() {
        return this.reexecutions;
    }

    public Callable<Integer> getTimeoutAction() {
        return this.timeoutAction;
    }

    /** Sets the action run by the monitor started via {@link #armTimeoutAction(int)}. */
    public void setTimeoutAction(Callable<Integer> callable) {
        this.timeoutAction = callable;
    }

    public Function<State, Integer> getDefaultBeforeTransportPreInitCallback() {
        return this.defaultBeforeTransportPreInitCallback;
    }

    public void setDefaultBeforeTransportPreInitCallback(Function<State, Integer> function) {
        this.defaultBeforeTransportPreInitCallback = function;
    }

    public Function<State, Integer> getDefaultBeforeTransportInitCallback() {
        return this.defaultBeforeTransportInitCallback;
    }

    public void setDefaultBeforeTransportInitCallback(Function<State, Integer> function) {
        this.defaultBeforeTransportInitCallback = function;
    }

    public Function<State, Integer> getDefaultAfterTransportInitCallback() {
        return this.defaultAfterTransportInitCallback;
    }

    public void setDefaultAfterTransportInitCallback(Function<State, Integer> function) {
        this.defaultAfterTransportInitCallback = function;
    }

    public Function<State, Integer> getDefaultAfterExecutionCallback() {
        return this.defaultAfterExecutionCallback;
    }

    public void setDefaultAfterExecutionCallback(Function<State, Integer> function) {
        this.defaultAfterExecutionCallback = function;
    }
}
