package org.fnlp.nlp.similarity.train;

import gnu.trove.iterator.TIntFloatIterator;
import gnu.trove.iterator.TIntIterator;
import gnu.trove.iterator.TIntObjectIterator;
import gnu.trove.iterator.hash.TObjectHashIterator;
import gnu.trove.map.hash.TIntFloatHashMap;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.set.hash.TIntHashSet;
import gnu.trove.set.hash.TLinkedHashSet;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.Serializable;
import java.util.Date;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.apache.commons.cli.BasicParser;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.fnlp.data.reader.Reader;
import org.fnlp.ml.types.alphabet.LabelAlphabet;
import org.fnlp.nlp.similarity.Cluster;
import org.fnlp.util.MyHashSparseArrays;

/* loaded from: input_file:org/fnlp/nlp/similarity/train/WordCluster.class */
/**
 * Agglomerative clustering of single characters driven by bigram statistics.
 *
 * <p>{@link #read(Reader)} collects unigram and bigram counts from a corpus and turns them into
 * relative frequencies; {@link #startClustering()} then keeps a bounded set of active cluster ids
 * ("slots") and repeatedly merges the pair of slots with the highest {@link #calcL(int, int)}
 * score, feeding in one further candidate id after every merge.
 *
 * <p>Two pseudo ids are used while reading: {@code -1} stands for the position before the first
 * character of a sample and {@code -2} for the position after the last one.
 *
 * <p>Field names are part of the serialized form written by {@link #saveModel(String)} and are
 * therefore left untouched, including their historical spelling.
 */
public class WordCluster implements Serializable {
    private static final long serialVersionUID = 1632709924496094832L;

    /** Threshold passed to {@code MyHashSparseArrays.trim} when selecting the ids to cluster. */
    private static final float ENERGY = 0.999f;

    /** Next id to hand out to a merged cluster; starts at the alphabet size after reading. */
    int lastid;

    /** Number of counted positions: every character plus two boundary markers per sample. */
    public int totalword;

    /** Maximum number of ids placed into the active slot set before merging starts. */
    public int slotsize = 50;

    /** Maps each distinct character (as a one-character string) to its integer id. */
    LabelAlphabet alpahbet = new LabelAlphabet();

    /** For every id, the set of ids observed immediately before it. */
    TIntObjectHashMap<TIntHashSet> leftnodes = new TIntObjectHashMap<>();

    /** For every id, the set of ids observed immediately after it. */
    TIntObjectHashMap<TIntHashSet> rightnodes = new TIntObjectHashMap<>();

    /** Cluster node for every character id and every merged id. */
    TIntObjectHashMap<Cluster> clusters = new TIntObjectHashMap<>();

    /** Maps an id to the id of the cluster it was merged into; {@code -1} when never merged. */
    TIntIntHashMap heads = new TIntIntHashMap(200, 0.5f, -1, -1);

    /** Ids that are currently candidates for merging. */
    TIntHashSet slots = new TIntHashSet();

    /** Bigram table: {@code pcc.get(a).get(b)} is the relative frequency of {@code a} followed by {@code b}. */
    TIntObjectHashMap<TIntFloatHashMap> pcc = new TIntObjectHashMap<>();

    /** Pair weights, stored once per unordered pair under the smaller id. */
    TIntObjectHashMap<TIntFloatHashMap> wcc = new TIntObjectHashMap<>();

    /** Relative frequency of every id (raw counts until {@link #read(Reader)} has finished). */
    TIntFloatHashMap wordProb = new TIntFloatHashMap();

    /** When false, {@link #startClustering()} stops with {@code null} once all candidates are used. */
    private boolean meger = true;

    /**
     * Collects unigram counts, bigram counts and left/right neighbour sets from every sample the
     * reader yields, then converts the counts into relative frequencies.
     *
     * @param reader source of samples; the data of each instance is cast to {@code String} and
     *               processed one {@code char} at a time
     */
    public void read(Reader reader) {
        totalword = 0;
        while (reader.hasNext()) {
            String sample = (String) reader.next().getData();
            int previous = -1; // start-of-sample marker
            wordProb.adjustOrPutValue(-1, 1.0f, 1.0f);
            totalword += sample.length() + 2;
            // One extra iteration so the last character is paired with the end marker (-2).
            for (int pos = 0; pos <= sample.length(); pos++) {
                int current = pos < sample.length()
                        ? alpahbet.lookupIndex(String.valueOf(sample.charAt(pos)))
                        : -2;
                wordProb.adjustOrPutValue(current, 1.0f, 1.0f);
                rowOf(pcc, previous).adjustOrPutValue(current, 1.0f, 1.0f);
                neighboursOf(leftnodes, current).add(previous);
                neighboursOf(rightnodes, previous).add(current);
                previous = current;
            }
        }
        lastid = alpahbet.size();
        System.out.println("[总个数：]\t" + totalword);
        System.out.println("[字典大小：]\t" + alpahbet.size());
        statisticProb();
    }

