package org.deeplearning4j.rl4j.trainer;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
import lombok.NonNull;
import org.apache.commons.lang3.builder.Builder;
import org.deeplearning4j.rl4j.agent.IAgentLearner;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/rl4j/trainer/AsyncTrainer.class */
/**
 * Trainer that runs several {@link IAgentLearner} instances in parallel, one per worker thread.
 *
 * <p>Each worker obtains its own agent learner from {@code agentLearnerBuilder} and repeatedly
 * calls {@link IAgentLearner#run()} on it, reporting the episode's step count back to this trainer
 * after every run. After each report the shared {@code stoppingCondition} is evaluated against this
 * trainer; once it returns {@code true}, every worker exits after finishing its current run.
 *
 * @param <ACTION> the action type used by the agent learners
 */
public class AsyncTrainer<ACTION> implements ITrainer {
    private final Builder<IAgentLearner<ACTION>> agentLearnerBuilder;
    private final Predicate<AsyncTrainer<ACTION>> stoppingCondition;
    private final int numThreads;
    private final AtomicInteger episodeCount = new AtomicInteger();
    private final AtomicInteger stepCount = new AtomicInteger();

    // Set by whichever worker satisfies the stopping condition (or by train() on interrupt/failure)
    // and polled by every other worker. It must be volatile: otherwise the write is not guaranteed to
    // become visible to the other threads, and the read may be hoisted out of the worker loop.
    private volatile boolean shouldStop = false;

    /** Worker thread that repeatedly runs a single agent learner until the trainer asks it to stop. */
    private class AgentLearnerThread extends Thread {
        private final IAgentLearner<ACTION> agentLearner;
        private final int deviceNum;

        /**
         * @param iAgentLearner the learner this worker drives exclusively
         * @param i device number handed to ND4J's affinity manager (the worker's index)
         */
        public AgentLearnerThread(IAgentLearner<ACTION> iAgentLearner, int i) {
            this.agentLearner = iAgentLearner;
            this.deviceNum = i;
        }

        @Override
        public void run() {
            // Assign this thread to device `deviceNum` via ND4J's affinity manager.
            // NOTE(review): assumes numThreads does not exceed the number of devices on
            // multi-device backends — confirm.
            Nd4j.getAffinityManager().unsafeSetDevice(Integer.valueOf(this.deviceNum));
            while (!AsyncTrainer.this.shouldStop) {
                this.agentLearner.run();
                AsyncTrainer.this.onEpisodeEnded(this.agentLearner.getEpisodeStepCount());
            }
        }
    }

    /** Fluent builder for {@link AsyncTrainer}; obtain one via {@link AsyncTrainer#builder()}. */
    public static class AsyncTrainerBuilder<ACTION> {
        private Builder<IAgentLearner<ACTION>> agentLearnerBuilder;
        private Predicate<AsyncTrainer<ACTION>> stoppingCondition;
        private int numThreads;

        AsyncTrainerBuilder() {
        }

        /**
         * Sets the factory used to create one agent learner per worker thread.
         *
         * @throws NullPointerException if {@code builder} is null
         */
        public AsyncTrainerBuilder<ACTION> agentLearnerBuilder(@NonNull Builder<IAgentLearner<ACTION>> builder) {
            if (builder == null) {
                throw new NullPointerException("agentLearnerBuilder is marked non-null but is null");
            }
            this.agentLearnerBuilder = builder;
            return this;
        }

        /**
         * Sets the predicate evaluated after every episode; training stops once it returns true.
         *
         * @throws NullPointerException if {@code predicate} is null
         */
        public AsyncTrainerBuilder<ACTION> stoppingCondition(@NonNull Predicate<AsyncTrainer<ACTION>> predicate) {
            if (predicate == null) {
                throw new NullPointerException("stoppingCondition is marked non-null but is null");
            }
            this.stoppingCondition = predicate;
            return this;
        }

        /** Sets the number of worker threads; must be positive (validated by {@link #build()}). */
        public AsyncTrainerBuilder<ACTION> numThreads(int i) {
            this.numThreads = i;
            return this;
        }

