package edu.emory.mathcs.nlp.component.template.feature;

import edu.emory.mathcs.nlp.common.constant.CharConst;
import edu.emory.mathcs.nlp.common.constant.MetaConst;
import edu.emory.mathcs.nlp.common.propbank.frameset.PBFXml;
import edu.emory.mathcs.nlp.common.util.CharUtils;
import edu.emory.mathcs.nlp.common.util.FastUtils;
import edu.emory.mathcs.nlp.common.util.Joiner;
import edu.emory.mathcs.nlp.common.util.Splitter;
import edu.emory.mathcs.nlp.common.util.XMLUtils;
import edu.emory.mathcs.nlp.component.template.node.AbstractNLPNode;
import edu.emory.mathcs.nlp.component.template.node.Orthographic;
import edu.emory.mathcs.nlp.component.template.state.NLPState;
import edu.emory.mathcs.nlp.component.template.train.HyperParameter;
import edu.emory.mathcs.nlp.learning.util.ColumnMajorVector;
import edu.emory.mathcs.nlp.learning.util.FeatureMap;
import edu.emory.mathcs.nlp.learning.util.FeatureVector;
import edu.emory.mathcs.nlp.learning.util.MajorVector;
import edu.emory.mathcs.nlp.learning.util.SparseVector;
import edu.emory.mathcs.nlp.learning.util.StringPrediction;
import edu.emory.mathcs.nlp.learning.util.WeightVector;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.StringJoiner;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

/* loaded from: input_file:edu/emory/mathcs/nlp/component/template/feature/FeatureTemplate.class */
/**
 * Extracts sparse (string-valued) and dense (word-embedding) features from an {@link NLPState}
 * according to templates declared as {@code <feature>} elements in an XML configuration.
 *
 * <p>Three kinds of templates are kept:
 * <ul>
 *   <li>{@link #feature_list}: joined templates; the values of all items are concatenated with
 *       {@code '_'} into a single feature string.</li>
 *   <li>{@link #feature_set}: set-valued templates (declared with {@code set="true"}); one item
 *       may yield several feature strings.</li>
 *   <li>{@link #word_embeddings}: items whose node embeddings are concatenated into the dense
 *       vector.</li>
 * </ul>
 */
public class FeatureTemplate<N extends AbstractNLPNode<N>, S extends NLPState<N>> implements Serializable {
    private static final long serialVersionUID = -6755594173767815098L;

    /** Joined templates; each array yields at most one feature string. */
    protected List<FeatureItem[]> feature_list = new ArrayList<>();
    /** Set-valued templates; each item may yield several feature strings. */
    protected List<FeatureItem> feature_set = new ArrayList<>();
    /** Items whose node word embeddings form the dense vector. */
    protected List<FeatureItem> word_embeddings = new ArrayList<>();
    /** Training-time occurrence counts, keyed by template index and feature string. */
    protected Object2IntMap<String> feature_count = new Object2IntOpenHashMap<>();
    /** Maps (template index, feature string) pairs to sparse-vector indices. */
    protected FeatureMap feature_map = new FeatureMap();
    /** During training a feature is indexed only once its count exceeds this value. */
    protected int cutoff;

    /**
     * Creates a template from the {@code <feature>} children of {@code element}.
     *
     * @param element        the XML element holding {@code <feature>} children; may be {@code null}
     * @param hyperParameter supplies the feature cutoff
     */
    public FeatureTemplate(Element element, HyperParameter hyperParameter) {
        setCutoff(hyperParameter.getFeature_cutoff());
        init(element);
    }

    /** @return the joined feature templates (live list, not a copy). */
    public List<FeatureItem[]> getFeatureList() {
        return this.feature_list;
    }

    /** @return the set-valued feature templates (live list, not a copy). */
    public List<FeatureItem> getSetFeatureList() {
        return this.feature_set;
    }