    /** Returns the row stored under {@code key}, creating and registering an empty one if absent. */
    private static TIntFloatHashMap rowOf(TIntObjectHashMap<TIntFloatHashMap> table, int key) {
        TIntFloatHashMap row = table.get(key);
        if (row == null) {
            row = new TIntFloatHashMap();
            table.put(key, row);
        }
        return row;
    }

    /** Returns the id set stored under {@code key}, creating and registering an empty one if absent. */
    private static TIntHashSet neighboursOf(TIntObjectHashMap<TIntHashSet> table, int key) {
        TIntHashSet ids = table.get(key);
        if (ids == null) {
            ids = new TIntHashSet();
            table.put(key, ids);
        }
        return ids;
    }

    /**
     * Divides every unigram and bigram count by {@link #totalword} and creates a leaf
     * {@link Cluster} for every non-negative id.
     */
    private void statisticProb() {
        System.out.println("统计概率");
        TIntFloatIterator unigrams = wordProb.iterator();
        while (unigrams.hasNext()) {
            unigrams.advance();
            float prob = unigrams.value() / totalword;
            unigrams.setValue(prob);
            int id = unigrams.key();
            if (id >= 0) { // the boundary markers -1 / -2 get no cluster of their own
                clusters.put(id, new Cluster(id, prob, alpahbet.lookupString(id)));
            }
        }
        TIntObjectIterator<TIntFloatHashMap> rows = pcc.iterator();
        while (rows.hasNext()) {
            rows.advance();
            TIntFloatIterator bigrams = rows.value().iterator();
            while (bigrams.hasNext()) {
                bigrams.advance();
                bigrams.setValue(bigrams.value() / totalword);
            }
        }
    }

    /**
     * Computes the weight of the unordered pair {@code (i, i2)} from both bigram directions and
     * stores it in {@link #wcc} (zero weights are not stored).
     *
     * @return the computed weight
     */
    private float weight(int i, int i2) {
        float pi = wordProb.get(i);
        float pi2 = wordProb.get(i2);
        float w;
        if (i == i2) {
            w = clacW(getProb(i, i), pi, pi2);
        } else {
            w = clacW(getProb(i, i2), pi, pi2) + clacW(getProb(i2, i), pi2, pi);
        }
        setweight(i, i2, w);
        return w;
    }

    /**
     * Computes the weight between the hypothetical merge of {@code i} and {@code i2} and a third
     * id {@code i3}. When {@code i3 == i} the result is the self weight of the merged pair.
     * Nothing is stored.
     */
    private float weight(int i, int i2, int i3) {
        float mergedProb = wordProb.get(i) + wordProb.get(i2);
        if (i == i3) {
            float self = getProb(i, i) + getProb(i2, i2) + getProb(i, i2) + getProb(i2, i);
            return clacW(self, mergedProb, mergedProb);
        }
        float otherProb = wordProb.get(i3);
        return clacW(getProb(i, i3) + getProb(i2, i3), mergedProb, otherProb)
                + clacW(getProb(i3, i) + getProb(i3, i2), otherProb, mergedProb);
    }

    /**
     * Returns {@code f * (ln f - ln f2 - ln f3)}, i.e. {@code f * ln(f / (f2 * f3))}, or
     * {@code 0} when {@code f} is exactly zero (avoiding {@code 0 * ln 0}).
     */
    private float clacW(float f, float f2, float f3) {
        if (f == 0.0f) {
            return 0.0f;
        }
        return f * ((float) ((Math.log(f) - Math.log(f2)) - Math.log(f3)));
    }

    /** Returns the stored bigram value for {@code i} followed by {@code i2}; 0 when {@code i} has no row. */
    private float getProb(int i, int i2) {
        TIntFloatHashMap row = pcc.get(i);
        return row == null ? 0.0f : row.get(i2);
    }

