package edu.emory.mathcs.nlp.component.template.config;

import edu.emory.mathcs.nlp.common.constant.CharConst;
import edu.emory.mathcs.nlp.common.util.Language;
import edu.emory.mathcs.nlp.common.util.Splitter;
import edu.emory.mathcs.nlp.common.util.XMLUtils;
import edu.emory.mathcs.nlp.component.template.node.AbstractNLPNode;
import edu.emory.mathcs.nlp.component.template.train.HyperParameter;
import edu.emory.mathcs.nlp.component.template.train.LOLS;
import edu.emory.mathcs.nlp.learning.activation.ActivationFunction;
import edu.emory.mathcs.nlp.learning.activation.HyperbolicTanFunction;
import edu.emory.mathcs.nlp.learning.activation.IdentityFunction;
import edu.emory.mathcs.nlp.learning.activation.RectifiedLinearUnitFunction;
import edu.emory.mathcs.nlp.learning.activation.SigmoidFunction;
import edu.emory.mathcs.nlp.learning.activation.SoftmaxFunction;
import edu.emory.mathcs.nlp.learning.activation.SoftplusFunction;
import edu.emory.mathcs.nlp.learning.initialization.RandomWeightGenerator;
import edu.emory.mathcs.nlp.learning.initialization.WeightGenerator;
import edu.emory.mathcs.nlp.learning.neural.FeedForwardNeuralNetworkSoftmax;
import edu.emory.mathcs.nlp.learning.optimization.OnlineOptimizer;
import edu.emory.mathcs.nlp.learning.optimization.method.AdaDeltaMiniBatch;
import edu.emory.mathcs.nlp.learning.optimization.method.AdaGrad;
import edu.emory.mathcs.nlp.learning.optimization.method.AdaGradMiniBatch;
import edu.emory.mathcs.nlp.learning.optimization.method.AdaGradRegression;
import edu.emory.mathcs.nlp.learning.optimization.method.Perceptron;
import edu.emory.mathcs.nlp.learning.optimization.method.SoftmaxRegression;
import edu.emory.mathcs.nlp.learning.optimization.reguralization.RegularizedDualAveraging;
import edu.emory.mathcs.nlp.learning.util.WeightVector;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import java.io.InputStream;
import java.util.Arrays;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

/* loaded from: input_file:edu/emory/mathcs/nlp/component/template/config/NLPConfig.class */
public class NLPConfig<N extends AbstractNLPNode<N>> implements ConfigXML {
    protected Object2IntMap<String> reader_map;
    protected Element xml;

    /** Creates a configuration without a backing XML element; {@link #xml} stays unset. */
    public NLPConfig() {
    }

    /**
     * Creates a configuration backed by the document element parsed from {@code inputStream}.
     *
     * @param inputStream stream containing the XML configuration.
     */
    public NLPConfig(InputStream inputStream) {
        this.xml = XMLUtils.getDocumentElement(inputStream);
    }

    /** @return the root XML element of this configuration (may be {@code null} if never set). */
    public Element getDocumentElement() {
        return this.xml;
    }

    /**
     * @param str tag name to look up.
     * @return the integer text content of the first element named {@code str}.
     */
    public int getIntegerTextContent(String str) {
        return XMLUtils.getIntegerTextContentFromFirstElementByTagName(this.xml, str);
    }

    /**
     * @param str tag name to look up.
     * @return the text content of the first element named {@code str}.
     */
    public String getTextContent(String str) {
        return XMLUtils.getTextContentFromFirstElementByTagName(this.xml, str);
    }

    /**
     * @return the language named by the {@link ConfigXML#LANGUAGE} element, or
     *     {@link Language#ENGLISH} when that element has no text content.
     */
    public Language getLanguage() {
        String language = XMLUtils.getTextContentFromFirstElementByTagName(this.xml, ConfigXML.LANGUAGE);
        return (language == null) ? Language.ENGLISH : Language.getType(language);
    }

    /**
     * Builds the reader's field-to-column mapping from the {@code tsv} element.
     *
     * <p>Each {@link ConfigXML#COLUMN} child contributes one entry: its trimmed {@code field}
     * attribute mapped to its integer {@link ConfigXML#INDEX} attribute.
     *
     * @return a new map from field name to column index.
     */
    public Object2IntMap<String> getReaderFieldMap() {
        NodeList columns = XMLUtils.getFirstElementByTagName(this.xml, "tsv").getElementsByTagName(ConfigXML.COLUMN);
        int count = columns.getLength();
        Object2IntMap<String> fieldMap = new Object2IntOpenHashMap<>(count);

        for (int i = 0; i < count; i++) {
            Element column = (Element) columns.item(i);
            fieldMap.put(XMLUtils.getTrimmedAttribute(column, "field"), XMLUtils.getIntegerAttribute(column, ConfigXML.INDEX));
        }

        return fieldMap;
    }

