package org.nd4j.linalg.profiler;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.GridOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.MetaOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.profiler.data.StackAggregator;
import org.nd4j.linalg.profiler.data.StringAggregator;
import org.nd4j.linalg.profiler.data.StringCounter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/profiler/OpProfiler.class */
/**
 * Singleton that gathers statistics about executed ops: call counts per op and per op class,
 * consecutive-call pairs, timings, and operand layout penalties (mixed ordering, non-EWS and
 * strided access).
 *
 * <p>All counters and aggregators are static and therefore shared by every caller of
 * {@link #getInstance()}. The "previous op" bookkeeping ({@code prevOp*}, {@code lastZ}) is kept
 * in plain, unsynchronized instance fields.
 */
public class OpProfiler {
    private static final AtomicLong invocationsCount = new AtomicLong(0);
    private static final OpProfiler ourInstance = new OpProfiler();
    private static final StringAggregator classAggregator = new StringAggregator();
    private static final StringAggregator longAggregator = new StringAggregator();
    private static final StringCounter classCounter = new StringCounter();
    private static final StringCounter opCounter = new StringCounter();
    private static final StringCounter classPairsCounter = new StringCounter();
    private static final StringCounter opPairsCounter = new StringCounter();
    private static final StringCounter matchingCounter = new StringCounter();
    private static final StringCounter matchingCounterDetailed = new StringCounter();
    private static final StringCounter matchingCounterInverted = new StringCounter();
    private static final StringCounter orderCounter = new StringCounter();
    private static final StackAggregator methodsAggregator = new StackAggregator();
    private static final StackAggregator scalarAggregator = new StackAggregator();
    private static final StackAggregator mixedOrderAggregator = new StackAggregator();
    private static final StackAggregator nonEwsAggregator = new StackAggregator();
    private static final StackAggregator stridedAggregator = new StackAggregator();
    private static final StackAggregator tadStridedAggregator = new StackAggregator();
    private static final StackAggregator tadNonEwsAggregator = new StackAggregator();
    private static final StackAggregator blasAggregator = new StackAggregator();
    private static final StringCounter blasOrderCounter = new StringCounter();
    private static final Logger logger = LoggerFactory.getLogger(OpProfiler.class);

    /** Elapsed time, in nanoseconds, above which an op call is also recorded as a "long" call. */
    private static final long THRESHOLD = 100000;

    private String prevOpClass = "";
    private String prevOpName = "";
    private String prevOpMatching = "";
    private String prevOpMatchingDetailed = "";
    private String prevOpMatchingInverted = "";

    /** Address of the result ({@code z}) buffer of the most recently processed op; 0 if none. */
    private long lastZ = 0;

    /** Reasons an operand combination is reported as a potential performance penalty. */
    public enum PenaltyCause {
        NONE,
        NON_EWS_ACCESS,
        STRIDED_ACCESS,
        MIXED_ORDER,
        TAD_NON_EWS_ACCESS,
        TAD_STRIDED_ACCESS
    }

    private OpProfiler() {
    }

    /**
     * Returns the shared profiler instance.
     *
     * @return the singleton
     */
    public static OpProfiler getInstance() {
        return ourInstance;
    }

    /**
     * Clears the invocation count and every counter and aggregator. The "previous op" fields
     * ({@code prevOp*}, {@code lastZ}) are left untouched.
     */
    public void reset() {
        invocationsCount.set(0L);
        classAggregator.reset();
        longAggregator.reset();
        classCounter.reset();
        opCounter.reset();
        classPairsCounter.reset();
        opPairsCounter.reset();
        matchingCounter.reset();
        matchingCounterDetailed.reset();
        matchingCounterInverted.reset();
        methodsAggregator.reset();
        scalarAggregator.reset();
        nonEwsAggregator.reset();
        stridedAggregator.reset();
        tadNonEwsAggregator.reset();
        tadStridedAggregator.reset();
        mixedOrderAggregator.reset();
        blasAggregator.reset();
        blasOrderCounter.reset();
        orderCounter.reset();
    }

