package edu.cmu.graphchi.apps.pig;

import edu.cmu.graphchi.ChiEdge;
import edu.cmu.graphchi.ChiLogger;
import edu.cmu.graphchi.ChiVertex;
import edu.cmu.graphchi.GraphChiContext;
import edu.cmu.graphchi.GraphChiProgram;
import edu.cmu.graphchi.datablocks.FloatConverter;
import edu.cmu.graphchi.datablocks.IntConverter;
import edu.cmu.graphchi.engine.GraphChiEngine;
import edu.cmu.graphchi.engine.VertexInterval;
import edu.cmu.graphchi.hadoop.PigGraphChiBase;
import edu.cmu.graphchi.preprocessing.EdgeProcessor;
import edu.cmu.graphchi.preprocessing.FastSharder;
import edu.cmu.graphchi.preprocessing.VertexIdTranslate;
import edu.cmu.graphchi.util.HugeDoubleMatrix;
import java.io.IOException;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.math.linear.ArrayRealVector;
import org.apache.commons.math.linear.BlockRealMatrix;
import org.apache.commons.math.linear.CholeskyDecompositionImpl;
import org.apache.commons.math.linear.NotPositiveDefiniteMatrixException;
import org.apache.commons.math.linear.RealVector;
import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;

/* loaded from: input_file:edu/cmu/graphchi/apps/pig/PigALSMatrixFactorization.class */
/**
 * Alternating Least Squares (ALS) matrix factorization, run as a GraphChi program from Pig.
 *
 * <p>Every input edge {@code a -> b} carries a float observation. Source ids index rows of the
 * left factor matrix and target ids index rows of the right factor matrix. Both sides share one
 * GraphChi vertex-id space, so the GraphChi vertex whose original id is {@code k} re-solves left
 * row {@code k} from its out-edges and right row {@code k} from its in-edges.
 *
 * <p>Output tuples have the shape {@code (factor, id, x0, ..., x(D-1))}: first every left row
 * (factor {@code "U"}), then every right row (factor {@code "V"}).
 *
 * <p>Deterministic execution is disabled in {@link #runGraphChi}, so {@link #update} may run
 * concurrently for different vertices; only the RMSE accumulator is guarded by
 * {@code synchronized}.
 */
public class PigALSMatrixFactorization extends PigGraphChiBase implements GraphChiProgram<Integer, Float> {
    private static final Logger logger = ChiLogger.getLogger("ALS");

    /** Number of ALS sweeps performed by {@link #runGraphChi}. */
    private static final int NUM_ITERATIONS = 5;

    /** Side index of the left factor matrix (rows solved from out-edges). */
    private static final int LEFTSIDE = 0;

    /** Side index of the right factor matrix (rows solved from in-edges). */
    private static final int RIGHTSIDE = 1;

    private HugeDoubleMatrix leftSideMatrix;
    private HugeDoubleMatrix rightSideMatrix;

    /** Number of latent factors per row. */
    private int D = 5;

    /** Regularization weight; multiplied by the number of observations of the row being solved. */
    double LAMBDA = 0.065d;

    /** Sum of squared training errors, accumulated during the last iteration only. */
    double rmse = 0.0d;

    /** Largest source id seen while sharding; the left matrix has this many rows plus one. */
    private int maxLeftVertexId = 0;

    /** Largest target id seen while sharding; the right matrix has this many rows plus one. */
    private int maxRightVertexId = 0;

    /** Global index of the next row to emit from {@link #getNextResult}. */
    private int outputCounter = 0;

    /**
     * Re-solves the latent factors of this vertex on every side where it has observations.
     *
     * @param chiVertex the vertex being updated; out-edges are left-side observations and
     *     in-edges are right-side observations
     * @param graphChiContext engine context (id translation, iteration info)
     */
    @Override // edu.cmu.graphchi.GraphChiProgram
    public void update(ChiVertex<Integer, Float> chiVertex, GraphChiContext graphChiContext) {
        if (chiVertex.numEdges() == 0) {
            return;
        }
        VertexIdTranslate vertexIdTranslate = graphChiContext.getVertexIdTranslate();
        for (int side = LEFTSIDE; side <= RIGHTSIDE; side++) {
            int numSideEdges = side == LEFTSIDE ? chiVertex.numOutEdges() : chiVertex.numInEdges();
            if (numSideEdges > 0) {
                updateSide(chiVertex, graphChiContext, vertexIdTranslate, side, numSideEdges);
            }
        }
    }

