package com.kotlinnlp.lssencoder.decoder;

import com.kotlinnlp.linguisticdescription.sentence.token.FormToken;
import com.kotlinnlp.linguisticdescription.sentence.token.TokenIdentificable;
import com.kotlinnlp.lssencoder.LatentSyntacticStructure;
import com.kotlinnlp.simplednn.simplemath.SimpleMathKt;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.TuplesKt;
import kotlin.collections.CollectionsKt;
import kotlin.collections.MapsKt;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.sequences.SequencesKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: CosineDecoder.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��6\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010%\n\u0002\u0010\b\n\u0002\u0010\u0006\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0006\b\u0007\u0018�� \u00142\u00020\u0001:\u0001\u0014B\u0005¢\u0006\u0002\u0010\u0002J\u0018\u0010\t\u001a\u00020\n2\u000e\u0010\u000b\u001a\n\u0012\u0002\b\u0003\u0012\u0002\b\u00030\u0004H\u0016J\u001f\u0010\f\u001a\u00020\r\"\b\b��\u0010\u000e*\u00020\u000f2\u0006\u0010\u0010\u001a\u0002H\u000eH\u0002¢\u0006\u0002\u0010\u0011J\u001f\u0010\u0012\u001a\u00020\r\"\b\b��\u0010\u000e*\u00020\u000f2\u0006\u0010\u0010\u001a\u0002H\u000eH\u0002¢\u0006\u0002\u0010\u0011J\u001f\u0010\u0013\u001a\u00020\r\"\b\b��\u0010\u000e*\u00020\u000f2\u0006\u0010\u0010\u001a\u0002H\u000eH\u0002¢\u0006\u0002\u0010\u0011R\u0016\u0010\u0003\u001a\n\u0012\u0002\b\u0003\u0012\u0002\b\u00030\u0004X\u0082.¢\u0006\u0002\n��R&\u0010\u0005\u001a\u001a\u0012\u0004\u0012\u00020\u0007\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\b0\u00060\u0006X\u0082\u0004¢\u0006\u0002\n��¨\u0006\u0015"}, d2 = {"Lcom/kotlinnlp/lssencoder/decoder/CosineDecoder;", "Lcom/kotlinnlp/lssencoder/decoder/HeadsDecoder;", "()V", "lssNorm", "Lcom/kotlinnlp/lssencoder/LatentSyntacticStructure;", "similarityMatrix", "", "", "", "decode", "Lcom/kotlinnlp/lssencoder/decoder/ScoredArcs;", "lss", "normalizeToDistribution", "", "T", "Lcom/kotlinnlp/linguisticdescription/sentence/token/TokenIdentificable;", "dependent", "(Lcom/kotlinnlp/linguisticdescription/sentence/token/TokenIdentificable;)V", "setHeadsScores", "setRootScore", "Companion", "lssencoder"})
/* loaded from: input_file:com/kotlinnlp/lssencoder/decoder/CosineDecoder.class */
/*
 * Heads decoder based on cosine similarity.
 *
 * For every dependent token it compares the dependent's latent head with:
 *  - the context vector of every other token of the sentence (candidate heads), and
 *  - the virtual root (key -1); only non-punctuation FormTokens get a non-zero root similarity.
 * Each similarity c is mapped to HALF_PI - acos(c) and the dependent's row is divided by its sum.
 *
 * Instances keep per-call scratch state (lssNorm, similarityMatrix) in unsynchronized fields,
 * so one instance must not be shared by concurrent callers.
 */
public final class CosineDecoder implements HeadsDecoder {

    /**
     * Scratch matrix rebuilt by every {@link #decode} call:
     * dependent id -> (candidate head id, or {@link #ROOT_ID} for the virtual root) -> score.
     */
    private final Map<Integer, Map<Integer, Double>> similarityMatrix = new LinkedHashMap<>();

    /** Copy of the last decoded structure whose vectors were passed through normalize2(). */
    private LatentSyntacticStructure<?, ?> lssNorm;

    /** pi / 2; HALF_PI - acos(c) maps a cosine c in [-1, 1] to an angle in [-pi/2, pi/2]. */
    private static final double HALF_PI = Math.PI / 2.0;

    /** Key under which each row stores the score of the virtual root. */
    private static final int ROOT_ID = -1;

    public static final Companion Companion = new Companion(null);