    /**
     * Maps an op to the display name of its op class. The checks run in a fixed order, so an op
     * implementing several of these interfaces is reported under the first one that matches.
     *
     * @param op the op to classify
     * @return the class label, or {@code "Unknown Op calls"} if no known interface matches
     */
    protected String getOpClass(Op op) {
        if (op instanceof ScalarOp) {
            return "ScalarOp";
        }
        if (op instanceof MetaOp) {
            return "MetaOp";
        }
        if (op instanceof GridOp) {
            return "GridOp";
        }
        if (op instanceof BroadcastOp) {
            return "BroadcastOp";
        }
        if (op instanceof RandomOp) {
            return "RandomOp";
        }
        if (op instanceof Accumulation) {
            return "AccumulationOp";
        }
        if (op instanceof TransformOp) {
            // A transform with a second operand is reported as a pairwise transform.
            return op.y() == null ? "TransformOp" : "PairWiseTransformOp";
        }
        if (op instanceof IndexAccumulation) {
            return "IndexAccumulationOp";
        }
        return "Unknown Op calls";
    }

    /** Records one scalar call in the invocation count and the scalar aggregator. */
    public void processScalarCall() {
        invocationsCount.incrementAndGet();
        scalarAggregator.incrementCount();
    }

    /**
     * Records one op call: updates the per-op and per-class counters, the "matching" counters
     * (ops that consume the previous op's result buffer), the consecutive-pair counters, and the
     * operand penalty aggregators.
     *
     * @param op the op being executed; its {@code x} and {@code z} operands are dereferenced
     */
    public void processOpCall(Op op) {
        invocationsCount.incrementAndGet();
        String opName = op.name();
        opCounter.incrementCount(opName);
        String opClass = getOpClass(op);
        classCounter.incrementCount(opClass);

        String detailedName = opClass + " " + opName;
        boolean inPlaceOnPreviousResult =
                op.x().data().address() == this.lastZ && op.z() == op.x() && op.y() == null;
        if (inPlaceOnPreviousResult) {
            matchingCounter.incrementCount(this.prevOpMatching + " -> " + opClass);
            matchingCounterDetailed.incrementCount(this.prevOpMatchingDetailed + " -> " + detailedName);
        } else {
            matchingCounter.totalsIncrement();
            matchingCounterDetailed.totalsIncrement();
            // "Inverted" match: the previous result is consumed as y rather than as x.
            if (op.y() != null && op.y().data().address() == this.lastZ) {
                matchingCounterInverted.incrementCount(this.prevOpMatchingInverted + " -> " + detailedName);
            } else {
                matchingCounterInverted.totalsIncrement();
            }
        }

        this.lastZ = op.z().data().address();
        this.prevOpMatching = opClass;
        this.prevOpMatchingDetailed = detailedName;
        this.prevOpMatchingInverted = detailedName;
        updatePairs(opName, opClass);

        // The aggregator calls are intentionally made directly from this method rather than
        // from a shared helper, so the call path into the stack aggregators stays unchanged.
        PenaltyCause[] causes = processOperands(op.x(), op.y(), op.z());
        for (PenaltyCause cause : causes) {
            switch (cause) {
                case NON_EWS_ACCESS:
                    nonEwsAggregator.incrementCount();
                    break;
                case STRIDED_ACCESS:
                    stridedAggregator.incrementCount();
                    break;
                case MIXED_ORDER:
                    mixedOrderAggregator.incrementCount();
                    break;
                default:
                    break;
            }
        }
    }

