package com.kotlinnlp.simplednn.deeplearning.transformers;

import com.kotlinnlp.simplednn.core.arrays.ParamsArray;
import com.kotlinnlp.simplednn.core.embeddings.EmbeddingsMap;
import com.kotlinnlp.simplednn.core.layers.LayerParameters;
import com.kotlinnlp.simplednn.core.layers.StackedLayersParameters;
import com.kotlinnlp.simplednn.core.layers.models.feedforward.norm.NormLayerParameters;
import com.kotlinnlp.simplednn.core.layers.models.feedforward.simple.FeedforwardLayerParameters;
import com.kotlinnlp.simplednn.core.layers.models.merge.concatff.ConcatFFLayerParameters;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import com.kotlinnlp.utils.DictionarySet;
import com.kotlinnlp.utils.IOKt;
import com.kotlinnlp.utils.progressindicator.ProgressIndicator;
import com.kotlinnlp.utils.progressindicator.ProgressIndicatorBar;
import java.io.File;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.TuplesKt;
import kotlin.TypeCastException;
import kotlin.Unit;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.collections.MapsKt;
import kotlin.io.FilesKt;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.Ref;
import kotlin.ranges.RangesKt;
import kotlin.sequences.SequencesKt;
import kotlin.text.MatchResult;
import kotlin.text.Regex;
import kotlin.text.StringsKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: BERTBaseImportHelper.kt */
@Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��D\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010$\n\u0002\u0010\u000e\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\bÆ\u0002\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J0\u0010\u0005\u001a\u00020\u00062\u0012\u0010\u0007\u001a\u000e\u0012\u0004\u0012\u00020\t\u0012\u0004\u0012\u00020\n0\b2\f\u0010\u000b\u001a\b\u0012\u0004\u0012\u00020\t0\f2\u0006\u0010\r\u001a\u00020\u000eJ\u001c\u0010\u000f\u001a\u00020\u000e2\u0012\u0010\u0007\u001a\u000e\u0012\u0004\u0012\u00020\t\u0012\u0004\u0012\u00020\n0\bH\u0002J0\u0010\u0010\u001a\u000e\u0012\u0004\u0012\u00020\t\u0012\u0004\u0012\u00020\n0\b2\u0012\u0010\u0007\u001a\u000e\u0012\u0004\u0012\u00020\t\u0012\u0004\u0012\u00020\n0\b2\u0006\u0010\r\u001a\u00020\u000eH\u0002J\u001c\u0010\u0011\u001a\u000e\u0012\u0004\u0012\u00020\t\u0012\u0004\u0012\u00020\u00120\b2\u0006\u0010\u0013\u001a\u00020\u0006H\u0002J,\u0010\u0014\u001a\u000e\u0012\u0004\u0012\u00020\t\u0012\u0004\u0012\u00020\u00120\b2\u0006\u0010\u0007\u001a\u00020\u00152\u0006\u0010\u0016\u001a\u00020\u000e2\u0006\u0010\r\u001a\u00020\u000eH\u0002J\"\u0010\u0017\u001a\u000e\u0012\u0004\u0012\u00020\t\u0012\u0004\u0012\u00020\n0\b2\u0006\u0010\u0018\u001a\u00020\t2\u0006\u0010\r\u001a\u00020\u000eR\u000e\u0010\u0003\u001a\u00020\u0004X\u0082\u0004¢\u0006\u0002\n��¨\u0006\u0019"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTBaseImportHelper;", "", "()V", "LAYER_INDEX_REGEX", "Lkotlin/text/Regex;", "buildModel", "Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTModel;", "params", "", "", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "vocab", "Lcom/kotlinnlp/utils/DictionarySet;", "numOfHeads", "", "countLayers", "expandAttentionParams", "getAssignMap", 
"Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray;", "model", "getLayerAssignMap", "Lcom/kotlinnlp/simplednn/deeplearning/transformers/BERTParameters;", "i", "readParams", "filename", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/transformers/BERTBaseImportHelper.class */
/*
 * Helper that imports the parameters of a BERT-base model from a text dump and assigns them to a BERTModel.
 */
public final class BERTBaseImportHelper {

    /** The singleton instance (this class mirrors a Kotlin {@code object}). */
    public static final BERTBaseImportHelper INSTANCE = new BERTBaseImportHelper();

    /** Matches the {@code bert.encoder.layer.<n>} prefix of a parameter name, capturing {@code <n>} in group 1. */
    private static final Regex LAYER_INDEX_REGEX = new Regex("^bert\\.encoder\\.layer\\.(\\d+)");

