package de.learnlib.algorithm.lsharp.ads;

import de.learnlib.algorithm.lsharp.ObservationTree;
import java.lang.Comparable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import net.automatalib.common.util.HashUtil;
import net.automatalib.common.util.Pair;

/* loaded from: input_file:de/learnlib/algorithm/lsharp/ads/ADSTree.class */
/**
 * An {@link ADS} represented as an explicit decision tree of {@link ADSNode}s, built greedily from the observations
 * stored in an {@link ObservationTree}.
 * <p>
 * Each constructed inner node stores the input to apply next, and its children are keyed by the output observed for
 * that input. A node's score is the number of ordered pairs of states that end up in different output classes at
 * that node, plus the scores of its children. Children reached via the sink output are not expanded further.
 * <p>
 * The tree carries a mutable cursor: {@link #nextInput(Object)} walks down from the root and {@link #resetToRoot()}
 * rewinds it.
 *
 * @param <S> state type of the observation tree
 * @param <I> input symbol type
 * @param <O> output symbol type
 */
public final class ADSTree<S extends Comparable<S>, I, O> implements ADS<I, O> {

    private final ADSNode<I, O> initialNode;
    private ADSNode<I, O> currentNode;

    /**
     * Builds an ADS for the given block of states and positions the cursor at its root.
     *
     * @param oTree        observation tree supplying the observed outputs and successors
     * @param currentBlock the states to be split
     * @param sinkOut      output below which no further subtree is constructed
     */
    public ADSTree(ObservationTree<S, I, O> oTree, Collection<S> currentBlock, O sinkOut) {
        ADSNode<I, O> root = constructADS(oTree, currentBlock, sinkOut);
        this.initialNode = root;
        this.currentNode = root;
    }

    /**
     * Returns the score of the root node, i.e. of the whole tree.
     *
     * @return the root node's score
     */
    public int getScore() {
        return this.initialNode.getScore();
    }

    /** Converts a list of key/value pairs into a map; later pairs overwrite earlier ones with an equal key. */
    private static <A, B> Map<A, B> toMap(List<Pair<A, B>> pairs) {
        Map<A, B> map = new HashMap<>(HashUtil.capacity(pairs.size()));
        for (Pair<A, B> pair : pairs) {
            map.put(pair.getFirst(), pair.getSecond());
        }
        return map;
    }

    /**
     * Greedily constructs an ADS that splits {@code currentBlock}.
     * <p>
     * A block of exactly one state yields a default-constructed node. Otherwise, the input found by
     * {@link #maximalBaseInput} is evaluated first, and its score serves as the baseline {@code maxRec}. Every input
     * whose two split-score components sum to at least {@code maxRec} is then evaluated recursively. The best-scoring
     * candidate becomes the root. Among equally scored candidates, the one that comes last in alphabet order wins.
     *
     * @param oTree        observation tree supplying the observed outputs and successors
     * @param currentBlock the states to be split
     * @param sinkOut      output below which no further subtree is constructed
     * @return the root of the constructed tree
     */
    public static <S extends Comparable<S>, I, O> ADSNode<I, O> constructADS(ObservationTree<S, I, O> oTree,
                                                                             Collection<S> currentBlock,
                                                                             O sinkOut) {
        if (currentBlock.size() == 1) {
            return new ADSNode<>();
        }

        // Insertion-ordered (= alphabet-ordered), so that the tie-break among candidates below does not depend on
        // the hash codes of the input symbols.
        Map<I, Pair<Integer, Integer>> splitScore = new LinkedHashMap<>();
        I maxInput = maximalBaseInput(oTree, currentBlock, splitScore).getFirst();

        // Evaluate the greedy input exactly once. Its children yield the pruning baseline maxRec, and the whole
        // evaluation is reused if maxInput is among the candidates, instead of rebuilding the same subtrees again.
        Map<O, List<S>> maxPartition = partitionOnOutput(oTree, currentBlock, maxInput);
        Pair<Integer, List<Pair<O, ADSNode<I, O>>>> maxEvaluation = evaluateSplit(oTree, maxPartition, sinkOut);
        int maxUi = computeUI(maxPartition);
        int maxRec = 0;
        for (Pair<O, ADSNode<I, O>> child : maxEvaluation.getSecond()) {
            O output = child.getFirst();
            // For the baseline, sink children always contribute a child score of 0.
            int childScore = Objects.equals(output, sinkOut) ? 0 : child.getSecond().getScore();
            maxRec += computeRegScore(maxPartition.get(output).size(), maxUi, childScore);
        }

        // Discard inputs whose split-score sum cannot reach the baseline.
        List<I> candidates = new ArrayList<>(splitScore.size());
        for (Map.Entry<I, Pair<Integer, Integer>> entry : splitScore.entrySet()) {
            Pair<Integer, Integer> split = entry.getValue();
            if (split.getFirst() + split.getSecond() >= maxRec) {
                candidates.add(entry.getKey());
            }
        }
        assert !candidates.isEmpty();

        int bestScore = -1;
        I bestInput = null;
        List<Pair<O, ADSNode<I, O>>> bestChildren = null;
        for (I input : candidates) {
            Pair<Integer, List<Pair<O, ADSNode<I, O>>>> evaluation = Objects.equals(input, maxInput) ?
                    maxEvaluation :
                    evaluateSplit(oTree, partitionOnOutput(oTree, currentBlock, input), sinkOut);
            int score = evaluation.getFirst();
            // '>=' against bestScore: the last of several equally good candidates wins.
            if (score >= maxRec && score >= bestScore) {
                bestScore = score;
                bestInput = input;
                bestChildren = evaluation.getSecond();
            }
        }

        assert bestInput != null && bestChildren != null;
        return new ADSNode<>(bestInput, toMap(bestChildren), bestScore);
    }

