package org.deeplearning4j.spark.api.stats;

import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.spark.api.stats.CommonSparkTrainingStats;
import org.deeplearning4j.spark.stats.BaseEventStats;
import org.deeplearning4j.spark.stats.EventStats;
import org.deeplearning4j.spark.stats.ExampleCountEventStats;
import org.deeplearning4j.spark.time.TimeSource;
import org.deeplearning4j.spark.time.TimeSourceProvider;

/* loaded from: input_file:org/deeplearning4j/spark/api/stats/StatsCalculationHelper.class */
/**
 * Collects timing information for one invocation of a worker-side training method and
 * packages it into a {@link CommonSparkTrainingStats} instance.
 *
 * <p>Usage is a sequence of paired {@code log...Before()} / {@code log...After()} calls
 * around the operations being timed, followed by a single call to
 * {@link #build(SparkTrainingStats)}. All timestamps are read from the {@link TimeSource}
 * returned by {@link TimeSourceProvider#getInstance()}.
 *
 * <p>This class performs no synchronization of its own.
 */
public class StatsCalculationHelper {
    /** Timestamp (ms) recorded by {@link #logMethodStartTime()}. */
    private long methodStartTime;
    /** Timestamp (ms) recorded by {@link #logReturnTime()}. */
    private long returnTime;
    /** Timestamp (ms) recorded by {@link #logInitialModelBefore()}. */
    private long initialModelBefore;
    /** Timestamp (ms) recorded by {@link #logInitialModelAfter()}. */
    private long initialModelAfter;
    /** Timestamp (ms) of the most recent {@link #logNextDataSetBefore()} call. */
    private long lastDataSetBefore;
    /** Timestamp (ms) of the most recent {@link #logProcessMinibatchBefore()} call. */
    private long lastProcessBefore;
    /** Running sum of the example counts passed to {@link #logNextDataSetAfter(long)}. */
    private long totalExampleCount;
    /** One entry per {@link #logNextDataSetAfter(long)} call, in call order. */
    private final List<EventStats> dataSetGetTimes = new ArrayList<EventStats>();
    /** One entry per {@link #logProcessMinibatchAfter()} call, in call order. */
    private final List<EventStats> processMiniBatchTimes = new ArrayList<EventStats>();
    /** Clock used for every timestamp taken by this helper. */
    private final TimeSource timeSource = TimeSourceProvider.getInstance();

    /** Records the current time as the start of the timed method. */
    public void logMethodStartTime() {
        this.methodStartTime = now();
    }

    /** Records the current time as the moment the timed method returns. */
    public void logReturnTime() {
        this.returnTime = now();
    }

    /** Records the current time as the start of the initial-model step. */
    public void logInitialModelBefore() {
        this.initialModelBefore = now();
    }

    /** Records the current time as the end of the initial-model step. */
    public void logInitialModelAfter() {
        this.initialModelAfter = now();
    }

    /** Records the current time as the start of the next data set retrieval. */
    public void logNextDataSetBefore() {
        this.lastDataSetBefore = now();
    }

    /**
     * Completes the data set retrieval started by {@link #logNextDataSetBefore()}: appends an
     * event built from that start timestamp and the time elapsed since it, and adds
     * {@code numExamples} to the running example total.
     *
     * @param numExamples value to add to the total example count
     */
    public void logNextDataSetAfter(long numExamples) {
        long elapsedMs = now() - this.lastDataSetBefore;
        this.dataSetGetTimes.add(new BaseEventStats(this.lastDataSetBefore, elapsedMs));
        this.totalExampleCount += numExamples;
    }

    /** Records the current time as the start of processing a minibatch. */
    public void logProcessMinibatchBefore() {
        this.lastProcessBefore = now();
    }

    /**
     * Completes the minibatch step started by {@link #logProcessMinibatchBefore()}: appends an
     * event built from that start timestamp and the time elapsed since it.
     */
    public void logProcessMinibatchAfter() {
        long elapsedMs = now() - this.lastProcessBefore;
        this.processMiniBatchTimes.add(new BaseEventStats(this.lastProcessBefore, elapsedMs));
    }

    /**
     * Assembles the recorded timings into a {@link CommonSparkTrainingStats}.
     *
     * <p>The total-time entry is built from the method start timestamp, the difference between
     * the return and start timestamps, and the accumulated example count. The initial-model
     * entry is built from the two initial-model timestamps. The per-data-set and per-minibatch
     * lists are handed to the builder as-is (the internal lists, not copies).
     *
     * @param sparkTrainingStats passed through to the builder's
     *                           {@code trainingMasterSpecificStats}
     * @return the stats object produced by the builder
     */
    public CommonSparkTrainingStats build(SparkTrainingStats sparkTrainingStats) {
        List<EventStats> totalTime = new ArrayList<EventStats>();
        totalTime.add(new ExampleCountEventStats(
                this.methodStartTime,
                this.returnTime - this.methodStartTime,
                this.totalExampleCount));

        List<EventStats> initialModelTime = new ArrayList<EventStats>();
        initialModelTime.add(new BaseEventStats(
                this.initialModelBefore,
                this.initialModelAfter - this.initialModelBefore));

        return new CommonSparkTrainingStats.Builder()
                .trainingMasterSpecificStats(sparkTrainingStats)
                .workerFlatMapTotalTimeMs(totalTime)
                .workerFlatMapGetInitialModelTimeMs(initialModelTime)
                .workerFlatMapDataSetGetTimesMs(this.dataSetGetTimes)
                .workerFlatMapProcessMiniBatchTimesMs(this.processMiniBatchTimes)
                .build();
    }

    /** Reads the current time in milliseconds from the configured time source. */
    private long now() {
        return this.timeSource.currentTimeMillis();
    }
}
