package org.platanios.tensorflow.api.ops.rnn.attention;

import org.platanios.tensorflow.api.core.Indexer;
import org.platanios.tensorflow.api.core.IndexerConstructionWithTwoNumbers$;
import org.platanios.tensorflow.api.core.NewAxis$;
import org.platanios.tensorflow.api.core.Shape;
import org.platanios.tensorflow.api.core.package$exception$;
import org.platanios.tensorflow.api.core.types.Cpackage;
import org.platanios.tensorflow.api.core.types.package$TF$;
import org.platanios.tensorflow.api.implicits.Implicits$;
import org.platanios.tensorflow.api.ops.Basic$;
import org.platanios.tensorflow.api.ops.Math$;
import org.platanios.tensorflow.api.ops.Output;
import org.platanios.tensorflow.api.ops.OutputOps$;
import org.platanios.tensorflow.api.ops.rnn.attention.Attention;
import org.platanios.tensorflow.api.utilities.DefaultsTo$;
import org.platanios.tensorflow.jni.InvalidArgumentException;
import scala.Function1;
import scala.Predef$;
import scala.collection.Seq$;
import scala.reflect.ScalaSignature;
import scala.runtime.Nothing$;

/* compiled from: BahdanauAttention.scala */
@ScalaSignature(bytes = "\u0006\u0001\tMd\u0001B\u0001\u0003\u0001E\u0011\u0011CQ1iI\u0006t\u0017-^!ui\u0016tG/[8o\u0015\t\u0019A!A\u0005biR,g\u000e^5p]*\u0011QAB\u0001\u0004e:t'BA\u0004\t\u0003\ry\u0007o\u001d\u0006\u0003\u0013)\t1!\u00199j\u0015\tYA\"\u0001\u0006uK:\u001cxN\u001d4m_^T!!\u0004\b\u0002\u0013Ad\u0017\r^1oS>\u001c(\"A\b\u0002\u0007=\u0014xm\u0001\u0001\u0016\u0005II2C\u0001\u0001\u0014!\r!RcF\u0007\u0002\u0005%\u0011aC\u0001\u0002\u0010'&l\u0007\u000f\\3BiR,g\u000e^5p]B\u0011\u0001$\u0007\u0007\u0001\t\u0015Q\u0002A1\u0001\u001c\u0005\u0005!\u0016C\u0001\u000f#!\ti\u0002%D\u0001\u001f\u0015\u0005y\u0012!B:dC2\f\u0017BA\u0011\u001f\u0005\u001dqu\u000e\u001e5j]\u001e\u0004\"!H\u0012\n\u0005\u0011r\"aA!os\"Aa\u0005\u0001BC\u0002\u0013\u0005s%\u0001\u0006nK6|'/_*ju\u0016,\u0012\u0001\u000b\t\u0004S)bS\"\u0001\u0004\n\u0005-2!AB(viB,H\u000f\u0005\u0002\u001e[%\u0011aF\b\u0002\u0004\u0013:$\b\u0002\u0003\u0019\u0001\u0005\u0003\u0005\u000b\u0011\u0002\u0015\u0002\u00175,Wn\u001c:z'&TX\r\t\u0005\te\u0001\u0011)\u0019!C\u0001g\u0005iQ.Z7pef<V-[4iiN,\u0012\u0001\u000e\t\u0004S):\u0002\u0002\u0003\u001c\u0001\u0005\u0003\u0005\u000b\u0011\u0002\u001b\u0002\u001d5,Wn\u001c:z/\u0016Lw\r\u001b;tA!A\u0001\b\u0001BC\u0002\u0013\u00051'\u0001\u0007rk\u0016\u0014\u0018pV3jO\"$8\u000f\u0003\u0005;\u0001\t\u0005\t\u0015!\u00035\u00035\tX/\u001a:z/\u0016Lw\r\u001b;tA!AA\b\u0001BC\u0002\u0013\u00051'\u0001\u0007tG>\u0014XmV3jO\"$8\u000f\u0003\u0005?\u0001\t\u0005\t\u0015!\u00035\u00035\u00198m\u001c:f/\u0016Lw\r\u001b;tA!A\u0001\t\u0001BC\u0002\u0013\u0005\u0011)A\u0007qe>\u0014\u0017MY5mSRLhI\\\u000b\u0002\u0005B!Qd\u0011\u001b5\u0013\t!eDA\u0005Gk:\u001cG/[8oc!Aa\t\u0001B\u0001B\u0003%!)\u0001\bqe>\u0014\u0017MY5mSRLhI\u001c\u0011\t\u0011!\u0003!Q1A\u0005\u0002M\n1C\\8s[\u0006d\u0017N_1uS>tg)Y2u_JD\u0001B\u0013\u0001\u0003\u0002\u0003\u0006I\u0001N\u0001\u0015]>\u0014X.\u00197ju\u0006$\u0018n\u001c8GC\u000e$xN\u001d\u0011\t\u00111\u0003!Q1A\u0005\u0002M\n\u0011C\\8s[\u0
006d\u0017N_1uS>t')[1t\u0011!q\u0005A!A!\u0002\u0013!\u0014A\u00058pe6\fG.\u001b>bi&|gNQ5bg\u0002B\u0001\u0002\u0015\u0001\u0003\u0006\u0004%\t%U\u0001\u000fg\u000e|'/Z'bg.4\u0016\r\\;f+\u0005\u0011\u0006cA\u0015+'B\u0011Q\u0004V\u0005\u0003+z\u0011QA\u00127pCRD\u0011b\u0016\u0001\u0003\u0002\u0003\u0006IA\u0015-\u0002\u001fM\u001cwN]3NCN\\g+\u00197vK\u0002J!\u0001U\u000b\t\u0011i\u0003!Q1A\u0005Bm\u000bAA\\1nKV\tA\f\u0005\u0002^I:\u0011aL\u0019\t\u0003?zi\u0011\u0001\u0019\u0006\u0003CB\ta\u0001\u0010:p_Rt\u0014BA2\u001f\u0003\u0019\u0001&/\u001a3fM&\u0011QM\u001a\u0002\u0007'R\u0014\u0018N\\4\u000b\u0005\rt\u0002\u0002\u00035\u0001\u0005\u0003\u0005\u000b\u0011\u0002/\u0002\u000b9\fW.\u001a\u0011\t\u0011)\u0004!1!Q\u0001\f-\f!\"\u001a<jI\u0016t7-\u001a\u00132!\rawp\u0006\b\u0003[rt!A\\=\u000f\u0005=<hB\u00019w\u001d\t\tXO\u0004\u0002si:\u0011ql]\u0005\u0002\u001f%\u0011QBD\u0005\u0003\u00171I!!\u0003\u0006\n\u0005aD\u0011\u0001B2pe\u0016L!A_>\u0002\u000bQL\b/Z:\u000b\u0005aD\u0011BA?\u007f\u0003\u001d\u0001\u0018mY6bO\u0016T!A_>\n\t\u0005\u0005\u00111\u0001\u0002\u0003)\u001aS!! 
@\t\u0015\u0005\u001d\u0001AaA!\u0002\u0017\tI!\u0001\u0006fm&$WM\\2fII\u0002B\u0001\\A\u0006/%!\u0011QBA\u0002\u0005%I5\u000fR3dS6\fG\u000eC\u0004\u0002\u0012\u0001!\t!a\u0005\u0002\rqJg.\u001b;?)Q\t)\"!\b\u0002 \u0005\u0005\u00121EA\u0013\u0003O\tI#a\u000b\u0002.Q1\u0011qCA\r\u00037\u00012\u0001\u0006\u0001\u0018\u0011\u0019Q\u0017q\u0002a\u0002W\"A\u0011qAA\b\u0001\b\tI\u0001\u0003\u0004'\u0003\u001f\u0001\r\u0001\u000b\u0005\u0007e\u0005=\u0001\u0019\u0001\u001b\t\ra\ny\u00011\u00015\u0011\u0019a\u0014q\u0002a\u0001i!1\u0001)a\u0004A\u0002\tC\u0001\u0002SA\b!\u0003\u0005\r\u0001\u000e\u0005\t\u0019\u0006=\u0001\u0013!a\u0001i!A\u0001+a\u0004\u0011\u0002\u0003\u0007!\u000b\u0003\u0005[\u0003\u001f\u0001\n\u00111\u0001]\u0011\u001d\t\t\u0004\u0001C!\u0003g\t\u0011b[3zgNC\u0017\r]3\u0015\t\u0005U\u0012Q\b\t\u0005\u0003o\tI$D\u0001|\u0013\r\tYd\u001f\u0002\u0006'\"\f\u0007/\u001a\u0005\t\u0003\u007f\ty\u00031\u0001\u00026\u0005Ya/\u00197vKN\u001c\u0006.\u00199f\u0011\u001d\t\u0019\u0005\u0001C)\u0003\u000b\nAa[3zgR)A'a\u0012\u0002Z!A\u0011\u0011JA!\u0001\u0004\tY%\u0001\u0004nK6|'/\u001f\t\u0006\u0003\u001b\n\u0019f\u0006\b\u0004)\u0005=\u0013bAA)\u0005\u0005I\u0011\t\u001e;f]RLwN\\\u0005\u0005\u0003+\n9F\u0001\u0004NK6|'/\u001f\u0006\u0004\u0003#\u0012\u0001bBA.\u0003\u0003\u0002\r\u0001N\u0001\u0007m\u0006dW/Z:\t\u000f\u0005}\u0003\u0001\"\u0015\u0002b\u0005)1oY8sKR)A'a\u0019\u0002h!9\u0011QMA/\u0001\u0004!\u0014!B9vKJL\b\u0002CA5\u0003;\u0002\r!a\u001b\u0002\u000bM$\u0018\r^3\u0011\r\u00055\u0013QN\f5\u0013\u0011\ty'a\u0016\u0003\u000bM#\u0018\r^3)\r\u0005u\u00131OAG!\u0015i\u0012QOA=\u0013\r\t9H\b\u0002\u0007i\"\u0014xn^:\u0011\t\u0005m\u0014q\u0011\b\u0005\u0003{\n\tID\u0002o\u0003\u007fJ!!`>\n\t\u0005\r\u0015QQ\u0001\nKb\u001cW\r\u001d;j_:T!!`>\n\t\u0005%\u00151\u0012\u0002\u0019\u0013:4\u0018\r\\5e\u0003J<W/\\3oi\u0016C8-\u001a9uS>t'\u0002BAB\u0003\u000b\u000bdA\b/\u0002\u0010\u0006U\u0016'C\u0012\u0002\u0012\u0006U\u00151VAL+\rY\u00161\u0013\u0003\u00075A\u0
011\r!!(\n\t\u0005]\u0015\u0011T\u0001\u001cI1,7o]5oSR$sM]3bi\u0016\u0014H\u0005Z3gCVdG\u000fJ\u0019\u000b\u0007\u0005me$\u0001\u0004uQJ|wo]\t\u00049\u0005}\u0005\u0003BAQ\u0003Ks1!HAR\u0013\tih$\u0003\u0003\u0002(\u0006%&!\u0003+ie><\u0018M\u00197f\u0015\tih$M\u0005$\u0003[\u000by+!-\u0002\u001c:\u0019Q$a,\n\u0007\u0005me$M\u0003#;y\t\u0019LA\u0003tG\u0006d\u0017-M\u0002'\u0003sBq!!/\u0001\t#\nY,A\u0006qe>\u0014\u0017MY5mSRLH#\u0002\u001b\u0002>\u0006}\u0006bBA0\u0003o\u0003\r\u0001\u000e\u0005\t\u0003S\n9\f1\u0001\u0002l\u001d9\u00111\u0019\u0002\t\u0002\u0005\u0015\u0017!\u0005\"bQ\u0012\fg.Y;BiR,g\u000e^5p]B\u0019A#a2\u0007\r\u0005\u0011\u0001\u0012AAe'\u0011\t9-a3\u0011\u0007u\ti-C\u0002\u0002Pz\u0011a!\u00118z%\u00164\u0007\u0002CA\t\u0003\u000f$\t!a5\u0015\u0005\u0005\u0015\u0007\u0002CAl\u0003\u000f$\t!!7\u0002\u000b\u0005\u0004\b\u000f\\=\u0016\t\u0005m\u00171\u001d\u000b\u0015\u0003;\f\t0a=\u0002x\u0006e\u00181`A��\u0005\u0003\u0011\u0019A!\u0002\u0015\r\u0005}\u0017Q]Av!\u0011!\u0002!!9\u0011\u0007a\t\u0019\u000f\u0002\u0004\u001b\u0003+\u0014\ra\u0007\u0005\u000b\u0003O\f).!AA\u0004\u0005%\u0018AC3wS\u0012,gnY3%gA!An`Aq\u0011)\ti/!6\u0002\u0002\u0003\u000f\u0011q^\u0001\u000bKZLG-\u001a8dK\u0012\"\u0004#\u00027\u0002\f\u0005\u0005\bB\u0002\u0014\u0002V\u0002\u0007\u0001\u0006C\u00043\u0003+\u0004\r!!>\u0011\t%R\u0013\u0011\u001d\u0005\bq\u0005U\u0007\u0019AA{\u0011\u001da\u0014Q\u001ba\u0001\u0003kD\u0011\u0002QAk!\u0003\u0005\r!!@\u0011\ru\u0019\u0015Q_A{\u0011%A\u0015Q\u001bI\u0001\u0002\u0004\t)\u0010C\u0005M\u0003+\u0004\n\u00111\u0001\u0002v\"A\u0001+!6\u0011\u0002\u0003\u0007!\u000b\u0003\u0005[\u0003+\u0004\n\u00111\u0001]\u0011)\u0011I!a2\u0012\u0002\u0013\u0005!1B\u0001\u001cI1,7o]5oSR$sM]3bi\u0016\u0014H\u0005Z3gCVdG\u000f\n\u001c\u0016\t\t5!\u0011F\u000b\u0003\u0005\u001fQCA!\u0005\u0003\u0018A\u0019QDa\u0005\n\u0007\tUaD\u0001\u0003Ok2d7F\u0001B\r!\u0011\u0011YB!\n\u000e\u0005\tu!\u0002\u0002B\u0010\u0005C\t\u0011\"\u001e8dQ\u0016\u001c7.\u00
1a3\u000b\u0007\t\rb$\u0001\u0006b]:|G/\u0019;j_:LAAa\n\u0003\u001e\t\tRO\\2iK\u000e\\W\r\u001a,be&\fgnY3\u0005\ri\u00119A1\u0001\u001c\u0011)\u0011i#a2\u0012\u0002\u0013\u0005!qF\u0001\u001cI1,7o]5oSR$sM]3bi\u0016\u0014H\u0005Z3gCVdG\u000fJ\u001c\u0016\t\t5!\u0011\u0007\u0003\u00075\t-\"\u0019A\u000e\t\u0015\tU\u0012qYI\u0001\n\u0003\u00119$A\u000e%Y\u0016\u001c8/\u001b8ji\u0012:'/Z1uKJ$C-\u001a4bk2$H\u0005O\u000b\u0005\u0005s\u0011i$\u0006\u0002\u0003<)\u001a!Ka\u0006\u0005\ri\u0011\u0019D1\u0001\u001c\u0011)\u0011\t%a2\u0012\u0002\u0013\u0005!1I\u0001\u001cI1,7o]5oSR$sM]3bi\u0016\u0014H\u0005Z3gCVdG\u000fJ\u001d\u0016\t\t\u0015#\u0011J\u000b\u0003\u0005\u000fR3\u0001\u0018B\f\t\u0019Q\"q\bb\u00017!Q!QJAd#\u0003%\tAa\u0014\u0002\u001f\u0005\u0004\b\u000f\\=%I\u00164\u0017-\u001e7uIU*BA!\u0004\u0003R\u00111!Da\u0013C\u0002mA!B!\u0016\u0002HF\u0005I\u0011\u0001B,\u0003=\t\u0007\u000f\u001d7zI\u0011,g-Y;mi\u00122T\u0003\u0002B\u0007\u00053\"aA\u0007B*\u0005\u0004Y\u0002B\u0003B/\u0003\u000f\f\n\u0011\"\u0001\u0003`\u0005y\u0011\r\u001d9ms\u0012\"WMZ1vYR$s'\u0006\u0003\u0003\u000e\t\u0005DA\u0002\u000e\u0003\\\t\u00071\u0004\u0003\u0006\u0003f\u0005\u001d\u0017\u0013!C\u0001\u0005O\nq\"\u00199qYf$C-\u001a4bk2$H\u0005O\u000b\u0005\u0005s\u0011I\u0007\u0002\u0004\u001b\u0005G\u0012\ra\u0007\u0005\u000b\u0005[\n9-%A\u0005\u0002\t=\u0014aD1qa2LH\u0005Z3gCVdG\u000fJ\u001d\u0016\t\t\u0015#\u0011\u000f\u0003\u00075\t-$\u0019A\u000e")
/* loaded from: input_file:org/platanios/tensorflow/api/ops/rnn/attention/BahdanauAttention.class */
/**
 * Decompiled Java view of the Scala class {@code BahdanauAttention} (an additive attention
 * mechanism built on top of {@link SimpleAttention}).
 *
 * <p>What the code in this class computes:
 * <ul>
 *   <li>{@link #keys} multiplies the memory values by {@code memoryWeights} (flattening rank-3
 *       values to a matrix first and restoring the leading dimensions afterwards).</li>
 *   <li>{@link #score} evaluates
 *       {@code sum(w * tanh(keys + expandDims(matmul(query, queryWeights), 1) [+ normalizationBias]), 2)}
 *       where {@code w} is {@code scoreWeights}, or, when {@code normalizationFactor} is non-null,
 *       {@code normalizationFactor * scoreWeights * rsqrt(sum(square(scoreWeights)))}.</li>
 *   <li>{@link #probability} applies {@code probabilityFn} to the score.</li>
 * </ul>
 *
 * <p>{@code normalizationFactor} and {@code normalizationBias} are optional and are represented
 * by {@code null} when absent.
 *
 * @param <T> element type of the tensors; the two {@code evidence$N} fields carry the Scala
 *            implicit type-class witnesses that the tensor ops require for it
 */
