package org.nd4j.autodiff.validation;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/autodiff/validation/GradCheckUtil.class */
public class GradCheckUtil {
    private static final Logger log = LoggerFactory.getLogger(GradCheckUtil.class);
    private static final boolean DEFAULT_PRINT = true;
    private static final boolean DEFAULT_EXIT_FIRST_FAILURE = false;
    private static final boolean DEFAULT_DEBUG_MODE = false;
    private static final double DEFAULT_EPS = 1.0E-5d;
    private static final double DEFAULT_MAX_REL_ERROR = 1.0E-5d;
    private static final double DEFAULT_MIN_ABS_ERROR = 1.0E-6d;

    /**
     * Runs a gradient check using the configuration carried by {@code testCase}.
     * Internal-state validation of the SameDiff instance is always performed.
     *
     * @param testCase test case providing the graph, placeholders and grad-check settings
     * @return true if no parameter failed the check
     */
    public static boolean checkGradients(TestCase testCase) {
        return checkGradients(testCase.sameDiff(), testCase.placeholderValues(), testCase.gradCheckEpsilon(),
                testCase.gradCheckMaxRelativeError(), testCase.gradCheckMinAbsError(), testCase.gradCheckPrint(),
                testCase.gradCheckDefaultExitFirstFailure(), false, testCase.gradCheckDebugMode(),
                testCase.gradCheckSkipVariables(), testCase.gradCheckMask());
    }

    /**
     * Runs a gradient check with default epsilon/error thresholds, printing enabled,
     * no early exit, validation enabled and debug mode disabled.
     *
     * @param sameDiff          graph to check
     * @param placeholderValues placeholder arrays to feed
     * @param skipVariables     names of variables not to check (may be null)
     * @return true if no parameter failed the check
     */
    public static boolean checkGradients(SameDiff sameDiff, Map<String, INDArray> placeholderValues, String... skipVariables) {
        Set<String> skip = null;
        if (skipVariables != null) {
            skip = new HashSet<>();
            Collections.addAll(skip, skipVariables);
        }
        return checkGradients(sameDiff, placeholderValues, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR,
                DEFAULT_PRINT, DEFAULT_EXIT_FIRST_FAILURE, false, DEFAULT_DEBUG_MODE, skip, null);
    }

    /**
     * Runs a gradient check with the default epsilon and error thresholds.
     *
     * @param sameDiff           graph to check
     * @param placeholderValues  placeholder arrays to feed
     * @param print              whether to log per-parameter results
     * @param exitOnFirstFailure whether to return false as soon as one parameter fails
     * @return true if no parameter failed the check
     */
    public static boolean checkGradients(SameDiff sameDiff, Map<String, INDArray> placeholderValues, boolean print, boolean exitOnFirstFailure) {
        return checkGradients(sameDiff, placeholderValues, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR,
                print, exitOnFirstFailure);
    }

    /**
     * Runs a gradient check with validation enabled, debug mode disabled, no skipped
     * variables and no mask.
     *
     * @param sameDiff           graph to check
     * @param placeholderValues  placeholder arrays to feed
     * @param eps                finite-difference step
     * @param maxRelError        relative error threshold
     * @param minAbsError        absolute error below which a relative-error failure is ignored
     * @param print              whether to log per-parameter results
     * @param exitOnFirstFailure whether to return false as soon as one parameter fails
     * @return true if no parameter failed the check
     */
    public static boolean checkGradients(SameDiff sameDiff, Map<String, INDArray> placeholderValues, double eps,
                                         double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure) {
        return checkGradients(sameDiff, placeholderValues, eps, maxRelError, minAbsError, print, exitOnFirstFailure,
                false, DEFAULT_DEBUG_MODE, null, null);
    }