    private BERTBaseImportHelper() {
    }

    /**
     * Builds a {@link BERTModel} from a map of named parameter arrays.
     *
     * <p>The hidden size is the second dimension of {@code bert.embeddings.word_embeddings.weight}; the embedding of
     * each vocabulary element is the row of that array at the element's vocabulary id. Positional and token-type
     * embeddings are copied row by row, then every other known parameter is copied into the corresponding model array.
     *
     * @param params the parameters by name (e.g. as returned by {@link #readParams(String, int)})
     * @param vocab the vocabulary; element ids are used as row indices of the word embeddings array
     * @param numOfHeads the number of attention heads
     * @return the new model
     * @throws IllegalArgumentException if {@code numOfHeads} is not positive or is not an exact divider of the hidden size
     * @throws java.util.NoSuchElementException if a required parameter is missing from {@code params}
     */
    @NotNull
    public final BERTModel buildModel(@NotNull Map<String, DenseNDArray> params, @NotNull DictionarySet<String> vocab, int numOfHeads) {
        Intrinsics.checkParameterIsNotNull(params, "params");
        Intrinsics.checkParameterIsNotNull(vocab, "vocab");

        DenseNDArray wordEmbeddings = param(params, "bert.embeddings.word_embeddings.weight");
        int hiddenSize = wordEmbeddings.getShape().getDim2();
        requirePositiveNumOfHeads(numOfHeads);
        requireExactDivider(numOfHeads, hiddenSize);

        EmbeddingsMap embeddingsMap = new EmbeddingsMap(hiddenSize, null, false, 6, null);
        List<DenseNDArray> wordRows = wordEmbeddings.mo155getRows();
        for (String word : vocab.getElements()) {
            Integer id = vocab.getId(word);
            if (id == null) {
                Intrinsics.throwNpe();
            }
            embeddingsMap.set(word, newParamsArray(wordRows.get(id.intValue())));
        }

        int headSize = hiddenSize / numOfHeads;
        int intermediateSize = param(params, "bert.encoder.layer.0.intermediate.dense.bias").getShape().getDim1();
        BERTModel model = new BERTModel(hiddenSize, headSize, headSize, intermediateSize, numOfHeads, vocab, embeddingsMap, countLayers(params), null, null);

        // Row k of the positional embeddings array becomes the embedding of position k.
        int position = 0;
        for (DenseNDArray row : param(params, "bert.embeddings.position_embeddings.weight").mo155getRows()) {
            model.getPositionalEmb().set(Integer.valueOf(position++), newParamsArray(row));
        }

        // Row k of the token-type embeddings array is copied into the embedding of token type k.
        int tokenType = 0;
        for (DenseNDArray row : param(params, "bert.embeddings.token_type_embeddings.weight").mo155getRows()) {
            EmbeddingsMap.getOrSet$default(model.getTokenTypeEmb(), Integer.valueOf(tokenType++), 0.0d, (ParamsArray) null, 6, (Object) null)
                .getValues().assignValues((NDArray<?>) row);
        }

        for (Map.Entry<String, ParamsArray> entry : getAssignMap(model).entrySet()) {
            entry.getValue().getValues().assignValues((NDArray<?>) param(params, entry.getKey()));
        }

        return model;
    }

    /**
     * Reads the parameters stored in a text file and adds the per-head slices of the self-attention parameters.
     *
     * <p>The file is parsed as a sequence of records. Each record has:
     * <ul>
     *   <li>a line with the parameter name (trimmed);</li>
     *   <li>one line per array row, with tab-separated numbers;</li>
     *   <li>a terminating blank line.</li>
     * </ul>
     * The last record is kept even if the file does not end with a blank line. Extra blank lines between records are
     * skipped.
     *
     * @param filename the path of the file to read
     * @param numOfHeads the number of attention heads used to split the self-attention parameters
     * @return the parameters by name, including the per-head entries
     *         {@code bert.encoder.layer.<l>.<h>.attention.self.{query,key,value}.{weight,bias}}
     * @throws IllegalArgumentException if {@code numOfHeads} is not positive or is not an exact divider of a layer's
     *         query size
     * @throws NumberFormatException if a row contains a value that cannot be parsed as a double
     */
    @NotNull
    public final Map<String, DenseNDArray> readParams(@NotNull String filename, int numOfHeads) {
        Intrinsics.checkParameterIsNotNull(filename, "filename");
        requirePositiveNumOfHeads(numOfHeads); // fail fast, before reading a potentially large file

        ParamsFileParser parser = new ParamsFileParser(
            new ProgressIndicatorBar(IOKt.getLinesCount(filename), (OutputStream) null, 0, 6, (DefaultConstructorMarker) null));

        // UTF-8 is the default charset of Kotlin's File.forEachLine; pass it explicitly instead of calling the
        // synthetic forEachLine$default bridge, which is not visible to javac.
        FilesKt.forEachLine(new File(filename), Charset.forName("UTF-8"), parser);

        return expandAttentionParams(parser.finish(), numOfHeads);
    }

