package edu.emory.mathcs.nlp.component.template.train;

import edu.emory.mathcs.nlp.common.collection.tuple.DoubleIntPair;
import edu.emory.mathcs.nlp.common.collection.tuple.ObjectDoublePair;
import edu.emory.mathcs.nlp.common.collection.tuple.Pair;
import edu.emory.mathcs.nlp.common.random.XORShiftRandom;
import edu.emory.mathcs.nlp.common.treebank.POSTagEn;
import edu.emory.mathcs.nlp.common.util.FileUtils;
import edu.emory.mathcs.nlp.common.util.IOUtils;
import edu.emory.mathcs.nlp.common.util.XMLUtils;
import edu.emory.mathcs.nlp.component.template.OnlineComponent;
import edu.emory.mathcs.nlp.component.template.config.NLPConfig;
import edu.emory.mathcs.nlp.component.template.eval.Eval;
import edu.emory.mathcs.nlp.component.template.lexicon.GlobalLexica;
import edu.emory.mathcs.nlp.component.template.node.AbstractNLPNode;
import edu.emory.mathcs.nlp.component.template.reader.TSVReader;
import edu.emory.mathcs.nlp.component.template.state.NLPState;
import edu.emory.mathcs.nlp.component.template.util.NLPFlag;
import edu.emory.mathcs.nlp.component.template.util.NLPMode;
import edu.emory.mathcs.nlp.learning.optimization.OnlineOptimizer;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Element;

/* loaded from: input_file:edu/emory/mathcs/nlp/component/template/train/OnlineTrainer.class */
/**
 * Drives training of an {@link OnlineComponent} with its {@link OnlineOptimizer}.
 *
 * <p>Subclasses supply the component, the {@link TSVReader} and the {@link GlobalLexica}. This class
 * provides:
 * <ul>
 *   <li>epoch-based training that keeps the model scoring best on development data;</li>
 *   <li>k-fold cross-validation that runs the folds concurrently;</li>
 *   <li>model serialization;</li>
 *   <li>feature-space reduction.</li>
 * </ul>
 *
 * @param <N> node type processed by the component
 * @param <S> decoding-state type of the component
 */
public abstract class OnlineTrainer<N extends AbstractNLPNode<N>, S extends NLPState<N>> {
    private static final Logger LOG = LoggerFactory.getLogger(OnlineTrainer.class);

    /** Fixed seed so that the per-epoch shuffling of training files is reproducible across runs. */
    private static final long SHUFFLE_SEED = 9L;

    /**
     * Trains one model through
     * {@link OnlineTrainer#train(NLPMode, List, List, String, String, String)} and yields its
     * best development score. Used to run cross-validation folds on a thread pool.
     */
    class TrainTask implements Callable<Double> {
        private final NLPMode mode;
        private final List<String> trainFiles;
        private final List<String> developFiles;
        private final String configurationFile;
        private final String modelFile;
        private final String previousModelFile;

        public TrainTask(NLPMode mode, List<String> trainFiles, List<String> developFiles,
                         String configurationFile, String modelFile, String previousModelFile) {
            this.mode = mode;
            this.trainFiles = trainFiles;
            this.developFiles = developFiles;
            this.configurationFile = configurationFile;
            this.modelFile = modelFile;
            this.previousModelFile = previousModelFile;
        }

        @Override
        public Double call() {
            return OnlineTrainer.this.train(mode, trainFiles, developFiles, configurationFile, modelFile, previousModelFile);
        }
    }