    /**
     * Builds one child per output class of {@code partition} (see {@link #computeOSubtree}).
     *
     * @return the summed regular scores of all children, paired with the output-labelled children in the
     *         partition's iteration order
     */
    private static <S extends Comparable<S>, I, O> Pair<Integer, List<Pair<O, ADSNode<I, O>>>> evaluateSplit(
            ObservationTree<S, I, O> oTree,
            Map<O, List<S>> partition,
            O sinkOut) {
        int ui = computeUI(partition);
        int score = 0;
        List<Pair<O, ADSNode<I, O>>> children = new ArrayList<>(partition.size());
        for (Map.Entry<O, List<S>> entry : partition.entrySet()) {
            Pair<Integer, Pair<O, ADSNode<I, O>>> subtree =
                    computeOSubtree(oTree, entry.getKey(), entry.getValue(), sinkOut, ui);
            score += subtree.getFirst();
            children.add(subtree.getSecond());
        }
        return Pair.of(score, children);
    }

    /**
     * Builds the child for a single output class.
     * <p>
     * If {@code output} equals {@code sinkOut}, the child is a default-constructed node. Otherwise it is a
     * recursively constructed ADS for {@code oPart}.
     *
     * @param oTree   observation tree supplying the observed outputs and successors
     * @param output  the output labelling this class
     * @param oPart   the (successor) states of this class
     * @param sinkOut output below which no further subtree is constructed
     * @param ui      total number of states over all classes of the enclosing partition
     * @return the child's contribution {@code |oPart| * (ui - |oPart|) + childScore}, paired with the labelled child
     */
    public static <S extends Comparable<S>, I, O> Pair<Integer, Pair<O, ADSNode<I, O>>> computeOSubtree(
            ObservationTree<S, I, O> oTree,
            O output,
            List<S> oPart,
            O sinkOut,
            int ui) {
        ADSNode<I, O> child = Objects.equals(output, sinkOut) ? new ADSNode<>() : constructADS(oTree, oPart, sinkOut);
        int regScore = computeRegScore(oPart.size(), ui, child.getScore());
        return Pair.of(regScore, Pair.of(output, child));
    }

    /**
     * Computes the score contribution of one output class.
     *
     * @param uIO        number of states in the class
     * @param ui         number of states over all classes
     * @param childScore score of the child built for the class
     * @return ordered pairs (state in the class, state outside it) plus the child's score
     */
    private static int computeRegScore(int uIO, int ui, int childScore) {
        return uIO * (ui - uIO) + childScore;
    }

