package org.fnlp.nlp.tag;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.apache.commons.cli.BasicParser;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.fnlp.data.reader.SequenceReader;
import org.fnlp.ml.classifier.linear.Linear;
import org.fnlp.ml.classifier.linear.OnlineTrainer;
import org.fnlp.ml.classifier.linear.inf.Inferencer;
import org.fnlp.ml.classifier.linear.update.Update;
import org.fnlp.ml.classifier.struct.inf.HigherOrderViterbi;
import org.fnlp.ml.classifier.struct.inf.LinearViterbi;
import org.fnlp.ml.classifier.struct.update.HigherOrderViterbiPAUpdate;
import org.fnlp.ml.classifier.struct.update.LinearViterbiPAUpdate;
import org.fnlp.ml.loss.struct.HammingLoss;
import org.fnlp.ml.types.Instance;
import org.fnlp.ml.types.InstanceSet;
import org.fnlp.ml.types.alphabet.AlphabetFactory;
import org.fnlp.ml.types.alphabet.IFeatureAlphabet;
import org.fnlp.ml.types.alphabet.LabelAlphabet;
import org.fnlp.nlp.cn.tag.format.SimpleFormatter;
import org.fnlp.nlp.pipe.Pipe;
import org.fnlp.nlp.pipe.SeriesPipes;
import org.fnlp.nlp.pipe.Target2Label;
import org.fnlp.nlp.pipe.seq.Sequence2FeatureSequence;
import org.fnlp.nlp.pipe.seq.templet.TempletGroup;

/* loaded from: input_file:org/fnlp/nlp/tag/Tagger.class */
public class Tagger {
    protected Linear cl;
    protected String train;
    protected String templateFile;
    public static boolean standard = true;
    protected String model;
    protected int iterNum;
    protected float c;
    protected AlphabetFactory factory;
    protected Pipe featurePipe;
    protected TempletGroup templets;
    protected String newmodel;
    protected boolean hasLabel;
    protected String testfile = null;
    protected String output = null;
    protected boolean useLoss = true;
    protected String delimiter = "\\s+|\\t+";
    protected boolean interim = false;

    public void setFile(String str, String str2, String str3) {
        this.templateFile = str;
        this.train = str2;
        this.model = str3;
    }

    public static void main(String[] strArr) throws Exception {
        Options options = new Options();
        options.addOption("h", false, "Print help for this application");
        options.addOption("iter", true, "iterative num, default 50");
        options.addOption("c", true, "parameters C in PA algorithm, default 0.8");
        options.addOption("train", false, "switch to training mode(Default: test model");
        options.addOption("retrain", false, "switch to retraining mode(Default: test model");
        options.addOption("margin", false, "use hamming loss as margin threshold");
        options.addOption("interim", false, "save interim model file");
        options.addOption("haslabel", false, "test file has includes label or not");
        try {
            CommandLine parse = new BasicParser().parse(options, strArr);
            if (strArr.length == 0 || parse.hasOption('h')) {
                new HelpFormatter().printHelp("Tagger:\ntagger [option] -train templet_file train_file model_file [test_file];\ntagger [option] -retrain train_file model_file newmodel_file [test_file];\ntagger [option] -label model_file test_file output_file\n", options);
                return;
            }
            Tagger tagger = new Tagger();
            tagger.iterNum = Integer.parseInt(parse.getOptionValue("iter", "50"));
            tagger.c = Float.parseFloat(parse.getOptionValue("c", "0.8"));
            tagger.useLoss = parse.hasOption("margin");
            tagger.interim = parse.hasOption("interim");
            tagger.hasLabel = parse.hasOption("haslabel");
            String[] args = parse.getArgs();
            if (parse.hasOption("train") && args.length == 3) {
                tagger.templateFile = args[0];
                tagger.train = args[1];
                tagger.model = args[2];
                System.out.println("Training model ...");
                tagger.train();
            } else if (parse.hasOption("train") && args.length == 4) {
                tagger.templateFile = args[0];
                tagger.train = args[1];
                tagger.model = args[2];
                tagger.testfile = args[3];
                System.out.println("Training model ...");
                tagger.train();
            } else if (parse.hasOption("train") && args.length == 5) {
                tagger.templateFile = args[0];
                tagger.train = args[1];
                tagger.model = args[2];
                tagger.testfile = args[3];
                System.out.println("Training model ...");
                tagger.train();
                System.gc();
                tagger.output = args[4];
                tagger.test();
            } else if (parse.hasOption("retrain") && args.length == 3) {
                tagger.train = args[0];
                tagger.model = args[1];
                tagger.newmodel = args[2];
                System.out.println("Re-Training model ...");
                tagger.train(true);
            } else if (parse.hasOption("retrain") && args.length == 4) {
                tagger.train = args[0];
                tagger.model = args[1];
                tagger.newmodel = args[2];
                tagger.testfile = args[3];
                System.out.println("Re-Training model ...");
                tagger.train(true);
            } else if (parse.hasOption("retrain") && args.length == 5) {
                tagger.train = args[0];
                tagger.model = args[1];
                tagger.newmodel = args[2];
                tagger.testfile = args[3];
                System.out.println("Re-Training model ...");
                tagger.train(true);
                System.gc();
                tagger.output = args[4];
                tagger.test();
            } else if (args.length == 3) {
                tagger.model = args[0];
                tagger.testfile = args[1];
                tagger.output = args[2];
                tagger.test();
            } else if (args.length != 2) {
                System.err.println("paramenters format error!");
                System.err.println("Print option \"-h\" for help.");
                return;
            } else {
                tagger.model = args[0];
                tagger.testfile = args[1];
                tagger.test();
            }
            System.gc();
        } catch (Exception e) {
            System.err.println("Parameters format error");
        }
    }