    /** @return the first {@link ConfigXML#FEATURE_TEMPLATE} element of the configuration. */
    public Element getFeatureTemplateElement() {
        return XMLUtils.getFirstElementByTagName(this.xml, ConfigXML.FEATURE_TEMPLATE);
    }

    /**
     * Reads training hyper-parameters from the {@link ConfigXML#OPTIMIZER} element.
     *
     * <p>An L1 regularizer ({@link RegularizedDualAveraging}) is created only when the L1 value
     * is positive; otherwise it is {@code null}. When the {@link ConfigXML#LOLS} element is
     * absent, LOLS uses {@code fixed = 0} and {@code decaying = 1.0}.
     *
     * @return a new, fully populated {@link HyperParameter}.
     */
    public HyperParameter getHyperParameter() {
        Element optimizer = XMLUtils.getFirstElementByTagName(this.xml, ConfigXML.OPTIMIZER);
        Element lols = XMLUtils.getFirstElementByTagName(optimizer, ConfigXML.LOLS);

        int featureCutoff = XMLUtils.getIntegerTextContentFromFirstElementByTagName(optimizer, ConfigXML.FEATURE_CUTOFF);
        int batchSize = XMLUtils.getIntegerTextContentFromFirstElementByTagName(optimizer, ConfigXML.BATCH_SIZE);
        int maxEpochs = XMLUtils.getIntegerTextContentFromFirstElementByTagName(optimizer, ConfigXML.MAX_EPOCH);
        float learningRate = XMLUtils.getFloatTextContentFromFirstElementByTagName(optimizer, ConfigXML.LEARNING_RATE);
        float decayingRate = XMLUtils.getFloatTextContentFromFirstElementByTagName(optimizer, ConfigXML.DECAYING_RATE);
        float bias = XMLUtils.getFloatTextContentFromFirstElementByTagName(optimizer, ConfigXML.BIAS);
        float l1 = XMLUtils.getFloatTextContentFromFirstElementByTagName(optimizer, ConfigXML.L1_REGULARIZATION);

        // Defaults used when no LOLS element is configured.
        int lolsFixed = 0;
        double lolsDecaying = 1.0d;
        if (lols != null) {
            lolsFixed = XMLUtils.getIntegerAttribute(lols, ConfigXML.FIXED);
            lolsDecaying = XMLUtils.getDoubleAttribute(lols, ConfigXML.DECAYING);
        }

        RegularizedDualAveraging l1Regularizer = (l1 > 0.0f) ? new RegularizedDualAveraging(l1) : null;

        HyperParameter hyperParameter = new HyperParameter();
        hyperParameter.setFeature_cutoff(featureCutoff);
        hyperParameter.setBatchSize(batchSize);
        hyperParameter.setMaxEpochs(maxEpochs);
        hyperParameter.setLearningRate(learningRate);
        hyperParameter.setDecayingRate(decayingRate);
        hyperParameter.setBias(bias);
        hyperParameter.setL1Regularizer(l1Regularizer);
        hyperParameter.setLOLS(new LOLS(lolsFixed, lolsDecaying));
        hyperParameter.setHiddenDimensions(getHiddenDimensions(optimizer));
        hyperParameter.setActivationFunctions(getActivationFunction(optimizer));
        hyperParameter.setDropoutProb(getDropoutProb(optimizer));
        hyperParameter.setWeightGenerator(getWeightGenerator(optimizer));
        return hyperParameter;
    }

    /**
     * Creates the online optimizer named by the {@link ConfigXML#ALGORITHM} element under
     * {@link ConfigXML#OPTIMIZER}.
     *
     * @param hyperParameter hyper-parameters supplying learning rate, bias, etc.
     * @return a new optimizer instance.
     * @throws IllegalArgumentException if the algorithm name is missing or not recognized.
     */
    public OnlineOptimizer getOnlineOptimizer(HyperParameter hyperParameter) {
        Element optimizer = XMLUtils.getFirstElementByTagName(this.xml, ConfigXML.OPTIMIZER);
        String algorithm = XMLUtils.getTextContentFromFirstElementByTagName(optimizer, ConfigXML.ALGORITHM);

        // A String switch throws NPE on null, so guard it and fall through to the IAE below.
        if (algorithm != null) {
            switch (algorithm) {
                case ConfigXML.PERCEPTRON:
                    return new Perceptron(new WeightVector(), hyperParameter.getLearningRate(), hyperParameter.getBias());
                case ConfigXML.SOFTMAX_REGRESSION:
                    return new SoftmaxRegression(new WeightVector(), hyperParameter.getLearningRate(), hyperParameter.getBias());
                case ConfigXML.ADAGRAD_REGRESSION:
                    return new AdaGradRegression(new WeightVector(), hyperParameter.getLearningRate(), hyperParameter.getBias());
                case ConfigXML.ADAGRAD:
                    return new AdaGrad(new WeightVector(), hyperParameter.getLearningRate(), hyperParameter.getBias(), hyperParameter.getL1Regularizer());
                case ConfigXML.ADAGRAD_MINI_BATCH:
                    return new AdaGradMiniBatch(new WeightVector(), hyperParameter.getLearningRate(), hyperParameter.getBias(), hyperParameter.getL1Regularizer());
                case ConfigXML.ADADELTA_MINI_BATCH:
                    return new AdaDeltaMiniBatch(new WeightVector(), hyperParameter.getLearningRate(), hyperParameter.getDecayingRate(), hyperParameter.getBias(), hyperParameter.getL1Regularizer());
                case ConfigXML.FFNN_SOFTMAX:
                    return new FeedForwardNeuralNetworkSoftmax(hyperParameter.getHiddenDimensions(), hyperParameter.getActivationFunctions(), hyperParameter.getLearningRate(), hyperParameter.getBias(), hyperParameter.getWeightGenerator(), hyperParameter.getDropoutProb());
                default:
                    break;
            }
        }

        throw new IllegalArgumentException(algorithm + " is not a valid algorithm name.");
    }