    /**
     * Records one op call as {@link #processOpCall(Op)} does and additionally inspects the given
     * TAD buffers for TAD-level penalties.
     *
     * @param op         the op being executed
     * @param tadBuffers buffers examined by {@link #processTADOperands(DataBuffer...)}
     */
    public void processOpCall(Op op, DataBuffer... tadBuffers) {
        processOpCall(op);
        PenaltyCause[] causes = processTADOperands(tadBuffers);
        for (PenaltyCause cause : causes) {
            switch (cause) {
                case TAD_NON_EWS_ACCESS:
                    tadNonEwsAggregator.incrementCount();
                    break;
                case TAD_STRIDED_ACCESS:
                    tadStridedAggregator.incrementCount();
                    break;
                default:
                    break;
            }
        }
    }

    /**
     * @return the aggregator fed by {@link PenaltyCause#MIXED_ORDER} penalties
     */
    public StackAggregator getMixedOrderAggregator() {
        return mixedOrderAggregator;
    }

    /**
     * @return the aggregator fed by {@link #processScalarCall()}
     */
    public StackAggregator getScalarAggregator() {
        return scalarAggregator;
    }

    /**
     * Counts the transition from the previously seen op to the current one, by name and by
     * class, then remembers the current op as the new "previous".
     *
     * @param opName  name of the current op
     * @param opClass class label of the current op
     */
    protected void updatePairs(String opName, String opClass) {
        String namePair = this.prevOpName + " -> " + opName;
        classPairsCounter.incrementCount(this.prevOpClass + " -> " + opClass);
        opPairsCounter.incrementCount(namePair);
        this.prevOpName = opName;
        this.prevOpClass = opClass;
    }

    /**
     * Records how long an op took. The elapsed time is always added to the per-class timings;
     * calls longer than {@code THRESHOLD} nanoseconds are also added to the long-call timings.
     *
     * @param op        the op that was executed
     * @param startTime the {@link System#nanoTime()} value taken before the op started
     */
    public void timeOpCall(Op op, long startTime) {
        long elapsed = System.nanoTime() - startTime;
        String opClass = getOpClass(op);
        classAggregator.putTime(opClass, op, elapsed);
        if (elapsed > THRESHOLD) {
            longAggregator.putTime(opClass + " " + op.name() + " (" + op.opNum() + ")", elapsed);
        }
    }

    /**
     * Records a BLAS call by name only and clears the "matching" state, so the next op is not
     * treated as a continuation of the previous result buffer.
     *
     * @param blasOpName name of the BLAS routine
     * @deprecated retained for compatibility; {@link #processBlasCall(boolean, INDArray...)}
     *     records operand-level statistics instead.
     */
    @Deprecated
    public void processBlasCall(String blasOpName) {
        invocationsCount.incrementAndGet();
        opCounter.incrementCount(blasOpName);
        classCounter.incrementCount("BLAS");
        updatePairs(blasOpName, "BLAS");
        this.prevOpMatching = "";
        this.lastZ = 0L;
    }

    /** Placeholder; currently does nothing. */
    public void timeBlasCall() {
    }