    /**
     * Finds the pair of active slots with the highest {@link #calcL(int, int)} score and merges
     * it. Does nothing when fewer than two slots are active, since there is no pair to merge.
     */
    public void mergeCluster() {
        if (slots.size() < 2) {
            return;
        }
        int bestLeft = -1;
        int bestRight = -1;
        float bestScore = Float.NEGATIVE_INFINITY;
        TIntIterator outer = slots.iterator();
        while (outer.hasNext()) {
            int a = outer.next();
            TIntIterator inner = slots.iterator();
            while (inner.hasNext()) {
                int b = inner.next();
                if (a >= b) {
                    continue; // score each unordered pair once
                }
                float score = calcL(a, b);
                if (score > bestScore) {
                    bestScore = score;
                    bestLeft = a;
                    bestRight = b;
                }
            }
        }
        merge(bestLeft, bestRight);
    }

    /**
     * Replaces slots {@code i} and {@code i2} by a new cluster with a fresh id and updates the
     * unigram table, bigram table, pair weights, head links and cluster tree accordingly.
     *
     * @param i  id of the first slot to merge
     * @param i2 id of the second slot to merge
     */
    public void merge(int i, int i2) {
        int merged = lastid++;
        heads.put(i, merged);
        heads.put(i2, merged);

        TIntFloatHashMap outgoing = new TIntFloatHashMap(); // bigram values (merged -> k)
        TIntFloatHashMap incoming = new TIntFloatHashMap(); // bigram values (k -> merged)
        TIntFloatHashMap weights = new TIntFloatHashMap();  // pair weights (k, merged)

        float mergedProb = wordProb.get(i) + wordProb.get(i2);
        float selfProb = getProb(i, i) + getProb(i2, i2) + getProb(i, i2) + getProb(i2, i);
        if (selfProb != 0.0f) {
            outgoing.put(merged, selfProb);
        }
        float selfWeight = clacW(selfProb, mergedProb, mergedProb);
        if (selfWeight != 0.0f) {
            weights.put(merged, selfWeight);
        }

        TIntIterator slotIt = slots.iterator();
        while (slotIt.hasNext()) {
            int other = slotIt.next();
            if (other == i || other == i2) {
                continue;
            }
            float otherProb = wordProb.get(other);

            // Entries are keyed by the other slot so every pair keeps its own value and the
            // self entry written above is not overwritten.
            float forward = getProb(i, other) + getProb(i2, other);
            if (forward != 0.0f) {
                outgoing.put(other, forward);
            }
            float forwardWeight = clacW(forward, mergedProb, otherProb);

            float backward = getProb(other, i) + getProb(other, i2);
            if (backward != 0.0f) {
                incoming.put(other, backward);
            }
            float pairWeight = forwardWeight + clacW(backward, otherProb, mergedProb);
            if (pairWeight != 0.0f) {
                weights.put(other, pairWeight);
            }
        }

        slots.remove(i);
        slots.remove(i2);
        slots.add(merged);

        pcc.put(merged, outgoing);
        pcc.remove(i);
        pcc.remove(i2);
        TIntFloatIterator incomingIt = incoming.iterator();
        while (incomingIt.hasNext()) {
            incomingIt.advance();
            // A non-zero incoming value implies the row exists: it was read from that row above.
            TIntFloatHashMap row = pcc.get(incomingIt.key());
            row.put(merged, incomingIt.value());
            row.remove(i);
            row.remove(i2);
        }

        // The merged id is larger than every existing id, so each pair weight lives in the
        // row of the other slot; the self weight lives in the merged id's own (new) row.
        wcc.put(merged, new TIntFloatHashMap());
        wcc.remove(i);
        wcc.remove(i2);
        TIntFloatIterator weightIt = weights.iterator();
        while (weightIt.hasNext()) {
            weightIt.advance();
            TIntFloatHashMap row = rowOf(wcc, weightIt.key());
            row.put(merged, weightIt.value());
            row.remove(i);
            row.remove(i2);
        }

        wordProb.remove(i);
        wordProb.remove(i2);
        wordProb.put(merged, mergedProb);

        Cluster cluster = new Cluster(merged, clusters.get(i), clusters.get(i2), mergedProb);
        clusters.put(merged, cluster);
        System.out.println("合并：" + cluster.rep);
    }

