package org.deeplearning4j.rl4j.trainer;

import java.util.Objects;
import java.util.function.Predicate;
import lombok.NonNull;
import org.apache.commons.lang3.builder.Builder;
import org.deeplearning4j.rl4j.agent.IAgentLearner;

/* loaded from: input_file:org/deeplearning4j/rl4j/trainer/SyncTrainer.class */
/**
 * A trainer that drives an {@link IAgentLearner} one episode after another from within
 * {@link #train()}, until a user-supplied stopping condition is satisfied.
 *
 * <p>Each call to {@link IAgentLearner#run()} is counted as one episode.
 *
 * @param <ACTION> the action type used by the agent learner
 */
public class SyncTrainer<ACTION> implements ITrainer {
    private static final String AGENT_LEARNER_BUILDER_NULL_MESSAGE = "agentLearnerBuilder is marked non-null but is null";
    private static final String STOPPING_CONDITION_NULL_MESSAGE = "stoppingCondition is marked non-null but is null";

    /** Evaluated before every episode; training stops as soon as it returns {@code true}. */
    private final Predicate<SyncTrainer<ACTION>> stoppingCondition;

    /** Number of {@code run()} calls completed by the current (or most recent) {@link #train()}. */
    private int episodeCount;

    /** Sum of {@code getEpisodeStepCount()} values accumulated during {@link #train()}. */
    private int stepCount;

    /** The learner, built exactly once in the constructor from the supplied builder. */
    final IAgentLearner<ACTION> agentLearner;

    /**
     * Fluent builder for {@link SyncTrainer}. Both properties are required; calling
     * {@link #build()} before setting them results in a {@link NullPointerException}.
     *
     * @param <ACTION> the action type used by the agent learner
     */
    public static class SyncTrainerBuilder<ACTION> {
        private Builder<IAgentLearner<ACTION>> agentLearnerBuilder;
        private Predicate<SyncTrainer<ACTION>> stoppingCondition;

        SyncTrainerBuilder() {
        }

        /**
         * Sets the builder that will produce the agent learner.
         *
         * @param agentLearnerBuilder builder invoked once by the {@link SyncTrainer} constructor
         * @return this builder, for chaining
         * @throws NullPointerException if {@code agentLearnerBuilder} is null
         */
        public SyncTrainerBuilder<ACTION> agentLearnerBuilder(@NonNull Builder<IAgentLearner<ACTION>> agentLearnerBuilder) {
            this.agentLearnerBuilder = Objects.requireNonNull(agentLearnerBuilder, AGENT_LEARNER_BUILDER_NULL_MESSAGE);
            return this;
        }

        /**
         * Sets the condition that ends training.
         *
         * @param stoppingCondition predicate tested against the trainer before each episode
         * @return this builder, for chaining
         * @throws NullPointerException if {@code stoppingCondition} is null
         */
        public SyncTrainerBuilder<ACTION> stoppingCondition(@NonNull Predicate<SyncTrainer<ACTION>> stoppingCondition) {
            this.stoppingCondition = Objects.requireNonNull(stoppingCondition, STOPPING_CONDITION_NULL_MESSAGE);
            return this;
        }

        /**
         * Creates the trainer. Note that this immediately calls {@code build()} on the
         * configured agent-learner builder.
         *
         * @return a new {@link SyncTrainer}
         * @throws NullPointerException if either required property has not been set
         */
        public SyncTrainer<ACTION> build() {
            return new SyncTrainer<>(this.agentLearnerBuilder, this.stoppingCondition);
        }

        @Override
        public String toString() {
            return "SyncTrainer.SyncTrainerBuilder(agentLearnerBuilder=" + this.agentLearnerBuilder + ", stoppingCondition=" + this.stoppingCondition + ")";
        }
    }

    /**
     * Creates a trainer and builds its agent learner.
     *
     * @param agentLearnerBuilder builder whose {@code build()} result becomes the agent learner
     * @param stoppingCondition   predicate tested against this trainer before each episode
     * @throws NullPointerException if either argument is null
     */
    public SyncTrainer(@NonNull Builder<IAgentLearner<ACTION>> agentLearnerBuilder, @NonNull Predicate<SyncTrainer<ACTION>> stoppingCondition) {
        Objects.requireNonNull(agentLearnerBuilder, AGENT_LEARNER_BUILDER_NULL_MESSAGE);
        this.stoppingCondition = Objects.requireNonNull(stoppingCondition, STOPPING_CONDITION_NULL_MESSAGE);
        // The builder is already typed to IAgentLearner<ACTION>; no (raw) cast is needed.
        this.agentLearner = agentLearnerBuilder.build();
    }

    /**
     * Runs episodes until the stopping condition holds.
     *
     * <p>The episode and step counters are reset to zero at the start of every call. The
     * condition is evaluated before each episode, so if it is already satisfied, no episode runs.
     */
    @Override // org.deeplearning4j.rl4j.trainer.ITrainer
    public void train() {
        this.episodeCount = 0;
        this.stepCount = 0;
        while (!this.stoppingCondition.test(this)) {
            this.agentLearner.run();
            this.episodeCount++;
            this.stepCount += this.agentLearner.getEpisodeStepCount();
        }
    }

    /**
     * Returns a new, empty builder.
     *
     * @param <ACTION> the action type used by the agent learner
     * @return a fresh {@link SyncTrainerBuilder}
     */
    public static <ACTION> SyncTrainerBuilder<ACTION> builder() {
        return new SyncTrainerBuilder<>();
    }

    /** @return the number of episodes completed by the current or most recent {@link #train()} */
    @Override // org.deeplearning4j.rl4j.trainer.ITrainer
    public int getEpisodeCount() {
        return this.episodeCount;
    }

    /** @return the total steps accumulated by the current or most recent {@link #train()} */
    @Override // org.deeplearning4j.rl4j.trainer.ITrainer
    public int getStepCount() {
        return this.stepCount;
    }

    /** @return the agent learner built at construction time */
    public IAgentLearner<ACTION> getAgentLearner() {
        return this.agentLearner;
    }
}