    /**
     * Creates a fresh component, or loads one from a previous model, and prepares it for training.
     *
     * <p>The component always receives the hyper-parameters from the configuration. A component
     * that already has an optimizer (loaded from a previous model) has that optimizer adapted to
     * the hyper-parameters. Otherwise a new optimizer is created from the configuration and the
     * feature template is initialized.
     *
     * @param mode          NLP mode passed to {@link #createComponent(NLPMode, InputStream, String)}
     * @param configuration configuration stream; it is not closed by this method
     * @param previousModel XZ-compressed serialized component to continue from, or {@code null}
     * @param name          optional component name; the default implementation ignores it
     * @return the initialized component
     * @throws IllegalStateException if the previous model cannot be read
     */
    public OnlineComponent<N, S> initComponent(NLPMode mode, InputStream configuration, InputStream previousModel, String name) {
        OnlineComponent<N, S> component;
        NLPConfig<N> config;

        if (previousModel != null) {
            LOG.info("Loading the previous model");
            try (ObjectInputStream in = IOUtils.createObjectXZBufferedInputStream(previousModel)) {
                @SuppressWarnings("unchecked") // the model file is expected to hold a component of this trainer's type
                OnlineComponent<N, S> loaded = (OnlineComponent<N, S>) in.readObject();
                component = loaded;
                config = component.setConfiguration(configuration);
            } catch (IOException | ClassNotFoundException e) {
                // Previously this was printed and then crashed with an NPE on the missing config.
                throw new IllegalStateException("Failed to load the previous model", e);
            }
        } else {
            component = createComponent(mode, configuration, name);
            config = component.getConfiguration();
        }

        HyperParameter hyperParameter = config.getHyperParameter();
        component.setHyperParameter(hyperParameter);

        OnlineOptimizer optimizer = component.getOptimizer();
        if (optimizer != null) {
            optimizer.adapt(hyperParameter);
        } else {
            component.setOptimizer(config.getOnlineOptimizer(hyperParameter));
            component.initFeatureTemplate();
        }
        return component;
    }

    /**
     * Creates a new component. Components do not support names, so a non-null {@code name} is
     * logged and ignored.
     */
    public OnlineComponent<N, S> createComponent(NLPMode mode, InputStream configuration, String name) {
        if (name != null) {
            LOG.warn("Name not implemented for OnlineComponent. Input name - " + name + " will be ignored.");
        }
        return createComponent(mode, configuration);
    }

    /** Creates a new, untrained component for {@code mode} from the configuration stream. */
    public abstract OnlineComponent<N, S> createComponent(NLPMode mode, InputStream configuration);

    /** Creates a reader for the training/development files using the configured field map. */
    public abstract TSVReader<N> createTSVReader(Object2IntMap<String> fieldMap);

    /** Creates the global lexica described by the configuration stream. */
    public abstract GlobalLexica<N> createGlobalLexica(InputStream configuration);

    /**
     * Runs {@code folds}-fold cross-validation with one worker thread per fold.
     *
     * @param mode              NLP mode
     * @param files             all data files; each fold holds out one contiguous slice (see {@link #split})
     * @param configurationFile path to the configuration file
     * @param modelFile         model path prefix; fold {@code k} is saved to {@code modelFile + "." + k}.
     *                          If {@code null}, no models are saved.
     * @param previousModelFile optional previous model to start every fold from
     * @param folds             number of folds (also the thread-pool size)
     * @return the sum of the fold scores divided by {@code folds}; failed folds contribute 0 and are logged
     */
    public double crossValidate(NLPMode mode, List<String> files, String configurationFile, String modelFile, String previousModelFile, int folds) {
        ExecutorService executor = Executors.newFixedThreadPool(folds);
        List<Future<Double>> futures = new ArrayList<>(folds);

        for (int fold = 0; fold < folds; fold++) {
            Pair<List<String>, List<String>> split = split(files, folds, fold);
            // Without this guard a null prefix produced files literally named "null.0", "null.1", ...
            String foldModelFile = modelFile != null ? modelFile + POSTagEn.POS_PERIOD + fold : null;
            futures.add(executor.submit(new TrainTask(mode, split.o1, split.o2, configurationFile, foldModelFile, previousModelFile)));
        }
        executor.shutdown();

        double total = 0.0d;
        for (int fold = 0; fold < futures.size(); fold++) {
            try {
                total += futures.get(fold).get();
            } catch (ExecutionException e) {
                // Keep collecting the remaining folds instead of abandoning them all.
                LOG.error("Cross-validation fold " + fold + " failed", e.getCause());
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                executor.shutdownNow();
                LOG.error("Interrupted while waiting for cross-validation results", e);
                break;
            }
        }

        double average = total / folds;
        LOG.info(String.format("Cross-validation score: %5.2f\n", average));
        return average;
    }