    public Pipe createProcessor(boolean z) throws Exception {
        if (!z) {
            this.templets = new TempletGroup();
            this.templets.load(this.templateFile);
        }
        if (this.cl != null) {
            this.factory = this.cl.getAlphabetFactory();
        } else {
            this.factory = AlphabetFactory.buildFactory();
        }
        LabelAlphabet DefaultLabelAlphabet = this.factory.DefaultLabelAlphabet();
        this.featurePipe = new Sequence2FeatureSequence(this.templets, this.factory.DefaultFeatureAlphabet(), DefaultLabelAlphabet);
        return new SeriesPipes(new Pipe[]{new Target2Label(DefaultLabelAlphabet), this.featurePipe});
    }

    public void train() throws Exception {
        train(false);
    }

    public void train(boolean z) throws Exception {
        Inferencer higherOrderViterbi;
        Update higherOrderViterbiPAUpdate;
        System.out.print("Loading training data ...");
        long currentTimeMillis = System.currentTimeMillis();
        if (z) {
            loadFrom(this.model);
        }
        Pipe createProcessor = createProcessor(z);
        InstanceSet instanceSet = new InstanceSet(createProcessor, this.factory);
        LabelAlphabet DefaultLabelAlphabet = this.factory.DefaultLabelAlphabet();
        IFeatureAlphabet DefaultFeatureAlphabet = this.factory.DefaultFeatureAlphabet();
        if (z) {
            DefaultFeatureAlphabet.setStopIncrement(false);
            DefaultLabelAlphabet.setStopIncrement(false);
        }
        instanceSet.loadThruStagePipes(new SequenceReader(this.train, true));
        long currentTimeMillis2 = System.currentTimeMillis();
        System.out.println(" done!");
        System.out.println("Time escape: " + ((currentTimeMillis2 - currentTimeMillis) / 1000) + "s");
        System.out.println();
        System.out.println("Training Number: " + instanceSet.size());
        System.out.println("Label Number: " + DefaultLabelAlphabet.size());
        System.out.println("Feature Number: " + DefaultFeatureAlphabet.size());
        System.out.println();
        DefaultFeatureAlphabet.setStopIncrement(true);
        DefaultLabelAlphabet.setStopIncrement(true);
        InstanceSet instanceSet2 = null;
        if (this.testfile != null) {
            instanceSet2 = new InstanceSet(createProcessor);
            instanceSet2.loadThruStagePipes(new SequenceReader(this.testfile, true, "utf8"));
            System.out.println("Test Number: " + instanceSet2.size());
        }
        HammingLoss hammingLoss = new HammingLoss();
        if (standard) {
            higherOrderViterbi = new LinearViterbi(this.templets, DefaultLabelAlphabet.size());
            higherOrderViterbiPAUpdate = new LinearViterbiPAUpdate((LinearViterbi) higherOrderViterbi, hammingLoss);
        } else {
            higherOrderViterbi = new HigherOrderViterbi(this.templets, DefaultLabelAlphabet.size());
            higherOrderViterbiPAUpdate = new HigherOrderViterbiPAUpdate(this.templets, DefaultLabelAlphabet.size(), true);
        }
        OnlineTrainer onlineTrainer = z ? new OnlineTrainer(this.cl, higherOrderViterbiPAUpdate, hammingLoss, DefaultFeatureAlphabet.size(), this.iterNum, this.c) : new OnlineTrainer(higherOrderViterbi, higherOrderViterbiPAUpdate, hammingLoss, DefaultFeatureAlphabet.size(), this.iterNum, this.c);
        onlineTrainer.innerOptimized = false;
        onlineTrainer.finalOptimized = false;
        this.cl = onlineTrainer.train(instanceSet, instanceSet2);
        if (z) {
            saveTo(this.newmodel);
        } else {
            saveTo(this.model);
        }
    }

