package org.nd4j.autodiff.samediff.config;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.records.EvaluationRecord;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/* loaded from: input_file:org/nd4j/autodiff/samediff/config/EvaluationConfig.class */
/**
 * Fluent configuration for evaluating a {@link SameDiff} graph.
 *
 * <p>Typical use:
 * <ol>
 *   <li>Register one or more {@link IEvaluation}s per graph variable with one of the {@code evaluate(...)} overloads.</li>
 *   <li>Supply data with {@code data(...)}.</li>
 *   <li>Optionally set label indices and listeners.</li>
 *   <li>Call {@link #exec()}.</li>
 * </ol>
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public class EvaluationConfig {
    /** Data to evaluate on. A {@link DataSetIterator} is stored wrapped in a {@link MultiDataSetIteratorAdapter}. */
    private MultiDataSetIterator data;

    /** Graph the evaluations are run against. */
    private SameDiff sd;

    /** Evaluations to run, keyed by variable name. */
    @NonNull
    private Map<String, List<IEvaluation>> evaluations = new HashMap<>();

    /** Label index for each evaluated variable, keyed by variable name; passed through to {@code SameDiff.evaluate}. */
    @NonNull
    private Map<String, Integer> labelIndices = new HashMap<>();

    /** Listeners passed to {@code SameDiff.evaluate} when {@link #exec()} runs. */
    @NonNull
    private List<Listener> listeners = new ArrayList<>();

    /** True when the data was supplied via {@link #data(DataSetIterator)}; see {@link #exec()}. */
    private boolean singleInput = false;

    /**
     * Creates a configuration bound to the given graph.
     *
     * @param sd the graph to evaluate; must not be null
     * @throws NullPointerException if {@code sd} is null
     */
    public EvaluationConfig(@NonNull SameDiff sd) {
        if (sd == null) {
            throw new NullPointerException("sd is marked @NonNull but is null");
        }
        this.sd = sd;
    }

    /**
     * Registers evaluations for a variable and sets the label index used for it.
     *
     * <p>The label index is registered first. If it conflicts with an index already specified for
     * {@code param}, the call is rejected before any evaluations are added, so the configuration
     * is left unchanged.
     *
     * @param param       name of the variable to evaluate
     * @param labelIndex  label index to use for this variable
     * @param evaluations evaluations to add for this variable
     * @return this configuration, for chaining
     * @throws NullPointerException if {@code param} or {@code evaluations} is null
     */
    public EvaluationConfig evaluate(@NonNull String param, int labelIndex, @NonNull IEvaluation... evaluations) {
        if (param == null) {
            throw new NullPointerException("param is marked @NonNull but is null");
        }
        if (evaluations == null) {
            throw new NullPointerException("evaluations is marked @NonNull but is null");
        }
        // Validate and record the label index before mutating the evaluations map, so a rejected
        // (conflicting) index does not leave evaluations half-registered.
        labelIndex(param, labelIndex);
        return evaluate(param, evaluations);
    }

    /**
     * Registers evaluations for a variable and sets the label index used for it.
     * Equivalent to {@code evaluate(variable.getVarName(), labelIndex, evaluations)}.
     *
     * @param variable    variable to evaluate
     * @param labelIndex  label index to use for this variable
     * @param evaluations evaluations to add for this variable
     * @return this configuration, for chaining
     * @throws NullPointerException if {@code variable} or {@code evaluations} is null
     */
    public EvaluationConfig evaluate(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) {
        if (variable == null) {
            throw new NullPointerException("variable is marked @NonNull but is null");
        }
        if (evaluations == null) {
            throw new NullPointerException("evaluations is marked @NonNull but is null");
        }
        return evaluate(variable.getVarName(), labelIndex, evaluations);
    }

    /**
     * Adds evaluations for a variable.
     *
     * <p>Repeated calls for the same name append to the existing list; they do not replace it.
     *
     * @param param       name of the variable to evaluate
     * @param evaluations evaluations to add for this variable
     * @return this configuration, for chaining
     * @throws NullPointerException if {@code param} or {@code evaluations} is null
     */
    public EvaluationConfig evaluate(@NonNull String param, @NonNull IEvaluation... evaluations) {
        if (param == null) {
            throw new NullPointerException("param is marked @NonNull but is null");
        }
        if (evaluations == null) {
            throw new NullPointerException("evaluations is marked @NonNull but is null");
        }
        List<IEvaluation> forParam = this.evaluations.get(param);
        if (forParam == null) {
            forParam = new ArrayList<>();
            this.evaluations.put(param, forParam);
        }
        forParam.addAll(Arrays.asList(evaluations));
        return this;
    }

    /**
     * Adds evaluations for a variable.
     * Equivalent to {@code evaluate(variable.getVarName(), evaluations)}.
     *
     * @param variable    variable to evaluate
     * @param evaluations evaluations to add for this variable
     * @return this configuration, for chaining
     * @throws NullPointerException if {@code variable} or {@code evaluations} is null
     */
    public EvaluationConfig evaluate(@NonNull SDVariable variable, @NonNull IEvaluation... evaluations) {
        if (variable == null) {
            throw new NullPointerException("variable is marked @NonNull but is null");
        }
        if (evaluations == null) {
            throw new NullPointerException("evaluations is marked @NonNull but is null");
        }
        return evaluate(variable.getVarName(), evaluations);
    }

    /**
     * Sets the label index used for a variable.
     *
     * <p>Re-specifying the same index is allowed. A different index for a name that already has
     * one is rejected via {@code Preconditions.checkArgument}.
     *
     * @param param      name of the variable
     * @param labelIndex label index to use for this variable
     * @return this configuration, for chaining
     * @throws NullPointerException if {@code param} is null
     */
    public EvaluationConfig labelIndex(@NonNull String param, int labelIndex) {
        if (param == null) {
            throw new NullPointerException("param is marked @NonNull but is null");
        }
        Integer existing = this.labelIndices.get(param);
        if (existing != null) {
            Preconditions.checkArgument(existing.intValue() == labelIndex, "Different label index already specified for param %s.  Already specified: %s, given: %s", param, existing, Integer.valueOf(labelIndex));
        }
        this.labelIndices.put(param, Integer.valueOf(labelIndex));
        return this;
    }

    /**
     * Sets the label index used for a variable.
     * Equivalent to {@code labelIndex(variable.getVarName(), labelIndex)}.
     *
     * @param variable   variable whose label index is set
     * @param labelIndex label index to use for this variable
     * @return this configuration, for chaining
     * @throws NullPointerException if {@code variable} is null
     */
    public EvaluationConfig labelIndex(@NonNull SDVariable variable, int labelIndex) {
        if (variable == null) {
            throw new NullPointerException("variable is marked @NonNull but is null");
        }
        return labelIndex(variable.getVarName(), labelIndex);
    }

    /**
     * Appends listeners to be passed to {@code SameDiff.evaluate}.
     *
     * @param listeners listeners to add
     * @return this configuration, for chaining
     * @throws NullPointerException if {@code listeners} is null
     */
    public EvaluationConfig listeners(@NonNull Listener... listeners) {
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        this.listeners.addAll(Arrays.asList(listeners));
        return this;
    }

    /**
     * Sets multi-input data.
     *
     * <p>Clears the single-input flag, so every evaluated variable needs an explicit label index
     * (checked when {@link #exec()} runs).
     *
     * @param data data to evaluate on
     * @return this configuration, for chaining
     * @throws NullPointerException if {@code data} is null
     */
    public EvaluationConfig data(@NonNull MultiDataSetIterator data) {
        if (data == null) {
            throw new NullPointerException("data is marked @NonNull but is null");
        }
        this.data = data;
        this.singleInput = false;
        return this;
    }

    /**
     * Sets single-input data.
     *
     * <p>The iterator is wrapped in a {@link MultiDataSetIteratorAdapter}, and the single-input
     * flag is set so that {@link #exec()} uses label index 0 for every evaluated variable.
     *
     * @param data data to evaluate on
     * @return this configuration, for chaining
     * @throws NullPointerException if {@code data} is null
     */
    public EvaluationConfig data(@NonNull DataSetIterator data) {
        if (data == null) {
            throw new NullPointerException("data is marked @NonNull but is null");
        }
        this.data = new MultiDataSetIteratorAdapter(data);
        this.singleInput = true;
        return this;
    }

    /**
     * Checks that the configuration is complete.
     *
     * <p>The checks are:
     * <ul>
     *   <li>data is set;</li>
     *   <li>for multi-input data, every evaluated variable has a label index;</li>
     *   <li>every evaluated variable exists in the graph's variable map.</li>
     * </ul>
     */
    private void validateConfig() {
        Preconditions.checkNotNull(this.data, "Must specify data.  It may not be null.");
        if (!this.singleInput) {
            for (String param : this.evaluations.keySet()) {
                Preconditions.checkState(this.labelIndices.containsKey(param), "Using multiple input dataset iterator without specifying a label index for %s", param);
            }
        }
        for (String param : this.evaluations.keySet()) {
            Preconditions.checkState(this.sd.variableMap().containsKey(param), "Parameter %s not present in this SameDiff graph", param);
        }
    }

    /**
     * Validates the configuration and runs the evaluations on the graph.
     *
     * <p>For single-input data, the stored label index of every evaluated variable is set to 0
     * before evaluation. Any previously specified index is overwritten in
     * {@link #getLabelIndices()}.
     *
     * @return a record wrapping this configuration's evaluations map
     */
    public EvaluationRecord exec() {
        validateConfig();
        if (this.singleInput) {
            // Single-input iterators expose only one label array, so index 0 is the only choice.
            for (String param : this.evaluations.keySet()) {
                this.labelIndices.put(param, 0);
            }
        }
        this.sd.evaluate(this.data, this.evaluations, this.labelIndices, this.listeners.toArray(new Listener[0]));
        return new EvaluationRecord(this.evaluations);
    }

    /** @return the live evaluations map, keyed by variable name (not a copy) */
    @NonNull
    public Map<String, List<IEvaluation>> getEvaluations() {
        return this.evaluations;
    }

    /** @return the live label-index map, keyed by variable name (not a copy) */
    @NonNull
    public Map<String, Integer> getLabelIndices() {
        return this.labelIndices;
    }

    /** @return the configured data, or null if none has been set */
    public MultiDataSetIterator getData() {
        return this.data;
    }

    /** @return the live listener list (not a copy) */
    @NonNull
    public List<Listener> getListeners() {
        return this.listeners;
    }

    /** @return true if data was last supplied as single-input data (or the flag was set explicitly) */
    public boolean isSingleInput() {
        return this.singleInput;
    }

    /** @return the graph this configuration evaluates */
    public SameDiff getSd() {
        return this.sd;
    }

    /**
     * Replaces the evaluations map. The map is stored as-is, not copied.
     *
     * @param evaluations new evaluations map
     * @throws NullPointerException if {@code evaluations} is null
     */
    public void setEvaluations(@NonNull Map<String, List<IEvaluation>> evaluations) {
        if (evaluations == null) {
            throw new NullPointerException("evaluations is marked @NonNull but is null");
        }
        this.evaluations = evaluations;
    }

    /**
     * Replaces the label-index map. The map is stored as-is, not copied.
     *
     * @param labelIndices new label-index map
     * @throws NullPointerException if {@code labelIndices} is null
     */
    public void setLabelIndices(@NonNull Map<String, Integer> labelIndices) {
        if (labelIndices == null) {
            throw new NullPointerException("labelIndices is marked @NonNull but is null");
        }
        this.labelIndices = labelIndices;
    }

    /**
     * Sets the data directly.
     *
     * <p>Unlike the {@code data(...)} methods, this does not change {@link #isSingleInput()}.
     *
     * @param data data to evaluate on (may be null; rejected when {@link #exec()} runs)
     */
    public void setData(MultiDataSetIterator data) {
        this.data = data;
    }

    /**
     * Replaces the listener list. The list is stored as-is, not copied.
     *
     * @param listeners new listener list
     * @throws NullPointerException if {@code listeners} is null
     */
    public void setListeners(@NonNull List<Listener> listeners) {
        if (listeners == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        this.listeners = listeners;
    }

    /**
     * Sets the single-input flag directly.
     *
     * @param singleInput whether {@link #exec()} should treat the data as single-input
     */
    public void setSingleInput(boolean singleInput) {
        this.singleInput = singleInput;
    }
}