    /**
     * Splits {@code list} into a training part and a development part for one fold.
     *
     * <p>Each fold holds out a contiguous slice of {@code list.size() / folds} elements as
     * development data. The last fold also absorbs the remainder. The returned lists are new.
     *
     * @param list  the elements to split (not modified)
     * @param folds number of folds, must be positive
     * @param fold  zero-based index of the held-out fold, in {@code [0, folds)}
     * @return a pair of (training elements, development elements)
     * @throws IllegalArgumentException if {@code folds} or {@code fold} is out of range
     */
    public static Pair<List<String>, List<String>> split(List<String> list, int folds, int fold) {
        if (folds <= 0) {
            throw new IllegalArgumentException("folds must be positive: " + folds);
        }
        if (fold < 0 || fold >= folds) {
            throw new IllegalArgumentException("fold " + fold + " is outside [0, " + folds + ")");
        }

        int foldSize = list.size() / folds;
        int begin = fold * foldSize;
        int end = (fold + 1 == folds) ? list.size() : begin + foldSize;

        List<String> train = new ArrayList<>(list.subList(0, begin));
        train.addAll(list.subList(end, list.size()));
        List<String> develop = new ArrayList<>(list.subList(begin, end));
        return new Pair<>(train, develop);
    }

    /** Same as {@link #train(NLPMode, List, List, String, String, String, int)} with training ID {@code 0}. */
    public double train(NLPMode mode, List<String> trainFiles, List<String> developFiles, String configurationFile, String modelFile, String previousModelFile) {
        return train(mode, trainFiles, developFiles, configurationFile, modelFile, previousModelFile, 0);
    }

    /**
     * Builds a component from the configuration (optionally starting from a previous model),
     * trains it, and saves the best model if {@code modelFile} is given.
     *
     * @param trainID identifier shown in the training log lines
     * @return the best development score
     * @throws IllegalStateException if the component cannot be built or training fails
     */
    public double train(NLPMode mode, List<String> trainFiles, List<String> developFiles, String configurationFile, String modelFile, String previousModelFile, int trainID) {
        GlobalLexica<N> lexica;
        OnlineComponent<N, S> component;

        // The configuration is read twice (lexica + component), so two streams are opened.
        // All streams are closed here instead of being leaked. try-with-resources skips null resources.
        try (FileInputStream previousModel = previousModelFile != null ? IOUtils.createFileInputStream(previousModelFile) : null;
             FileInputStream lexicaConfig = IOUtils.createFileInputStream(configurationFile);
             FileInputStream componentConfig = IOUtils.createFileInputStream(configurationFile)) {
            lexica = createGlobalLexica(lexicaConfig);
            component = initComponent(mode, componentConfig, previousModel, modelFile != null ? FileUtils.getBaseName(modelFile) : null);
        } catch (IOException e) {
            throw new IllegalStateException("Failed to close the input streams for " + configurationFile, e);
        }

        ObjectDoublePair<OnlineComponent<N, S>> best;
        try {
            TSVReader<N> reader = createTSVReader(component.getConfiguration().getReaderFieldMap());
            best = train(reader, trainFiles, developFiles, component, lexica, trainID);
        } catch (Exception e) {
            // Previously this was printed and then crashed with an NPE on the null result.
            throw new IllegalStateException("Training failed (train ID " + trainID + ")", e);
        }

        if (modelFile != null) {
            // Saving stays best-effort: the score is still returned if the model cannot be written.
            try {
                saveModel(best.o, IOUtils.createFileOutputStream(modelFile));
            } catch (RuntimeException e) {
                LOG.error("Failed to save the model to " + modelFile, e);
            }
        }
        return best.d;
    }