    /**
     * Prints every collected statistic. Section headings go through the logger while the
     * bodies are written to {@code System.out}.
     */
    public void printOutDashboard() {
        logger.info("---Total Op Calls: {}", invocationsCount.get());
        System.out.println();
        logger.info("--- OpClass calls statistics: ---");
        System.out.println(classCounter.asString());
        System.out.println();
        logger.info("--- OpClass pairs statistics: ---");
        System.out.println(classPairsCounter.asString());
        System.out.println();
        logger.info("--- Individual Op calls statistics: ---");
        System.out.println(opCounter.asString());
        System.out.println();
        logger.info("--- Matching Op calls statistics: ---");
        System.out.println(matchingCounter.asString());
        System.out.println();
        logger.info("--- Matching detailed Op calls statistics: ---");
        System.out.println(matchingCounterDetailed.asString());
        System.out.println();
        logger.info("--- Matching inverts Op calls statistics: ---");
        System.out.println(matchingCounterInverted.asString());
        System.out.println();
        logger.info("--- Time for OpClass calls statistics: ---");
        System.out.println(classAggregator.asString());
        System.out.println();
        logger.info("--- Time for long Op calls statistics: ---");
        System.out.println(longAggregator.asString());
        System.out.println();
        logger.info("--- Time spent for Op calls statistics: ---");
        System.out.println(classAggregator.asPercentageString());
        System.out.println();
        logger.info("--- Time spent for long Op calls statistics: ---");
        System.out.println(longAggregator.asPercentageString());
        System.out.println();
        logger.info("--- Time spent within methods: ---");
        methodsAggregator.renderTree(true);
        System.out.println();
        logger.info("--- Bad strides stack tree: ---");
        System.out.println("Unique entries: " + stridedAggregator.getUniqueBranchesNumber());
        stridedAggregator.renderTree();
        System.out.println();
        logger.info("--- non-EWS access stack tree: ---");
        System.out.println("Unique entries: " + nonEwsAggregator.getUniqueBranchesNumber());
        nonEwsAggregator.renderTree();
        System.out.println();
        logger.info("--- Mixed orders access stack tree: ---");
        System.out.println("Unique entries: " + mixedOrderAggregator.getUniqueBranchesNumber());
        mixedOrderAggregator.renderTree();
        System.out.println();
        logger.info("--- TAD bad strides stack tree: ---");
        System.out.println("Unique entries: " + tadStridedAggregator.getUniqueBranchesNumber());
        tadStridedAggregator.renderTree();
        System.out.println();
        logger.info("--- TAD non-EWS access stack tree: ---");
        System.out.println("Unique entries: " + tadNonEwsAggregator.getUniqueBranchesNumber());
        tadNonEwsAggregator.renderTree();
        System.out.println();
        logger.info("--- Scalar access stack tree: ---");
        System.out.println("Unique entries: " + scalarAggregator.getUniqueBranchesNumber());
        scalarAggregator.renderTree(false);
        System.out.println();
        logger.info("--- Blas GEMM odrders count: ---");
        System.out.println(blasOrderCounter.asString());
        System.out.println();
        logger.info("--- BLAS access stack trace: ---");
        System.out.println("Unique entries: " + blasAggregator.getUniqueBranchesNumber());
        blasAggregator.renderTree(false);
        System.out.println();
    }

    /**
     * @return the number of calls recorded since start-up or the last {@link #reset()}
     */
    public long getInvocationsCount() {
        return invocationsCount.get();
    }

    /**
     * Adds the time elapsed since {@code startTime}, divided by 1000 (nanoseconds to
     * microseconds), to the methods aggregator.
     *
     * @param op        unused
     * @param startTime the {@link System#nanoTime()} value taken before the call started
     */
    public void processStackCall(Op op, long startTime) {
        methodsAggregator.incrementCount((System.nanoTime() - startTime) / 1000);
    }

    /**
     * Builds a description of the operands' orderings, for example {@code "C x F x null"},
     * and counts that combination in the order counter.
     *
     * @param operands arrays to describe; individual entries may be {@code null}
     * @return upper-cased orderings joined by {@code " x "}, with {@code "null"} for null entries
     */
    public String processOrders(INDArray... operands) {
        StringBuilder orders = new StringBuilder();
        for (int i = 0; i < operands.length; i++) {
            if (i > 0) {
                orders.append(" x ");
            }
            INDArray operand = operands[i];
            orders.append(operand == null ? "null" : String.valueOf(operand.ordering()).toUpperCase());
        }
        String description = orders.toString();
        orderCounter.incrementCount(description);
        return description;
    }