    /**
     * Solves the regularized normal equations {@code (X^T X + lambda * n * I) w = X^T y} for one
     * side of a vertex and stores {@code w} in that side's factor matrix. On the last iteration,
     * the right side also adds this vertex's squared training error to {@link #rmse}.
     *
     * <p>A matrix that is not positive definite is logged and the row is left unchanged. An
     * observation below 1.0, or any other failure, is logged and rethrown.
     *
     * @param side {@link #LEFTSIDE} (use out-edges) or {@link #RIGHTSIDE} (use in-edges)
     * @param numSideEdges number of observations on that side; must be positive
     */
    private void updateSide(ChiVertex<Integer, Float> chiVertex, GraphChiContext graphChiContext,
                            VertexIdTranslate vertexIdTranslate, int side, int numSideEdges) {
        HugeDoubleMatrix ownFactors = side == LEFTSIDE ? this.leftSideMatrix : this.rightSideMatrix;
        HugeDoubleMatrix neighborFactors = side == LEFTSIDE ? this.rightSideMatrix : this.leftSideMatrix;
        BlockRealMatrix xtx = new BlockRealMatrix(this.D, this.D);
        ArrayRealVector xty = new ArrayRealVector(this.D);
        try {
            double[] neighborRow = new double[this.D];
            for (int e = 0; e < numSideEdges; e++) {
                ChiEdge<Float> edge = side == LEFTSIDE ? chiVertex.outEdge(e) : chiVertex.inEdge(e);
                float observation = edge.getValue().floatValue();
                if (observation < 1.0d) {
                    throw new RuntimeException("Had invalid observation: " + observation + " on edge " + vertexIdTranslate.backward(chiVertex.getId()) + "->" + vertexIdTranslate.backward(edge.getVertexId()));
                }
                neighborFactors.getRow(vertexIdTranslate.backward(edge.getVertexId()), neighborRow);
                for (int a = 0; a < this.D; a++) {
                    xty.setEntry(a, xty.getEntry(a) + (neighborRow[a] * observation));
                    // Accumulate only the lower triangle; it is mirrored below.
                    for (int b = a; b < this.D; b++) {
                        xtx.setEntry(b, a, xtx.getEntry(b, a) + (neighborRow[a] * neighborRow[b]));
                    }
                }
            }
            for (int a = 0; a < this.D; a++) {
                for (int b = a + 1; b < this.D; b++) {
                    xtx.setEntry(a, b, xtx.getEntry(b, a));
                }
            }
            // Weighted-lambda regularization uses the observation count of THIS side only. A
            // vertex id can own both a left and a right row, so numEdges() would over-regularize.
            double regularization = this.LAMBDA * numSideEdges;
            for (int a = 0; a < this.D; a++) {
                xtx.setEntry(a, a, xtx.getEntry(a, a) + regularization);
            }
            RealVector solution = new CholeskyDecompositionImpl(xtx).getSolver().solve(xty);
            int ownRow = vertexIdTranslate.backward(chiVertex.getId());
            for (int a = 0; a < this.D; a++) {
                ownFactors.setValue(ownRow, a, solution.getEntry(a));
            }
            if (side == RIGHTSIDE && graphChiContext.isLastIteration()) {
                // Training error: each in-edge's observation against (left row . new right row).
                double squaredError = 0.0d;
                for (int e = 0; e < numSideEdges; e++) {
                    ChiEdge<Float> inEdge = chiVertex.inEdge(e);
                    float observation = inEdge.getValue().floatValue();
                    neighborFactors.getRow(vertexIdTranslate.backward(inEdge.getVertexId()), neighborRow);
                    double residual = new ArrayRealVector(neighborRow).dotProduct(solution) - observation;
                    squaredError += residual * residual;
                }
                synchronized (this) {
                    this.rmse += squaredError;
                }
            }
        } catch (NotPositiveDefiniteMatrixException e) {
            logger.warning("Matrix was not positive definite: " + xtx);
        } catch (Exception e) {
            logger.log(Level.SEVERE, "ALS update failed", e);
            // Rethrow runtime exceptions (e.g. invalid observation) as-is; wrap checked ones.
            throw e instanceof RuntimeException ? (RuntimeException) e : new RuntimeException(e);
        }
    }

    /**
     * On the first iteration, allocates both factor matrices (sized from the maximum ids seen
     * while sharding) and fills them via {@code randomize(0.0, 1.0)}.
     */
    @Override // edu.cmu.graphchi.GraphChiProgram
    public void beginIteration(GraphChiContext graphChiContext) {
        if (graphChiContext.getIteration() == 0) {
            logger.info("Initializing latent factors for " + (1 + this.maxLeftVertexId) + " vertices on the left side");
            logger.info("Initializing latent factors for " + (1 + this.maxRightVertexId) + " vertices on the right side");
            this.leftSideMatrix = new HugeDoubleMatrix(this.maxLeftVertexId + 1, this.D);
            this.rightSideMatrix = new HugeDoubleMatrix(this.maxRightVertexId + 1, this.D);
            this.leftSideMatrix.randomize(0.0d, 1.0d);
            this.rightSideMatrix.randomize(0.0d, 1.0d);
        }
    }

