package net.haesleinhuepf.clijx.assistant.interactive.handcrafted;

import ij.IJ;
import ij.gui.GenericDialog;
import ij.gui.Toolbar;
import ij.measure.ResultsTable;
import ij.plugin.frame.RoiManager;
import java.awt.Button;
import java.awt.Checkbox;
import java.awt.Frame;
import java.awt.Menu;
import java.awt.Panel;
import java.awt.TextField;
import java.util.HashMap;
import net.haesleinhuepf.clij.clearcl.ClearCLBuffer;
import net.haesleinhuepf.clij.macro.CLIJMacroPlugin;
import net.haesleinhuepf.clij2.assistant.annotation.AnnotationTool;
import net.haesleinhuepf.clij2.assistant.optimize.OptimizationUtilities;
import net.haesleinhuepf.clij2.assistant.services.AssistantGUIPlugin;
import net.haesleinhuepf.clij2.assistant.utilities.IJLogger;
import net.haesleinhuepf.clij2.assistant.utilities.Logger;
import net.haesleinhuepf.clijx.CLIJx;
import net.haesleinhuepf.clijx.assistant.AbstractCLIJxAssistantGUIPlugin;
import net.haesleinhuepf.clijx.assistant.utilities.AssistantUtilities;
import net.haesleinhuepf.clijx.weka.GenerateLabelFeatureImage;
import net.haesleinhuepf.clijx.weka.TrainWekaFromTable;
import net.haesleinhuepf.spimcat.io.CLIJxVirtualStack;
import org.scijava.plugin.Plugin;
import org.scijava.util.VersionUtils;

/**
 * Assistant GUI for training a Weka label classifier from an intensity image and a label image.
 *
 * <p>The plugin wraps either {@link net.haesleinhuepf.clijx.weka.WekaLabelClassifier} (default
 * constructor) or the {@link CLIJMacroPlugin} passed to the other constructor. When that plugin is
 * a {@link net.haesleinhuepf.clijx.weka.WekaRegionalLabelClassifier}, the dialog additionally
 * offers neighbourhood radii, and training uses regional label features.
 *
 * <p>Training (menu entry "Train classifier") proceeds as follows:
 * <ol>
 *   <li>Read the ground-truth ROIs from the ROI Manager.</li>
 *   <li>Build a per-label feature table from source 0 (intensity image) and source 1 (label
 *       image).</li>
 *   <li>Keep only rows that received a positive class.</li>
 *   <li>Train a model via {@link TrainWekaFromTable}.</li>
 * </ol>
 * The feature list is saved next to the model as {@code <model>.features.txt}, so the dialog can
 * restore it later.
 */
@Plugin(type = AssistantGUIPlugin.class)
public class WekaLabelClassifier extends AbstractCLIJxAssistantGUIPlugin {

    /** Suffix of the side-car file that stores the feature list belonging to a model file. */
    private static final String FEATURES_SUFFIX = ".features.txt";

    /** Training-table column holding the ground-truth class of each label. */
    private static final String CLASS_COLUMN = "CLASS";

    /** The non-modal dialog created by {@link #buildNonModalDialog}; null until it is built. */
    GenericDialog dialog;
    /** Space-separated feature names passed to the label feature-image generator. */
    String features = GenerateLabelFeatureImage.defaultFeatures();
    /** Path of the model file written by training. */
    String filename = "label_classification.model";
    int num_trees = 200;
    int num_features = 2;
    int max_depth = 0;
    // The radii below are only shown and used for the regional classifier.
    int radius_of_maximum = 0;
    int radius_of_minimum = 0;
    int radius_of_mean = 1;
    int radius_of_standard_deviation = 0;
    /** Whether the filtered training table is displayed before training. */
    boolean show_table = false;

    /** Creates the GUI around a plain {@link net.haesleinhuepf.clijx.weka.WekaLabelClassifier}. */
    public WekaLabelClassifier() {
        super(new net.haesleinhuepf.clijx.weka.WekaLabelClassifier());
    }