    /** @return the word-embedding feature items (live list, not a copy). */
    public List<FeatureItem> getEmbeddingFeatureList() {
        return this.word_embeddings;
    }

    /**
     * Registers a template for every {@code <feature>} element under {@code element}.
     * Does nothing when {@code element} is {@code null}.
     */
    protected void init(Element element) {
        if (element == null) {
            return;
        }
        NodeList features = element.getElementsByTagName("feature");
        for (int k = 0; k < features.getLength(); k++) {
            initFeatureItems((Element) features.item(k));
        }
    }

    /**
     * Parses one {@code <feature>} element and files it as a set-valued, embedding, or joined
     * template.
     *
     * @throws IllegalArgumentException if the element declares no feature item at all
     */
    protected void initFeatureItems(Element element) {
        FeatureItem[] items = createFeatureItems(element);
        // Without this guard an element lacking the first item attribute crashed with an
        // ArrayIndexOutOfBoundsException on items[0].
        if (items.length == 0) {
            throw new IllegalArgumentException(
                    "feature element has no '" + PBFXml.A_F + "0' attribute");
        }
        if (XMLUtils.getBooleanAttribute(element, "set")) {
            // Set-valued templates use only their first item.
            addSet(items[0]);
        } else if (items[0].field == Field.word_embedding) {
            addWordEmbedding(items[0]);
        } else {
            add(items);
        }
    }

    /**
     * Reads the consecutively numbered item attributes ({@code PBFXml.A_F + 0},
     * {@code PBFXml.A_F + 1}, ...) of {@code element}, stopping at the first missing or empty one.
     *
     * @return the parsed items, possibly empty
     */
    public FeatureItem[] createFeatureItems(Element element) {
        List<FeatureItem> items = new ArrayList<>();
        for (int k = 0; ; k++) {
            String attribute = element.getAttribute(PBFXml.A_F + k);
            if (attribute.isEmpty()) {
                break;
            }
            items.add(createFeatureItem(attribute));
        }
        return items.toArray(new FeatureItem[0]);
    }

    /**
     * Parses a colon-separated item of the form {@code <source>[offset][_relation]:field[:attribute]}.
     * The first character names the {@link Source}, and an optional integer offset may follow it.
     * An optional {@link Relation} may follow the first {@code '_'}.
     *
     * @throws IllegalArgumentException if the item has no field part
     */
    private FeatureItem createFeatureItem(String str) {
        String[] tokens = Splitter.splitColons(str);
        if (tokens.length < 2 || tokens[0].isEmpty()) {
            throw new IllegalArgumentException("Invalid feature item: " + str);
        }

        String head = tokens[0];
        Source source = Source.valueOf(head.substring(0, 1));

        int underscore = head.indexOf('_');
        if (underscore < 0) {
            underscore = head.length();
        }

        int offset = (underscore == 1) ? 0 : Integer.parseInt(head.substring(1, underscore));
        Relation relation = (underscore < head.length()) ? Relation.valueOf(head.substring(underscore + 1)) : null;
        Field field = Field.valueOf(tokens[1]);
        Object attribute = (tokens.length > 2) ? createAttribute(field, tokens[2]) : null;
        return new FeatureItem(source, relation, offset, field, attribute);
    }

    /**
     * Converts the textual attribute of an item into the object its field expects.
     * {@code prefix}/{@code suffix} become an {@link Integer}, {@code feats} stays a string,
     * {@code valency} becomes a {@link Direction}, and {@code dependent_set} becomes a {@link Field}.
     * Every other field yields {@code null}.
     */
    protected Object createAttribute(Field field, String str) {
        switch (field) {
            case prefix:
            case suffix:
                return Integer.valueOf(str);
            case feats:
                return str;
            case valency:
                return Direction.valueOf(str);
            case dependent_set:
                return Field.valueOf(str);
            default:
                return null;
        }
    }

    /** Appends a joined template. */
    public void add(FeatureItem... featureItemArr) {
        this.feature_list.add(featureItemArr);
    }