    /**
     * Trains {@code component} for the configured number of epochs. After each epoch the model is
     * evaluated on the development files, and a serialized snapshot of the best model so far is
     * kept. Ties in score are broken by fewer non-zero weights.
     *
     * <p>The training files are shuffled with a fixed seed at the start of every epoch. A private
     * copy is shuffled, so the caller's list is not modified.
     *
     * @return the best component (or the last one if no epoch improved on score 0) and its score
     */
    public ObjectDoublePair<OnlineComponent<N, S>> train(TSVReader<N> reader, List<String> trainFiles, List<String> developFiles, OnlineComponent<N, S> component, GlobalLexica<N> lexica, int trainID) throws Exception {
        OnlineOptimizer optimizer = component.getOptimizer();
        HyperParameter hyperParameter = component.getHyperParameter();
        XORShiftRandom random = new XORShiftRandom(SHUFFLE_SEED);
        List<String> epochFiles = new ArrayList<>(trainFiles);

        int bestEpoch = -1;
        int bestNonZeroWeights = -1;
        double bestScore = 0.0d;
        byte[] bestModel = null;

        LOG.info(optimizer.toString() + "\n" + hyperParameter.toString("- "));
        LOG.info("Training: " + trainID);

        for (int epoch = 1; epoch <= hyperParameter.getMaxEpochs(); epoch++) {
            component.setFlag(NLPFlag.TRAIN);
            Collections.shuffle(epochFiles, random);
            hyperParameter.getLOLS().updateGoldProbability();
            iterate(reader, epochFiles, component, lexica, false);

            int labelSize = optimizer.getLabelSize();
            int sparseFeatureSize = component.getFeatureTemplate().getSparseFeatureSize();
            int nonZeroWeights = optimizer.getWeightVector().countNonZeroWeights();
            component.getFeatureTemplate().initFeatureCount();

            DoubleIntPair result = evaluate(developFiles, component, lexica, reader);
            double score = result.d;
            LOG.info(String.format("%2d:%5d: %s, L = %3d, SF = %7d, NZW = %8d, N/S = %6d",
                    trainID, epoch, component.getEval().toString(), labelSize, sparseFeatureSize, nonZeroWeights, result.i));

            if (bestScore < score || (bestScore == score && nonZeroWeights < bestNonZeroWeights)) {
                bestNonZeroWeights = nonZeroWeights;
                bestEpoch = epoch;
                bestScore = score;
                bestModel = IOUtils.toByteArray(component);
            }
        }

        if (bestModel != null) {
            component = restoreComponent(bestModel);
        }
        LOG.info(String.format("%2d: Best: %5.2f, epoch = %d", trainID, bestScore, bestEpoch));
        return new ObjectDoublePair<>(component, bestScore);
    }

    /**
     * Evaluates {@code component} on the development files without updating the model.
     *
     * @return a pair of (evaluation score, processing speed rounded to an int; see {@link #iterate})
     */
    public DoubleIntPair evaluate(List<String> developFiles, OnlineComponent<N, S> component, GlobalLexica<N> lexica, TSVReader<N> reader) {
        component.setFlag(NLPFlag.EVALUATE);
        Eval eval = component.getEval();
        eval.clear();
        // Process the data BEFORE reading the score. Inlining iterate() as the second constructor
        // argument made eval.score() run first (left-to-right evaluation), which scored the cleared Eval.
        double speed = iterate(reader, developFiles, component, lexica, true);
        return new DoubleIntPair(eval.score(), (int) Math.round(speed));
    }

    /**
     * Runs {@code component} over every file in {@code files}.
     *
     * <p>Sentence-based components process the file sentence by sentence. Document-based
     * components read and process one document per file. When {@code evaluate} is
     * {@code false}, each processed unit is counted toward a mini-batch via {@link #update}, and
     * any pending mini-batch is flushed at the end. A failure inside one file is logged, and
     * processing continues with the next file.
     *
     * @return processed units per second of {@code component.process} time. Units are nodes
     *         (minus one per sentence, presumably the artificial root; confirm against
     *         TSVReader) or documents. Returns {@code 0} if no time was measured.
     */
    protected double iterate(TSVReader<N> reader, List<String> files, OnlineComponent<N, S> component, GlobalLexica<N> lexica, boolean evaluate) {
        long elapsedNanos = 0;
        long processed = 0;
        int batchCount = 0;

        for (String file : files) {
            reader.open(IOUtils.createFileInputStream(file));
            try {
                if (component.isDocumentBased()) {
                    List<N[]> document = reader.readDocument();
                    lexica.process(document);
                    long start = System.nanoTime();
                    component.process(document);
                    elapsedNanos += System.nanoTime() - start;
                    if (!evaluate) {
                        batchCount = update(component, batchCount, false);
                    }
                    processed++;
                } else {
                    N[] nodes;
                    while ((nodes = reader.next()) != null) {
                        lexica.process(nodes);
                        long start = System.nanoTime();
                        component.process(nodes);
                        elapsedNanos += System.nanoTime() - start;
                        if (!evaluate) {
                            batchCount = update(component, batchCount, false);
                        }
                        processed += nodes.length - 1;
                    }
                }
            } catch (Exception e) {
                LOG.error("Failed to process " + file, e);
            } finally {
                reader.close();
            }
        }

        if (!evaluate) {
            update(component, batchCount, true);
        }
        // Nanosecond timing plus a guard avoid the NaN/Infinity that a 0 ms total used to produce.
        return elapsedNanos > 0 ? (1e9d * processed) / elapsedNanos : 0.0d;
    }

