package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;

import java.util.ArrayList;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.class */
/**
 * Q-learning over a {@link DiscreteSpace} action space, supporting both vanilla DQN and Double DQN targets.
 *
 * <p>Each {@link #trainStep} call does the following:
 * <ul>
 *   <li>selects an action epsilon-greedily from the current DQN (or repeats the last action on skipped frames);</li>
 *   <li>steps the MDP;</li>
 *   <li>on non-skipped frames (or episode end), stores a {@link Transition} in experience replay;</li>
 *   <li>once past the update start, fits the current DQN on a replayed batch whose targets come from
 *       {@link #setTarget}.</li>
 * </ul>
 *
 * @param <O> observation type
 */
public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O, Integer, DiscreteSpace> {

    /** Minimum number of steps between the starts of two recorded monitoring videos. */
    private static final int MONITOR_INTERVAL = 10000;

    private final QLearning.QLConfiguration configuration;
    private final DataManager dataManager;
    private final MDP<O, Integer, DiscreteSpace> mdp;
    private final IDQN currentDQN;
    private final DQNPolicy<O> policy;
    private final EpsGreedy<O, Integer, DiscreteSpace> egPolicy;
    private IDQN targetDQN;
    /** Action chosen on the last non-skipped frame; repeated on skipped frames. */
    private int lastAction;
    /** Observation stack of the current state; null until the first non-skipped frame of an epoch. */
    private INDArray[] history = null;
    /** Reward (scaled by the reward factor) accumulated since the last stored transition. */
    private double accuReward = 0.0d;
    /** Step counter value at which the last monitoring video was started. */
    private int lastMonitor = -MONITOR_INTERVAL;

    /**
     * Creates the learner. A clone of {@code dqn} becomes the initial target network.
     *
     * @param mdp           environment to learn from
     * @param dqn           network being trained
     * @param conf          Q-learning hyper-parameters
     * @param dataManager   data/video output manager
     * @param epsilonNbStep forwarded to {@link EpsGreedy}; presumably the number of steps over which epsilon
     *                      anneals (confirm against EpsGreedy)
     */
    public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf,
                    DataManager dataManager, int epsilonNbStep) {
        super(conf);
        this.configuration = conf;
        this.mdp = mdp;
        this.dataManager = dataManager;
        this.currentDQN = dqn;
        this.targetDQN = dqn.m8clone();
        this.policy = new DQNPolicy<>(dqn);
        this.egPolicy = new EpsGreedy<>(this.policy, mdp, conf.getUpdateStart(), epsilonNbStep, getRandom(),
                        conf.getMinEpsilon(), this);
    }

    /** Stops any monitoring video recording started in {@link #preEpoch()}. */
    @Override
    public void postEpoch() {
        if (getHistoryProcessor() != null) {
            getHistoryProcessor().stopMonitor();
        }
    }

    /**
     * Resets per-epoch state.
     *
     * <p>Starts a monitoring video only when all of these hold:
     * <ul>
     *   <li>at least {@link #MONITOR_INTERVAL} steps have passed since the last one;</li>
     *   <li>a history processor exists;</li>
     *   <li>the data manager saves data.</li>
     * </ul>
     */
    @Override
    public void preEpoch() {
        history = null;
        lastAction = 0;
        accuReward = 0.0d;
        if (getStepCounter() - lastMonitor >= MONITOR_INTERVAL && getHistoryProcessor() != null
                        && getDataManager().isSaveData()) {
            lastMonitor = getStepCounter();
            getHistoryProcessor().startMonitor(getDataManager().getVideoDir() + "/video-" + getEpochCounter() + "-"
                            + getStepCounter() + ".mp4");
        }
    }

    /**
     * Performs one environment step and, when appropriate, one training update.
     *
     * @param obs current observation
     * @return the max Q-value of the current state (NaN on skipped frames), the DQN's latest score, and the
     *         MDP step reply
     */
    @Override
    protected QLearning.QLStepReturn<O> trainStep(O obs) {
        INDArray input = getInput(obs);
        boolean hasHistoryProcessor = getHistoryProcessor() != null;
        if (hasHistoryProcessor) {
            getHistoryProcessor().record(input);
        }

        int skipFrame = hasHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
        int historyLength = hasHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
        // Fitting is delayed past updateStart by (batchSize + historyLength) * skipFrame further steps.
        int updateStart = getConfiguration().getUpdateStart()
                        + (getConfiguration().getBatchSize() + historyLength) * skipFrame;

        // Stays NaN when no Q-value is computed this step (skipped frame).
        Double maxQ = Double.NaN;
        Integer action;
        if (getStepCounter() % skipFrame != 0) {
            // Skipped frame: repeat the previous action without querying the network.
            action = lastAction;
        } else {
            if (history == null) {
                if (hasHistoryProcessor) {
                    getHistoryProcessor().add(input);
                    history = getHistoryProcessor().getHistory();
                } else {
                    history = new INDArray[] {input};
                }
            }
            INDArray stacked = Transition.concat(Transition.dup(history));
            // Prepend a batch dimension of size 1 when the stacked input has rank > 2.
            if (stacked.shape().length > 2) {
                stacked = stacked.reshape(Learning.makeShape(1, stacked.shape()));
            }
            INDArray qValues = getCurrentDQN().output(stacked);
            maxQ = qValues.getDouble(Learning.getMaxAction(qValues).intValue());
            action = getEgPolicy().nextAction(stacked);
        }
        lastAction = action;

        StepReply<O> stepReply = getMdp().step(action);
        accuReward += stepReply.getReward() * configuration.getRewardFactor();

        // Store a transition (and possibly train) only on non-skipped frames or at episode end.
        if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) {
            INDArray nextInput = getInput(stepReply.getObservation());
            if (hasHistoryProcessor) {
                getHistoryProcessor().add(nextInput);
            }
            INDArray[] nextHistory = hasHistoryProcessor ? getHistoryProcessor().getHistory()
                            : new INDArray[] {nextInput};
            getExpReplay().store(new Transition<>(history, action, accuReward, stepReply.isDone(), nextHistory[0]));

            if (getStepCounter() > updateStart) {
                Pair<INDArray, INDArray> targets = setTarget(getExpReplay().getBatch());
                getCurrentDQN().fit(targets.getFirst(), targets.getSecond());
            }

            history = nextHistory;
            accuReward = 0.0d;
        }

        return new QLearning.QLStepReturn<>(maxQ, getCurrentDQN().getLatestScore(), stepReply);
    }

    /**
     * Builds the (input, label) pair for one DQN fit from a batch of replayed transitions.
     *
     * <p>Labels are the current DQN's Q-values on the batch observations. For each transition, the entry of the
     * taken action is replaced by the TD target {@code r + gamma * Q'(s', a*)} (just {@code r} if terminal).
     * The target network provides {@code Q'}:
     * <ul>
     *   <li>Vanilla DQN: {@code a*} maximizes the target network's Q-values.</li>
     *   <li>Double DQN: {@code a*} is chosen by the current network and evaluated by the target network.</li>
     * </ul>
     * The replaced value is clamped to within {@code errorClamp} of the current prediction.
     *
     * @param transitions replayed batch; must not be empty
     * @return pair of (batched observations, Q-value labels)
     * @throws IllegalArgumentException if {@code transitions} is empty
     */
    protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
        if (transitions.isEmpty()) {
            throw new IllegalArgumentException("too few transitions");
        }

        int size = transitions.size();
        int[] obsShape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
                        : getHistoryProcessor().getConf().getShape();
        int[] batchShape = makeShape(size, obsShape);
        INDArray obs = Nd4j.create(batchShape);
        INDArray nextObs = Nd4j.create(batchShape);
        int[] actions = new int[size];
        boolean[] terminal = new boolean[size];

        for (int i = 0; i < size; i++) {
            Transition<Integer> trans = transitions.get(i);
            terminal[i] = trans.isTerminal();
            actions[i] = trans.getAction();
            obs.putRow(i, Transition.concat(trans.getObservation()));
            nextObs.putRow(i, Transition.concat(Transition.append(trans.getObservation(), trans.getNextObservation())));
        }

        INDArray labels = dqnOutput(obs);
        // Bootstrapped values always come from the target network (vanilla and Double DQN alike).
        INDArray targetNextQ = targetDqnOutput(nextObs);
        INDArray bestNextActions = null;
        INDArray maxTargetNextQ = null;
        if (getConfiguration().isDoubleDQN()) {
            // Double DQN: the current network selects the next action, the target network evaluates it.
            bestNextActions = Nd4j.argMax(dqnOutput(nextObs), 1);
        } else {
            maxTargetNextQ = Nd4j.max(targetNextQ, 1);
        }

        double gamma = getConfiguration().getGamma();
        double errorClamp = getConfiguration().getErrorClamp();
        for (int i = 0; i < size; i++) {
            double yTarget = transitions.get(i).getReward();
            if (!terminal[i]) {
                double nextQ = bestNextActions != null ? targetNextQ.getDouble(i, bestNextActions.getInt(i))
                                : maxTargetNextQ.getDouble(i);
                yTarget += gamma * nextQ;
            }
            double previous = labels.getDouble(i, actions[i]);
            // Limit how far a single update can move the prediction.
            double clamped = Math.min(previous + errorClamp, Math.max(yTarget, previous - errorClamp));
            labels.putScalar(i, actions[i], clamped);
        }

        return new Pair<>(obs, labels);
    }

    /** @return the Q-learning configuration passed at construction */
    @Override
    public QLearning.QLConfiguration getConfiguration() {
        return configuration;
    }

    /** @return the data manager passed at construction */
    @Override
    public DataManager getDataManager() {
        return dataManager;
    }

    /** @return the MDP being learned */
    @Override
    public MDP<O, Integer, DiscreteSpace> getMdp() {
        return mdp;
    }

    /** @return the network being trained */
    @Override
    public IDQN getCurrentDQN() {
        return currentDQN;
    }

    /** @return the greedy policy backed by the current DQN */
    @Override
    public DQNPolicy<O> getPolicy() {
        return policy;
    }

    /** @return the epsilon-greedy exploration policy used during training */
    @Override
    public EpsGreedy<O, Integer, DiscreteSpace> getEgPolicy() {
        return egPolicy;
    }

    /** @return the current target network */
    @Override
    public IDQN getTargetDQN() {
        return targetDQN;
    }

    /** Replaces the target network. */
    @Override
    public void setTargetDQN(IDQN idqn) {
        this.targetDQN = idqn;
    }
}