    /* JADX WARN: Type inference failed for: r0v18, types: [java.lang.String[], java.lang.String[][]] */
    /* JADX WARN: Type inference failed for: r0v21, types: [java.lang.String[], java.lang.String[][]] */
    private void test() throws Exception {
        if (this.cl == null) {
            loadFrom(this.model);
        }
        long currentTimeMillis = System.currentTimeMillis();
        InstanceSet instanceSet = new InstanceSet(createProcessor(true));
        instanceSet.loadThruStagePipes(new SequenceReader(this.testfile, this.hasLabel, "utf8"));
        System.out.println("Test Number: " + instanceSet.size());
        long currentTimeMillis2 = System.currentTimeMillis();
        float f = 0.0f;
        int i = 0;
        int i2 = 0;
        int i3 = 0;
        int i4 = 0;
        HammingLoss hammingLoss = new HammingLoss();
        ?? r0 = new String[instanceSet.size()];
        ?? r02 = new String[instanceSet.size()];
        LabelAlphabet DefaultLabelAlphabet = this.cl.getAlphabetFactory().DefaultLabelAlphabet();
        for (int i5 = 0; i5 < instanceSet.size(); i5++) {
            Instance instance = instanceSet.get(i5);
            int[] iArr = (int[]) this.cl.classify(instance).getLabel(0);
            if (this.hasLabel) {
                i2 += iArr.length;
                float calc = hammingLoss.calc(instance.getTarget(), iArr);
                f += calc;
                if (calc != 0.0f) {
                    i++;
                }
                if (0 != 0) {
                    String[][] strArr = (String[][]) instance.getSource();
                    int[] iArr2 = (int[]) instance.getTarget();
                    for (int i6 = 0; i6 < iArr2.length; i6++) {
                        if (strArr[i6][0].contains("ENG")) {
                            i3++;
                            if (iArr2[i6] == iArr[i6]) {
                                i4++;
                            }
                        }
                    }
                }
            }
            r0[i5] = DefaultLabelAlphabet.lookupString(iArr);
            if (this.hasLabel) {
                r02[i5] = DefaultLabelAlphabet.lookupString((int[]) instance.getTarget());
            }
        }
        long currentTimeMillis3 = System.currentTimeMillis();
        System.out.println("totaltime\t" + ((currentTimeMillis3 - currentTimeMillis) / 1000.0d));
        System.out.println("feature\t" + ((currentTimeMillis2 - currentTimeMillis) / 1000.0d));
        System.out.println("predict\t" + ((currentTimeMillis3 - currentTimeMillis2) / 1000.0d));
        if (this.hasLabel) {
            System.out.println("Test Accuracy:\t" + (1.0f - (f / i2)));
            System.out.println("Sentence Accuracy:\t" + ((instanceSet.size() - i) / instanceSet.size()));
            if (0 != 0) {
                System.out.println("ENG Accuracy:\t" + (i4 / i3));
            }
        }
        if (this.output != null) {
            BufferedWriter bufferedWriter = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(this.output), "utf8"));
            bufferedWriter.write((this.hasLabel ? SimpleFormatter.format(instanceSet, (String[][]) r0, (String[][]) r02) : SimpleFormatter.format(instanceSet, (String[][]) r0)).trim());
            bufferedWriter.close();
        }
        System.out.println("Done");
    }

    protected void saveTo(String str) throws IOException {
        ObjectOutputStream objectOutputStream = new ObjectOutputStream(new BufferedOutputStream(new GZIPOutputStream(new FileOutputStream(str))));
        objectOutputStream.writeObject(this.templets);
        objectOutputStream.writeObject(this.cl);
        objectOutputStream.close();
    }

    protected void loadFrom(String str) throws IOException, ClassNotFoundException {
        ObjectInputStream objectInputStream = new ObjectInputStream(new BufferedInputStream(new GZIPInputStream(new FileInputStream(str))));
        this.templets = (TempletGroup) objectInputStream.readObject();
        this.cl = (Linear) objectInputStream.readObject();
        objectInputStream.close();
    }
}
