package org.fnlp.nlp.parser.dep.train;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.List;
import java.util.zip.GZIPOutputStream;
import org.apache.commons.cli.BasicParser;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.fnlp.ml.classifier.linear.Linear;
import org.fnlp.ml.classifier.linear.OnlineTrainer;
import org.fnlp.ml.classifier.linear.inf.LinearMax;
import org.fnlp.ml.classifier.linear.update.LinearMaxPAUpdate;
import org.fnlp.ml.feature.SFGenerator;
import org.fnlp.ml.loss.ZeroOneLoss;
import org.fnlp.ml.types.Instance;
import org.fnlp.ml.types.InstanceSet;
import org.fnlp.ml.types.alphabet.AlphabetFactory;
import org.fnlp.ml.types.alphabet.IFeatureAlphabet;
import org.fnlp.ml.types.alphabet.LabelAlphabet;
import org.fnlp.ml.types.sv.HashSparseVector;
import org.fnlp.nlp.parser.Sentence;
import org.fnlp.nlp.parser.dep.ParsingState;
import org.fnlp.nlp.parser.dep.reader.CoNLLReader;

/* loaded from: input_file:org/fnlp/nlp/parser/dep/train/ParserTrainer.class */
/**
 * Trains one linear classifier per POS tag from a CoNLL-formatted corpus and
 * stores the resulting models, together with the alphabet factory, in a
 * gzip-compressed object stream.
 *
 * <p>Training is done in two passes: {@link #train(String, int, float)} first
 * dumps every parsing decision (POS tag, action, sorted feature indices) to a
 * temporary text file, then re-reads that file once per POS tag to build the
 * instance set for that tag's classifier.
 */
public class ParserTrainer {
    /** Path of the model file written by {@link #train(String, int, float)}. */
    String modelfile;
    /** Charset used to write and re-read the temporary feature file. */
    Charset charset;
    /** Temporary feature file; only non-null while {@code train} is running. */
    File fp;
    /** Alphabet factory shared by feature extraction, training and model saving. */
    AlphabetFactory factory;

    /**
     * Creates a trainer that uses UTF-8 for its temporary feature file.
     *
     * @param str path of the model file to write
     */
    public ParserTrainer(String str) {
        this(str, "UTF-8");
    }

    /**
     * Creates a trainer with an explicit charset for its temporary feature file.
     *
     * @param str  path of the model file to write
     * @param str2 name of the charset, resolved with {@link Charset#forName(String)}
     */
    public ParserTrainer(String str, String str2) {
        this.modelfile = str;
        this.charset = Charset.forName(str2);
        // The factory must be set here as well: previously only the one-argument
        // constructor initialised it, so instances built through this constructor
        // failed with a NullPointerException as soon as train() was called.
        this.factory = AlphabetFactory.buildFactory();
    }