    /**
     * Compares SameDiff's analytic gradients against central-difference numerical gradients
     * for every non-op-output variable (minus any skipped ones).
     *
     * <p>A parameter fails when its relative error exceeds {@code maxRelError} (or is NaN)
     * AND its absolute error is at least {@code minAbsError}. Requires
     * {@code Nd4j.dataType() == DataType.DOUBLE}.
     *
     * <p>If {@code debugMode} is requested, debug mode is restored to its previous state
     * on every exit path, including early return and exceptions.
     *
     * @param sameDiff           graph to check
     * @param placeholderValues  placeholder arrays to feed
     * @param eps                finite-difference step
     * @param maxRelError        relative error threshold
     * @param minAbsError        absolute error below which a relative-error failure is ignored
     * @param print              whether to log per-parameter results and a summary
     * @param exitOnFirstFailure whether to return false as soon as one parameter fails
     * @param skipValidation     if true, {@link #validateInternalState} is not called
     * @param debugMode          whether to enable SameDiff debug mode for the duration of the check
     * @param skipVariables      names of variables not to check (may be null)
     * @param gradCheckMask      optional per-variable BOOL masks; 0 entries are not checked (may be null)
     * @return true if no parameter failed the check
     * @throws IllegalStateException on invalid graph state, non-double data type, missing arrays
     *                               or gradients, or non-finite gradients
     */
    public static boolean checkGradients(SameDiff sameDiff, Map<String, INDArray> placeholderValues, double eps,
                                         double maxRelError, double minAbsError, boolean print,
                                         boolean exitOnFirstFailure, boolean skipValidation, boolean debugMode,
                                         Set<String> skipVariables, Map<String, INDArray> gradCheckMask) {
        boolean debugBefore = sameDiff.isDebugMode();
        if (debugMode) {
            sameDiff.enableDebugMode();
        }
        try {
            return runGradientCheck(sameDiff, placeholderValues, eps, maxRelError, minAbsError, print,
                    exitOnFirstFailure, skipValidation, skipVariables, gradCheckMask);
        } finally {
            // Restore debug state even on early exit (exitOnFirstFailure) or exception
            if (debugMode && !debugBefore) {
                sameDiff.disableDebugging();
            }
        }
    }

    /** Performs the actual gradient check; debug-mode handling is done by the caller. */
    private static boolean runGradientCheck(SameDiff sameDiff, Map<String, INDArray> placeholderValues, double eps,
                                            double maxRelError, double minAbsError, boolean print,
                                            boolean exitOnFirstFailure, boolean skipValidation,
                                            Set<String> skipVariables, Map<String, INDArray> gradCheckMask) {
        if (!skipValidation) {
            validateInternalState(sameDiff, true);
        }
        if (Nd4j.dataType() != DataType.DOUBLE) {
            throw new IllegalStateException("Data type must be set to double");
        }

        // Outputs of ops are computed values, not inputs, so they are never perturbed
        Set<String> fnOutputs = new HashSet<>();
        for (DifferentialFunction fn : sameDiff.functions()) {
            for (SDVariable out : fn.outputVariables()) {
                fnOutputs.add(out.getVarName());
            }
        }

        for (Variable v : sameDiff.getVariables().values()) {
            if (v.getVariable().getVariableType() != VariableType.ARRAY && v.getVariable().getArr(true) == null) {
                throw new IllegalStateException("Variable \"" + v.getName() + "\" does not have array associated with it");
            }
        }

        List<String> lossVariables = sameDiff.getLossVariables();
        Preconditions.checkState(lossVariables != null && !lossVariables.isEmpty(),
                "Expected 1 or more loss function variables for gradient check, got %s", lossVariables);

        Map<String, INDArray> analyticGrads = computeAnalyticGradients(sameDiff, placeholderValues, fnOutputs);

        int totalCount = 0;
        int totalFailed = 0;
        double maxErrorSeen = 0.0d;
        for (SDVariable variable : sameDiff.variables()) {
            String varName = variable.getVarName();
            if (fnOutputs.contains(varName)) {
                continue;
            }
            if (skipVariables != null && skipVariables.contains(varName)) {
                log.info("Grad check: skipping variable \"{}\"", varName);
                continue;
            }

            INDArray paramArr = variable.getArr();
            long length = paramArr.length();
            if (print) {
                log.info("Starting test for variable \"{}\" with {} values", varName, length);
            }

            INDArray mask = gradCheckMask == null ? null : gradCheckMask.get(varName);
            if (mask != null) {
                Preconditions.checkState(paramArr.equalShapes(mask), "Variable \"%s\": Gradient check mask and array shapes must be equal: got %s vs. mask shape %s", varName, paramArr.shape(), mask.shape());
                Preconditions.checkState(mask.dataType() == DataType.BOOL, "Variable \"%s\": Gradient check mask must be BOOLEAN datatype, got %s", varName, mask.dataType());
            }

            INDArray analyticGradArr = analyticGrads.get(varName);
            NdIndexIterator iter = new NdIndexIterator('c', paramArr.shape());
            int paramIdx = 0;
            while (iter.hasNext()) {
                long[] idx = iter.next();
                String strIdx = print ? Arrays.toString(idx).replace(" ", "") : null;
                if (mask != null && mask.getDouble(idx) == 0.0d) {
                    // Masked-out entry: not checked
                    continue;
                }
                if (analyticGradArr == null) {
                    // Fail clearly (instead of an NPE) before perturbing anything
                    throw new IllegalStateException("No analytic gradient was computed for variable \"" + varName + "\"");
                }
                totalCount++;

                double origValue = paramArr.getDouble(idx);
                double scorePlus;
                double scoreMinus;
                try {
                    paramArr.putScalar(idx, origValue + eps);
                    scorePlus = sumLoss(sameDiff, placeholderValues, lossVariables);
                    paramArr.putScalar(idx, origValue - eps);
                    scoreMinus = sumLoss(sameDiff, placeholderValues, lossVariables);
                } finally {
                    // Always restore the parameter, even if execution fails
                    paramArr.putScalar(idx, origValue);
                }

                double numericalGrad = (scorePlus - scoreMinus) / (2.0d * eps);
                double analyticGrad = analyticGradArr.getDouble(idx);
                if (!Double.isFinite(numericalGrad)) {
                    throw new IllegalStateException("Numerical gradient was " + numericalGrad + " for variable \"" + varName + "\", parameter " + paramIdx + " of " + length + " (position: " + strIdx + ")");
                }
                if (!Double.isFinite(analyticGrad)) {
                    throw new IllegalStateException("Analytic (SameDiff) gradient was " + analyticGrad + " for variable \"" + varName + "\", parameter " + paramIdx + " of " + length + " (position: " + strIdx + ")");
                }

                double relError = relativeError(analyticGrad, numericalGrad);
                if (relError > maxErrorSeen) {
                    maxErrorSeen = relError;
                }

                if (relError > maxRelError || Double.isNaN(relError)) {
                    double absError = Math.abs(analyticGrad - numericalGrad);
                    if (absError >= minAbsError) {
                        if (print) {
                            log.info("Param {} ({}{}) FAILED: grad= {}, numericalGrad= {}, relError= {}, absError={}, scorePlus={}, scoreMinus= {}",
                                    paramIdx, varName, strIdx, analyticGrad, numericalGrad, relError, absError, scorePlus, scoreMinus);
                        }
                        if (exitOnFirstFailure) {
                            return false;
                        }
                        totalFailed++;
                    } else if (print) {
                        log.info("Param {} ({}{}) passed: grad= {}, numericalGrad= {}, relError= {}; absolute error = {} < minAbsoluteError = {}",
                                paramIdx, varName, strIdx, analyticGrad, numericalGrad, relError, absError, minAbsError);
                    }
                } else if (print) {
                    log.info("Param {} ({}{}) passed: grad= {}, numericalGrad= {}, relError= {}",
                            paramIdx, varName, strIdx, analyticGrad, numericalGrad, relError);
                }
                paramIdx++;
            }
        }

        if (print) {
            log.info("GradCheckUtil.checkGradients(): {} params checked, {} passed, {} failed. Largest relative error = {}",
                    totalCount, totalCount - totalFailed, totalFailed, maxErrorSeen);
        }
        return totalFailed == 0;
    }