    @Override // edu.cmu.graphchi.GraphChiProgram
    public void endIteration(GraphChiContext graphChiContext) {
    }

    @Override // edu.cmu.graphchi.GraphChiProgram
    public void beginInterval(GraphChiContext graphChiContext, VertexInterval vertexInterval) {
    }

    @Override // edu.cmu.graphchi.GraphChiProgram
    public void endInterval(GraphChiContext graphChiContext, VertexInterval vertexInterval) {
    }

    @Override // edu.cmu.graphchi.GraphChiProgram
    public void beginSubInterval(GraphChiContext graphChiContext, VertexInterval vertexInterval) {
    }

    @Override // edu.cmu.graphchi.GraphChiProgram
    public void endSubInterval(GraphChiContext graphChiContext, VertexInterval vertexInterval) {
    }

    /**
     * Builds a sharder that parses edge values as floats and records the largest source and
     * target ids seen, which later size the factor matrices.
     *
     * <p>NOTE(review): an edge with no value becomes 0.0f, which {@link #update} rejects as an
     * invalid observation (values below 1.0). Confirm this is intended.
     */
    @Override // edu.cmu.graphchi.hadoop.PigGraphChiBase
    protected FastSharder createSharder(String str, int i) throws IOException {
        return new FastSharder(str, i, null, new EdgeProcessor<Float>() { // from class: edu.cmu.graphchi.apps.pig.PigALSMatrixFactorization.1
            @Override // edu.cmu.graphchi.preprocessing.EdgeProcessor
            public Float receiveEdge(int from, int to, String token) {
                PigALSMatrixFactorization.this.maxLeftVertexId = Math.max(from, PigALSMatrixFactorization.this.maxLeftVertexId);
                PigALSMatrixFactorization.this.maxRightVertexId = Math.max(to, PigALSMatrixFactorization.this.maxRightVertexId);
                return Float.valueOf(token == null ? 0.0f : Float.parseFloat(token));
            }
        }, new IntConverter(), new FloatConverter());
    }

    /** Returns the output schema {@code {factor:string,id:int,x0,...,x(D-1)}}. */
    @Override // edu.cmu.graphchi.hadoop.PigGraphChiBase
    protected String getSchemaString() {
        StringBuilder schema = new StringBuilder("{factor:string,id:int");
        for (int i = 0; i < this.D; i++) {
            schema.append(",x").append(i);
        }
        return schema.append('}').toString();
    }

    @Override // edu.cmu.graphchi.hadoop.PigGraphChiBase
    protected int getNumShards() {
        return 20;
    }

    /** Runs {@link #NUM_ITERATIONS} ALS sweeps and logs the training RMSE. */
    @Override // edu.cmu.graphchi.hadoop.PigGraphChiBase
    protected void runGraphChi() throws Exception {
        GraphChiEngine graphChiEngine = new GraphChiEngine(getGraphName(), getNumShards());
        graphChiEngine.setEdataConverter(new FloatConverter());
        graphChiEngine.setEnableDeterministicExecution(false);
        graphChiEngine.setVertexDataConverter(null);
        graphChiEngine.setModifiesInedges(false);
        graphChiEngine.setModifiesOutedges(false);
        graphChiEngine.run(this, NUM_ITERATIONS);
        logger.info("Train RMSE: " + Math.sqrt(this.rmse / (1.0d * graphChiEngine.numEdges())) + ", total edges:" + graphChiEngine.numEdges());
    }

    /**
     * Emits the next factor row. It emits all {@code maxLeftVertexId + 1} left rows first
     * (factor {@code "U"}), then all right rows (factor {@code "V"}).
     *
     * @return the next tuple, or {@code null} once every row has been emitted
     * @throws ExecException if a tuple field cannot be set
     */
    @Override // edu.cmu.graphchi.hadoop.PigGraphChiBase
    protected Tuple getNextResult(TupleFactory tupleFactory) throws ExecException {
        // The left matrix holds ids 0..maxLeftVertexId inclusive. Comparing against
        // maxLeftVertexId alone would skip the last left row.
        int numLeftRows = this.maxLeftVertexId + 1;
        HugeDoubleMatrix factors;
        int row;
        String factorName;
        if (this.outputCounter < numLeftRows) {
            factors = this.leftSideMatrix;
            row = this.outputCounter;
            factorName = "U";
        } else {
            factors = this.rightSideMatrix;
            row = this.outputCounter - numLeftRows;
            factorName = "V";
            if (row >= this.rightSideMatrix.getNumRows()) {
                return null;
            }
        }
        Tuple tuple = tupleFactory.newTuple(2 + this.D);
        tuple.set(0, factorName);
        tuple.set(1, Integer.valueOf(row));
        for (int k = 0; k < this.D; k++) {
            tuple.set(2 + k, Double.valueOf(factors.getValue(row, k)));
        }
        this.outputCounter++;
        return tuple;
    }
}