    /** Appends a set-valued template. */
    public void addSet(FeatureItem featureItem) {
        this.feature_set.add(featureItem);
    }

    /** Appends a word-embedding item. */
    public void addWordEmbedding(FeatureItem featureItem) {
        this.word_embeddings.add(featureItem);
    }

    /** @return the number of sparse features currently indexed by the feature map. */
    public int getSparseFeatureSize() {
        return this.feature_map.size();
    }

    /** @return the total number of joined, set-valued and embedding templates. */
    public int getTemplateSize() {
        return this.feature_list.size() + this.feature_set.size() + this.word_embeddings.size();
    }

    public int getCutoff() {
        return this.cutoff;
    }

    public void setCutoff(int i) {
        this.cutoff = i;
    }

    /** Empties the training counts in place. */
    public void clearFeatureCount() {
        this.feature_count.clear();
    }

    /** Replaces the training counts with a fresh, empty map. */
    public void initFeatureCount() {
        this.feature_count = new Object2IntOpenHashMap<>();
    }

    /**
     * Builds the sparse and dense parts of the feature vector for {@code state}.
     *
     * @param train {@code true} to count features and index those above the cutoff;
     *              {@code false} to look up already-indexed features only
     */
    public FeatureVector createFeatureVector(S state, boolean train) {
        return new FeatureVector(createSparseVector(state, train), createDenseVector(state));
    }

    /**
     * Builds the sparse vector for {@code state}. Set-valued templates occupy template indices
     * {@code [0, feature_set.size())}. Joined templates follow in list order.
     *
     * @param train see {@link #createFeatureVector(NLPState, boolean)}
     */
    public SparseVector createSparseVector(S state, boolean train) {
        SparseVector vector = new SparseVector();
        int type = 0;

        for (FeatureItem item : this.feature_set) {
            Collection<String> values = getFeatures(state, item);
            if (values != null) {
                for (String value : values) {
                    add(vector, type, value, 1.0f, train);
                }
            }
            type++;
        }

        for (FeatureItem[] items : this.feature_list) {
            add(vector, type++, getFeature(state, items), 1.0f, train);
        }
        return vector;
    }

    /**
     * Adds one feature to {@code sparseVector}; {@code null} values are ignored.
     *
     * <p>During training the feature's count is incremented, and the feature is indexed only once
     * the count exceeds {@link #cutoff}. Otherwise the existing index is looked up. Only indices
     * {@code > 0} are added. Note that {@link #reduce} keeps feature 0's weights unconditionally,
     * which suggests index 0 is reserved.
     */
    public void add(SparseVector sparseVector, int type, String value, float weight, boolean train) {
        if (value == null) {
            return;
        }

        int index;
        if (train) {
            // The '_' separator keeps keys unique across templates: plain concatenation made
            // (1, "1990") and (11, "990") share the key "11990". The decimal prefix never
            // contains '_'.
            String key = type + "_" + value;
            index = FastUtils.increment(this.feature_count, key) > this.cutoff ? this.feature_map.add(type, value) : -1;
        } else {
            index = this.feature_map.index(type, value);
        }

        if (index > 0) {
            sparseVector.add(index, weight);
        }
    }

    /**
     * Returns the joined value of {@code items}, concatenated with {@code '_'}. Returns
     * {@code null} if any item yields {@code null}.
     */
    protected String getFeature(S state, FeatureItem... items) {
        if (items.length == 1) {
            return getFeature(state, items[0]);
        }

        StringJoiner joiner = new StringJoiner("_");
        for (FeatureItem item : items) {
            String value = getFeature(state, item);
            if (value == null) {
                return null;
            }
            joiner.add(value);
        }
        return joiner.toString();
    }

    /** Resolves the item's node in {@code state}; returns {@code null} if there is none. */
    protected String getFeature(S state, FeatureItem item) {
        N node = state.getNode(item);
        return (node != null) ? getFeature(state, item, node) : null;
    }

