package org.openimaj.workinprogress.sgdsvm;

import gnu.trove.list.array.TDoubleArrayList;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.openimaj.time.Timer;
import org.openimaj.util.array.SparseFloatArray;

/* loaded from: input_file:org/openimaj/workinprogress/sgdsvm/SvmSgdMain.class */
/**
 * Command-line driver for {@link SvmSgd}.
 * <p>
 * Parses options, loads a training file (and optionally a test file) through
 * {@link Loader}, then trains the model for {@code -epochs} epochs. After every
 * epoch it reports performance on the training set and, if present, the test set.
 * <p>
 * The {@code # Compiled with:} line reports the loss/bias settings held in the
 * fields below; despite its wording, they are ordinary runtime fields here.
 */
public class SvmSgdMain {
    /** Loss function handed to the {@link SvmSgd} model. */
    Loss LOSS = LossFunctions.LogLoss;
    /** Copied to {@code SvmSgd.BIAS} before training. */
    boolean BIAS = true;
    /** Copied to {@code SvmSgd.REGULARIZED_BIAS} before training. */
    boolean REGULARIZED_BIAS = false;
    /** Training file path: first positional argument (required). */
    String trainfile = null;
    /** Test file path: optional second positional argument. */
    String testfile = null;
    /** Normalization flag passed to the loader; cleared by {@code -dontnormalize}. */
    boolean normalize = true;
    /** Regularization parameter, set by {@code -lambda}. */
    double lambda = 1.0E-5d;
    /** Number of training epochs, set by {@code -epochs}. */
    int epochs = 5;
    /**
     * Example limit for the training set, set by {@code -maxtrain}. The default of
     * -1 matches what the no-limit {@code load_datafile} overload passes.
     */
    int maxtrain = -1;
    /** One-element holder for the feature dimension; raised in place by {@code load_datafile}. */
    int[] dims = {0};
    List<SparseFloatArray> xtrain = new ArrayList<SparseFloatArray>();
    TDoubleArrayList ytrain = new TDoubleArrayList();
    List<SparseFloatArray> xtest = new ArrayList<SparseFloatArray>();
    TDoubleArrayList ytest = new TDoubleArrayList();

    /** Formats an option name right-aligned in a 16-character column, followed by a space. */
    String NAM(String str) {
        return String.format("%16s ", str);
    }

    /** Formats a " (default: ...)" suffix for the usage text. */
    String DEF(Object obj) {
        return " (default: " + obj + ".)";
    }

    /**
     * Prints usage information to stderr and terminates the JVM with exit status 10.
     *
     * @param str program name shown in the usage line
     */
    void usage(String str) {
        System.err.println("Usage: " + str + " [options] trainfile [testfile]");
        System.err.println("Options:");
        System.err.println(NAM("-lambda x") + "Regularization parameter" + DEF(Double.valueOf(this.lambda)));
        System.err.println(NAM("-epochs n") + "Number of training epochs" + DEF(Integer.valueOf(this.epochs)));
        System.err.println(NAM("-dontnormalize") + "Do not normalize the L2 norm of patterns.");
        System.err.println(NAM("-maxtrain n") + "Restrict training set to n examples.");
        System.exit(10);
    }

    /**
     * Parses command-line arguments into this object's fields.
     * <p>
     * Arguments not starting with {@code -} are positional: first the training
     * file, then the test file. A third positional argument, an unknown option,
     * an option missing its value, or a missing training file all call
     * {@link #usage(String)}, which exits.
     *
     * @param strArr the command-line arguments
     */
    void parse(String[] strArr) {
        for (int i = 0; i < strArr.length; i++) {
            final String arg = strArr[i];
            // Checked once per argument, before any ++i below consumes the value.
            final boolean hasValue = i + 1 < strArr.length;
            // startsWith (not charAt(0)) so an empty argument does not throw.
            if (!arg.startsWith("-")) {
                if (this.trainfile == null) {
                    this.trainfile = arg;
                } else if (this.testfile == null) {
                    this.testfile = arg;
                } else {
                    usage(getClass().getName());
                }
            } else if ("-lambda".equals(arg) && hasValue) {
                this.lambda = Double.parseDouble(strArr[++i]);
                assert this.lambda > 0.0d && this.lambda < 10000.0d;
            } else if ("-epochs".equals(arg) && hasValue) {
                this.epochs = Integer.parseInt(strArr[++i]);
                assert this.epochs > 0 && this.epochs < 1000000;
            } else if ("-dontnormalize".equals(arg)) {
                this.normalize = false;
            } else if ("-maxtrain".equals(arg) && hasValue) {
                this.maxtrain = Integer.parseInt(strArr[++i]);
                assert this.maxtrain > 0;
            } else {
                System.err.println("Option " + arg + " not recognized.");
                usage(getClass().getName());
            }
        }
        if (this.trainfile == null) {
            usage(getClass().getName());
        }
    }

