package org.fnlp.ml.classifier.struct;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.fnlp.ml.classifier.linear.Linear;
import org.fnlp.ml.classifier.linear.OnlineTrainer;
import org.fnlp.ml.classifier.linear.inf.Inferencer;
import org.fnlp.ml.classifier.linear.update.Update;
import org.fnlp.ml.loss.Loss;
import org.fnlp.ml.types.Instance;
import org.fnlp.ml.types.InstanceSet;
import org.fnlp.util.MyArrays;

/* loaded from: input_file:org/fnlp/ml/classifier/struct/OnlineHybridTrainer.class */
public class OnlineHybridTrainer extends OnlineTrainer {

    /**
     * Creates a trainer; every argument is forwarded unchanged to {@link OnlineTrainer}.
     *
     * @param inferencer decoder queried for the current best prediction of an instance
     * @param update     update rule applied to the weights whenever the loss is positive
     * @param loss       loss between a prediction and the gold target
     * @param i          forwarded to the superclass (NOTE(review): presumably the feature-space
     *                   size — confirm against OnlineTrainer)
     * @param i2         forwarded to the superclass (NOTE(review): presumably the number of
     *                   training passes — confirm)
     * @param f          forwarded to the superclass (NOTE(review): presumably the update
     *                   constant {@code c} — confirm)
     */
    public OnlineHybridTrainer(Inferencer inferencer, Update update, Loss loss, int i, int i2, float f) {
        super(inferencer, update, loss, i, i2, f);
    }

    /**
     * Runs {@code iternum} online passes over the training set and returns the resulting model.
     *
     * <p>Each instance is decoded with {@code getBest(inst, 1)}; when the loss against the gold
     * {@code int[][]} target is positive, the weights are updated in place. For the
     * {@code Average}/{@code FastAverage} methods the returned model uses weights averaged over
     * the passes instead of the final weights.
     *
     * @param instanceSet  training instances; targets are expected to be {@code int[][]}
     * @param instanceSet2 optional held-out set evaluated after every pass; may be {@code null}
     * @return a {@link Linear} wrapping the trained inferencer and the training alphabet factory
     */
    @Override // org.fnlp.ml.classifier.linear.OnlineTrainer, org.fnlp.ml.classifier.linear.AbstractTrainer
    public Linear train(InstanceSet instanceSet, InstanceSet instanceSet2) {
        final int numSamples = instanceSet.size();
        int numTags = 0;
        for (int k = 0; k < numSamples; k++) {
            numTags += countTags(instanceSet.getInstance(k).getTarget());
        }
        System.out.println("Training Number: " + numSamples);
        System.out.println("Chars Number: " + numTags);

        // A progress dot is printed every ~10% of a pass. With fewer than 10 samples the step is
        // 0; guarding it avoids an integer "% 0" (ArithmeticException).
        final int progressStep = numSamples / 10;
        final boolean averaging = this.method == OnlineTrainer.TrainMethod.Average
                || this.method == OnlineTrainer.TrainMethod.FastAverage;
        float[] averagedWeights = averaging ? new float[this.weights.length] : null;

        long startMillis = System.currentTimeMillis();
        if (this.shuffle) {
            instanceSet.shuffle(this.random);
        }

        for (int pass = 0; pass < this.iternum; pass++) {
            if (!this.simpleOutput) {
                System.out.print("iter:");
                System.out.print((pass + 1) + "\t");
            }
            float tagLoss = 0.0f;
            int wrongSentences = 0;
            long passStart = System.currentTimeMillis();

            // Average: start from a copy of the weights and add the weights seen after every
            // instance. NOTE(review): this sums numSamples + 1 snapshots, but the sum is divided
            // by numSamples below; kept as-is to preserve trained models.
            float[] passSum = this.method == OnlineTrainer.TrainMethod.Average
                    ? Arrays.copyOf(this.weights, this.weights.length)
                    : null;

            for (int k = 0; k < numSamples; k++) {
                Instance inst = instanceSet.getInstance(k);
                List<?> best = (List<?>) this.inferencer.getBest(inst, 1);
                float calc = this.loss.calc((int[][]) best.get(0), (int[][]) inst.getTarget());
                if (calc > 0.0f) {
                    tagLoss += calc;
                    wrongSentences++;
                    this.update.update(inst, this.weights, best.get(0), this.c);
                }
                if (passSum != null) {
                    for (int w = 0; w < this.weights.length; w++) {
                        passSum[w] += this.weights[w];
                    }
                }
                if (this.DEBUG && calc > 0.0f) {
                    // Re-decode after the update so the post-update loss can be inspected in a
                    // debugger; the result is currently discarded. Targets are int[][] here.
                    List<?> after = (List<?>) this.inferencer.getBest(inst, 1);
                    this.loss.calc((int[][]) after.get(0), (int[][]) inst.getTarget());
                }
                if (!this.simpleOutput && progressStep > 0 && k % progressStep == 0) {
                    System.out.print('.');
                }
            }

            long passEnd = System.currentTimeMillis();
            if (!this.simpleOutput) {
                System.out.println("\ttime:" + ((passEnd - passStart) / 1000.0d) + "s");
                System.out.print("Train:");
                System.out.print("\tTag acc:");
            }
            System.out.print(1.0f - (tagLoss / numTags));
            if (!this.simpleOutput) {
                System.out.print("\tSentence acc:");
                System.out.print(1.0f - ((float) wrongSentences / numSamples));
                System.out.println();
            }

            reportWeights(this.innerOptimized);

            if (instanceSet2 != null) {
                evaluate(instanceSet2);
            }

            if (this.method == OnlineTrainer.TrainMethod.Average) {
                for (int w = 0; w < passSum.length; w++) {
                    averagedWeights[w] += passSum[w] / numSamples;
                }
            } else if (this.method == OnlineTrainer.TrainMethod.FastAverage) {
                for (int w = 0; w < this.weights.length; w++) {
                    averagedWeights[w] += this.weights[w];
                }
            }

            if (this.interim) {
                try {
                    new Linear(this.inferencer, instanceSet.getAlphabetFactory()).saveTo("tmp.model");
                } catch (IOException e) {
                    // Interim snapshots are best-effort: report (with the cause) and keep training.
                    System.err.println("write model error!");
                    System.err.println(e);
                }
            }
        }

        // Guard iternum > 0: otherwise the all-zero accumulator would be divided by zero,
        // replacing the weights with NaN.
        if (averaging && this.iternum > 0) {
            for (int w = 0; w < averagedWeights.length; w++) {
                averagedWeights[w] /= this.iternum;
            }
            this.weights = averagedWeights;
            this.inferencer.setWeights(this.weights);
        }

        reportWeights(this.finalOptimized);
        System.out.println("time escape:" + ((System.currentTimeMillis() - startMillis) / 1000.0d) + "s");
        return new Linear(this.inferencer, instanceSet.getAlphabetFactory());
    }