    /**
     * Returns the string value of {@code item.field} for {@code node}. A non-null
     * {@code node.getValue(field)} takes precedence. Otherwise prefix, suffix, feats and valency
     * fields are computed from the item's attribute.
     */
    public String getFeature(S state, FeatureItem item, N node) {
        String value = node.getValue(item.field);
        if (value != null) {
            return value;
        }

        switch (item.field) {
            case prefix:
                return getPrefix(node, (Integer) item.attribute);
            case suffix:
                return getSuffix(node, (Integer) item.attribute);
            case feats:
                return node.getFeat((String) item.attribute);
            case valency:
                return node.getValency((Direction) item.attribute);
            default:
                return null;
        }
    }

    /**
     * Returns the first {@code n} characters of the simplified lowercase word form. Returns
     * {@code null} unless the form is strictly longer than {@code n}.
     */
    protected String getPrefix(N node, int n) {
        String form = node.getWordFormSimplifiedLowercase();
        return (n < form.length()) ? form.substring(0, n) : null;
    }

    /**
     * Returns the last {@code n} characters of the simplified lowercase word form. Returns
     * {@code null} unless the form is strictly longer than {@code n}.
     */
    protected String getSuffix(N node, int n) {
        String form = node.getWordFormSimplifiedLowercase();
        return (n < form.length()) ? form.substring(form.length() - n) : null;
    }

    /**
     * Resolves the item's node in {@code state} and returns its set-valued feature. Returns
     * {@code null} if there is no node.
     */
    protected Collection<String> getFeatures(S state, FeatureItem item) {
        N node = state.getNode(item);
        return (node != null) ? getFeatures(state, item, node) : null;
    }

    /**
     * Returns the values of a set-valued field for {@code node}, or {@code null} for fields that
     * are not set-valued.
     */
    protected Collection<String> getFeatures(S state, FeatureItem item, N node) {
        // NOTE(review): the decompiled source dispatched through a missing synthetic switch map.
        // dependent_set is certain; the remaining constant names were reconstructed and should be
        // confirmed against the Field enum.
        switch (item.field) {
            case dependent_set:
                return node.getDependentValueSet((Field) item.attribute);
            case positional:
                return getPositionFeatures(state, node);
            case orthographic:
                return getOrthographicFeatures(state, node, true);
            case orthographic_lowercase:
                return getOrthographicFeatures(state, node, false);
            case ambiguity_classes:
                return node.getAmbiguityClasseList();
            case named_entity_gazetteers:
                return node.getNamedEntityGazetteerSet();
            case word_clusters:
                return node.getWordClusters();
            default:
                return null;
        }
    }

    /**
     * Returns {@code ["0"]} for the first node of the state and
     * {@code [Orthographic.ALL_UPPER]} for the last. Returns {@code null} otherwise.
     */
    protected List<String> getPositionFeatures(S state, N node) {
        List<String> values = new ArrayList<>();
        if (state.isFirst(node)) {
            values.add("0");
        } else if (state.isLast(node)) {
            // NOTE(review): reusing ALL_UPPER here is probably a decompiler constant substitution
            // for a plain "1" literal; kept verbatim to preserve the emitted value.
            values.add(Orthographic.ALL_UPPER);
        }
        return values.isEmpty() ? null : values;
    }

    /**
     * Returns orthographic features of the simplified word form. A hyperlink form yields
     * {@code ["0"]}.
     *
     * @param caseSensitive whether case-dependent features (all-upper, all-lower, etc.) are emitted
     * @return the features, or {@code null} if none apply
     */
    protected List<String> getOrthographicFeatures(S state, N node, boolean caseSensitive) {
        List<String> values = new ArrayList<>();
        String form = node.getWordFormSimplified();
        if (MetaConst.HYPERLINK.equals(form)) {
            values.add("0");
        } else {
            getOrthographicFeauturesAux(values, form.toCharArray(), state.isFirst(node), caseSensitive);
        }
        return values.isEmpty() ? null : values;
    }