    /* compiled from: CosineDecoder.kt */
    @Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��\u0012\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010\u0006\n��\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002R\u000e\u0010\u0003\u001a\u00020\u0004X\u0082T¢\u0006\u0002\n��¨\u0006\u0005"}, d2 = {"Lcom/kotlinnlp/lssencoder/decoder/CosineDecoder$Companion;", "", "()V", "HALF_PI", "", "lssencoder"})
    /* loaded from: input_file:com/kotlinnlp/lssencoder/decoder/CosineDecoder$Companion.class */
    public static final class Companion {
        private Companion() {
        }

        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }
    }

    /**
     * Scores every possible head (and the virtual root) of every token of {@code lss}.
     *
     * @param lss the latent syntactic structure to decode
     * @return the arc scores: for each dependent id, a map from candidate head id
     *     ({@code -1} for the virtual root) to its normalized score. The matrix is a snapshot,
     *     so later calls to this method do not modify previously returned results.
     */
    @Override // com.kotlinnlp.lssencoder.decoder.HeadsDecoder
    @NotNull
    public ScoredArcs decode(@NotNull LatentSyntacticStructure<?, ?> lss) {
        Intrinsics.checkParameterIsNotNull(lss, "lss");
        // NOTE(review): copy$default is Kotlin's synthetic default-arguments bridge of
        // LatentSyntacticStructure.copy (mask 3 presumably keeps the first two properties of lss).
        // javac may refuse to resolve synthetic members; confirm, or call copy(...) explicitly.
        this.lssNorm = LatentSyntacticStructure.copy$default(
                lss,
                null,
                null,
                normalizeAll(lss.getContextVectors()),
                normalizeAll(lss.getLatentHeads()),
                (DenseNDArray) lss.getVirtualRoot().normalize2(),
                3,
                null);
        // Drop rows left over from a previous (possibly longer) sentence.
        this.similarityMatrix.clear();
        for (TokenIdentificable token : lss.getSentence().getTokens()) {
            this.similarityMatrix.put(token.getId(), new LinkedHashMap<>());
            setHeadsScores(token);
            setRootScore(token);
            normalizeToDistribution(token);
        }
        return new ScoredArcs(snapshotMatrix());
    }

    /**
     * Returns a new list holding {@code normalize2()} of each vector, in the same order.
     *
     * @param vectors the vectors to normalize (left untouched)
     * @return the normalized vectors
     */
    private static List<DenseNDArray> normalizeAll(List<DenseNDArray> vectors) {
        List<DenseNDArray> normalized = new ArrayList<>(vectors.size());
        for (DenseNDArray vector : vectors) {
            normalized.add((DenseNDArray) vector.normalize2());
        }
        return normalized;
    }

    /**
     * Returns the normalized structure set by {@link #decode}.
     *
     * @throws kotlin.UninitializedPropertyAccessException if {@code decode} has not run yet
     */
    private LatentSyntacticStructure<?, ?> normalizedLss() {
        LatentSyntacticStructure<?, ?> norm = this.lssNorm;
        if (norm == null) {
            Intrinsics.throwUninitializedPropertyAccessException("lssNorm");
        }
        return norm;
    }

    /**
     * Stores, in the dependent's row, the raw cosine similarity between the context vector of
     * every other token and the dependent's latent head.
     *
     * @param dependent the dependent token whose row is filled
     */
    private <T extends TokenIdentificable> void setHeadsScores(T dependent) {
        Map<Integer, Double> scores = MapsKt.getValue(this.similarityMatrix, dependent.getId());
        LatentSyntacticStructure<?, ?> norm = normalizedLss();
        // Loop-invariant: the dependent's latent head is the same for every candidate.
        DenseNDArray latentHead = norm.getLatentHeadById(dependent.getId());
        for (TokenIdentificable candidate : norm.getSentence().getTokens()) {
            if (candidate.getId() == dependent.getId()) {
                continue; // a token cannot be its own head
            }
            DenseNDArray contextVector = norm.getContextVectorById(candidate.getId());
            scores.put(candidate.getId(), SimpleMathKt.cosineSimilarity(contextVector, latentHead));
        }
    }

    /**
     * Stores the raw root score in the dependent's row under key {@code -1}.
     *
     * <p>The score is the cosine similarity between the dependent's latent head and the virtual
     * root for a non-punctuation {@link FormToken}, and {@code 0.0} for any other token.
     *
     * @param dependent the dependent token whose row is completed
     */
    private <T extends TokenIdentificable> void setRootScore(T dependent) {
        Map<Integer, Double> scores = MapsKt.getValue(this.similarityMatrix, dependent.getId());
        double rootScore = 0.0;
        if (dependent instanceof FormToken && !((FormToken) dependent).isPunctuation()) {
            LatentSyntacticStructure<?, ?> norm = normalizedLss();
            rootScore = SimpleMathKt.cosineSimilarity(
                    norm.getLatentHeadById(dependent.getId()), norm.getVirtualRoot());
        }
        scores.put(ROOT_ID, rootScore);
    }

    /**
     * Maps every raw cosine c of the dependent's row to {@code HALF_PI - acos(c)}, then divides
     * every value by the row sum.
     *
     * <p>NOTE(review): mapped values are negative when c &lt; 0, so the row is not guaranteed to
     * be a probability distribution, and a row sum of 0 yields NaN/Infinity (e.g. a sentence
     * made of one punctuation token). Confirm this is intended.
     *
     * @param dependent the dependent token whose row is normalized in place
     */
    private <T extends TokenIdentificable> void normalizeToDistribution(T dependent) {
        Map<Integer, Double> scores = MapsKt.getValue(this.similarityMatrix, dependent.getId());
        scores.replaceAll((headId, cosine) -> HALF_PI - Math.acos(clampCosine(cosine)));
        double sum = 0.0;
        for (Double score : scores.values()) {
            sum += score;
        }
        final double total = sum;
        scores.replaceAll((headId, angle) -> angle / total);
    }

    /**
     * Clamps a cosine into [-1, 1].
     *
     * <p>The similarity of normalized vectors can overshoot ±1 by an ulp because of rounding,
     * and {@code Math.acos} returns NaN outside that range. NaN inputs are returned unchanged.
     */
    private static double clampCosine(double cosine) {
        return Math.max(-1.0, Math.min(1.0, cosine));
    }

    /** Returns a deep copy of {@link #similarityMatrix}, preserving iteration order. */
    private Map<Integer, Map<Integer, Double>> snapshotMatrix() {
        Map<Integer, Map<Integer, Double>> copy = new LinkedHashMap<>();
        for (Map.Entry<Integer, Map<Integer, Double>> row : this.similarityMatrix.entrySet()) {
            copy.put(row.getKey(), new LinkedHashMap<>(row.getValue()));
        }
        return copy;
    }
}
