package org.deeplearning4j.nn.adapters;

import org.deeplearning4j.nn.api.OutputAdapter;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * {@link OutputAdapter} that reduces a single network output to argmax indices.
 *
 * <p>For a rank-2 output, one index is produced per row: the argmax along dimension 1. For any
 * other accepted rank (below 3), a single index is produced. It is computed by passing {@code
 * Integer.MAX_VALUE} as the dimension to {@link Nd4j#argMax}. NOTE(review): this is presumably
 * ND4J's "reduce over all dimensions" marker; confirm against the ND4J version in use.
 */
public class ArgmaxAdapter implements OutputAdapter<int[]> {

    /**
     * Converts the (single) output array into argmax indices.
     *
     * @param outputs the network outputs; exactly one array is accepted, and its rank must be below 3
     * @return for rank 2, an array of length {@code size(0)} holding each row's argmax; otherwise a
     *     one-element array holding the overall argmax
     * @throws IllegalArgumentException presumably, via {@code Preconditions.checkArgument}, when
     *     more than one output is supplied or the output rank is 3 or higher
     */
    @Override
    public int[] apply(INDArray... outputs) {
        Preconditions.checkArgument(outputs.length == 1, "Argmax adapter can have only 1 output");
        final INDArray output = outputs[0];
        Preconditions.checkArgument(output.rank() < 3, "Argmax adapter requires 2D or 1D output");

        if (output.rank() != 2) {
            // Non-2D input: collapse everything into one argmax index.
            final INDArray overall = Nd4j.argMax(output, Integer.MAX_VALUE);
            return new int[] {(int) overall.getDouble(0L)};
        }

        // 2D input: one slot per row, filled with that row's argmax along dimension 1.
        final int[] indices = new int[(int) output.size(0)];
        final INDArray perRow = Nd4j.argMax(output, 1);
        final long count = perRow.length();
        for (int idx = 0; idx < count; idx++) {
            indices[idx] = (int) perRow.getDouble(idx);
        }
        return indices;
    }
}