    /**
     * Appends orthographic feature values for {@code chars} to {@code list}.
     *
     * @param isFirst       whether the token is the first node of the state (suppresses
     *                      {@code FST_UPPER})
     * @param caseSensitive whether case-dependent features are emitted
     */
    protected void getOrthographicFeauturesAux(List<String> list, char[] chars, boolean isFirst, boolean caseSensitive) {
        boolean hasDigit = false;
        boolean hasPeriod = false;
        boolean hasHyphen = false;
        boolean hasPunct = false;
        boolean firstUpper = false;
        boolean allDigit = true;
        boolean allPunct = true;
        boolean allUpper = true;
        boolean allLower = true;
        boolean noLower = true;
        boolean allDigitOrPunct = true;
        int upperCount = 0; // uppercase characters after position 0

        for (int pos = 0; pos < chars.length; pos++) {
            char c = chars[pos];
            boolean upper = CharUtils.isUpperCase(c);
            boolean lower = CharUtils.isLowerCase(c);
            boolean digit = CharUtils.isDigit(c);
            boolean punct = CharUtils.isPunctuation(c);

            if (!upper) {
                allUpper = false;
            } else if (pos == 0) {
                firstUpper = true;
            } else {
                upperCount++;
            }

            if (lower) {
                noLower = false;
            } else {
                allLower = false;
            }

            if (digit) {
                hasDigit = true;
            } else {
                allDigit = false;
            }

            if (punct) {
                hasPunct = true;
                if (c == '.') {
                    hasPeriod = true;
                }
                if (c == '-') {
                    hasHyphen = true;
                }
            } else {
                allPunct = false;
            }

            if (!digit && !punct) {
                allDigitOrPunct = false;
            }
        }

        // At most one "whole-token" class is emitted, checked in this priority order.
        if (allUpper) {
            if (caseSensitive) {
                list.add(Orthographic.ALL_UPPER);
            }
        } else if (allLower) {
            if (caseSensitive) {
                list.add(Orthographic.ALL_LOWER);
            }
        } else if (allDigit) {
            list.add(Orthographic.ALL_DIGIT);
        } else if (allPunct) {
            list.add(Orthographic.ALL_PUNCT);
        } else if (allDigitOrPunct) {
            list.add(Orthographic.ALL_DIGIT_OR_PUNCT);
        } else if (noLower && caseSensitive) {
            list.add(Orthographic.NO_LOWER);
        }

        if (caseSensitive && !allUpper) {
            if (firstUpper && !isFirst) {
                list.add(Orthographic.FST_UPPER);
            }
            if (upperCount == 1) {
                list.add(Orthographic.UPPER_1);
            } else if (upperCount > 1) {
                list.add(Orthographic.UPPER_2);
            }
        }

        if (!allDigit && hasDigit) {
            list.add(Orthographic.HAS_DIGIT);
        }
        if (hasPeriod) {
            list.add(Orthographic.HAS_PERIOD);
        }
        if (hasHyphen) {
            list.add(Orthographic.HAS_HYPHEN);
        }
        if (hasPunct && !allPunct && !hasPeriod && !hasHyphen) {
            list.add(Orthographic.HAS_OTHER_PUNCT);
        }
    }

    /** Converts each (label, count) entry of {@code object2IntMap} into a {@link StringPrediction}. */
    protected Set<StringPrediction> toSet(Object2IntMap<String> object2IntMap) {
        Set<StringPrediction> predictions = new HashSet<>();
        for (Object2IntMap.Entry<String> entry : object2IntMap.object2IntEntrySet()) {
            predictions.add(new StringPrediction(entry.getKey(), entry.getIntValue()));
        }
        return predictions;
    }

    /** @return the dense part of the feature vector; see {@link #getEmbeddings(NLPState)}. */
    public float[] createDenseVector(S state) {
        return getEmbeddings(state);
    }