    /**
     * Creates the GUI around the given plugin.
     *
     * @param plugin the wrapped CLIJ plugin, e.g. a
     *     {@link net.haesleinhuepf.clijx.weka.WekaRegionalLabelClassifier}
     */
    public WekaLabelClassifier(CLIJMacroPlugin plugin) {
        super(plugin);
    }

    /** @return true if the wrapped plugin is the regional label classifier. */
    private boolean isRegional() {
        return getCLIJMacroPlugin() instanceof net.haesleinhuepf.clijx.weka.WekaRegionalLabelClassifier;
    }

    /**
     * Builds the parameter dialog.
     *
     * <p>The widgets are, in order:
     * <ul>
     *   <li>string fields: features (index 0) and model file (index 1);</li>
     *   <li>numeric fields: trees, features and depth (indices 0-2), plus the four radii
     *       (indices 3-6) for the regional classifier;</li>
     *   <li>one checkbox for "show table".</li>
     * </ul>
     * {@link #readDialog()} relies on these indices.
     */
    @Override
    public GenericDialog buildNonModalDialog(Frame frame) {
        GenericDialog gd = new GenericDialog(AssistantUtilities.niceNameWithoutDimShape(getName()));
        this.dialog = gd;

        // If a feature list was stored next to the current model, show that one.
        String storedFeatures = loadFeatureList(this.filename);
        if (storedFeatures != null) {
            this.features = storedFeatures;
        }

        gd.addStringField("Feature definition", this.features, 30);
        TextField featuresField = (TextField) gd.getStringFields().get(0);
        Panel featuresPanel = new Panel();
        Button featuresButton = new Button("Features...");
        featuresButton.addActionListener(event -> chooseFeatures(featuresField));
        featuresPanel.add(featuresButton);
        gd.addPanel(featuresPanel);

        gd.addStringField("Model file", this.filename, 30);
        TextField modelField = (TextField) gd.getStringFields().get(1);
        Panel filePanel = new Panel();
        Button fileButton = new Button("File...");
        fileButton.addActionListener(event -> chooseModelFile(modelField, featuresField));
        filePanel.add(fileButton);
        gd.addPanel(filePanel);

        gd.addNumericField("Number of trees", this.num_trees, 0);
        gd.addNumericField("Number of features", this.num_features, 0);
        gd.addNumericField("Max depth", this.max_depth, 0);
        if (isRegional()) {
            gd.addNumericField("Radius_of_maximum", this.radius_of_maximum, 0);
            gd.addNumericField("Radius_of_minimum", this.radius_of_minimum, 0);
            gd.addNumericField("Radius_of_mean", this.radius_of_mean, 0);
            gd.addNumericField("Radius_of_standard_deviation", this.radius_of_standard_deviation, 0);
        }
        gd.addCheckbox("Show table while training", this.show_table);
        return gd;
    }

    /**
     * Loads the feature list stored alongside a model file.
     *
     * @param modelPath path of the model file (the suffix is appended here)
     * @return the stored feature list, or {@code null} if nothing (or an empty list) was loaded
     */
    private static String loadFeatureList(String modelPath) {
        String loaded = BinaryWekaPixelClassifier.loadFeatures(modelPath + FEATURES_SUFFIX);
        return (loaded == null || loaded.isEmpty()) ? null : loaded;
    }

    /** @return true if {@code feature} appears as a whole word in the space-separated list. */
    private static boolean containsFeature(String featureList, String feature) {
        return (" " + featureList + " ").toLowerCase().contains(" " + feature.toLowerCase() + " ");
    }

    /**
     * Shows a checkbox dialog of all available features.
     *
     * <p>Boxes are pre-checked from the field's current text. On OK, the field is replaced by the
     * selection. The field is left untouched if nothing is selected or the dialog is cancelled.
     */
    private static void chooseFeatures(TextField featuresField) {
        // Use the live field content, not the last value stored by readDialog().
        String current = featuresField.getText();
        String[] available = GenerateLabelFeatureImage.allFeatures();

        GenericDialog chooser = new GenericDialog("Features");
        for (String feature : available) {
            chooser.addCheckbox(feature, containsFeature(current, feature));
        }
        chooser.showDialog();
        if (chooser.wasCanceled()) {
            return;
        }

        StringBuilder selected = new StringBuilder();
        for (String feature : available) {
            if (chooser.getNextBoolean()) {
                selected.append(feature).append(' ');
            }
        }
        if (selected.length() > 0) {
            featuresField.setText(selected.toString().trim());
        }
    }