public class BahdanauAttention<T> extends SimpleAttention<T> {
    /** Value returned by {@link #memorySize()}; also forwarded to the superclass constructor. */
    private final Output<Object> memorySize;
    /** Right-hand operand of the matrix product applied to the memory values in {@link #keys}. */
    private final Output<T> memoryWeights;
    /** Right-hand operand of the matrix product applied to the query in {@link #score}. */
    private final Output<T> queryWeights;
    /** Weights multiplied with the {@code tanh} activations in {@link #score}. */
    private final Output<T> scoreWeights;
    /** Function applied to the score by {@link #probability}. */
    private final Function1<Output<T>, Output<T>> probabilityFn;
    /** Optional (nullable) scale used to normalize {@code scoreWeights} in {@link #score}. */
    private final Output<T> normalizationFactor;
    /** Optional (nullable) bias added inside the {@code tanh} in {@link #score}. */
    private final Output<T> normalizationBias;
    /** Value returned by {@link #name()}; also forwarded to the superclass constructor. */
    private final String name;
    /** Implicit {@code TF[T]} evidence passed to the tensor ops. */
    private final Cpackage.TF<T> evidence$1;
    /** Implicit subtype-witness evidence passed to the tensor ops. */
    private final Predef$.less.colon.less<Function1<Function1<T, Nothing$>, Nothing$>, Function1<Function1<Cpackage.TruncatedHalf, Nothing$>, Nothing$>> evidence$2;