    /**
     * Decodes every instance of {@code instanceSet} with the current weights and prints tag and
     * sentence accuracy to standard output.
     *
     * @param instanceSet evaluation instances; targets may be {@code int[][]} or {@code int[]}
     */
    @Override // org.fnlp.ml.classifier.linear.OnlineTrainer, org.fnlp.ml.classifier.linear.AbstractTrainer
    public void evaluate(InstanceSet instanceSet) {
        final int numSamples = instanceSet.size();
        float tagErrors = 0.0f;
        int wrongSentences = 0;
        int numTags = 0;
        for (int k = 0; k < numSamples; k++) {
            Instance inst = instanceSet.getInstance(k);
            // Count tags with the same rule as train(); casting an int[][] target to int[]
            // would throw ClassCastException.
            numTags += countTags(inst.getTarget());
            float calc = this.loss.calc(((List<?>) this.inferencer.getBest(inst, 1)).get(0), inst.getTarget());
            if (calc > 0.0f) {
                wrongSentences++;
                tagErrors += calc;
            }
        }
        if (this.simpleOutput) {
            System.out.print('\t');
        } else {
            System.out.print("Test:\t");
            System.out.print(numTags - tagErrors);
            System.out.print('/');
            System.out.print(numTags);
            System.out.print("\tTag acc:");
        }
        System.out.print(1.0f - (tagErrors / numTags));
        if (!this.simpleOutput) {
            System.out.print("\tSentence acc:");
            System.out.println(1.0f - ((float) wrongSentences / numSamples));
        }
        System.out.println();
    }

    /**
     * Prints the number of non-zero weights. When {@code prune} is set, also zeroes the weights
     * at the indices returned by {@code MyArrays.getTop(copy, threshold, false)} and prints the
     * new count. A copy is passed to {@code getTop}, presumably because it may reorder its input.
     */
    private void reportWeights(boolean prune) {
        System.out.print("Weight Numbers: " + MyArrays.countNoneZero(this.weights));
        if (prune) {
            MyArrays.set(this.weights, MyArrays.getTop(this.weights.clone(), this.threshold, false), 0.0f);
            System.out.print("\tAfter Optimized: " + MyArrays.countNoneZero(this.weights));
        }
        System.out.println();
    }

    /**
     * Counts the gold tags in a target. For an {@code int[][]} this is the sum of the inner
     * lengths; for an {@code int[]} it is the array length.
     *
     * @throws IllegalArgumentException if the target is neither {@code int[][]} nor {@code int[]}
     */
    private static int countTags(Object target) {
        if (target instanceof int[][]) {
            int count = 0;
            for (int[] row : (int[][]) target) {
                count += row.length;
            }
            return count;
        }
        if (target instanceof int[]) {
            return ((int[]) target).length;
        }
        throw new IllegalArgumentException("unsupported target type: "
                + (target == null ? "null" : target.getClass().getName()));
    }
}
