package edu.emory.mathcs.nlp.component.doc;

import edu.emory.mathcs.nlp.common.collection.tuple.ObjectFloatPair;
import edu.emory.mathcs.nlp.common.constant.CharConst;
import edu.emory.mathcs.nlp.common.util.MathUtils;
import edu.emory.mathcs.nlp.common.util.XMLUtils;
import edu.emory.mathcs.nlp.component.doc.DOCState;
import edu.emory.mathcs.nlp.component.template.feature.FeatureItem;
import edu.emory.mathcs.nlp.component.template.feature.FeatureTemplate;
import edu.emory.mathcs.nlp.component.template.feature.Field;
import edu.emory.mathcs.nlp.component.template.node.AbstractNLPNode;
import edu.emory.mathcs.nlp.component.template.train.HyperParameter;
import edu.emory.mathcs.nlp.learning.util.SparseVector;
import it.unimi.dsi.fastutil.objects.Object2FloatMap;
import it.unimi.dsi.fastutil.objects.Object2FloatOpenHashMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.StringJoiner;
import java.util.stream.Collectors;
import org.w3c.dom.Element;

/* loaded from: input_file:edu/emory/mathcs/nlp/component/doc/DOCFeatureTemplate.class */
public class DOCFeatureTemplate<N extends AbstractNLPNode<N>, S extends DOCState<N>> extends FeatureTemplate<N, S> {
    /** Serialization version identifier for this class. */
    private static final long serialVersionUID = 8581842859392646419L;
    /**
     * {@link Field} values parsed from the {@code "t"} attribute of feature elements.
     * {@code createSparseVector} reads entry {@code k} of this list together with entry
     * {@code k} of {@code feature_list}.
     */
    protected List<Field> feature_list_type;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: edu.emory.mathcs.nlp.component.doc.DOCFeatureTemplate$1, reason: invalid class name */
    /* loaded from: input_file:edu/emory/mathcs/nlp/component/doc/DOCFeatureTemplate$1.class */
    /**
     * Lookup table that translates {@code Field.ordinal()} into a dense case label.
     *
     * <p>The array has one slot per {@link Field} constant. The twelve bag-of-words /
     * bag-of-clusters constants are assigned the labels 1 through 12; every other slot keeps
     * the array default of 0.
     */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$edu$emory$mathcs$nlp$component$template$feature$Field = new int[Field.values().length];