    /**
     * Line consumer used by {@link #readParams(String, int)}.
     *
     * <p>It accumulates the rows of the current parameter and stores the finished array under the parameter name.
     */
    private static final class ParamsFileParser implements Function1<String, Unit> {

        private final ProgressIndicatorBar progress;

        private final Map<String, DenseNDArray> params = new LinkedHashMap<>();

        /** The name of the parameter being read, or null while waiting for the next name line. */
        private String currentName = null;

        /** The rows read so far for {@link #currentName}; a fresh list per parameter, so no stored array aliases it. */
        private List<double[]> currentRows = new ArrayList<>();

        ParamsFileParser(ProgressIndicatorBar progress) {
            this.progress = progress;
        }

        @Override
        public Unit invoke(String line) {
            ProgressIndicator.tick$default(progress, 0, 1, (Object) null);

            if (currentName == null) {
                // Expecting a name line: skip stray blank lines instead of taking them as an empty name.
                if (!StringsKt.isBlank(line)) {
                    currentName = StringsKt.trim(line).toString();
                    currentRows = new ArrayList<>();
                }
            } else if (StringsKt.isBlank(line)) {
                storeCurrent();
            } else {
                currentRows.add(parseRow(line));
            }

            return Unit.INSTANCE;
        }

        /**
         * Stores the record still being read at end of input (when the file lacks a trailing blank line) and returns
         * all the parameters read.
         */
        Map<String, DenseNDArray> finish() {
            if (currentName != null && !currentRows.isEmpty()) {
                storeCurrent();
            }
            return params;
        }

        /** Builds the array of the current parameter from its rows and goes back to waiting for a name line. */
        private void storeCurrent() {
            params.put(currentName, DenseNDArrayFactory.INSTANCE.arrayOf(currentRows));
            currentName = null;
        }

        /** Parses a line of tab-separated numbers into a row. */
        private static double[] parseRow(String line) {
            List<String> fields = StringsKt.split(line, new String[]{"\t"}, false, 0);
            double[] row = new double[fields.size()];
            for (int k = 0; k < row.length; k++) {
                row[k] = Double.parseDouble(fields.get(k));
            }
            return row;
        }
    }

    /**
     * Returns a copy of {@code params} extended with per-head slices of each encoder layer's self-attention
     * parameters.
     *
     * <p>Let {@code s} be the size of the layer's query bias divided by {@code numOfHeads}. For every layer {@code l}
     * and head {@code h}:
     * <ul>
     *   <li>rows {@code [h * s, (h + 1) * s)} of the query, key and value weights are stored under
     *       {@code bert.encoder.layer.<l>.<h>.attention.self.*};</li>
     *   <li>the same range of the biases is stored under the same keys.</li>
     * </ul>
     * The original entries are kept.
     */
    private final Map<String, DenseNDArray> expandAttentionParams(Map<String, DenseNDArray> params, int numOfHeads) {
        Map<String, DenseNDArray> expanded = MapsKt.toMutableMap(params);
        int numOfLayers = countLayers(params);

        for (int layer = 0; layer < numOfLayers; layer++) {
            String layerPrefix = "bert.encoder.layer." + layer + ".attention.self";

            DenseNDArray queryBias = param(params, layerPrefix + ".query.bias");
            int querySize = queryBias.getShape().getDim1();
            requireExactDivider(numOfHeads, querySize); // otherwise the trailing rows would be silently dropped
            int headSize = querySize / numOfHeads;

            List<DenseNDArray> queryRows = param(params, layerPrefix + ".query.weight").mo155getRows();
            List<DenseNDArray> keyRows = param(params, layerPrefix + ".key.weight").mo155getRows();
            List<DenseNDArray> valueRows = param(params, layerPrefix + ".value.weight").mo155getRows();
            DenseNDArray keyBias = param(params, layerPrefix + ".key.bias");
            DenseNDArray valueBias = param(params, layerPrefix + ".value.bias");

            for (int head = 0; head < numOfHeads; head++) {
                String headPrefix = "bert.encoder.layer." + layer + '.' + head + ".attention.self";
                int from = head * headSize;
                int to = from + headSize;

                expanded.put(headPrefix + ".query.weight", DenseNDArrayFactory.INSTANCE.fromRows(queryRows.subList(from, to)));
                expanded.put(headPrefix + ".query.bias", queryBias.getRange(from, to));
                expanded.put(headPrefix + ".key.weight", DenseNDArrayFactory.INSTANCE.fromRows(keyRows.subList(from, to)));
                expanded.put(headPrefix + ".key.bias", keyBias.getRange(from, to));
                expanded.put(headPrefix + ".value.weight", DenseNDArrayFactory.INSTANCE.fromRows(valueRows.subList(from, to)));
                expanded.put(headPrefix + ".value.bias", valueBias.getRange(from, to));
            }
        }

        return expanded;
    }