    /**
     * Lets the user pick a model file and writes its path into {@code modelField}.
     *
     * <p>If a feature list was saved next to the chosen model, it is shown in
     * {@code featuresField}.
     */
    private static void chooseModelFile(TextField modelField, TextField featuresField) {
        String path = IJ.getFilePath("Model location");
        if (path == null || path.isEmpty()) {
            return; // dialog cancelled
        }
        modelField.setText(path);
        String stored = loadFeatureList(path);
        if (stored != null) {
            featuresField.setText(stored);
        }
    }

    /** Copies the current dialog values into the fields; does nothing if no dialog was built. */
    private void readDialog() {
        if (this.dialog == null) {
            return;
        }
        this.features = readStringField(0);
        this.filename = readStringField(1);
        this.num_trees = readIntField(0, this.num_trees);
        this.num_features = readIntField(1, this.num_features);
        this.max_depth = readIntField(2, this.max_depth);
        this.show_table = ((Checkbox) this.dialog.getCheckboxes().get(0)).getState();
        if (isRegional()) {
            this.radius_of_maximum = readIntField(3, this.radius_of_maximum);
            this.radius_of_minimum = readIntField(4, this.radius_of_minimum);
            this.radius_of_mean = readIntField(5, this.radius_of_mean);
            this.radius_of_standard_deviation = readIntField(6, this.radius_of_standard_deviation);
        }
    }

    /** @return the text of the dialog's string field at {@code index}. */
    private String readStringField(int index) {
        return ((TextField) this.dialog.getStringFields().get(index)).getText();
    }

    /**
     * Parses the dialog's numeric field at {@code index}, truncating towards zero.
     *
     * @return the parsed value, or {@code fallback} if the text is not a number
     */
    private int readIntField(int index, int fallback) {
        String text = ((TextField) this.dialog.getNumericFields().get(index)).getText();
        try {
            return (int) Double.parseDouble(text);
        } catch (NumberFormatException e) {
            // Keep the previous value instead of aborting the GUI action on a typo.
            IJ.log("Could not parse '" + text + "' as a number; keeping " + fallback + ".");
            return fallback;
        }
    }