        /**
         * Creates the trainer.
         *
         * @throws NullPointerException if the learner builder or stopping condition was not set
         * @throws IllegalArgumentException if {@code numThreads} is not positive
         */
        public AsyncTrainer<ACTION> build() {
            return new AsyncTrainer<>(this.agentLearnerBuilder, this.stoppingCondition, this.numThreads);
        }

        @Override
        public String toString() {
            return "AsyncTrainer.AsyncTrainerBuilder(agentLearnerBuilder=" + this.agentLearnerBuilder + ", stoppingCondition=" + this.stoppingCondition + ", numThreads=" + this.numThreads + ")";
        }
    }

    /**
     * @param builder factory producing one agent learner per worker thread
     * @param predicate stopping condition, evaluated after every episode from worker threads
     * @param i number of worker threads; must be greater than 0
     * @throws NullPointerException if {@code builder} or {@code predicate} is null
     * @throws IllegalArgumentException if {@code i <= 0}
     */
    public AsyncTrainer(@NonNull Builder<IAgentLearner<ACTION>> builder, @NonNull Predicate<AsyncTrainer<ACTION>> predicate, int i) {
        if (builder == null) {
            throw new NullPointerException("agentLearnerBuilder is marked non-null but is null");
        }
        if (predicate == null) {
            throw new NullPointerException("stoppingCondition is marked non-null but is null");
        }
        // "%s" is required for ND4J's Preconditions to substitute the value into the message.
        Preconditions.checkArgument(i > 0, "numThreads must be greater than 0, got: %s", i);
        this.agentLearnerBuilder = builder;
        this.stoppingCondition = predicate;
        this.numThreads = i;
    }

    /**
     * Resets the counters, starts {@code numThreads} workers and blocks until all of them finish.
     *
     * <p>If the calling thread is interrupted while waiting, the workers are told to stop after
     * their current episode, the interrupt status is restored and this method returns without
     * waiting for them. If creating or starting a worker fails, the already-started workers are
     * told to stop and the exception is rethrown.
     */
    @Override
    public void train() {
        reset();
        Thread[] workers = new Thread[this.numThreads];
        try {
            for (int i = 0; i < this.numThreads; i++) {
                AgentLearnerThread worker = new AgentLearnerThread(this.agentLearnerBuilder.build(), i);
                workers[i] = worker;
                worker.start();
            }
        } catch (RuntimeException e) {
            // Don't leave earlier workers running unattended when a later one cannot be created.
            this.shouldStop = true;
            throw e;
        }
        for (Thread worker : workers) {
            try {
                worker.join();
            } catch (InterruptedException e) {
                // The caller asked us to stop: wind the workers down and propagate the interrupt.
                this.shouldStop = true;
                Thread.currentThread().interrupt();
                return;
            }
        }
    }

    /** Clears the episode/step counters and the stop flag before a new training run. */
    private void reset() {
        this.episodeCount.set(0);
        this.stepCount.set(0);
        this.shouldStop = false;
    }

    /** @return number of episodes completed across all workers in the current/last run */
    @Override
    public int getEpisodeCount() {
        return this.episodeCount.get();
    }

    /** @return total number of steps reported across all workers in the current/last run */
    @Override
    public int getStepCount() {
        return this.stepCount.get();
    }

    /**
     * Called by a worker after each episode: updates the shared counters and, if the stopping
     * condition now holds, signals all workers to stop. May be invoked concurrently from several
     * workers, so the stopping condition must tolerate concurrent calls.
     *
     * <p>Public only as a decompilation artifact (originally private); not meant for external use.
     *
     * @param i number of steps taken in the episode that just ended
     */
    public void onEpisodeEnded(int i) {
        this.episodeCount.incrementAndGet();
        this.stepCount.addAndGet(i);
        if (this.stoppingCondition.test(this)) {
            this.shouldStop = true;
        }
    }

    /** @return a new, empty {@link AsyncTrainerBuilder} */
    public static <ACTION> AsyncTrainerBuilder<ACTION> builder() {
        return new AsyncTrainerBuilder<>();
    }
}