    /**
     * Maps each imported parameter name to the model array that must receive its values.
     *
     * <p>The map covers:
     * <ul>
     *   <li>the embeddings layer norm;</li>
     *   <li>the three classifier layers (transform dense, transform layer norm, decoder);</li>
     *   <li>every encoder layer, as returned by {@link #getLayerAssignMap(BERTParameters, int, int)}.</li>
     * </ul>
     */
    private final Map<String, ParamsArray> getAssignMap(BERTModel model) {
        Map<String, ParamsArray> assignMap = new LinkedHashMap<>();

        NormLayerParameters embNorm = layerAt(model.getEmbNorm(), 0, NormLayerParameters.class);
        assignMap.put("bert.embeddings.LayerNorm.weight", embNorm.getG());
        assignMap.put("bert.embeddings.LayerNorm.bias", embNorm.getB());

        StackedLayersParameters classifier = model.getClassifier();

        FeedforwardLayerParameters transform = layerAt(classifier, 0, FeedforwardLayerParameters.class);
        assignMap.put("cls.predictions.transform.dense.weight", transform.getUnit().getWeights());
        assignMap.put("cls.predictions.transform.dense.bias", transform.getUnit().getBiases());

        NormLayerParameters transformNorm = layerAt(classifier, 1, NormLayerParameters.class);
        assignMap.put("cls.predictions.transform.LayerNorm.weight", transformNorm.getG());
        assignMap.put("cls.predictions.transform.LayerNorm.bias", transformNorm.getB());

        FeedforwardLayerParameters decoder = layerAt(classifier, 2, FeedforwardLayerParameters.class);
        assignMap.put("cls.predictions.decoder.weight", decoder.getUnit().getWeights());
        assignMap.put("cls.predictions.decoder.bias", decoder.getUnit().getBiases());

        int layerIndex = 0;
        for (Object layer : model.getLayers()) {
            assignMap.putAll(getLayerAssignMap((BERTParameters) layer, layerIndex++, model.getNumOfHeads()));
        }

        return assignMap;
    }

    /**
     * Maps the parameter names of encoder layer {@code layerIndex} to the arrays of {@code layer}.
     *
     * <p>The map covers, in this order:
     * <ul>
     *   <li>the per-head self-attention queries/keys/values;</li>
     *   <li>the attention output merge;</li>
     *   <li>the attention layer norm;</li>
     *   <li>the intermediate and output dense layers;</li>
     *   <li>the output layer norm.</li>
     * </ul>
     */
    private final Map<String, ParamsArray> getLayerAssignMap(BERTParameters layer, int layerIndex, int numOfHeads) {
        Map<String, ParamsArray> assignMap = new LinkedHashMap<>();

        for (int head = 0; head < numOfHeads; head++) {
            String headPrefix = "bert.encoder.layer." + layerIndex + '.' + head + ".attention.self";
            assignMap.put(headPrefix + ".query.weight", layer.getAttention().getAttention().get(head).getQueries().getWeights());
            assignMap.put(headPrefix + ".query.bias", layer.getAttention().getAttention().get(head).getQueries().getBiases());
            assignMap.put(headPrefix + ".key.weight", layer.getAttention().getAttention().get(head).getKeys().getWeights());
            assignMap.put(headPrefix + ".key.bias", layer.getAttention().getAttention().get(head).getKeys().getBiases());
            assignMap.put(headPrefix + ".value.weight", layer.getAttention().getAttention().get(head).getValues().getWeights());
            assignMap.put(headPrefix + ".value.bias", layer.getAttention().getAttention().get(head).getValues().getBiases());
        }

        String prefix = "bert.encoder.layer." + layerIndex;

        ConcatFFLayerParameters merge = layerAt(layer.getAttention().getMerge(), 0, ConcatFFLayerParameters.class);
        assignMap.put(prefix + ".attention.output.dense.weight", merge.getOutput().getUnit().getWeights());
        assignMap.put(prefix + ".attention.output.dense.bias", merge.getOutput().getUnit().getBiases());

        NormLayerParameters multiHeadNorm = layerAt(layer.getMultiHeadNorm(), 0, NormLayerParameters.class);
        assignMap.put(prefix + ".attention.output.LayerNorm.weight", multiHeadNorm.getG());
        assignMap.put(prefix + ".attention.output.LayerNorm.bias", multiHeadNorm.getB());

        FeedforwardLayerParameters intermediate = layerAt(layer.getOutputFF(), 0, FeedforwardLayerParameters.class);
        assignMap.put(prefix + ".intermediate.dense.weight", intermediate.getUnit().getWeights());
        assignMap.put(prefix + ".intermediate.dense.bias", intermediate.getUnit().getBiases());

        FeedforwardLayerParameters output = layerAt(layer.getOutputFF(), 1, FeedforwardLayerParameters.class);
        assignMap.put(prefix + ".output.dense.weight", output.getUnit().getWeights());
        assignMap.put(prefix + ".output.dense.bias", output.getUnit().getBiases());

        NormLayerParameters outputNorm = layerAt(layer.getOutputNorm(), 0, NormLayerParameters.class);
        assignMap.put(prefix + ".output.LayerNorm.weight", outputNorm.getG());
        assignMap.put(prefix + ".output.LayerNorm.bias", outputNorm.getB());

        return assignMap;
    }