    /**
     * Scores a candidate merge of slots {@code i} and {@code i2}: the sum of the merged
     * cluster's weights against every slot except {@code i2}, minus the stored weights that
     * {@code i} and {@code i2} currently have with every slot.
     *
     * @return the score; {@link #mergeCluster()} merges the pair with the largest value
     */
    public float calcL(int i, int i2) {
        float score = 0.0f;
        TIntIterator gained = slots.iterator();
        while (gained.hasNext()) {
            int other = gained.next();
            if (other != i2) {
                score += weight(i, i2, other);
            }
        }
        TIntIterator lost = slots.iterator();
        while (lost.hasNext()) {
            int other = lost.next();
            score = (score - getweight(i, other)) - getweight(i2, other);
        }
        return score;
    }

    /** Stores a non-zero weight for the unordered pair under the smaller id; zero is not stored. */
    private void setweight(int i, int i2, float f) {
        if (f == 0.0f) {
            return;
        }
        int low = Math.min(i, i2);
        int high = Math.max(i, i2);
        rowOf(wcc, low).put(high, f);
    }

    /** Returns the stored weight of the unordered pair, or 0 when the smaller id has no row. */
    private float getweight(int i, int i2) {
        int low = Math.min(i, i2);
        int high = Math.max(i, i2);
        TIntFloatHashMap row = wcc.get(low);
        return row == null ? 0.0f : row.get(high);
    }

    /**
     * Runs the merge loop: fills the slots with the first {@link #slotsize} candidate ids,
     * computes their pair weights, then merges the best pair and adds the next candidate until a
     * single slot remains. A text snapshot named {@code res-<step>} is written after every step;
     * failures to write it are reported on stderr and do not stop the clustering.
     *
     * @return the cluster of the single remaining slot, or {@code null} when no slot exists or
     *         when merging is disabled and the candidates are exhausted
     */
    public Cluster startClustering() {
        wordProb.remove(-1);
        wordProb.remove(-2);
        int[] candidates = MyHashSparseArrays.trim(wordProb, ENERGY);
        int total = candidates.length;
        int remaining = candidates.length;
        System.out.println("[待合并个数：]\t" + total);
        System.out.println("[总个数：]\t" + totalword);

        int step = 0;
        int initial = Math.min(slotsize, total);
        for (; step < initial; step++) {
            slots.add(candidates[step]);
            System.out.println(step + "\t" + alpahbet.lookupString(candidates[step]) + "\t" + slots.size());
        }

        TIntIterator outer = slots.iterator();
        while (outer.hasNext()) {
            int a = outer.next();
            TIntIterator inner = slots.iterator();
            while (inner.hasNext()) {
                int b = inner.next();
                if (a <= b) {
                    weight(a, b);
                }
            }
        }

        while (slots.size() > 1) {
            if (step < total) {
                System.out.println(step + "\t" + alpahbet.lookupString(candidates[step]) + "\tSize:\t" + slots.size());
            } else {
                System.out.println(step + "\t\tSize:\t" + slots.size());
            }
            System.out.println("[待合并个数：]\t" + remaining);
            remaining--;

            long started = System.currentTimeMillis();
            mergeCluster();
            System.out.println("\tTime:\t" + ((System.currentTimeMillis() - started) / 1000.0d));

            if (step < total) {
                int added = candidates[step];
                slots.add(added);
                TIntIterator it = slots.iterator();
                while (it.hasNext()) {
                    weight(it.next(), added);
                }
            } else if (!meger) {
                return null;
            }

            try {
                saveTxt("res-" + step);
            } catch (Exception e) {
                // The snapshot is only a progress dump; keep clustering if it cannot be written.
                e.printStackTrace();
            }
            step++;
        }

        if (slots.size() == 0) {
            return null; // nothing was selected for clustering, e.g. an empty corpus
        }
        return clusters.get(slots.toArray()[0]);
    }