    /**
     * Runs backprop for all floating-point VARIABLE/PLACEHOLDER gradients and returns a copy
     * of the gradient array for every non-op-output variable that has a gradient.
     */
    private static Map<String, INDArray> computeAnalyticGradients(SameDiff sameDiff, Map<String, INDArray> placeholderValues,
                                                                  Set<String> fnOutputs) {
        Set<String> gradsToCalc = new HashSet<>();
        for (Variable v : sameDiff.getVariables().values()) {
            SDVariable var = v.getVariable();
            VariableType type = var.getVariableType();
            if (var.dataType().isFPType() && (type == VariableType.VARIABLE || type == VariableType.PLACEHOLDER)) {
                SDVariable gradVar = var.getGradient();
                Preconditions.checkNotNull(gradVar, "No gradient variable found for variable %s", var);
                gradsToCalc.add(gradVar.getVarName());
            }
        }
        List<String> gradNames = new ArrayList<>(gradsToCalc);
        sameDiff.execBackwards(placeholderValues, gradNames);

        Map<String, INDArray> grads = new HashMap<>();
        for (SDVariable variable : sameDiff.variables()) {
            String varName = variable.getVarName();
            if (fnOutputs.contains(varName) || !variable.hasGradient()) {
                continue;
            }
            SDVariable gradVar = sameDiff.grad(varName);
            if (gradVar == null) {
                throw new IllegalStateException("Null gradient variable for \"" + varName + "\"");
            }
            INDArray gradArr = gradVar.getArr();
            if (gradArr == null) {
                throw new IllegalStateException("Null gradient array encountered for variable: " + varName);
            }
            long[] varShape = variable.getArr().shape();
            if (!Arrays.equals(varShape, gradArr.shape())) {
                throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" + varName + "\": shape " + Arrays.toString(varShape) + " vs. gradient shape " + Arrays.toString(gradArr.shape()));
            }
            // Copy: later forward passes may overwrite the gradient arrays
            grads.put(varName, gradArr.dup());
        }
        return grads;
    }

    /** Executes the loss variables and returns the sum over all elements of all their outputs. */
    private static double sumLoss(SameDiff sameDiff, Map<String, INDArray> placeholderValues, List<String> lossVariables) {
        double sum = 0.0d;
        for (INDArray out : sameDiff.exec(placeholderValues, lossVariables).values()) {
            sum += out.sumNumber().doubleValue();
        }
        return sum;
    }

    /** Returns |a - b| / (|a| + |b|), or 0 when both values are zero. */
    private static double relativeError(double analytic, double numerical) {
        if (analytic == 0.0d && numerical == 0.0d) {
            return 0.0d;
        }
        return Math.abs(analytic - numerical) / (Math.abs(analytic) + Math.abs(numerical));
    }

    /**
     * Checks internal consistency of a SameDiff instance.
     *
     * <p>The following checks are performed:
     * <ul>
     *   <li>variable names are unique;</li>
     *   <li>every function has an entry in the ops map;</li>
     *   <li>op inputs and outputs refer to known variables;</li>
     *   <li>each variable is the output of at most one op;</li>
     *   <li>the variables list and map agree.</li>
     * </ul>
     *
     * @param sameDiff        instance to validate
     * @param generateAndCheckGradFn if true, creates the "grad" function if absent, validates it,
     *                        and checks it contains every function of {@code sameDiff}
     * @throws IllegalStateException if any check fails
     */
    public static void validateInternalState(SameDiff sameDiff, boolean generateAndCheckGradFn) {
        DifferentialFunction[] functions = sameDiff.functions();
        List<SDVariable> varList = sameDiff.variables();
        Preconditions.checkState(varList.size() == new HashSet<>(varList).size(), "Duplicate variables in variables() list");

        Set<String> varNames = new HashSet<>();
        for (SDVariable v : varList) {
            if (!varNames.add(v.getVarName())) {
                throw new IllegalStateException("Variable with name " + v.getVarName() + " already encountered");
            }
        }

        Map<String, SameDiffOp> ops = sameDiff.getOps();
        Preconditions.checkState(functions.length == ops.size(),
                "All functions not present in ops map: %s functions vs. %s ops", functions.length, ops.size());
        for (DifferentialFunction fn : functions) {
            String opName = fn.getOwnName();
            Preconditions.checkState(ops.containsKey(opName), "%s not present in ops map", opName);
            SameDiffOp op = ops.get(opName);
            List<String> inputs = op.getInputsToOp();
            if (inputs != null) {
                for (String in : inputs) {
                    Preconditions.checkState(varNames.contains(in), "Variable %s in op inputs not a known variable name", in);
                }
            }
            List<String> outputs = op.getOutputsOfOp();
            if (outputs != null) {
                for (String out : outputs) {
                    Preconditions.checkState(varNames.contains(out), "Variable %s in op outputs not a known variable name", out);
                }
            }
        }

        // Each variable may be produced by at most one op
        Map<String, String> producerOf = new HashMap<>();
        for (Map.Entry<String, SameDiffOp> entry : ops.entrySet()) {
            List<String> outputs = entry.getValue().getOutputsOfOp();
            if (outputs == null) {
                continue;
            }
            for (String out : outputs) {
                if (producerOf.containsKey(out)) {
                    throw new IllegalStateException("Already saw variable \"" + out + "\" as output for op \"" + producerOf.get(out) + "\": expected variables to be present as an output only once; also seen as output for op \"" + entry.getKey() + "\"");
                }
                producerOf.put(out, entry.getKey());
            }
        }

        Map<String, Variable> varMap = sameDiff.getVariables();
        Preconditions.checkState(varList.size() == varMap.size(), "Variable map size check failed");
        for (Map.Entry<String, Variable> entry : varMap.entrySet()) {
            Preconditions.checkState(entry.getKey().equals(entry.getValue().getVariable().getVarName()), "Name not equal");
        }

        if (generateAndCheckGradFn) {
            if (sameDiff.getFunction("grad") == null) {
                sameDiff.createGradFunction();
            }
            SameDiff gradFn = sameDiff.getFunction("grad");
            validateInternalState(gradFn, false);
            for (DifferentialFunction fn : functions) {
                Preconditions.checkNotNull(gradFn.getFunctionById(fn.getOwnName()),
                        "DifferentialFunction %s from original SameDiff instance not present in grad fn", fn.getOwnName());
            }
        }
    }

    /**
     * Reads a (possibly private) declared field of {@code cls} from {@code obj} via reflection.
     * Any failure is rethrown wrapped in a RuntimeException.
     */
    @SuppressWarnings("unchecked")
    private static <T> T getObject(String fieldName, Object obj, Class<?> cls) {
        try {
            Field field = cls.getDeclaredField(fieldName);
            field.setAccessible(true);
            return (T) field.get(obj);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
