package org.deeplearning4j.rl4j.network;

import lombok.NonNull;
import org.apache.commons.collections4.map.HashedMap;
import org.deeplearning4j.rl4j.agent.learning.update.Features;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.IntFunction;

/**
 * Maps the channels of an {@link Observation} (or the corresponding entries of a {@link Features}
 * instance) to the inputs of a network, according to a set of name-based bindings.
 *
 * <p>All name lookups and validation happen once, in the constructor. The names are resolved into
 * (network input index, channel index) pairs, so {@code getNetworkInputs} only does array indexing.
 */
public class ChannelToNetworkInputMapper {
    /** Resolved (network input index, channel index) pairs, one per binding. */
    private final IdxBinding[] networkInputsToChannelNameMap;

    /** Number of network inputs; the length of the arrays returned by {@code getNetworkInputs}. */
    private final int inputCount;

    /** A binding whose names have been resolved to positional indices. */
    private static final class IdxBinding {
        final int networkInputIdx;
        final int channelIdx;

        IdxBinding(int networkInputIdx, int channelIdx) {
            this.networkInputIdx = networkInputIdx;
            this.channelIdx = channelIdx;
        }
    }

    /** Declares that the network input named {@code networkInputName} is fed from channel {@code channelName}. */
    public static class NetworkInputToChannelBinding {
        private final String networkInputName;
        private final String channelName;

        /**
         * Convenience factory, equivalent to calling the constructor.
         *
         * @param networkInputName name of the network input
         * @param channelName name of the channel that feeds it
         * @return a new binding
         */
        public static NetworkInputToChannelBinding map(String networkInputName, String channelName) {
            return new NetworkInputToChannelBinding(networkInputName, channelName);
        }

        /**
         * @param networkInputName name of the network input
         * @param channelName name of the channel that feeds it
         */
        public NetworkInputToChannelBinding(String networkInputName, String channelName) {
            this.networkInputName = networkInputName;
            this.channelName = channelName;
        }

        public String getNetworkInputName() {
            return this.networkInputName;
        }

        public String getChannelName() {
            return this.channelName;
        }
    }

    /**
     * Builds the mapper and validates the bindings.
     *
     * @param bindings one binding per network input; must not be null or empty
     * @param networkInputNames network input names, in network input order; must not be empty or
     *     contain duplicates
     * @param channelNames channel names, in channel order; must not be empty
     * @throws NullPointerException if any argument is null
     * @throws IllegalArgumentException if any array is empty, if a network input name is declared
     *     more than once, if a network input is not bound exactly once, or if a binding names an
     *     unknown network input or channel
     */
    public ChannelToNetworkInputMapper(@NonNull NetworkInputToChannelBinding[] bindings, String[] networkInputNames, String[] channelNames) {
        if (bindings == null) {
            throw new NullPointerException("networkInputsToChannelNameMap is marked non-null but is null");
        }
        Preconditions.checkArgument(bindings.length > 0, "networkInputsToChannelNameMap is empty.");
        Objects.requireNonNull(networkInputNames, "networkInputNames is null.");
        Preconditions.checkArgument(networkInputNames.length > 0, "networkInputNames is empty.");
        Objects.requireNonNull(channelNames, "channelNames is null.");
        Preconditions.checkArgument(channelNames.length > 0, "channelNames is empty.");

        // Count bindings per network input by value. The previous code compared Strings with
        // '==', so equal but distinct String instances were counted as unmapped.
        Map<String, Integer> timesMapped = new HashMap<>();
        for (NetworkInputToChannelBinding binding : bindings) {
            timesMapped.merge(binding.networkInputName, 1, Integer::sum);
        }
        for (String networkInputName : networkInputNames) {
            int count = timesMapped.getOrDefault(networkInputName, 0);
            if (count != 1) {
                throw new IllegalArgumentException("All network inputs must be mapped exactly once. Input '" + networkInputName + "' is mapped " + count + " times.");
            }
        }

        // A duplicated input name would leave a null slot in every array built by getNetworkInputs.
        Map<String, Integer> networkInputIdxByName = new HashMap<>();
        for (int i = 0; i < networkInputNames.length; i++) {
            if (networkInputIdxByName.put(networkInputNames[i], i) != null) {
                throw new IllegalArgumentException("Network input '" + networkInputNames[i] + "' is declared more than once in networkInputNames.");
            }
        }

        // If a channel name repeats, the last occurrence wins.
        Map<String, Integer> channelIdxByName = new HashMap<>();
        for (int i = 0; i < channelNames.length; i++) {
            channelIdxByName.put(channelNames[i], i);
        }

        IdxBinding[] resolved = new IdxBinding[bindings.length];
        for (int i = 0; i < bindings.length; i++) {
            NetworkInputToChannelBinding binding = bindings[i];
            Integer networkInputIdx = networkInputIdxByName.get(binding.networkInputName);
            if (networkInputIdx == null) {
                throw new IllegalArgumentException("'" + binding.networkInputName + "' not found in networkInputNames");
            }
            Integer channelIdx = channelIdxByName.get(binding.channelName);
            if (channelIdx == null) {
                throw new IllegalArgumentException("'" + binding.channelName + "' not found in channelNames");
            }
            resolved[i] = new IdxBinding(networkInputIdx, channelIdx);
        }
        this.networkInputsToChannelNameMap = resolved;
        this.inputCount = networkInputNames.length;
    }

    /**
     * Builds the network input array from an observation.
     *
     * @param observation source of the channel data
     * @return a new array of length {@code networkInputNames.length}; element {@code i} is the
     *     channel data bound to network input {@code i}
     */
    public INDArray[] getNetworkInputs(Observation observation) {
        return mapInputs(observation::getChannelData);
    }

    /**
     * Builds the network input array from a set of features.
     *
     * @param features source of the per-channel data
     * @return a new array of length {@code networkInputNames.length}; element {@code i} is the
     *     feature entry bound to network input {@code i}
     */
    public INDArray[] getNetworkInputs(Features features) {
        return mapInputs(features::get);
    }

    /** Places the data for each bound channel, fetched by channel index, at its network input index. */
    private INDArray[] mapInputs(IntFunction<INDArray> channelData) {
        INDArray[] inputs = new INDArray[this.inputCount];
        for (IdxBinding binding : this.networkInputsToChannelNameMap) {
            inputs[binding.networkInputIdx] = channelData.apply(binding.channelIdx);
        }
        return inputs;
    }
}