    /**
     * Echoes the effective options and model settings to stdout.
     *
     * @param str program name to report
     */
    void config(String str) {
        System.out.print("# Running: " + str);
        System.out.print(" -lambda " + this.lambda);
        System.out.print(" -epochs " + this.epochs);
        if (!this.normalize) {
            System.out.print(" -dontnormalize");
        }
        if (this.maxtrain > 0) {
            System.out.print(" -maxtrain " + this.maxtrain);
        }
        System.out.println();
        System.out.println("# Compiled with: -DLOSS=" + this.LOSS + " -DBIAS=" + this.BIAS
                + " -DREGULARIZED_BIAS=" + this.REGULARIZED_BIAS);
    }

    /** Entry point: see {@link #run(String[])}. */
    public static void main(String[] strArr) throws IOException {
        new SvmSgdMain().run(strArr);
    }

    /**
     * Parses the arguments, loads the data, and runs the training loop.
     *
     * @param strArr command-line arguments
     * @throws IOException if loading a data file fails
     */
    void run(String[] strArr) throws IOException {
        parse(strArr);
        config(getClass().getName());
        if (this.trainfile != null) {
            load_datafile(this.trainfile, this.xtrain, this.ytrain, this.dims, this.normalize, this.maxtrain);
        }
        if (this.testfile != null) {
            load_datafile(this.testfile, this.xtest, this.ytest, this.dims, this.normalize);
        }
        System.out.println("# Number of features " + this.dims[0] + ".");
        // Inclusive upper indices, as passed to SvmSgd.train/test below.
        final int trainMax = this.xtrain.size() - 1;
        final int testMax = this.xtest.size() - 1;
        final SvmSgd svmSgd = new SvmSgd(this.dims[0], this.lambda);
        svmSgd.BIAS = this.BIAS;
        svmSgd.LOSS = this.LOSS;
        svmSgd.REGULARIZED_BIAS = this.REGULARIZED_BIAS;
        final Timer timer = new Timer();
        // Use at most the first ~1000 training examples to pick the initial learning rate.
        final int etaMax = Math.min(1000, trainMax);
        timer.start();
        svmSgd.determineEta0(0, etaMax, this.xtrain, this.ytrain);
        timer.stop();
        for (int i = 0; i < this.epochs; i++) {
            System.out.println("--------- Epoch " + (i + 1) + ".");
            timer.start();
            svmSgd.train(0, trainMax, this.xtrain, this.ytrain);
            timer.stop();
            // NOTE(review): assumes Timer.duration() is in milliseconds -- confirm.
            System.out.println("Total training time " + (timer.duration() / 1000) + "secs.");
            svmSgd.test(0, trainMax, this.xtrain, this.ytrain, "train:");
            if (testMax >= 0) {
                svmSgd.test(0, testMax, this.xtest, this.ytest, "test: ");
            }
        }
    }

    /** Loads a data file without an example limit (passes -1 as the limit). */
    private static int load_datafile(String file, List<SparseFloatArray> xs, TDoubleArrayList ys, int[] maxdim,
            boolean normalize) throws IOException {
        return load_datafile(file, xs, ys, maxdim, normalize, -1);
    }

    /**
     * Loads examples from {@code file} into {@code xs}/{@code ys} and raises
     * {@code maxdim[0]} to the dimension reported by the loader, if larger.
     *
     * @param file      path of the data file
     * @param xs        receives the feature vectors
     * @param ys        receives the labels
     * @param maxdim    one-element holder for the running maximum dimension
     * @param normalize normalization flag forwarded to the loader
     * @param maxn      example limit forwarded to the loader
     * @return the sum of the two example counts reported by the loader
     * @throws IOException if the loader fails
     */
    private static int load_datafile(String file, List<SparseFloatArray> xs, TDoubleArrayList ys, int[] maxdim,
            boolean normalize, int maxn) throws IOException {
        final int[] dimOut = {0};
        // NOTE(review): presumably positive / negative example counts -- confirm against Loader.
        final int[] pcountOut = {0};
        final int[] ncountOut = {0};
        new Loader(file).load(xs, ys, normalize, maxn, dimOut, pcountOut, ncountOut);
        final int pcount = pcountOut[0];
        final int ncount = ncountOut[0];
        final int total = pcount + ncount;
        if (total > 0) {
            System.out.println("# Read " + pcount + "+" + ncount + "=" + total + " examples from \"" + file + "\".");
        }
        if (maxdim[0] < dimOut[0]) {
            maxdim[0] = dimOut[0];
        }
        return total;
    }
}