    /**
     * Static forwarder to the Scala companion object's {@code apply}. All arguments are passed
     * through unchanged and in the same order.
     *
     * <p>NOTE(review): the parameter list has the same types and order as the constructor, so the
     * arguments presumably carry the same meaning; the companion is not visible here to confirm.
     *
     * @return whatever {@code BahdanauAttention$.MODULE$.apply} returns for these arguments
     */
    public static <T> BahdanauAttention<T> apply(Output<Object> output, Output<T> output2, Output<T> output3, Output<T> output4, Function1<Output<T>, Output<T>> function1, Output<T> output5, Output<T> output6, Output<Object> output7, String str, Cpackage.TF<T> tf, Predef$.less.colon.less<Function1<Function1<T, Nothing$>, Nothing$>, Function1<Function1<Cpackage.TruncatedHalf, Nothing$>, Nothing$>> lessVar) {
        return BahdanauAttention$.MODULE$.apply(
                output, output2, output3, output4,
                function1,
                output5, output6, output7,
                str, tf, lessVar);
    }

    /** @return the memory size tensor supplied at construction */
    @Override // org.platanios.tensorflow.api.ops.rnn.attention.SimpleAttention, org.platanios.tensorflow.api.ops.rnn.attention.Attention
    public Output<Object> memorySize() {
        return memorySize;
    }

