package org.deeplearning4j.rl4j.agent.learning.algorithm.dqn;

import java.util.List;
import lombok.NonNull;
import org.deeplearning4j.rl4j.agent.learning.algorithm.IUpdateAlgorithm;
import org.deeplearning4j.rl4j.agent.learning.update.Features;
import org.deeplearning4j.rl4j.agent.learning.update.FeaturesBuilder;
import org.deeplearning4j.rl4j.agent.learning.update.FeaturesLabels;
import org.deeplearning4j.rl4j.experience.StateActionRewardState;
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/BaseTransitionTDAlgorithm.class */
/**
 * Base class for update algorithms that turn a batch of single transitions into Q-value labels
 * using a temporal-difference style target.
 *
 * <p>For every transition in the batch, the labels start as the {@code "Q"} output of
 * {@link #qNetwork} for the transition's features. Only the entry of the action that was taken is
 * replaced, by the value returned from {@link #computeTarget(int, double, boolean)}. When an error
 * clamp is configured (any non-NaN value), that target is first restricted to
 * {@code [Q - errorClamp, Q + errorClamp]} around the network's current prediction.
 */
public abstract class BaseTransitionTDAlgorithm implements IUpdateAlgorithm<FeaturesLabels, StateActionRewardState<Integer>> {
    /** Network whose {@code "Q"} output is used as the starting point for the labels. */
    protected final IOutputNeuralNet qNetwork;

    /**
     * {@code gamma} from the configuration, kept for subclasses (conventionally the discount
     * factor used in {@link #computeTarget}); this class does not read it itself.
     */
    protected final double gamma;

    /** Half-width of the clamp window around the current Q-value; NaN means "no clamping". */
    private final double errorClamp;

    /** Cached {@code !Double.isNaN(errorClamp)}. */
    private final boolean isClamped;

    private final FeaturesBuilder featuresBuilder;

    /**
     * Settings for {@link BaseTransitionTDAlgorithm}. Build instances with {@link #builder()}.
     */
    public static class Configuration {
        /** Defaults to {@code 0.99} when not set on the builder. */
        double gamma;

        /**
         * Maximum allowed distance between a target and the current Q-value. Defaults to
         * {@code Double.NaN}, which disables clamping.
         */
        double errorClamp;

        /**
         * Self-typed builder so that subclasses of {@link Configuration} can extend it. Each field
         * remembers whether it was explicitly set, so unset fields fall back to their defaults.
         */
        public static abstract class ConfigurationBuilder<C extends Configuration, B extends ConfigurationBuilder<C, B>> {
            private boolean gamma$set;
            private double gamma$value;
            private boolean errorClamp$set;
            private double errorClamp$value;

            /** Returns this builder typed as the concrete builder subtype. */
            protected abstract B self();

            /** Creates the configuration from the values set so far. */
            public abstract C build();

            /** Sets {@code gamma}; when never called, {@code 0.99} is used. */
            public B gamma(double d) {
                this.gamma$value = d;
                this.gamma$set = true;
                return self();
            }

            /**
             * Sets {@code errorClamp}; when never called, {@code Double.NaN} (no clamping) is used.
             */
            public B errorClamp(double d) {
                this.errorClamp$value = d;
                this.errorClamp$set = true;
                return self();
            }

            @Override
            public String toString() {
                return "BaseTransitionTDAlgorithm.Configuration.ConfigurationBuilder(gamma$value=" + this.gamma$value + ", errorClamp$value=" + this.errorClamp$value + ")";
            }
        }

        /** Concrete builder returned by {@link Configuration#builder()}. */
        private static final class ConfigurationBuilderImpl extends ConfigurationBuilder<Configuration, ConfigurationBuilderImpl> {
            private ConfigurationBuilderImpl() {
            }

            @Override
            protected ConfigurationBuilderImpl self() {
                return this;
            }

            @Override
            public Configuration build() {
                return new Configuration(this);
            }
        }

        private static double $default$gamma() {
            return 0.99d;
        }

        private static double $default$errorClamp() {
            return Double.NaN;
        }

        /**
         * Copies the builder's values, substituting the defaults for fields that were never set.
         *
         * @param configurationBuilder the builder to read from
         */
        protected Configuration(ConfigurationBuilder<?, ?> configurationBuilder) {
            this.gamma = configurationBuilder.gamma$set
                    ? configurationBuilder.gamma$value
                    : $default$gamma();
            this.errorClamp = configurationBuilder.errorClamp$set
                    ? configurationBuilder.errorClamp$value
                    : $default$errorClamp();
        }

        /** Returns a new builder with no fields set. */
        public static ConfigurationBuilder<?, ?> builder() {
            return new ConfigurationBuilderImpl();
        }