    /**
     * Concatenates the word embeddings of the nodes selected by {@link #word_embeddings}, one slot
     * per item. The array is sized from the first embedding found. Slots whose node is missing or
     * has no embedding remain zero.
     *
     * @return the concatenated embeddings, or {@code null} if there are no embedding items or no
     *         selected node has an embedding
     */
    public float[] getEmbeddings(S state) {
        if (this.word_embeddings == null || this.word_embeddings.isEmpty()) {
            return null;
        }

        float[] dense = null;
        int slots = this.word_embeddings.size();
        for (int slot = 0; slot < slots; slot++) {
            N node = state.getNode(this.word_embeddings.get(slot));
            if (node != null && node.hasWordEmbedding()) {
                float[] embedding = node.getWordEmbedding();
                if (dense == null) {
                    dense = new float[embedding.length * slots];
                }
                System.arraycopy(embedding, 0, dense, embedding.length * slot, embedding.length);
            }
        }
        return dense;
    }

    /** Lists the joined templates, then the set-valued ones, then the embedding items. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (FeatureItem[] items : this.feature_list) {
            sb.append("[").append(Joiner.join(items, "],[")).append("]\n");
        }
        for (FeatureItem item : this.feature_set) {
            sb.append(item).append("\n");
        }
        if (this.word_embeddings != null) {
            sb.append(Joiner.join(this.word_embeddings, ",")).append("\n");
        }
        return sb.toString();
    }

    /**
     * Prunes sparse features whose weights vary by less than {@code threshold} across labels. The
     * weights of features 1.. are stored at {@code feature * labelSize + label}. Feature 0's
     * weights are always kept.
     *
     * <p>Surviving features are renumbered densely from 1. Their weights are copied into a new
     * {@link ColumnMajorVector} that replaces the weight vector's sparse part. Entries of the
     * feature map are updated in place, and pruned entries are removed.
     *
     * @return the new number of sparse features, including index 0
     */
    public int reduce(WeightVector weightVector, float threshold) {
        int labelSize = weightVector.getLabelSize();
        int featureSize = weightVector.getSparseWeightVector().getFeatureSize();
        MajorVector oldWeights = weightVector.getSparseWeightVector();

        // remap[old] = new index; 0 means "pruned".
        int[] remap = new int[featureSize];
        int newSize = 1;
        for (int feature = 1; feature < featureSize; feature++) {
            int offset = feature * labelSize;
            float max = oldWeights.get(offset);
            float min = max;
            for (int label = 1; label < labelSize; label++) {
                float w = oldWeights.get(offset + label);
                max = Math.max(max, w);
                min = Math.min(min, w);
            }
            if (Math.abs(max - min) >= threshold) {
                remap[feature] = newSize++;
            }
        }

        ColumnMajorVector newWeights = new ColumnMajorVector();
        newWeights.expand(labelSize, newSize);
        for (int label = 0; label < labelSize; label++) {
            newWeights.set(label, oldWeights.get(label));
        }

        for (Object2IntMap<String> indexMap : this.feature_map.getIndexMaps()) {
            ObjectIterator<Object2IntMap.Entry<String>> it = indexMap.object2IntEntrySet().iterator();
            while (it.hasNext()) {
                Object2IntMap.Entry<String> entry = it.next();
                int oldIndex = entry.getIntValue();
                int newIndex = (oldIndex > 0 && oldIndex < remap.length) ? remap[oldIndex] : -1;
                if (newIndex > 0) {
                    entry.setValue(newIndex);
                    int from = oldIndex * labelSize;
                    int to = newIndex * labelSize;
                    for (int label = 0; label < labelSize; label++) {
                        newWeights.set(to + label, oldWeights.get(from + label));
                    }
                } else {
                    it.remove();
                }
            }
        }

        weightVector.setSparseWeightVector(newWeights);
        this.feature_map.setSize(newSize);
        return newSize;
    }
}
