package org.deeplearning4j.rl4j.network;

import lombok.NonNull;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.rl4j.network.ChannelToNetworkInputMapper;
import org.deeplearning4j.rl4j.network.CommonGradientNames;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/rl4j/network/ActorCriticNetwork.class */
/**
 * Actor-critic network producing two outputs: a {@code "value"} output and a {@code "policy"} output.
 *
 * <p>It is backed either by one combined {@link ComputationGraph} that produces both outputs, or by
 * two separate networks (each a {@link ComputationGraph} or a {@link MultiLayerNetwork}) whose
 * handlers are joined in a {@link CompoundNetworkHandler}, value first. Instances are created via
 * {@link #builder()}.
 */
public class ActorCriticNetwork extends BaseNetwork<ActorCriticNetwork> {
    /** Label / gradient name of the value output. */
    private static final String VALUE_LABEL = "value";
    /** Label / gradient name of the policy output. */
    private static final String POLICY_LABEL = "policy";
    /** Label names of a combined network, in output order (value, then policy). */
    private static final String[] LABEL_NAMES = {VALUE_LABEL, POLICY_LABEL};

    /** True when backed by a single combined network, false when backed by two separate ones. */
    private final boolean isCombined;

    /**
     * Builder for {@link ActorCriticNetwork}.
     *
     * <p>The last call to {@code withCombinedNetwork} / {@code withSeparateNetworks} determines the
     * networks used; earlier network settings are discarded. Inputs of {@link ComputationGraph}s are
     * bound with {@link #inputBindings} when set, otherwise with {@link #specificBinding};
     * {@link MultiLayerNetwork}s always use {@link #specificBinding}.
     */
    public static class Builder {
        private final NetworkHelper networkHelper = new NetworkHelper();
        private boolean isCombined;
        private ComputationGraph combinedNetwork;
        private ComputationGraph cgValueNetwork;
        private MultiLayerNetwork mlnValueNetwork;
        private ComputationGraph cgPolicyNetwork;
        private MultiLayerNetwork mlnPolicyNetwork;
        private String inputChannelName;
        private String[] channelNames;
        private ChannelToNetworkInputMapper.NetworkInputToChannelBinding[] networkInputsToFeatureBindings;

        /**
         * Uses a single graph producing both the value and the policy outputs.
         *
         * @param combinedNetwork the combined network; must not be null
         * @return this builder
         * @throws NullPointerException if {@code combinedNetwork} is null
         */
        public Builder withCombinedNetwork(@NonNull ComputationGraph combinedNetwork) {
            if (combinedNetwork == null) {
                throw new NullPointerException("combinedNetwork is marked non-null but is null");
            }
            clearNetworks();
            this.isCombined = true;
            this.combinedNetwork = combinedNetwork;
            return this;
        }

        /**
         * Uses two separate graphs for the value and the policy outputs.
         *
         * @param valueNetwork network producing the value output; must not be null
         * @param policyNetwork network producing the policy output; must not be null
         * @return this builder
         * @throws NullPointerException if either network is null
         */
        public Builder withSeparateNetworks(@NonNull ComputationGraph valueNetwork, @NonNull ComputationGraph policyNetwork) {
            if (valueNetwork == null) {
                throw new NullPointerException("valueNetwork is marked non-null but is null");
            }
            if (policyNetwork == null) {
                throw new NullPointerException("policyNetwork is marked non-null but is null");
            }
            clearNetworks();
            this.cgValueNetwork = valueNetwork;
            this.cgPolicyNetwork = policyNetwork;
            this.isCombined = false;
            return this;
        }

        /**
         * Uses a {@link MultiLayerNetwork} for the value output and a graph for the policy output.
         *
         * @param valueNetwork network producing the value output; must not be null
         * @param policyNetwork network producing the policy output; must not be null
         * @return this builder
         * @throws NullPointerException if either network is null
         */
        public Builder withSeparateNetworks(@NonNull MultiLayerNetwork valueNetwork, @NonNull ComputationGraph policyNetwork) {
            if (valueNetwork == null) {
                throw new NullPointerException("valueNetwork is marked non-null but is null");
            }
            if (policyNetwork == null) {
                throw new NullPointerException("policyNetwork is marked non-null but is null");
            }
            clearNetworks();
            this.mlnValueNetwork = valueNetwork;
            this.cgPolicyNetwork = policyNetwork;
            this.isCombined = false;
            return this;
        }

        /**
         * Uses a graph for the value output and a {@link MultiLayerNetwork} for the policy output.
         *
         * @param valueNetwork network producing the value output; must not be null
         * @param policyNetwork network producing the policy output; must not be null
         * @return this builder
         * @throws NullPointerException if either network is null
         */
        public Builder withSeparateNetworks(@NonNull ComputationGraph valueNetwork, @NonNull MultiLayerNetwork policyNetwork) {
            if (valueNetwork == null) {
                throw new NullPointerException("valueNetwork is marked non-null but is null");
            }
            if (policyNetwork == null) {
                throw new NullPointerException("policyNetwork is marked non-null but is null");
            }
            clearNetworks();
            this.cgValueNetwork = valueNetwork;
            this.mlnPolicyNetwork = policyNetwork;
            this.isCombined = false;
            return this;
        }