        // Each assignment has its own try/catch so that a NoSuchFieldError raised while
        // resolving one constant does not stop the remaining slots from being filled.
        // The error is swallowed, which leaves the affected slot at 0.
        static {
            try {
                $SwitchMap$edu$emory$mathcs$nlp$component$template$feature$Field[Field.bag_of_words.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$edu$emory$mathcs$nlp$component$template$feature$Field[Field.bag_of_words_norm.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$edu$emory$mathcs$nlp$component$template$feature$Field[Field.bag_of_words_count.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$edu$emory$mathcs$nlp$component$template$feature$Field[Field.bag_of_words_stopwords.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$edu$emory$mathcs$nlp$component$template$feature$Field[Field.bag_of_words_stopwords_norm.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$edu$emory$mathcs$nlp$component$template$feature$Field[Field.bag_of_words_stopwords_count.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$edu$emory$mathcs$nlp$component$template$feature$Field[Field.bag_of_clusters.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$edu$emory$mathcs$nlp$component$template$feature$Field[Field.bag_of_clusters_norm.ordinal()] = 8;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$edu$emory$mathcs$nlp$component$template$feature$Field[Field.bag_of_clusters_count.ordinal()] = 9;
            } catch (NoSuchFieldError e9) {
            }
            try {
                $SwitchMap$edu$emory$mathcs$nlp$component$template$feature$Field[Field.bag_of_clusters_stopwords.ordinal()] = 10;
            } catch (NoSuchFieldError e10) {
            }
            try {
                $SwitchMap$edu$emory$mathcs$nlp$component$template$feature$Field[Field.bag_of_clusters_stopwords_norm.ordinal()] = 11;
            } catch (NoSuchFieldError e11) {
            }
            try {
                $SwitchMap$edu$emory$mathcs$nlp$component$template$feature$Field[Field.bag_of_clusters_stopwords_count.ordinal()] = 12;
            } catch (NoSuchFieldError e12) {
            }
        }
    }

    /**
     * Creates a document-level feature template by delegating to the
     * {@link FeatureTemplate} constructor.
     *
     * @param element the XML element handed to the superclass constructor
     * @param hyperParameter the hyper-parameters handed to the superclass constructor
     */
    public DOCFeatureTemplate(Element element, HyperParameter hyperParameter) {
        super(element, hyperParameter);
    }

    /**
     * Registers the feature described by one XML element.
     *
     * <p>If the first created item is a {@link Field#word_embedding} item it is registered via
     * {@code addWordEmbedding}; otherwise the items are registered via {@code add} and the
     * {@link Field} named by the element's {@code "t"} attribute is appended to
     * {@link #feature_list_type}.
     *
     * <p>The type is recorded only for items registered via {@code add}, because
     * {@code createSparseVector} reads {@code feature_list} and {@code feature_list_type} at the
     * same index. Recording a type for a word-embedding element as well would shift every later
     * type by one position.
     * NOTE(review): this assumes {@code add} appends to {@code feature_list} and
     * {@code addWordEmbedding} does not - confirm against {@link FeatureTemplate}.
     *
     * @param element the XML element describing the feature
     * @throws IllegalArgumentException if a non-embedding element's {@code "t"} attribute does
     *         not name a {@link Field} constant
     */
    @Override // edu.emory.mathcs.nlp.component.template.feature.FeatureTemplate
    protected void initFeatureItems(Element element) {
        // Created lazily instead of in a field initializer.
        // NOTE(review): presumably because this method can run from the superclass constructor,
        // before this class's own field initializers execute - confirm.
        if (this.feature_list_type == null) {
            this.feature_list_type = new ArrayList<>();
        }

        FeatureItem[] items = createFeatureItems(element);
        boolean wordEmbedding = items != null && items.length > 0 && items[0].field == Field.word_embedding;
        if (wordEmbedding) {
            addWordEmbedding(items[0]);
            return;
        }

        // Parse the type before mutating any list so that an invalid "t" attribute cannot
        // leave the two parallel lists with different lengths.
        Field type = Field.valueOf(XMLUtils.getTrimmedAttribute(element, "t"));
        add(items);
        this.feature_list_type.add(type);
    }

    /**
     * Builds the sparse feature vector for the given state.
     *
     * <p>For every entry of {@code feature_list}, the weighted features are computed with the
     * {@link Field} stored at the same index of {@link #feature_list_type}, and each resulting
     * (feature, weight) pair is passed to {@code add} together with that index.
     *
     * @param s the processing state the features are extracted from
     * @param z flag forwarded unchanged to {@code add}
     * @return the populated sparse vector; empty if no template produced any feature
     */
    @Override // edu.emory.mathcs.nlp.component.template.feature.FeatureTemplate
    public SparseVector createSparseVector(S s, boolean z) {
        SparseVector vector = new SparseVector();
        for (int index = 0; index < this.feature_list.size(); index++) {
            Collection<ObjectFloatPair<String>> weighted =
                    getWeightedFeatures(s, this.feature_list.get(index), this.feature_list_type.get(index));
            if (weighted == null) {
                // This template yielded nothing for the current state.
                continue;
            }
            for (ObjectFloatPair<String> pair : weighted) {
                add(vector, index, pair.o, pair.f, z);
            }
        }
        return vector;
    }

    /**
     * Computes the weighted features of one template for the given state.
     *
     * @param s the processing state
     * @param featureItemArr the feature items of the template
     * @param field the bag type of the template
     * @return the weighted (feature, weight) pairs, or {@code null} if the underlying bag is
     *         {@code null} or empty, or if {@code field} is not a supported bag type
     */
    protected Collection<ObjectFloatPair<String>> getWeightedFeatures(S s, FeatureItem[] featureItemArr, Field field) {
        Object2FloatMap<String> counts = getBagOfLexicons(s, featureItemArr, field);
        boolean hasEntries = counts != null && !counts.isEmpty();
        return hasEntries ? getBagOfLexicons(counts, field) : null;
    }

    /**
     * Collects the raw occurrence counts for the bag selected by {@code field}.
     *
     * <p>The {@code bag_of_words*} types delegate to {@code getBagOfWords} and the
     * {@code bag_of_clusters*} types to {@code getBagOfClusters}. The {@code *_stopwords*}
     * variants pass {@code true} as the boolean argument, the others pass {@code false}.
     *
     * @param s the processing state
     * @param featureItemArr the feature items of the template (used only by the word bags)
     * @param field the bag type
     * @return the count map, or {@code null} if {@code field} is not one of the bag types
     */
    protected Object2FloatMap<String> getBagOfLexicons(S s, FeatureItem[] featureItemArr, Field field) {
        // Switch on the enum itself rather than on an ordinal lookup table.
        switch (field) {
            case bag_of_words:
            case bag_of_words_norm:
            case bag_of_words_count:
                return getBagOfWords(s, featureItemArr, false);
            case bag_of_words_stopwords:
            case bag_of_words_stopwords_norm:
            case bag_of_words_stopwords_count:
                return getBagOfWords(s, featureItemArr, true);
            case bag_of_clusters:
            case bag_of_clusters_norm:
            case bag_of_clusters_count:
                return getBagOfClusters(s, false);
            case bag_of_clusters_stopwords:
            case bag_of_clusters_stopwords_norm:
            case bag_of_clusters_stopwords_count:
                return getBagOfClusters(s, true);
            default:
                return null;
        }
    }

    /**
     * Converts raw counts into weighted features according to the bag type.
     *
     * <ul>
     *   <li>plain types ({@code bag_of_words}, {@code bag_of_words_stopwords},
     *       {@code bag_of_clusters}, {@code bag_of_clusters_stopwords}): every key gets weight
     *       {@code 1.0};</li>
     *   <li>{@code *_norm} types: every key gets {@code MathUtils.sigmoid(count)};</li>
     *   <li>{@code *_count} types: every key keeps its raw count.</li>
     * </ul>
     *
     * @param object2FloatMap the count map to convert
     * @param field the bag type
     * @return one (key, weight) pair per map entry, or {@code null} if {@code field} is not one
     *         of the bag types
     */
    protected Collection<ObjectFloatPair<String>> getBagOfLexicons(Object2FloatMap<String> object2FloatMap, Field field) {
        // Switch on the enum itself rather than on an ordinal lookup table; the primitive
        // entry set avoids boxing every count.
        switch (field) {
            case bag_of_words:
            case bag_of_words_stopwords:
            case bag_of_clusters:
            case bag_of_clusters_stopwords:
                return object2FloatMap.object2FloatEntrySet().stream()
                        .map(entry -> new ObjectFloatPair<String>(entry.getKey(), 1.0f))
                        .collect(Collectors.toList());
            case bag_of_words_norm:
            case bag_of_words_stopwords_norm:
            case bag_of_clusters_norm:
            case bag_of_clusters_stopwords_norm:
                return object2FloatMap.object2FloatEntrySet().stream()
                        .map(entry -> new ObjectFloatPair<String>(entry.getKey(), (float) MathUtils.sigmoid(entry.getFloatValue())))
                        .collect(Collectors.toList());
            case bag_of_words_count:
            case bag_of_words_stopwords_count:
            case bag_of_clusters_count:
            case bag_of_clusters_stopwords_count:
                return object2FloatMap.object2FloatEntrySet().stream()
                        .map(entry -> new ObjectFloatPair<String>(entry.getKey(), entry.getFloatValue()))
                        .collect(Collectors.toList());
            default:
                return null;
        }
    }

    /**
     * Counts, over the whole document, the keys formed by joining the template's features at
     * every token position.
     *
     * <p>Index 0 of every node array is skipped. A position contributes a key only if every
     * feature item resolves to a non-null feature; otherwise the position is skipped. The
     * previous loop never advanced past an unresolvable item and therefore never terminated.
     *
     * @param s the processing state
     * @param featureItemArr the feature items whose values are joined with {@code "_"}
     * @param z forwarded to {@code s.getDocument(boolean)}
     *        (NOTE(review): presumably selects the stop-word variant of the document - confirm)
     * @return a map from joined feature key to its number of occurrences; never {@code null}
     */
    protected Object2FloatMap<String> getBagOfWords(S s, FeatureItem[] featureItemArr, boolean z) {
        Object2FloatMap<String> counts = new Object2FloatOpenHashMap<>();
        for (N[] nodes : s.getDocument(z)) {
            for (int position = 1; position < nodes.length; position++) {
                String key = bagOfWordsKey(s, featureItemArr, nodes, position);
                if (key != null) {
                    counts.merge(key, 1.0f, Float::sum);
                }
            }
        }
        return counts;
    }

    /**
     * Joins the features of all items for one token position with {@code "_"}.
     *
     * @return the joined key, or {@code null} if any item's window falls outside
     *         {@code [1, nodes.length)}, its relative node is {@code null}, or its feature is
     *         {@code null}
     */
    private String bagOfWordsKey(S s, FeatureItem[] featureItemArr, N[] nodes, int position) {
        StringJoiner joiner = new StringJoiner("_");
        for (FeatureItem item : featureItemArr) {
            int target = position + item.window;
            if (target < 1 || target >= nodes.length) {
                return null;
            }
            N node = s.getRelativeNode(nodes[target], item.relation);
            if (node == null) {
                return null;
            }
            String feature = getFeature(s, item, node);
            if (feature == null) {
                return null;
            }
            joiner.add(feature);
        }
        return joiner.toString();
    }

    /**
     * Counts how often each word cluster occurs in the document.
     *
     * <p>Index 0 of every node array is skipped, as are nodes whose cluster set is
     * {@code null}.
     *
     * @param s the processing state
     * @param z forwarded to {@code s.getDocument(boolean)}
     * @return a map from cluster identifier to its number of occurrences; never {@code null}
     */
    protected Object2FloatMap<String> getBagOfClusters(S s, boolean z) {
        Object2FloatOpenHashMap<String> clusterCounts = new Object2FloatOpenHashMap<>();
        for (N[] nodes : s.getDocument(z)) {
            for (int position = 1; position < nodes.length; position++) {
                Set<String> clusters = nodes[position].getWordClusters();
                if (clusters == null) {
                    continue;
                }
                for (String cluster : clusters) {
                    clusterCounts.merge(cluster, Float.valueOf(1.0f), (total, one) -> Float.valueOf(total.floatValue() + one.floatValue()));
                }
            }
        }
        return clusterCounts;
    }

    /**
     * Builds the dense vector for the given state from averaged word embeddings.
     *
     * @param s the processing state
     * @return the averaged embedding from {@code getEmbeddings(s, true)}, or {@code null} if
     *         {@code word_embeddings} is {@code null} or empty
     */
    @Override // edu.emory.mathcs.nlp.component.template.feature.FeatureTemplate
    public float[] createDenseVector(S s) {
        boolean embeddingsConfigured = this.word_embeddings != null && !this.word_embeddings.isEmpty();
        return embeddingsConfigured ? getEmbeddings(s, true) : null;
    }

    /**
     * Sums, and optionally averages, the word embeddings of the document's nodes.
     *
     * <p>Index 0 of every node array is skipped. Only nodes that are not stop words and that
     * have a word embedding contribute.
     *
     * @param s the processing state
     * @param z if {@code true}, the summed vector is divided by the number of contributing nodes
     * @return the summed or averaged embedding, or {@code null} if no node contributed
     */
    public float[] getEmbeddings(S s, boolean z) {
        float[] sum = null;
        int contributors = 0;
        for (N[] nodes : s.getDocument()) {
            for (int position = 1; position < nodes.length; position++) {
                N node = nodes[position];
                if (node.isStopWord() || !node.hasWordEmbedding()) {
                    continue;
                }
                float[] embedding = node.getWordEmbedding();
                if (sum == null) {
                    // Sized from the first embedding encountered.
                    sum = new float[embedding.length];
                }
                MathUtils.add(sum, embedding);
                contributors++;
            }
        }
        if (!z || sum == null) {
            return sum;
        }
        // sum != null implies contributors >= 1, so the division is well defined.
        for (int dim = 0; dim < sum.length; dim++) {
            sum[dim] /= contributors;
        }
        return sum;
    }
}