        public double getGamma() {
            return this.gamma;
        }

        public double getErrorClamp() {
            return this.errorClamp;
        }

        public void setGamma(double d) {
            this.gamma = d;
        }

        public void setErrorClamp(double d) {
            this.errorClamp = d;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Configuration)) {
                return false;
            }
            Configuration other = (Configuration) obj;
            // Double.compare treats NaN as equal to NaN, matching doubleToLongBits in hashCode().
            return other.canEqual(this)
                    && Double.compare(getGamma(), other.getGamma()) == 0
                    && Double.compare(getErrorClamp(), other.getErrorClamp()) == 0;
        }

        /** Lets subclasses that add state refuse equality with a plain {@code Configuration}. */
        protected boolean canEqual(Object obj) {
            return obj instanceof Configuration;
        }

        @Override
        public int hashCode() {
            // Double.hashCode(d) == (int) (bits ^ (bits >>> 32)) with bits = doubleToLongBits(d).
            int result = 59 + Double.hashCode(getGamma());
            return result * 59 + Double.hashCode(getErrorClamp());
        }

        @Override
        public String toString() {
            return "BaseTransitionTDAlgorithm.Configuration(gamma=" + getGamma() + ", errorClamp=" + getErrorClamp() + ")";
        }
    }

    /**
     * Creates the algorithm.
     *
     * @param iOutputNeuralNet the network whose {@code "Q"} output seeds the labels
     * @param configuration    the gamma / error-clamp settings
     * @throws NullPointerException     if either argument is {@code null}
     * @throws IllegalArgumentException if {@code errorClamp} is negative (NaN is allowed and
     *                                  disables clamping)
     */
    public BaseTransitionTDAlgorithm(@NonNull IOutputNeuralNet iOutputNeuralNet, @NonNull Configuration configuration) {
        if (iOutputNeuralNet == null) {
            throw new NullPointerException("qNetwork is marked non-null but is null");
        }
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        double clamp = configuration.getErrorClamp();
        // A negative clamp inverts the window [Q - c, Q + c]; min/max would then always yield
        // Q + c and silently ignore the computed target. NaN < 0 is false, so "no clamp" passes.
        if (clamp < 0) {
            throw new IllegalArgumentException("errorClamp must be non-negative or NaN, but was " + clamp);
        }
        this.qNetwork = iOutputNeuralNet;
        this.gamma = configuration.getGamma();
        this.errorClamp = clamp;
        this.isClamped = !Double.isNaN(clamp);
        this.featuresBuilder = new FeaturesBuilder(iOutputNeuralNet.isRecurrent());
    }

    /**
     * Hook called once per {@link #compute} call, before any target is computed. The default
     * implementation does nothing.
     *
     * @param features     features built from the transitions
     * @param nextFeatures features built from each transition's next observation
     */
    public void initComputation(Features features, Features nextFeatures) {
    }

    /**
     * Computes the target value for one transition of the batch currently being processed.
     *
     * @param batchIdx   index of the transition within the batch passed to {@link #compute}
     * @param reward     the transition's reward
     * @param isTerminal the transition's terminal flag
     * @return the target for the action taken in that transition (before optional clamping)
     */
    protected abstract double computeTarget(int batchIdx, double reward, boolean isTerminal);

    /**
     * Builds features and Q-value labels for a batch of transitions.
     *
     * @param stateActionRewardStates the transitions; element {@code i} corresponds to row
     *                                {@code i} of the labels
     * @return the features of the transitions together with the updated {@code "Q"} labels
     */
    @Override
    public FeaturesLabels compute(List<StateActionRewardState<Integer>> stateActionRewardStates) {
        int size = stateActionRewardStates.size();

        Features features = featuresBuilder.build(stateActionRewardStates);
        Features nextFeatures = featuresBuilder.build(
                stateActionRewardStates.stream().map(StateActionRewardState::getNextObservation), size);
        initComputation(features, nextFeatures);

        // NOTE(review): the network output is overwritten in place and returned as the labels. If
        // the network ever caches or reuses its output arrays, a dup() would be needed — confirm.
        INDArray qValues = qNetwork.output(features).get("Q");
        for (int i = 0; i < size; i++) {
            StateActionRewardState<Integer> transition = stateActionRewardStates.get(i);
            double target = computeTarget(i, transition.getReward(), transition.isTerminal());
            int action = transition.getAction();
            if (isClamped) {
                double currentQ = qValues.getDouble(i, action);
                target = Math.min(currentQ + errorClamp, Math.max(target, currentQ - errorClamp));
            }
            qValues.putScalar(i, action, target);
        }

        FeaturesLabels featuresLabels = new FeaturesLabels(features);
        featuresLabels.putLabels("Q", qValues);
        return featuresLabels;
    }
}