    /** @return the weights multiplied with the memory values in {@link #keys} */
    public Output<T> memoryWeights() {
        return memoryWeights;
    }

    /** @return the weights multiplied with the query in {@link #score} */
    public Output<T> queryWeights() {
        return queryWeights;
    }

    /** @return the weights multiplied with the {@code tanh} activations in {@link #score} */
    public Output<T> scoreWeights() {
        return scoreWeights;
    }

    /** @return the function that {@link #probability} applies to a score */
    public Function1<Output<T>, Output<T>> probabilityFn() {
        return probabilityFn;
    }

    /** @return the optional normalization factor; {@code null} disables weight normalization */
    public Output<T> normalizationFactor() {
        return normalizationFactor;
    }

    /** @return the optional normalization bias; {@code null} means no bias is added */
    public Output<T> normalizationBias() {
        return normalizationBias;
    }

    /** @return the score mask value as held by the superclass (this override only delegates) */
    @Override // org.platanios.tensorflow.api.ops.rnn.attention.SimpleAttention
    public Output<Object> scoreMaskValue() {
        return super.scoreMaskValue();
    }

    /** @return the name supplied at construction */
    @Override // org.platanios.tensorflow.api.ops.rnn.attention.SimpleAttention, org.platanios.tensorflow.api.ops.rnn.attention.Attention
    public String name() {
        return name;
    }