        /**
         * Uses two separate {@link MultiLayerNetwork}s for the value and the policy outputs.
         *
         * @param valueNetwork network producing the value output; must not be null
         * @param policyNetwork network producing the policy output; must not be null
         * @return this builder
         * @throws NullPointerException if either network is null
         */
        public Builder withSeparateNetworks(@NonNull MultiLayerNetwork valueNetwork, @NonNull MultiLayerNetwork policyNetwork) {
            if (valueNetwork == null) {
                throw new NullPointerException("valueNetwork is marked non-null but is null");
            }
            if (policyNetwork == null) {
                throw new NullPointerException("policyNetwork is marked non-null but is null");
            }
            clearNetworks();
            this.mlnValueNetwork = valueNetwork;
            this.mlnPolicyNetwork = policyNetwork;
            this.isCombined = false;
            return this;
        }

        /**
         * Sets explicit network-input-to-channel bindings (used for {@link ComputationGraph}s only).
         *
         * @param bindings the bindings; when null, {@link #specificBinding} is used instead
         * @return this builder
         */
        public Builder inputBindings(ChannelToNetworkInputMapper.NetworkInputToChannelBinding[] bindings) {
            this.networkInputsToFeatureBindings = bindings;
            return this;
        }

        /**
         * Sets the name of the single input channel fed to the network(s).
         *
         * @param inputChannelName the channel name
         * @return this builder
         */
        public Builder specificBinding(String inputChannelName) {
            this.inputChannelName = inputChannelName;
            return this;
        }

        /**
         * Sets the names of all available observation channels.
         *
         * @param channelNames the channel names
         * @return this builder
         */
        public Builder channelNames(String[] channelNames) {
            this.channelNames = channelNames;
            return this;
        }

        /**
         * Builds the network from the configured network(s) and bindings.
         *
         * @return the new {@link ActorCriticNetwork}
         * @throws IllegalStateException if neither a combined network nor both separate networks are set
         */
        public ActorCriticNetwork build() {
            Preconditions.checkState(combinedNetwork != null
                            || ((cgValueNetwork != null || mlnValueNetwork != null)
                                && (cgPolicyNetwork != null || mlnPolicyNetwork != null)),
                    "A network must be set.");

            INetworkHandler networkHandler;
            if (isCombined) {
                networkHandler = networkInputsToFeatureBindings == null
                        ? networkHelper.buildHandler(combinedNetwork, inputChannelName, channelNames, LABEL_NAMES, CommonGradientNames.ActorCritic.Combined)
                        : networkHelper.buildHandler(combinedNetwork, networkInputsToFeatureBindings, channelNames, LABEL_NAMES, CommonGradientNames.ActorCritic.Combined);
            } else {
                // Value handler first: packageResult reads the value output at index 0.
                INetworkHandler valueHandler = buildSeparateHandler(cgValueNetwork, mlnValueNetwork, VALUE_LABEL);
                INetworkHandler policyHandler = buildSeparateHandler(cgPolicyNetwork, mlnPolicyNetwork, POLICY_LABEL);
                networkHandler = new CompoundNetworkHandler(valueHandler, policyHandler);
            }
            return new ActorCriticNetwork(networkHandler, isCombined);
        }

        /**
         * Builds the handler for one output of a separate-networks configuration.
         *
         * @param cgNetwork the network as a ComputationGraph, or null if it is a MultiLayerNetwork
         * @param mlnNetwork the network as a MultiLayerNetwork; used when {@code cgNetwork} is null
         * @param label label and gradient name of this output ({@code "value"} or {@code "policy"})
         * @return the handler for that network
         */
        private INetworkHandler buildSeparateHandler(ComputationGraph cgNetwork, MultiLayerNetwork mlnNetwork, String label) {
            if (cgNetwork == null) {
                // Explicit input bindings are not used for MultiLayerNetworks; only inputChannelName is.
                return networkHelper.buildHandler(mlnNetwork, inputChannelName, channelNames, label, label);
            }
            // A single-output network gets only its own label, whichever way its inputs are bound.
            String[] labelNames = {label};
            return networkInputsToFeatureBindings == null
                    ? networkHelper.buildHandler(cgNetwork, inputChannelName, channelNames, labelNames, label)
                    : networkHelper.buildHandler(cgNetwork, networkInputsToFeatureBindings, channelNames, labelNames, label);
        }

        /** Discards any network set by an earlier {@code with...} call so the latest call wins. */
        private void clearNetworks() {
            combinedNetwork = null;
            cgValueNetwork = null;
            mlnValueNetwork = null;
            cgPolicyNetwork = null;
            mlnPolicyNetwork = null;
        }
    }

    private ActorCriticNetwork(INetworkHandler networkHandler, boolean isCombined) {
        super(networkHandler);
        this.isCombined = isCombined;
    }

    /**
     * Maps the raw network outputs to named results.
     *
     * @param outputs raw outputs; index 0 is the value output, index 1 the policy output
     * @return the outputs keyed by {@code "value"} and {@code "policy"}
     */
    @Override
    protected NeuralNetOutput packageResult(INDArray[] outputs) {
        NeuralNetOutput result = new NeuralNetOutput();
        result.put(VALUE_LABEL, outputs[0]);
        result.put(POLICY_LABEL, outputs[1]);
        return result;
    }

    /**
     * Creates a copy of this network backed by a clone of its network handler.
     *
     * <p>This is the decompiler-renamed {@code clone()}. The name is kept so that it still overrides
     * {@link BaseNetwork}.
     *
     * @return the copy
     */
    @Override
    public ActorCriticNetwork mo26clone() {
        return new ActorCriticNetwork(getNetworkHandler().m27clone(), isCombined);
    }

    /**
     * Returns a new, empty {@link Builder}.
     *
     * @return a builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