    /**
     * Counts one processed unit toward the current mini-batch, and calls the optimizer's
     * {@code updateMiniBatch()} when the count reaches the configured batch size.
     * {@link #iterate} calls this with {@code last == true} once at the end, to flush a partially
     * filled batch.
     *
     * @param component component whose optimizer and hyper-parameters are used
     * @param count     number of units accumulated in the current mini-batch so far
     * @param last      {@code true} to flush pending units
     * @return the new count; {@code 0} right after an update
     */
    protected int update(OnlineComponent<N, S> component, int count, boolean last) {
        OnlineOptimizer optimizer = component.getOptimizer();
        HyperParameter hyperParameter = component.getHyperParameter();

        // Short-circuit is significant: count is incremented only when this is not the flush of a
        // non-empty batch (reconstructed from the original bytecode).
        if ((last && count > 0) || ++count == hyperParameter.getBatchSize()) {
            optimizer.updateMiniBatch();
            count = 0;
        }
        return count;
    }

    /**
     * Serializes {@code component} as an XZ-compressed object stream to {@code out}. The stream is
     * always closed. I/O failures are logged and not rethrown.
     */
    public void saveModel(OnlineComponent<N, S> component, OutputStream out) {
        LOG.info("Saving the model");
        try (ObjectOutputStream objectOut = IOUtils.createObjectXZBufferedOutputStream(out)) {
            objectOut.writeObject(component);
        } catch (IOException e) {
            LOG.error("Failed to save the model", e);
        }
    }

    /**
     * Repeatedly reduces the feature space of {@code component}, re-evaluating on
     * {@code developFiles} after each step, and saves the result to {@code modelFile}.
     *
     * <p>Parameters are read from the configuration's {@code <reducer>} element:
     * {@code lower_bound}, {@code start}, {@code increment}, {@code range} and {@code iteration}.
     * <ul>
     *   <li>The loop stops when the score is within {@code range} of {@code lower_bound}, or when
     *       no iterations remain.</li>
     *   <li>When a step drops the score below {@code lower_bound}, that step is undone, the
     *       increment is halved, and one iteration is consumed.</li>
     * </ul>
     * NOTE(review): with {@code increment == 0}, a score that never changes loops forever.
     *
     * @param configurationFile unused
     */
    public void reduceModel(TSVReader<N> reader, List<String> developFiles, OnlineComponent<N, S> component, GlobalLexica<N> lexica, String configurationFile, String modelFile) {
        LOG.info("Reducing:");
        LOG.info(String.format("%8.4f: %7d -> %s", Float.valueOf(0.0f), component.getFeatureTemplate().getSparseFeatureSize(), component.getEval().toString()));

        Element reducer = XMLUtils.getFirstElementByTagName(component.getConfiguration().getDocumentElement(), "reducer");
        double lowerBound = XMLUtils.getDoubleTextContentFromFirstElementByTagName(reducer, "lower_bound");
        float start = XMLUtils.getFloatTextContentFromFirstElementByTagName(reducer, "start");
        float increment = XMLUtils.getFloatTextContentFromFirstElementByTagName(reducer, "increment");
        float range = XMLUtils.getFloatTextContentFromFirstElementByTagName(reducer, "range");
        int iterations = XMLUtils.getIntegerTextContentFromFirstElementByTagName(reducer, "iteration");

        float rate = start;
        while (true) {
            byte[] snapshot = IOUtils.toByteArray(component);
            component.getFeatureTemplate().reduce(component.getOptimizer().getWeightVector(), rate);
            DoubleIntPair result = evaluate(developFiles, component, lexica, reader);
            LOG.info(String.format("%8.4f: %7d -> %s, N/S = %6d", Float.valueOf(rate), component.getFeatureTemplate().getSparseFeatureSize(), component.getEval().toString(), result.i));

            if (iterations <= 0 || Math.abs(lowerBound - result.d) <= range) {
                break;
            }
            if (result.d < lowerBound) {
                // Reduced too much: undo this step and retry with a smaller increment.
                component = restoreComponent(snapshot);
                rate -= increment;
                increment /= 2.0f;
                iterations--;
            }
            rate += increment;
        }
        saveModel(component, IOUtils.createFileOutputStream(modelFile));
    }

    /** Restores a component from bytes produced by {@code IOUtils.toByteArray}. */
    @SuppressWarnings("unchecked") // the bytes always come from a component of this trainer's type
    private OnlineComponent<N, S> restoreComponent(byte[] bytes) {
        return (OnlineComponent<N, S>) IOUtils.fromByteArray(bytes);
    }
}