    /** Sums the sizes of all classes of a partition. */
    private static <S, O> int computeUI(Map<O, List<S>> partition) {
        int total = 0;
        for (List<S> part : partition.values()) {
            total += part.size();
        }
        return total;
    }

    /**
     * Groups the successors of the states in {@code block} under {@code input} by the observed output.
     * <p>
     * The lists contain successor states, not the states of {@code block} themselves. States for which the
     * observation tree returns {@code null} (no recorded transition) are omitted.
     */
    private static <S extends Comparable<S>, I, O> Map<O, List<S>> partitionOnOutput(ObservationTree<S, I, O> oTree,
                                                                                     Collection<S> block,
                                                                                     I input) {
        Map<O, List<S>> partition = new HashMap<>();
        for (S state : block) {
            Pair<O, S> outSucc = oTree.getOutSucc(state, input);
            if (outSucc != null) {
                partition.computeIfAbsent(outSucc.getFirst(), k -> new ArrayList<>()).add(outSucc.getSecond());
            }
        }
        return partition;
    }

    /**
     * Scores every input of the observation tree's alphabet on {@code currentBlock} and returns the best one.
     * <p>
     * For each input, the successors are grouped by output via {@link #partitionOnOutput}. Let {@code u_k} be the
     * class sizes and {@code u} their sum. The method stores the pair
     * {@code (sum of u_k * (u - u_k), sum of u_k * (u_k - 1))} in {@code splitScore}, overwriting any existing
     * entry. The two components count the ordered state pairs with different outputs and with equal outputs,
     * respectively.
     *
     * @param oTree        observation tree supplying the input alphabet and the observations
     * @param currentBlock the states to be split
     * @param splitScore   receives the per-input score pairs, in alphabet iteration order
     * @return the first input (in alphabet order) maximizing the first component, paired with that maximum. If no
     *         input separates any pair, the alphabet's first symbol is returned, paired with 0. NOTE(review):
     *         {@code getSymbol(0)} is queried up front, so an empty alphabet presumably fails.
     */
    public static <S extends Comparable<S>, I, O> Pair<I, Integer> maximalBaseInput(ObservationTree<S, I, O> oTree,
                                                                                   Collection<S> currentBlock,
                                                                                   Map<I, Pair<Integer, Integer>> splitScore) {
        I bestInput = oTree.getInputAlphabet().getSymbol(0);
        int bestPairs = 0;
        for (I input : oTree.getInputAlphabet()) {
            Map<O, List<S>> partition = partitionOnOutput(oTree, currentBlock, input);
            int total = computeUI(partition);
            int splitPairs = 0;
            int joinedPairs = 0;
            for (List<S> part : partition.values()) {
                int size = part.size();
                splitPairs += size * (total - size);
                joinedPairs += size * (size - 1);
            }
            splitScore.put(input, Pair.of(splitPairs, joinedPairs));
            // Strict '>': the first maximal input in alphabet order is kept.
            if (splitPairs > bestPairs) {
                bestInput = input;
                bestPairs = splitPairs;
            }
        }
        return Pair.of(bestInput, bestPairs);
    }

    /**
     * Advances the cursor and returns the next input to apply.
     * <p>
     * A {@code null} argument leaves the cursor in place and returns the current node's input (e.g. for the first
     * step). Otherwise the cursor moves to the child labelled {@code previousOutput}, if one exists.
     *
     * @param previousOutput output observed for the previously returned input, or {@code null}
     * @return the input stored at the (new) current node; {@code null} if there is no child for
     *         {@code previousOutput}, in which case the cursor does not move
     */
    @Override
    public I nextInput(O previousOutput) {
        if (previousOutput != null) {
            ADSNode<I, O> child = this.currentNode.getChildNode(previousOutput);
            if (child == null) {
                return null;
            }
            this.currentNode = child;
        }
        return this.currentNode.getInput();
    }

    /** Moves the cursor back to the root of the tree. */
    @Override
    public void resetToRoot() {
        this.currentNode = this.initialNode;
    }
}