    /**
     * Trains the classifier from the ROI Manager annotations and saves model and feature list.
     *
     * <p>If the ROI Manager is empty, the annotation tool is activated and nothing is trained.
     *
     * @param logger receives progress and error messages
     */
    private void train(Logger logger) {
        boolean regional = isRegional();
        if (regional) {
            logger.log("Train Weka regional label classifier");
            logger.log("------------------------------------");
        } else {
            logger.log("Train Weka label classifier");
            logger.log("---------------------------");
        }
        CLIJx clijx = CLIJx.getInstance();
        logger.log("GPU: " + clijx.getGPUName() + " (OCLv: " + clijx.getOpenCLVersion()
                + ", AssistantV: " + VersionUtils.getVersion(getClass()) + ")");
        readDialog();

        RoiManager roiManager = RoiManager.getRoiManager();
        if (roiManager.getCount() == 0) {
            logger.log("Please define reference ROIs in the ROI Manager.\n\nThese ROIs should have names starting with 'p' for positive and 'n' for negative.\n\nThe just activated annotation tool can help you with that.");
            Toolbar.addPlugInTool(new AnnotationTool());
            return;
        }
        HashMap<Integer, Integer> groundTruth =
                OptimizationUtilities.makeLabelClassificationGroundTruth(clijx, this.my_sources[1], roiManager);

        ClearCLBuffer[][] buffers = CLIJxVirtualStack.imagePlusesToBuffers(this.my_sources);
        try {
            // Use the currently selected channel of each source.
            ClearCLBuffer intensity = buffers[0][this.my_sources[0].getC() - 1];
            ClearCLBuffer labels = buffers[1][this.my_sources[1].getC() - 1];
            logger.log("Intensity image: " + this.my_sources[0].getTitle());
            logger.log("Label image: " + this.my_sources[1].getTitle());
            if (this.my_sources[0].getC() != this.my_sources[1].getC()) {
                logger.log("Warning: intensity and label image have different selected channels.");
            }

            ResultsTable table = buildFeatureTable(clijx, intensity, labels, regional);
            // Label n is stored in row n - 1; labels <= 0 (presumably background) are skipped.
            groundTruth.forEach((label, labelClass) -> {
                System.out.println("Label " + label + ", class " + labelClass);
                if (label > 0) {
                    table.setValue(CLASS_COLUMN, label - 1, labelClass.intValue());
                }
            });

            ResultsTable trainingTable = filterTable(table, CLASS_COLUMN);
            if (trainingTable.size() == 0) {
                logger.log("No annotated labels found; nothing to train. Please check the ROIs in the ROI Manager.");
                return; // buffers are still handed to cleanup by the finally block
            }
            if (this.show_table) {
                trainingTable.show("TRAINING");
            }
            net.haesleinhuepf.clijx.weka.WekaLabelClassifier.invalidateCache();
            TrainWekaFromTable.trainWekaFromTable(clijx, trainingTable, CLASS_COLUMN, this.filename,
                    Integer.valueOf(this.num_trees), Integer.valueOf(this.num_features),
                    Integer.valueOf(this.max_depth));
        } finally {
            // Release the source buffers even if feature generation or training fails.
            cleanup(this.my_sources, buffers);
        }

        logger.log("Model saved to " + this.filename);
        BinaryWekaPixelClassifier.saveFeatures(this.filename + FEATURES_SUFFIX, this.features);
        logger.log("Featurelist saved to " + this.filename + FEATURES_SUFFIX);
        setTargetInvalid();
        logger.log("Bye.");
    }

    /**
     * Computes the per-label feature image and pulls it into a new results table.
     *
     * <p>The intermediate GPU feature image is always closed.
     */
    private ResultsTable buildFeatureTable(CLIJx clijx, ClearCLBuffer intensity, ClearCLBuffer labels, boolean regional) {
        ClearCLBuffer featureImage = regional
                ? net.haesleinhuepf.clijx.weka.WekaRegionalLabelClassifier.generateRegionalLabelFeatureImage(
                        clijx, intensity, labels, this.features, this.radius_of_maximum,
                        this.radius_of_minimum, this.radius_of_mean, this.radius_of_standard_deviation)
                : GenerateLabelFeatureImage.generateLabelFeatureImage(clijx, intensity, labels, this.features);
        ResultsTable table = new ResultsTable();
        try {
            clijx.pullToResultsTable(featureImage, table);
        } finally {
            featureImage.close();
        }
        return table;
    }

    /** Adds the "Train classifier" entry, which trains using an {@link IJLogger}. */
    @Override
    protected void addMoreActions(Menu menu) {
        AssistantUtilities.addMenuAction(menu, "Train classifier", event -> train(new IJLogger()));
    }

    /**
     * Returns a new table containing only the rows whose value in {@code str} is positive.
     *
     * <p>All columns are copied for those rows.
     *
     * @param resultsTable the source table (not modified)
     * @param str name of the column to filter on
     * @return the filtered table; empty if the column does not exist
     */
    public static ResultsTable filterTable(ResultsTable resultsTable, String str) {
        ResultsTable filtered = new ResultsTable();
        // getValue(String, int) throws for unknown columns, e.g. when no label was annotated.
        if (resultsTable.getColumnIndex(str) == ResultsTable.COLUMN_NOT_FOUND) {
            return filtered;
        }
        String[] headings = resultsTable.getHeadings();
        for (int row = 0; row < resultsTable.size(); row++) {
            if (resultsTable.getValue(str, row) > 0.0d) {
                filtered.incrementCounter();
                for (String heading : headings) {
                    filtered.addValue(heading, resultsTable.getValue(heading, row));
                }
            }
        }
        return filtered;
    }
}