    /**
     * Walks every sentence of the corpus through {@link ParsingState} and writes
     * one line per parsing step to the temporary file {@link #fp}:
     * {@code <postag> <L|R|S> <sorted feature indices...>}. Sentences are
     * separated by an empty line.
     *
     * @param str path of the CoNLL training file
     * @throws IOException if the temporary file cannot be written
     */
    private void buildInstanceList(String str) throws IOException {
        System.out.print("generating training instances ...");
        CoNLLReader reader = new CoNLLReader(str);
        BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(this.fp), this.charset));
        int sentenceCount = 0;
        try {
            LabelAlphabet postagAlphabet = this.factory.buildLabelAlphabet("postag");
            while (reader.hasNext()) {
                Sentence sentence = (Sentence) reader.next();
                // Target array of the sentence; getAction() treats heads[x] == y as
                // "x is attached to y" (presumably head indices - confirm against Sentence).
                int[] heads = (int[]) sentence.getTarget();
                ParsingState state = new ParsingState(sentence, this.factory);
                while (!state.isFinalState()) {
                    int[] focus = state.getFocusIndices();
                    HashSparseVector features = state.getFeatures();
                    ParsingState.Action action = getAction(focus[0], focus[1], heads);
                    state.next(action);
                    // Mark the token consumed by this action so it no longer counts
                    // as a pending modifier in later modifierNumOf() calls.
                    if (action == ParsingState.Action.LEFT) {
                        heads[focus[1]] = -1;
                    }
                    if (action == ParsingState.Action.RIGHT) {
                        heads[focus[0]] = -1;
                    }
                    String tag = sentence.getTagAt(focus[0]);
                    // Registers the tag so train() can later iterate over all seen tags.
                    postagAlphabet.lookupIndex(tag);
                    writer.write(tag);
                    writer.write(" ");
                    writer.write(actionCode(action));
                    writer.write(" ");
                    int[] indices = features.indices();
                    Arrays.sort(indices);
                    for (int index : indices) {
                        writer.write(String.valueOf(index));
                        writer.write(" ");
                    }
                    writer.newLine();
                }
                writer.write('\n');
                writer.flush();
                sentenceCount++;
            }
        } finally {
            // Close even when feature extraction fails so the file handle is not leaked.
            writer.close();
        }
        System.out.println(" ... finished");
        System.out.printf("%d instances have benn loaded.\n\n", Integer.valueOf(sentenceCount));
    }

    /** Maps a parsing action to the one-letter code used in the temporary feature file. */
    private static String actionCode(ParsingState.Action action) {
        switch (action) {
            case LEFT:
                return "L";
            case RIGHT:
                return "R";
            default:
                return "S";
        }
    }

    /**
     * Trains one classifier per POS tag found in the corpus and saves all of them
     * to {@link #modelfile}.
     *
     * @param str path of the CoNLL training file
     * @param i   passed to {@link OnlineTrainer} (the command line calls it the iteration count)
     * @param f   passed to {@link OnlineTrainer} (the command line calls it parameter {@code c})
     * @throws IOException if the temporary file or the model file cannot be read or written
     */
    public void train(String str, int i, float f) throws IOException {
        this.fp = File.createTempFile("train-features", null, new File("./tmp/"));
        try {
            buildInstanceList(str);
            LabelAlphabet postagAlphabet = this.factory.buildLabelAlphabet("postag");
            IFeatureAlphabet featureAlphabet = this.factory.DefaultFeatureAlphabet();
            SFGenerator generator = new SFGenerator();
            Linear[] models = new Linear[postagAlphabet.size()];
            int featureCount = featureAlphabet.size();
            for (int tagIndex = 0; tagIndex < models.length; tagIndex++) {
                String tag = postagAlphabet.lookupString(tagIndex);
                InstanceSet instances = readInstanceSet(tag);
                int labelCount = this.factory.buildLabelAlphabet(tag).size();
                System.out.printf("Training with data: %s\n", tag);
                System.out.printf("Number of labels: %d\n", Integer.valueOf(labelCount));
                LinearMax inference = new LinearMax(generator, labelCount);
                ZeroOneLoss loss = new ZeroOneLoss();
                OnlineTrainer trainer = new OnlineTrainer(inference, new LinearMaxPAUpdate(loss), loss, featureCount, i, f);
                models[tagIndex] = trainer.train(instances, (InstanceSet) null);
                System.out.println();
            }
            this.factory.setStopIncrement(true);
            saveModels(this.modelfile, models, this.factory);
        } finally {
            // Remove the temporary file on failure too, not only after a successful run.
            this.fp.delete();
            this.fp = null;
        }
    }

    /**
     * Reads from the temporary feature file all lines that belong to the given
     * POS tag and turns each into an {@link Instance} whose data is the array of
     * feature indices and whose target is the index of the action label.
     *
     * @param str POS tag whose lines are selected
     * @return the instances recorded for that tag
     * @throws IOException if the temporary file cannot be read
     */
    private InstanceSet readInstanceSet(String str) throws IOException {
        InstanceSet instances = new InstanceSet();
        LabelAlphabet actionAlphabet = this.factory.buildLabelAlphabet(str);
        // Lines look like "<tag> <action> <idx> <idx> ..."; the trailing space in the
        // prefix prevents matching tags that merely start with this tag. Empty lines
        // (sentence separators) never match the prefix, so no extra check is needed.
        String prefix = str + " ";
        BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(this.fp), this.charset));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                String trimmed = line.trim();
                if (!trimmed.startsWith(prefix)) {
                    continue;
                }
                String[] tokens = trimmed.split("\\s+");
                // tokens[0] is the tag and tokens[1] the action; the rest are feature indices.
                int[] featureIndices = new int[tokens.length - 2];
                for (int k = 0; k < featureIndices.length; k++) {
                    featureIndices[k] = Integer.parseInt(tokens[k + 2]);
                }
                Instance instance = new Instance(featureIndices);
                instance.setTarget(Integer.valueOf(actionAlphabet.lookupIndex(tokens[1])));
                instances.add(instance);
            }
        } finally {
            reader.close();
        }
        actionAlphabet.setStopIncrement(true);
        instances.setAlphabetFactory(this.factory);
        return instances;
    }

    /**
     * Serialises the alphabet factory followed by the model array into a
     * gzip-compressed object stream.
     *
     * @param str             path of the file to write
     * @param linearArr       trained models, one per POS tag
     * @param alphabetFactory factory holding the alphabets the models refer to
     * @throws IOException if the file cannot be written
     */
    public static void saveModels(String str, Linear[] linearArr, AlphabetFactory alphabetFactory) throws IOException {
        ObjectOutputStream out = new ObjectOutputStream(new GZIPOutputStream(new FileOutputStream(str)));
        try {
            out.writeObject(alphabetFactory);
            out.writeObject(linearArr);
        } finally {
            out.close();
        }
    }

    /**
     * Chooses the action for the two focus positions.
     *
     * @param i    first focus index
     * @param i2   second focus index
     * @param iArr target array; {@code iArr[x] == y} is read as "x is attached to y"
     * @return RIGHT if {@code i} attaches to {@code i2} and nothing is still attached
     *         to {@code i}; LEFT for the mirrored case; SHIFT otherwise
     */
    private ParsingState.Action getAction(int i, int i2, int[] iArr) {
        if (iArr[i] == i2 && modifierNumOf(i, iArr) == 0) {
            return ParsingState.Action.RIGHT;
        }
        if (iArr[i2] == i && modifierNumOf(i2, iArr) == 0) {
            return ParsingState.Action.LEFT;
        }
        return ParsingState.Action.SHIFT;
    }

    /** Counts how many entries of {@code iArr} are equal to {@code i}. */
    private int modifierNumOf(int i, int[] iArr) {
        int count = 0;
        for (int head : iArr) {
            if (head == i) {
                count++;
            }
        }
        return count;
    }

    /**
     * Command-line entry point: {@code ParserTrainer [option] train_file model_file}.
     *
     * <p>When invoked without any arguments the built-in default paths are used,
     * which matches the previous behaviour. Supplied arguments are now honoured;
     * before, they were silently ignored in favour of the defaults.
     *
     * @param strArr command-line arguments
     * @throws Exception if training or saving the model fails
     */
    public static void main(String[] strArr) throws Exception {
        String[] defaults = {"./tmp/CoNLL2009-ST-Chinese-train.txt", "./tmp/modelConll.gz"};
        String[] cliArgs = (strArr == null || strArr.length == 0) ? defaults : strArr;
        Options options = new Options();
        options.addOption("h", false, "Print help for this application");
        options.addOption("iter", true, "iterative num, default 50");
        options.addOption("c", true, "parameters 1, default 1");

        String trainFile;
        String modelFile;
        int iterations;
        float c;
        try {
            CommandLine cmd = new BasicParser().parse(options, cliArgs);
            String[] positional = cmd.getArgs();
            if (cmd.hasOption('h') || positional.length < 2) {
                new HelpFormatter().printHelp("Tagger:\nParserTrainer [option] train_file model_file;\n", options);
                return;
            }
            trainFile = positional[0];
            modelFile = positional[1];
            iterations = Integer.parseInt(cmd.getOptionValue("iter", "50"));
            c = Float.parseFloat(cmd.getOptionValue("c", "1"));
        } catch (Exception e) {
            // Only argument parsing is guarded here, so this message is accurate.
            System.err.println("Parameters format error");
            System.err.println(e);
            return;
        }
        // Training runs outside the try block so that I/O and training failures
        // propagate with their stack trace instead of being reported as bad parameters.
        new ParserTrainer(modelFile).train(trainFile, iterations, c);
    }
}