    /**
     * Lists every group of at least two characters that share the same top-level head: one line
     * per group, starting with the head's value in {@link #wordProb} followed by the members,
     * all separated and terminated by a single space.
     */
    @Override
    public String toString() {
        TIntObjectHashMap<TLinkedHashSet<String>> groups = new TIntObjectHashMap<>();
        for (int id = 0; id < alpahbet.size(); id++) {
            int head = getHead(id);
            TLinkedHashSet<String> members = groups.get(head);
            if (members == null) {
                members = new TLinkedHashSet<>();
                groups.put(head, members);
            }
            members.add(alpahbet.lookupString(id));
        }

        StringBuilder sb = new StringBuilder();
        TIntObjectIterator<TLinkedHashSet<String>> it = groups.iterator();
        while (it.hasNext()) {
            it.advance();
            TLinkedHashSet<String> members = it.value();
            if (members.size() < 2) {
                continue;
            }
            sb.append(wordProb.get(it.key()));
            sb.append(" ");
            for (String member : members) {
                sb.append(member);
                sb.append(" ");
            }
            sb.append("\n");
        }
        return sb.toString();
    }

    /** Follows the {@link #heads} links from {@code i} until an id without a head is reached. */
    private int getHead(int i) {
        int current = i;
        int parent = heads.get(current);
        while (parent != -1) {
            current = parent;
            parent = heads.get(current);
        }
        return current;
    }

    /**
     * Serializes this object, gzip-compressed, to the given path, creating missing parent
     * directories first.
     *
     * @param str target file path
     * @throws IOException if the file cannot be written
     */
    public void saveModel(String str) throws IOException {
        File parent = new File(str).getParentFile();
        // getParentFile() is null for a bare file name; there is nothing to create then.
        if (parent != null && !parent.exists()) {
            parent.mkdirs();
        }
        try (ObjectOutputStream out = new ObjectOutputStream(
                new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(str))))) {
            out.writeObject(this);
        }
    }

    /**
     * Reads a model previously written by {@link #saveModel(String)}.
     *
     * <p>NOTE(review): this uses Java native deserialization; only load files from a trusted
     * source.
     *
     * @param str path of the gzip-compressed serialized model
     * @return the deserialized instance
     * @throws IOException            if the file cannot be read
     * @throws ClassNotFoundException if a serialized class is not on the classpath
     */
    public static WordCluster loadFrom(String str) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(
                new GZIPInputStream(new BufferedInputStream(new FileInputStream(str))))) {
            return (WordCluster) in.readObject();
        }
    }

    /**
     * Writes {@link #toString()} to the given file in UTF-8.
     *
     * @param str target file path
     * @throws Exception if the file cannot be written
     */
    public void saveTxt(String str) throws Exception {
        try (BufferedWriter writer = new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream(str), "UTF8"))) {
            writer.write(toString());
        }
    }

    /**
     * Command-line entry point. Options: {@code -path} corpus location, {@code -res} result
     * path, {@code -slot} slot size. Trains on the corpus, then writes the serialized model to
     * {@code <res>.m}, the text result to {@code <res>}, and a text dump of the reloaded model
     * to {@code <res>1}.
     *
     * @param strArr command-line arguments
     * @throws Exception if reading, clustering or writing fails
     */
    public static void main(String[] strArr) throws Exception {
        Options options = new Options();
        options.addOption("path", true, "保存路径");
        options.addOption("res", true, "评测结果保存路径");
        options.addOption("slot", true, "槽大小");

        CommandLine cmd;
        int slotSize;
        try {
            cmd = new BasicParser().parse(options, strArr);
            slotSize = Integer.parseInt(cmd.getOptionValue("slot", "50"));
        } catch (Exception e) {
            // Only argument problems (unparseable options, non-numeric slot size) end up here;
            // failures of the actual run below propagate with their own message and stack trace.
            System.err.println("Parameters format error");
            return;
        }

        System.out.println("槽大小:" + slotSize);
        String corpusPath = cmd.getOptionValue("path", "./tmp/news.allsites.txt");
        System.out.println("数据路径:" + corpusPath);
        String resultPath = cmd.getOptionValue("res", "./tmp/res.txt");
        System.out.println("测试结果:" + resultPath);

        SougouCA corpus = new SougouCA(corpusPath);
        WordCluster wordCluster = new WordCluster();
        wordCluster.slotsize = slotSize;
        wordCluster.read(corpus);
        wordCluster.startClustering();
        wordCluster.saveModel(resultPath + ".m");
        wordCluster.saveTxt(resultPath);
        loadFrom(resultPath + ".m").saveTxt(resultPath + "1");
        System.out.println(new Date().toString());
        System.out.println("Done");
    }
}
