package org.ranksys.javafm.learner.gd;

import java.util.List;
import java.util.logging.Logger;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.ranksys.javafm.FM;
import org.ranksys.javafm.FMInstance;
import org.ranksys.javafm.data.ListWiseFMData;
import org.ranksys.javafm.learner.FMLearner;

/**
 * List-wise learner for a factorization machine ({@link FM}).
 *
 * <p>For every group of instances it converts the instance targets into a softmax distribution
 * {@code P} and the model predictions into a softmax distribution {@code Q}. The per-group loss
 * is the cross-entropy {@code -sum_i P_i * log(Q_i)}, and {@link #error} averages it over groups.
 * Training runs stochastic gradient descent with L2 regularisation. The loss gradient with
 * respect to an instance's score is {@code Q_i - P_i}.
 */
public class ListRank implements FMLearner<ListWiseFMData> {
    private static final Logger LOG = Logger.getLogger(ListRank.class.getName());
    private final double learnRate;
    private final int numIter;
    private final double regB;
    private final double[] regW;
    private final double[] regM;

    /**
     * Creates a learner.
     *
     * @param learnRate step size applied to every gradient update
     * @param numIter number of passes over the training data performed by {@link #learn}
     * @param regB L2 regularisation coefficient for the global bias
     * @param regW per-feature L2 regularisation coefficients for the linear weights (indexed by feature)
     * @param regM per-feature L2 regularisation coefficients for the factor rows (indexed by feature)
     */
    public ListRank(double learnRate, int numIter, double regB, double[] regW, double[] regM) {
        this.learnRate = learnRate;
        this.numIter = numIter;
        this.regB = regB;
        this.regW = regW;
        this.regM = regM;
    }

    /**
     * Turns {@code scores} into a softmax distribution, in place.
     *
     * <p>The maximum score is subtracted before exponentiating. The result is mathematically
     * unchanged, but large scores no longer overflow {@code Math.exp} to infinity (which would
     * make the normalised values NaN).
     *
     * @param scores raw scores; overwritten with the probabilities
     * @return {@code scores}, now holding values that sum to 1 (or an empty array if empty)
     */
    private static double[] softmax(double[] scores) {
        if (scores.length == 0) {
            return scores;
        }
        double max = DoubleStream.of(scores).max().getAsDouble();
        double norm = 0.0;
        for (int i = 0; i < scores.length; i++) {
            scores[i] = Math.exp(scores[i] - max);
            norm += scores[i];
        }
        for (int i = 0; i < scores.length; i++) {
            scores[i] /= norm;
        }
        return scores;
    }

    /** Softmax distribution of the instances' targets within one group. */
    private double[] getP(List<? extends FMInstance> group) {
        return softmax(group.stream().mapToDouble(x -> x.getTarget()).toArray());
    }

    /** Softmax distribution of the model's predictions within one group. */
    private double[] getQ(FM fm, List<? extends FMInstance> group) {
        return softmax(group.stream().mapToDouble(fm::predict).toArray());
    }

    /**
     * Average, over all groups, of the cross-entropy between the target and prediction softmax
     * distributions.
     *
     * @return the mean per-group loss, or {@code Double.NaN} when {@code listWiseFMData} contains
     *     no groups (previously this threw {@code NoSuchElementException})
     */
    @Override
    public double error(FM fm, ListWiseFMData listWiseFMData) {
        return listWiseFMData.streamByGroup()
                .map(entry -> entry.getValue())
                .mapToDouble(group -> {
                    double[] p = getP(group);
                    double[] q = getQ(fm, group);
                    return IntStream.range(0, group.size())
                            .mapToDouble(i -> -p[i] * Math.log(q[i]))
                            .sum();
                })
                .average()
                .orElse(Double.NaN);
    }

    /**
     * Trains {@code fm} for {@code numIter} passes over {@code listWiseFMData}.
     *
     * <p>The training data is shuffled before each pass. Timing is logged at INFO. Train and test
     * error (the second and third arguments) are logged at FINE, and are only computed when FINE
     * is enabled.
     */
    @Override
    public void learn(FM fm, ListWiseFMData listWiseFMData, ListWiseFMData listWiseFMData2) {
        LOG.fine(() -> String.format("iteration n = %3d e = %.6f e = %.6f",
                0, error(fm, listWiseFMData), error(fm, listWiseFMData2)));
        for (int iter = 1; iter <= numIter; iter++) {
            final long start = System.nanoTime();
            listWiseFMData.shuffle();
            listWiseFMData.streamByGroup()
                    .map(entry -> entry.getValue())
                    .forEach(group -> updateGroup(fm, group));
            final int n = iter;
            LOG.info(String.format("iteration n = %3d t = %.2fs", n, (System.nanoTime() - start) / 1.0E9d));
            LOG.fine(() -> String.format("iteration n = %3d e = %.6f e = %.6f",
                    n, error(fm, listWiseFMData), error(fm, listWiseFMData2)));
        }
    }

    /**
     * Applies one SGD step per instance of {@code group}.
     *
     * <p>{@code P} and {@code Q} are computed once, before any update in the group. Weights and
     * factors are modified in place. NOTE(review): this assumes {@code getW()} and {@code getM()}
     * return the model's live arrays — confirm against {@code FM}.
     */
    private void updateGroup(FM fm, List<? extends FMInstance> group) {
        double[] w = fm.getW();
        double[][] m = fm.getM();
        double[] p = getP(group);
        double[] q = getQ(fm, group);
        int numFactors = m.length == 0 ? 0 : m[0].length;

        for (int c = 0; c < group.size(); c++) {
            FMInstance x = group.get(c);
            // d(cross-entropy)/d(score_c) for a softmax output.
            double delta = q[c] - p[c];

            // Re-read the bias each time so successive updates within the group accumulate,
            // like the in-place w/M updates below (reading it once per group discarded all
            // but the last instance's bias step).
            double b = fm.getB();
            fm.setB(b - learnRate * (delta + regB * b));

            // xm[f] = sum over active features j of x_j * m[j][f], gathered in the first pass
            // because every factor gradient in the second pass depends on it.
            double[] xm = new double[numFactors];
            x.consume((i, xi) -> {
                for (int f = 0; f < xm.length; f++) {
                    xm[f] += xi * m[i][f];
                }
                w[i] -= learnRate * (delta * xi + regW[i] * w[i]);
            });
            x.consume((i, xi) -> {
                double[] mi = m[i];
                for (int f = 0; f < mi.length; f++) {
                    mi[f] -= learnRate * (delta * xi * xm[f] - delta * xi * xi * mi[f] + regM[i] * mi[f]);
                }
            });
        }
    }
}