    /**
     * Counts the encoder layers in {@code params}.
     *
     * <p>The count is the number of distinct indices {@code n} among the names that start with
     * {@code bert.encoder.layer.<n>}.
     */
    private final int countLayers(Map<String, DenseNDArray> params) {
        return SequencesKt.toSet(
            SequencesKt.map(
                SequencesKt.mapNotNull(
                    CollectionsKt.asSequence(params.keySet()),
                    (String name) -> LAYER_INDEX_REGEX.find(name, 0)),
                (MatchResult match) -> Integer.valueOf(Integer.parseInt(match.getGroupValues().get(1))))
        ).size();
    }

    /**
     * Returns the layer parameters at {@code index} in {@code stack}, cast to {@code type}.
     *
     * @throws IllegalArgumentException if {@code index} is not lower than the number of layers of the stack
     * @throws TypeCastException if the stored layer parameters are null
     * @throws ClassCastException if the layer parameters are not an instance of {@code type}
     */
    private static <T> T layerAt(StackedLayersParameters stack, int index, Class<T> type) {
        if (index >= stack.getNumOfLayers()) {
            throw new IllegalArgumentException("Layer index (" + index + ") out of range ([0, " + (stack.getNumOfLayers() - 1) + "]).");
        }
        LayerParameters layer = stack.getParamsPerLayer().get(index);
        if (layer == null) {
            throw new TypeCastException("null cannot be cast to non-null type " + type.getName());
        }
        return type.cast(layer);
    }

    /**
     * Returns the parameter named {@code key}.
     *
     * @throws java.util.NoSuchElementException if it is missing
     */
    private static DenseNDArray param(Map<String, DenseNDArray> params, String key) {
        return MapsKt.getValue(params, key);
    }

    /**
     * Wraps {@code values} in a new ParamsArray through the Kotlin default-argument constructor.
     *
     * <p>The default mask is 2, so the errors type keeps its default.
     */
    private static ParamsArray newParamsArray(DenseNDArray values) {
        return new ParamsArray(values, (ParamsArray.ErrorsType) null, 2, (DefaultConstructorMarker) null);
    }

    /**
     * Checks that the number of attention heads is positive.
     *
     * @throws IllegalArgumentException if it is not
     */
    private static void requirePositiveNumOfHeads(int numOfHeads) {
        if (numOfHeads <= 0) {
            throw new IllegalArgumentException("The number of attention heads must be positive (got " + numOfHeads + ")");
        }
    }

    /**
     * Checks that {@code numOfHeads} divides {@code inputSize} exactly.
     *
     * <p>{@code numOfHeads} must be positive.
     *
     * @throws IllegalArgumentException if it does not
     */
    private static void requireExactDivider(int numOfHeads, int inputSize) {
        if (inputSize % numOfHeads != 0) {
            throw new IllegalArgumentException("The number of attention heads (" + numOfHeads + ") must be an exact divider of the input size (" + inputSize + ')');
        }
    }
}