    /**
     * Records operand-level statistics for a BLAS call.
     *
     * <p>When {@code isGemm} is {@code true} the ordering combination is counted and the BLAS
     * aggregator is incremented for every {@code NONE}, {@code NON_EWS_ACCESS} or
     * {@code STRIDED_ACCESS} cause. Otherwise the operands are treated like those of a regular
     * op and feed the non-EWS, strided and mixed-order aggregators.
     *
     * @param isGemm   selects the BLAS-specific branch; presumably set for GEMM calls, since the
     *                 order counter is reported under a "Blas GEMM" heading (confirm with callers)
     * @param operands arrays taking part in the call
     */
    public void processBlasCall(boolean isGemm, INDArray... operands) {
        if (isGemm) {
            blasOrderCounter.incrementCount(processOrders(operands));
            PenaltyCause[] causes = processOperands(operands);
            for (PenaltyCause cause : causes) {
                switch (cause) {
                    case NON_EWS_ACCESS:
                    case STRIDED_ACCESS:
                    case NONE:
                        blasAggregator.incrementCount();
                        break;
                    default:
                        break;
                }
            }
            return;
        }

        PenaltyCause[] causes = processOperands(operands);
        for (PenaltyCause cause : causes) {
            switch (cause) {
                case NON_EWS_ACCESS:
                    nonEwsAggregator.incrementCount();
                    break;
                case STRIDED_ACCESS:
                    stridedAggregator.incrementCount();
                    break;
                case MIXED_ORDER:
                    mixedOrderAggregator.incrementCount();
                    break;
                default:
                    break;
            }
        }
    }

    /**
     * Determines the penalties for a pair of operands.
     *
     * <p>A {@code null} operand is ignored: it contributes no mixed-order, non-EWS or strided
     * penalty. This keeps {@link #processOperands(INDArray...)} from failing when only one
     * array of an adjacent pair is present.
     *
     * @param x first operand, may be {@code null}
     * @param y second operand, may be {@code null}
     * @return the distinct penalties found, or a single {@link PenaltyCause#NONE}
     */
    public PenaltyCause[] processOperands(INDArray x, INDArray y) {
        List<PenaltyCause> causes = new ArrayList<>();

        if (x != null && y != null && x.ordering() != y.ordering()) {
            causes.add(PenaltyCause.MIXED_ORDER);
        }

        // An element-wise stride of exactly 1 triggers neither penalty, so it stands in for a
        // missing operand.
        int xStride = x == null ? 1 : x.elementWiseStride();
        int yStride = y == null ? 1 : y.elementWiseStride();

        if (xStride < 1 || yStride < 1) {
            causes.add(PenaltyCause.NON_EWS_ACCESS);
        }
        if (xStride > 1 || yStride > 1) {
            causes.add(PenaltyCause.STRIDED_ACCESS);
        }

        if (causes.isEmpty()) {
            causes.add(PenaltyCause.NONE);
        }
        return causes.toArray(new PenaltyCause[0]);
    }

    /**
     * Determines TAD-level penalties from the given buffers. Each non-null buffer is read as
     * a sequence of ints: element 0 is used as the rank and element {@code rank * 2 + 2} as the
     * element-wise stride (presumably the ND4J shape-info layout; confirm).
     *
     * @param tadBuffers buffers to inspect; the array and its entries may be {@code null}
     * @return the distinct penalties found, or a single {@link PenaltyCause#NONE}
     */
    public PenaltyCause[] processTADOperands(DataBuffer... tadBuffers) {
        if (tadBuffers == null) {
            return new PenaltyCause[] {PenaltyCause.NONE};
        }

        List<PenaltyCause> causes = new ArrayList<>();
        for (DataBuffer tadBuffer : tadBuffers) {
            if (tadBuffer == null) {
                continue;
            }
            int rank = tadBuffer.getInt(0L);
            int ews = tadBuffer.getInt(((rank * 2) + 4) - 2);

            boolean nonEws = ews < 1
                    || rank > 2
                    || (rank == 2 && tadBuffer.getInt(1L) > 1 && tadBuffer.getInt(2L) > 1);
            if (nonEws && !causes.contains(PenaltyCause.TAD_NON_EWS_ACCESS)) {
                causes.add(PenaltyCause.TAD_NON_EWS_ACCESS);
            } else if (ews > 1 && !causes.contains(PenaltyCause.TAD_STRIDED_ACCESS)) {
                // Also reached when the non-EWS cause was already recorded by an earlier buffer.
                causes.add(PenaltyCause.TAD_STRIDED_ACCESS);
            }
        }

        if (causes.isEmpty()) {
            causes.add(PenaltyCause.NONE);
        }
        return causes.toArray(new PenaltyCause[0]);
    }