    /**
     * Parses the comma-separated {@link ConfigXML#HIDDEN_DIMENSIONS} element.
     *
     * @param element the optimizer element to search.
     * @return the hidden-layer sizes, or {@code null} if the element is absent or empty.
     * @throws NumberFormatException if a token is not an integer.
     */
    private int[] getHiddenDimensions(Element element) {
        String dimensions = XMLUtils.getTextContentFromFirstElementByTagName(element, ConfigXML.HIDDEN_DIMENSIONS);
        if (dimensions == null || dimensions.isEmpty()) {
            return null;
        }
        // Integer.parseInt rejects surrounding whitespace (e.g. "100, 50"), so trim each token.
        return Arrays.stream(Splitter.splitCommas(dimensions)).map(String::trim).mapToInt(Integer::parseInt).toArray();
    }

    /**
     * Parses the comma-separated {@link ConfigXML#ACTIVATION_FUNCTIONS} element into function
     * instances, one per name, in order.
     *
     * @param element the optimizer element to search.
     * @return the activation functions, or {@code null} if the element is absent or empty.
     * @throws IllegalArgumentException if a name is not a recognized activation function.
     */
    private ActivationFunction[] getActivationFunction(Element element) {
        String functions = XMLUtils.getTextContentFromFirstElementByTagName(element, ConfigXML.ACTIVATION_FUNCTIONS);
        if (functions == null || functions.isEmpty()) {
            return null;
        }

        String[] names = Splitter.splitCommas(functions);
        ActivationFunction[] result = new ActivationFunction[names.length];

        for (int i = 0; i < names.length; i++) {
            String name = names[i].trim();
            switch (name) {
                case ConfigXML.SIGMOID:
                    result[i] = new SigmoidFunction();
                    break;
                case ConfigXML.SOFTMAX:
                    result[i] = new SoftmaxFunction();
                    break;
                case ConfigXML.IDENTITY:
                    result[i] = new IdentityFunction();
                    break;
                case ConfigXML.RELU:
                    result[i] = new RectifiedLinearUnitFunction();
                    break;
                case ConfigXML.TANH:
                    result[i] = new HyperbolicTanFunction();
                    break;
                case ConfigXML.SOFTPLUS:
                    result[i] = new SoftplusFunction();
                    break;
                default:
                    // Previously left a null entry that failed much later; fail fast instead.
                    throw new IllegalArgumentException(name + " is not a valid activation function name.");
            }
        }

        return result;
    }

    /**
     * Parses the comma-separated {@link ConfigXML#DROPOUT_PROB} element.
     *
     * @param element the optimizer element to search.
     * @return the dropout probabilities, or {@code null} if the element is absent or empty.
     * @throws NumberFormatException if a token is not a float.
     */
    private float[] getDropoutProb(Element element) {
        String probabilities = XMLUtils.getTextContentFromFirstElementByTagName(element, ConfigXML.DROPOUT_PROB);
        if (probabilities == null || probabilities.isEmpty()) {
            return null;
        }

        String[] tokens = Splitter.splitCommas(probabilities);
        float[] result = new float[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            result[i] = Float.parseFloat(tokens[i]);
        }
        return result;
    }

    /**
     * Builds a {@link RandomWeightGenerator} from the {@code lower}/{@code upper} attributes of
     * the {@link ConfigXML#WEIGHT_GENERATOR} element.
     *
     * @param element the optimizer element to search.
     * @return the weight generator, or {@code null} if the element is absent.
     * @throws NumberFormatException if a bound is not a float.
     */
    private WeightGenerator getWeightGenerator(Element element) {
        Element generator = XMLUtils.getFirstElementByTagName(element, ConfigXML.WEIGHT_GENERATOR);
        if (generator == null) {
            return null;
        }
        float lower = Float.parseFloat(XMLUtils.getTrimmedAttribute(generator, "lower"));
        float upper = Float.parseFloat(XMLUtils.getTrimmedAttribute(generator, "upper"));
        return new RandomWeightGenerator(lower, upper);
    }
}