    /**
     * Derives the static shape of the keys from the static shape of the values: the
     * {@code 0 :: -1} slice of {@code shape} with the last dimension of {@code memoryWeights}
     * appended.
     *
     * @param shape static shape of the memory values
     * @return the sliced shape extended by {@code memoryWeights().shape().apply(-1)}
     */
    @Override // org.platanios.tensorflow.api.ops.rnn.attention.Attention
    public Shape keysShape(Shape shape) {
        // Scala source form: shape(0 :: -1) + memoryWeights.shape(-1)
        return shape
                .apply(IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(
                        Implicits$.MODULE$.intToIndexerConstruction(-1).$colon$colon(Implicits$.MODULE$.intToIndexerConstruction(0))))
                .$plus(memoryWeights().shape().apply(-1));
    }

    /**
     * Computes the attention keys as the matrix product of the memory values and
     * {@code memoryWeights}.
     *
     * <p>Rank-3 values are first reshaped to a matrix, multiplied, and then reshaped back so that
     * the leading dimensions of the values are kept and the last dimension becomes the last
     * dimension of {@code memoryWeights}. Values of any other rank are multiplied directly.
     *
     * @param memory not used by this implementation
     * @param output the memory values
     * @return the projected values
     */
    /* JADX WARN: Multi-variable type inference failed */
    @Override // org.platanios.tensorflow.api.ops.rnn.attention.SimpleAttention
    public Output<T> keys(Attention.Memory<T> memory, Output<T> output) {
        final Output<T> values = output;
        if (values.rank() == 3) {
            // Step 1: reshape to the shape given by stack(-1, last entry of the dynamic shape of values).
            Output<T> flattened = Basic$.MODULE$.reshape(
                    values,
                    Basic$.MODULE$.stack(
                            Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{
                                    Basic$.MODULE$.constant(Implicits$.MODULE$.intToTensor(-1), Basic$.MODULE$.constant$default$2(), Basic$.MODULE$.constant$default$3()),
                                    Basic$.MODULE$.shape(values, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), this.evidence$1)
                                            .castTo(package$TF$.MODULE$.intEvTF())
                                            .slice(Implicits$.MODULE$.intToIndex(-1), Predef$.MODULE$.wrapRefArray(new Indexer[0]))
                            })),
                            Basic$.MODULE$.stack$default$2(),
                            Basic$.MODULE$.stack$default$3(),
                            package$TF$.MODULE$.intEvTF()),
                    Basic$.MODULE$.reshape$default$3(),
                    this.evidence$1,
                    package$TF$.MODULE$.intEvTF(),
                    Predef$.MODULE$.$conforms());

            // Step 2: project the flattened values with the memory weights.
            Output<T> projected = Math$.MODULE$.matmul(
                    flattened,
                    memoryWeights(),
                    Math$.MODULE$.matmul$default$3(),
                    Math$.MODULE$.matmul$default$4(),
                    Math$.MODULE$.matmul$default$5(),
                    Math$.MODULE$.matmul$default$6(),
                    Math$.MODULE$.matmul$default$7(),
                    Math$.MODULE$.matmul$default$8(),
                    Math$.MODULE$.matmul$default$9(),
                    this.evidence$1,
                    this.evidence$2);

            // Step 3: reshape to concat(shape(values)(0 :: -1), shape(memoryWeights)(-1, NewAxis)) along axis 0.
            Output<T> restored = Basic$.MODULE$.reshape(
                    projected,
                    Basic$.MODULE$.concatenate(
                            Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{
                                    Basic$.MODULE$.shape(values, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), this.evidence$1)
                                            .castTo(package$TF$.MODULE$.intEvTF())
                                            .slice(
                                                    IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(
                                                            Implicits$.MODULE$.intToIndexerConstruction(-1).$colon$colon(Implicits$.MODULE$.intToIndexerConstruction(0))),
                                                    Predef$.MODULE$.wrapRefArray(new Indexer[0])),
                                    Basic$.MODULE$.shape(memoryWeights(), Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), this.evidence$1)
                                            .castTo(package$TF$.MODULE$.intEvTF())
                                            .slice(Implicits$.MODULE$.intToIndex(-1), Predef$.MODULE$.wrapRefArray(new Indexer[]{NewAxis$.MODULE$}))
                            })),
                            Implicits$.MODULE$.intToOutput(0),
                            Basic$.MODULE$.concatenate$default$3(),
                            package$TF$.MODULE$.intEvTF()),
                    Basic$.MODULE$.reshape$default$3(),
                    this.evidence$1,
                    package$TF$.MODULE$.intEvTF(),
                    Predef$.MODULE$.$conforms());

            // Record the static shape as well, mirroring keysShape(values.shape()).
            restored.setShape(
                    values.shape()
                            .apply(IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(
                                    Implicits$.MODULE$.intToIndexerConstruction(-1).$colon$colon(Implicits$.MODULE$.intToIndexerConstruction(0))))
                            .$plus(memoryWeights().shape().apply(-1)));
            return restored;
        }

        // Any other rank: multiply directly.
        return Math$.MODULE$.matmul(
                values,
                memoryWeights(),
                Math$.MODULE$.matmul$default$3(),
                Math$.MODULE$.matmul$default$4(),
                Math$.MODULE$.matmul$default$5(),
                Math$.MODULE$.matmul$default$6(),
                Math$.MODULE$.matmul$default$7(),
                Math$.MODULE$.matmul$default$8(),
                Math$.MODULE$.matmul$default$9(),
                this.evidence$1,
                this.evidence$2);
    }

    /**
     * Computes the attention score for a query against the keys held in {@code state}.
     *
     * <p>The query is multiplied by {@code queryWeights} and given an extra axis at position 1,
     * added to the keys (plus {@code normalizationBias} when it is non-null), passed through
     * {@code tanh}, multiplied by the (optionally normalized) score weights, and summed over
     * axis 2.
     *
     * @param output the query
     * @param state  attention state supplying the keys
     * @return the score tensor
     * @throws InvalidArgumentException if the last static dimension of the query differs from the
     *                                  last static dimension of the keys
     */
    @Override // org.platanios.tensorflow.api.ops.rnn.attention.SimpleAttention
    public Output<T> score(Output<T> output, Attention.State<T, Output<T>> state) throws InvalidArgumentException {
        final Output<T> query = output;

        // The inner (last) static dimensions of query and keys have to agree.
        int queryDepth = query.shape().apply(-1);
        int keysDepth = state.keys().shape().apply(-1);
        if (queryDepth != keysDepth) {
            throw package$exception$.MODULE$.InvalidArgumentException().apply(
                    "Incompatible or unknown inner dimensions between query and keys. "
                            + "Query (" + query.name() + ") has " + queryDepth + " units. "
                            + "Keys (" + state.keys().name() + ") have " + keysDepth + " units. "
                            + "Perhaps you need to set the number of units of the attention model "
                            + "to the keys' number of units.");
        }

        // matmul(query, queryWeights) with an extra axis inserted at position 1.
        Output<T> projectedQuery = Math$.MODULE$.matmul(
                query,
                queryWeights(),
                Math$.MODULE$.matmul$default$3(),
                Math$.MODULE$.matmul$default$4(),
                Math$.MODULE$.matmul$default$5(),
                Math$.MODULE$.matmul$default$6(),
                Math$.MODULE$.matmul$default$7(),
                Math$.MODULE$.matmul$default$8(),
                Math$.MODULE$.matmul$default$9(),
                this.evidence$1,
                this.evidence$2);
        Output<T> processedQuery = Implicits$.MODULE$.outputBasicOps(projectedQuery).expandDims(Implicits$.MODULE$.intToOutput(1));

        Output<T> weights;
        if (normalizationFactor() != null) {
            // normalizationFactor * scoreWeights * rsqrt(sum(square(scoreWeights)))
            Output<T> scaled = normalizationFactor().$times(scoreWeights(), this.evidence$2);
            Output<T> squared = (Output) Math$.MODULE$.square(
                    scoreWeights(),
                    Math$.MODULE$.square$default$2(),
                    this.evidence$1,
                    this.evidence$2,
                    OutputOps$.MODULE$.outputOps());
            // The decompiler emitted the default-argument getter for the axes as a discarded call
            // and passes a literal null to sum below; both are kept as-is.
            Math$.MODULE$.sum$default$2();
            Output<T> sumOfSquares = Math$.MODULE$.sum(
                    squared,
                    null,
                    Math$.MODULE$.sum$default$3(),
                    Math$.MODULE$.sum$default$4(),
                    this.evidence$1,
                    this.evidence$2,
                    DefaultsTo$.MODULE$.defaultDefaultsTo(),
                    package$TF$.MODULE$.intEvTF(),
                    Predef$.MODULE$.$conforms());
            weights = scaled.$times(
                    (Output) Math$.MODULE$.rsqrt(
                            sumOfSquares,
                            Math$.MODULE$.rsqrt$default$2(),
                            this.evidence$1,
                            this.evidence$2,
                            OutputOps$.MODULE$.outputOps()),
                    this.evidence$2);
        } else {
            weights = scoreWeights();
        }

        if (normalizationBias() == null) {
            // sum(weights * tanh(keys + processedQuery), 2)
            return Math$.MODULE$.sum(
                    weights.$times(
                            (Output) Math$.MODULE$.tanh(
                                    state.keys().$plus(processedQuery, this.evidence$2),
                                    Math$.MODULE$.tanh$default$2(),
                                    this.evidence$1,
                                    this.evidence$2,
                                    OutputOps$.MODULE$.outputOps()),
                            this.evidence$2),
                    Implicits$.MODULE$.intToOutput(2),
                    Math$.MODULE$.sum$default$3(),
                    Math$.MODULE$.sum$default$4(),
                    this.evidence$1,
                    this.evidence$2,
                    DefaultsTo$.MODULE$.defaultDefaultsTo(),
                    package$TF$.MODULE$.intEvTF(),
                    Predef$.MODULE$.$conforms());
        }

        // sum(weights * tanh(keys + processedQuery + normalizationBias), 2)
        return Math$.MODULE$.sum(
                weights.$times(
                        (Output) Math$.MODULE$.tanh(
                                state.keys().$plus(processedQuery, this.evidence$2).$plus(normalizationBias(), this.evidence$2),
                                Math$.MODULE$.tanh$default$2(),
                                this.evidence$1,
                                this.evidence$2,
                                OutputOps$.MODULE$.outputOps()),
                        this.evidence$2),
                Implicits$.MODULE$.intToOutput(2),
                Math$.MODULE$.sum$default$3(),
                Math$.MODULE$.sum$default$4(),
                this.evidence$1,
                this.evidence$2,
                DefaultsTo$.MODULE$.defaultDefaultsTo(),
                package$TF$.MODULE$.intEvTF(),
                Predef$.MODULE$.$conforms());
    }

    /**
     * Converts a score into attention probabilities by applying {@link #probabilityFn()}.
     *
     * @param output the score
     * @param state  not used by this implementation
     * @return {@code probabilityFn().apply(output)}
     */
    @Override // org.platanios.tensorflow.api.ops.rnn.attention.SimpleAttention
    public Output<T> probability(Output<T> output, Attention.State<T, Output<T>> state) {
        Function1<Output<T>, Output<T>> fn = probabilityFn();
        return fn.apply(output);
    }

    /**
     * Creates the attention mechanism.
     *
     * <p>Decompiler note (JADX): the 'super' call was moved to the top of the method, which can
     * break code semantics. In this Java rendering the fields below are therefore assigned only
     * after the superclass constructor has run.
     *
     * @param output    memory size
     * @param output2   memory weights
     * @param output3   query weights
     * @param output4   score weights
     * @param function1 probability function
     * @param output5   normalization factor, or {@code null}
     * @param output6   normalization bias, or {@code null}
     * @param output7   forwarded to the superclass constructor as its second argument and not
     *                  stored here (presumably the score mask value; TODO confirm against
     *                  SimpleAttention)
     * @param str       name
     * @param tf        implicit {@code TF[T]} evidence
     * @param lessVar   implicit subtype-witness evidence
     */
    public BahdanauAttention(Output<Object> output, Output<T> output2, Output<T> output3, Output<T> output4, Function1<Output<T>, Output<T>> function1, Output<T> output5, Output<T> output6, Output<Object> output7, String str, Cpackage.TF<T> tf, Predef$.less.colon.less<Function1<Function1<T, Nothing$>, Nothing$>, Function1<Function1<Cpackage.TruncatedHalf, Nothing$>, Nothing$>> lessVar) {
        super(output, output7, str, tf, lessVar);

        // Implicit evidence.
        this.evidence$1 = tf;
        this.evidence$2 = lessVar;

        // Identification and memory description.
        this.name = str;
        this.memorySize = output;

        // Trainable weights and optional normalization terms.
        this.memoryWeights = output2;
        this.queryWeights = output3;
        this.scoreWeights = output4;
        this.normalizationFactor = output5;
        this.normalizationBias = output6;

        // Score-to-probability mapping.
        this.probabilityFn = function1;
    }
}