    /**
     * Determines the penalties for the three operands of an op. With no {@code y}, {@code x} is
     * compared with {@code z}; when {@code z} is the same array as {@code x} or {@code y}, only
     * {@code x} and {@code y} are compared; otherwise the results of comparing {@code x} with
     * {@code y} and {@code x} with {@code z} are merged.
     *
     * @param x first input
     * @param y second input, may be {@code null}
     * @param z result array
     * @return the distinct penalties found, or a single {@link PenaltyCause#NONE}
     */
    public PenaltyCause[] processOperands(INDArray x, INDArray y, INDArray z) {
        if (y == null) {
            return processOperands(x, z);
        }
        if (x == z || y == z) {
            return processOperands(x, y);
        }

        PenaltyCause[] inputCauses = processOperands(x, y);
        PenaltyCause[] resultCauses = processOperands(x, z);
        if (isNoneOnly(resultCauses)) {
            return inputCauses;
        }
        if (isNoneOnly(inputCauses)) {
            return resultCauses;
        }
        return joinDistinct(inputCauses, resultCauses);
    }

    /** Tells whether the array is exactly {@code {NONE}}. */
    private static boolean isNoneOnly(PenaltyCause[] causes) {
        return causes.length == 1 && causes[0] == PenaltyCause.NONE;
    }

    /**
     * Concatenates two cause arrays, dropping {@code null} entries and duplicates while keeping
     * first-seen order.
     *
     * @param first  causes listed first
     * @param second causes appended after {@code first}
     * @return the merged, duplicate-free causes
     */
    protected PenaltyCause[] joinDistinct(PenaltyCause[] first, PenaltyCause[] second) {
        List<PenaltyCause> merged = new ArrayList<>();
        for (PenaltyCause cause : first) {
            if (cause != null && !merged.contains(cause)) {
                merged.add(cause);
            }
        }
        for (PenaltyCause cause : second) {
            if (cause != null && !merged.contains(cause)) {
                merged.add(cause);
            }
        }
        return merged.toArray(new PenaltyCause[0]);
    }

    /**
     * Determines the penalties across a list of operands by checking each adjacent pair. Pairs
     * in which both arrays are {@code null} are skipped; fewer than two operands yield
     * {@link PenaltyCause#NONE}.
     *
     * @param operands arrays to inspect; the array and its entries may be {@code null}
     * @return the distinct penalties found, or a single {@link PenaltyCause#NONE}
     */
    public PenaltyCause[] processOperands(INDArray... operands) {
        if (operands == null) {
            return new PenaltyCause[] {PenaltyCause.NONE};
        }

        List<PenaltyCause> causes = new ArrayList<>();
        for (int i = 0; i < operands.length - 1; i++) {
            INDArray current = operands[i];
            INDArray next = operands[i + 1];
            if (current == null && next == null) {
                continue;
            }
            for (PenaltyCause cause : processOperands(current, next)) {
                if (cause != PenaltyCause.NONE && !causes.contains(cause)) {
                    causes.add(cause);
                }
            }
        }

        if (causes.isEmpty()) {
            causes.add(PenaltyCause.NONE);
        }
        return causes.toArray(new PenaltyCause[0]);
    }

    /** Placeholder; currently does nothing. */
    public void processMemoryAccess() {
    }
}
