package org.openimaj.ml.linear.evaluation;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.loss.LossFunction;
import org.openimaj.ml.linear.learner.loss.MatLossFunction;
import org.openimaj.util.pair.Pair;

/* loaded from: input_file:org/openimaj/ml/linear/evaluation/SumLossEvaluator.class */
/**
 * A {@link BilinearEvaluator} that scores a learner by summing the value of
 * its configured loss function over every data pair it is given.
 */
public class SumLossEvaluator extends BilinearEvaluator {
    // Left as a package-visible instance field so any existing access to it keeps working.
    Logger logger = LogManager.getLogger(SumLossEvaluator.class);

    /**
     * Sums the loss over {@code list} using the current state of
     * {@code this.learner}: its U and W matrices, its bias and its parameters.
     *
     * @param list the data pairs to evaluate
     * @return the summed loss, as computed by
     *         {@link #sumLoss(List, Matrix, Matrix, Matrix, BilinearLearnerParameters)}
     */
    @Override // org.openimaj.ml.linear.evaluation.BilinearEvaluator
    public double evaluate(List<Pair<Matrix>> list) {
        return sumLoss(list, this.learner.getU(), this.learner.getW(), this.learner.getBias(), this.learner.getParams());
    }

    /**
     * Sums the loss over every pair in {@code list}.
     *
     * <p>For each pair, the loss is evaluated on
     * {@code matrix^T * first^T * matrix2}, where {@code first} is the pair's
     * first object, against the pair's second object after it has been passed
     * through {@link BilinearSparseOnlineLearner#expandY(Matrix)}.
     *
     * @param list the data pairs to evaluate; an empty list yields {@code 0.0}
     * @param matrix the left-hand matrix of the product ({@link #evaluate(List)}
     *        passes the learner's U)
     * @param matrix2 the right-hand matrix of the product ({@link #evaluate(List)}
     *        passes the learner's W)
     * @param matrix3 the bias handed to the loss function, or {@code null} to
     *        leave the loss function's bias unset
     * @param bilinearLearnerParameters parameters from which the
     *        {@link LossFunction} stored under
     *        {@link BilinearLearnerParameters#LOSS} is read
     * @return the sum of the per-pair loss values
     */
    public double sumLoss(List<Pair<Matrix>> list, Matrix matrix, Matrix matrix2, Matrix matrix3, BilinearLearnerParameters bilinearLearnerParameters) {
        final MatLossFunction loss = new MatLossFunction((LossFunction) bilinearLearnerParameters.getTyped(BilinearLearnerParameters.LOSS));
        double total = 0.0d;
        int pairIndex = 0; // only used to label the debug output
        for (Pair<Matrix> pair : list) {
            final Matrix x = (Matrix) pair.firstObject();
            final SparseMatrix expandedY = BilinearSparseOnlineLearner.expandY((Matrix) pair.secondObject());
            final Matrix lossInput = matrix.transpose().times(x.transpose()).times(matrix2);
            loss.setY(expandedY);
            loss.setX(lossInput);
            if (matrix3 != null) {
                loss.setBias(matrix3);
            }
            // Parameterized message: the text is only formatted when debug is enabled,
            // instead of concatenating a new String for every pair.
            this.logger.debug("Testing pair: {}", pairIndex);
            // NOTE(review): eval is called with null; presumably the loss is evaluated
            // directly on the X set above rather than against a weight matrix - confirm
            // against MatLossFunction.
            total += loss.eval(null);
            pairIndex++;
        }
        return total;
    }
}
