package org.platanios.tensorflow.api.ops.rnn.attention;

import org.platanios.tensorflow.api.core.Indexer;
import org.platanios.tensorflow.api.core.IndexerConstructionWithTwoNumbers$;
import org.platanios.tensorflow.api.core.NewAxis$;
import org.platanios.tensorflow.api.core.Shape;
import org.platanios.tensorflow.api.core.package$exception$;
import org.platanios.tensorflow.api.core.types.Cpackage;
import org.platanios.tensorflow.api.core.types.package$TF$;
import org.platanios.tensorflow.api.implicits.Implicits$;
import org.platanios.tensorflow.api.ops.Basic$;
import org.platanios.tensorflow.api.ops.Math$;
import org.platanios.tensorflow.api.ops.Output;
import org.platanios.tensorflow.api.ops.rnn.attention.Attention;
import org.platanios.tensorflow.jni.InvalidArgumentException;
import scala.Function1;
import scala.Predef$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.reflect.ScalaSignature;
import scala.runtime.Nothing$;

/* compiled from: LuongAttention.scala */
@ScalaSignature(bytes = "\u0006\u0001\t\u0005c\u0001B\u0001\u0003\u0001E\u0011a\u0002T;p]\u001e\fE\u000f^3oi&|gN\u0003\u0002\u0004\t\u0005I\u0011\r\u001e;f]RLwN\u001c\u0006\u0003\u000b\u0019\t1A\u001d8o\u0015\t9\u0001\"A\u0002paNT!!\u0003\u0006\u0002\u0007\u0005\u0004\u0018N\u0003\u0002\f\u0019\u0005QA/\u001a8t_J4Gn\\<\u000b\u00055q\u0011!\u00039mCR\fg.[8t\u0015\u0005y\u0011aA8sO\u000e\u0001QC\u0001\n\u001a'\t\u00011\u0003E\u0002\u0015+]i\u0011AA\u0005\u0003-\t\u0011qbU5na2,\u0017\t\u001e;f]RLwN\u001c\t\u00031ea\u0001\u0001B\u0003\u001b\u0001\t\u00071DA\u0001U#\ta\"\u0005\u0005\u0002\u001eA5\taDC\u0001 \u0003\u0015\u00198-\u00197b\u0013\t\tcDA\u0004O_RD\u0017N\\4\u0011\u0005u\u0019\u0013B\u0001\u0013\u001f\u0005\r\te.\u001f\u0005\tM\u0001\u0011)\u0019!C!O\u0005QQ.Z7pef\u001c\u0016N_3\u0016\u0003!\u00022!\u000b\u0016-\u001b\u00051\u0011BA\u0016\u0007\u0005\u0019yU\u000f\u001e9viB\u0011Q$L\u0005\u0003]y\u00111!\u00138u\u0011!\u0001\u0004A!A!\u0002\u0013A\u0013aC7f[>\u0014\u0018pU5{K\u0002B\u0001B\r\u0001\u0003\u0006\u0004%\taM\u0001\u000e[\u0016lwN]=XK&<\u0007\u000e^:\u0016\u0003Q\u00022!\u000b\u0016\u0018\u0011!1\u0004A!A!\u0002\u0013!\u0014AD7f[>\u0014\u0018pV3jO\"$8\u000f\t\u0005\tq\u0001\u0011)\u0019!C\u0001s\u0005i\u0001O]8cC\nLG.\u001b;z\r:,\u0012A\u000f\t\u0005;m\"D'\u0003\u0002==\tIa)\u001e8di&|g.\r\u0005\t}\u0001\u0011\t\u0011)A\u0005u\u0005q\u0001O]8cC\nLG.\u001b;z\r:\u0004\u0003\u0002\u0003!\u0001\u0005\u000b\u0007I\u0011A\u001a\u0002\u0017M\u001c\u0017\r\\3GC\u000e$xN\u001d\u0005\t\u0005\u0002\u0011\t\u0011)A\u0005i\u0005a1oY1mK\u001a\u000b7\r^8sA!AA\t\u0001BC\u0002\u0013\u0005S)\u0001\btG>\u0014X-T1tWZ\u000bG.^3\u0016\u0003\u0019\u00032!\u000b\u0016H!\ti\u0002*\u0003\u0002J=\t)a\t\\8bi\"I1\n\u0001B\u0001B\u0003%a\tT\u0001\u0010g\u000e|'/Z'bg.4\u0016\r\\;fA%\u0011A)\u0006\u0005\t\u001d\u0002\u0011)\u0019!C!\u001f\u0006!a.Y7f+\u0005\u0001\u0006CA)Y\u001d\t\u0011f\u000b\u0005\u0002T=5\tAK\u0003\u0002V!\u00051AH]8pizJ!a\u0016\u0010\u0002\rA\u0013X\rZ3g\u0013\
tI&L\u0001\u0004TiJLgn\u001a\u0006\u0003/zA\u0001\u0002\u0018\u0001\u0003\u0002\u0003\u0006I\u0001U\u0001\u0006]\u0006lW\r\t\u0005\t=\u0002\u0011\u0019\u0011)A\u0006?\u0006QQM^5eK:\u001cW\rJ\u0019\u0011\u0007\u0001\u001cxC\u0004\u0002ba:\u0011!-\u001c\b\u0003G.t!\u0001\u001a6\u000f\u0005\u0015LgB\u00014i\u001d\t\u0019v-C\u0001\u0010\u0013\tia\"\u0003\u0002\f\u0019%\u0011\u0011BC\u0005\u0003Y\"\tAaY8sK&\u0011an\\\u0001\u0006if\u0004Xm\u001d\u0006\u0003Y\"I!!\u001d:\u0002\u000fA\f7m[1hK*\u0011an\\\u0005\u0003iV\u0014!\u0001\u0016$\u000b\u0005E\u0014\b\u0002C<\u0001\u0005\u0007\u0005\u000b1\u0002=\u0002\u0015\u00154\u0018\u000eZ3oG\u0016$#\u0007E\u0002as^I!A_;\u0003\u0013%\u001bH)Z2j[\u0006d\u0007\"\u0002?\u0001\t\u0003i\u0018A\u0002\u001fj]&$h\bF\u0007\u007f\u0003\u000b\t9!!\u0003\u0002\f\u00055\u0011q\u0002\u000b\u0006\u007f\u0006\u0005\u00111\u0001\t\u0004)\u00019\u0002\"\u00020|\u0001\by\u0006\"B<|\u0001\bA\b\"\u0002\u0014|\u0001\u0004A\u0003\"\u0002\u001a|\u0001\u0004!\u0004\"\u0002\u001d|\u0001\u0004Q\u0004b\u0002!|!\u0003\u0005\r\u0001\u000e\u0005\b\tn\u0004\n\u00111\u0001G\u0011\u001dq5\u0010%AA\u0002ACq!a\u0005\u0001\t\u0003\n)\"A\u0005lKf\u001c8\u000b[1qKR!\u0011qCA\u0010!\u0011\tI\"a\u0007\u000e\u0003=L1!!\bp\u0005\u0015\u0019\u0006.\u00199f\u0011!\t\t#!\u0005A\u0002\u0005]\u0011a\u0003<bYV,7o\u00155ba\u0016Dq!!\n\u0001\t#\n9#\u0001\u0003lKf\u001cH#\u0002\u001b\u0002*\u0005m\u0002\u0002CA\u0016\u0003G\u0001\r!!\f\u0002\r5,Wn\u001c:z!\u0015\ty#!\u000e\u0018\u001d\r!\u0012\u0011G\u0005\u0004\u0003g\u0011\u0011!C!ui\u0016tG/[8o\u0013\u0011\t9$!\u000f\u0003\r5+Wn\u001c:z\u0015\r\t\u0019D\u0001\u0005\b\u0003{\t\u0019\u00031\u00015\u0003\u00191\u0018\r\\;fg\"9\u0011\u0011\t\u0001\u0005R\u0005\r\u0013!B:d_J,G#\u0002\u001b\u0002F\u0005%\u0003bBA$\u0003\u007f\u0001\r\u0001N\u0001\u0006cV,'/\u001f\u0005\t\u0003\u0017\ny\u00041\u0001\u0002N\u0005)1\u000f^1uKB1\u0011qFA(/QJA!!\u0015\u0002:\t)1\u000b^1uK\"2\u0011qHA+\u0003_\u0002R!HA,\u00037J1!!\u0017\u001f\u0005\u0019!\
bN]8xgB!\u0011QLA5\u001d\u0011\ty&a\u0019\u000f\u0007\t\f\t'\u0003\u0002r_&!\u0011QMA4\u0003%)\u0007pY3qi&|gN\u0003\u0002r_&!\u00111NA7\u0005aIeN^1mS\u0012\f%oZ;nK:$X\t_2faRLwN\u001c\u0006\u0005\u0003K\n9'\r\u0004\u001f!\u0006E\u0014qS\u0019\nG\u0005M\u0014qOAG\u0003s*2aTA;\t\u0019Q\u0002C1\u0001\u0002��%!\u0011\u0011PA>\u0003m!C.Z:tS:LG\u000fJ4sK\u0006$XM\u001d\u0013eK\u001a\fW\u000f\u001c;%c)\u0019\u0011Q\u0010\u0010\u0002\rQD'o\\<t#\ra\u0012\u0011\u0011\t\u0005\u0003\u0007\u000b9ID\u0002\u001e\u0003\u000bK!!\u001d\u0010\n\t\u0005%\u00151\u0012\u0002\n)\"\u0014xn^1cY\u0016T!!\u001d\u00102\u0013\r\ny)!%\u0002\u0014\u0006udbA\u000f\u0002\u0012&\u0019\u0011Q\u0010\u00102\u000b\tjb$!&\u0003\u000bM\u001c\u0017\r\\12\u0007\u0019\nY\u0006C\u0004\u0002\u001c\u0002!\t&!(\u0002\u0017A\u0014xNY1cS2LG/\u001f\u000b\u0006i\u0005}\u0015\u0011\u0015\u0005\b\u0003\u0003\nI\n1\u00015\u0011!\tY%!'A\u0002\u00055saBAS\u0005!\u0005\u0011qU\u0001\u000f\u0019V|gnZ!ui\u0016tG/[8o!\r!\u0012\u0011\u0016\u0004\u0007\u0003\tA\t!a+\u0014\t\u0005%\u0016Q\u0016\t\u0004;\u0005=\u0016bAAY=\t1\u0011I\\=SK\u001aDq\u0001`AU\t\u0003\t)\f\u0006\u0002\u0002(\"A\u0011\u0011XAU\t\u0003\tY,A\u0003baBd\u00170\u0006\u0003\u0002>\u0006\u0015GCDA`\u0003'\f).!7\u0002^\u0006\u0005\u00181\u001d\u000b\u0007\u0003\u0003\f9-!4\u0011\tQ\u0001\u00111\u0019\t\u00041\u0005\u0015GA\u0002\u000e\u00028\n\u00071\u0004\u0003\u0006\u0002J\u0006]\u0016\u0011!a\u0002\u0003\u0017\f!\"\u001a<jI\u0016t7-\u001a\u00134!\u0011\u00017/a1\t\u0015\u0005=\u0017qWA\u0001\u0002\b\t\t.\u0001\u0006fm&$WM\\2fIQ\u0002B\u0001Y=\u0002D\"1a%a.A\u0002!BqAMA\\\u0001\u0004\t9\u000e\u0005\u0003*U\u0005\r\u0007\"\u0003\u001d\u00028B\u0005\t\u0019AAn!\u0019i2(a6\u0002X\"Q\u0011q\\A\\!\u0003\u0005\r!a6\u0002\u0019M\u001c\u0017\r\\3XK&<\u0007\u000e^:\t\u0011\u0011\u000b9\f%AA\u0002\u0019C\u0001BTA\\!\u0003\u0005\r\u0001\u0015\u0005\u000b\u0003O\fI+%A\u0005\u0002\u0005%\u0018a\u0007\u0013mKN\u001c\u0018N\\5uI\u001d\u0014X-\u0019;fe\u0012\"WMZ1vYR$C'\u0006
\u0003\u0002l\n\u001dQCAAwU\u0011\ty/!>\u0011\u0007u\t\t0C\u0002\u0002tz\u0011AAT;mY.\u0012\u0011q\u001f\t\u0005\u0003s\u0014\u0019!\u0004\u0002\u0002|*!\u0011Q`A��\u0003%)hn\u00195fG.,GMC\u0002\u0003\u0002y\t!\"\u00198o_R\fG/[8o\u0013\u0011\u0011)!a?\u0003#Ut7\r[3dW\u0016$g+\u0019:jC:\u001cW\r\u0002\u0004\u001b\u0003K\u0014\ra\u0007\u0005\u000b\u0005\u0017\tI+%A\u0005\u0002\t5\u0011a\u0007\u0013mKN\u001c\u0018N\\5uI\u001d\u0014X-\u0019;fe\u0012\"WMZ1vYR$S'\u0006\u0003\u0003\u0010\tMQC\u0001B\tU\r1\u0015Q\u001f\u0003\u00075\t%!\u0019A\u000e\t\u0015\t]\u0011\u0011VI\u0001\n\u0003\u0011I\"A\u000e%Y\u0016\u001c8/\u001b8ji\u0012:'/Z1uKJ$C-\u001a4bk2$HEN\u000b\u0005\u00057\u0011y\"\u0006\u0002\u0003\u001e)\u001a\u0001+!>\u0005\ri\u0011)B1\u0001\u001c\u0011)\u0011\u0019#!+\u0012\u0002\u0013\u0005!QE\u0001\u0010CB\u0004H.\u001f\u0013eK\u001a\fW\u000f\u001c;%gU!\u00111\u001eB\u0014\t\u0019Q\"\u0011\u0005b\u00017!Q!1FAU#\u0003%\tA!\f\u0002\u001f\u0005\u0004\b\u000f\\=%I\u00164\u0017-\u001e7uIQ*B!a;\u00030\u00111!D!\u000bC\u0002mA!Ba\r\u0002*F\u0005I\u0011\u0001B\u001b\u0003=\t\u0007\u000f\u001d7zI\u0011,g-Y;mi\u0012*T\u0003\u0002B\b\u0005o!aA\u0007B\u0019\u0005\u0004Y\u0002B\u0003B\u001e\u0003S\u000b\n\u0011\"\u0001\u0003>\u0005y\u0011\r\u001d9ms\u0012\"WMZ1vYR$c'\u0006\u0003\u0003\u001c\t}BA\u0002\u000e\u0003:\t\u00071\u0004")
/* loaded from: input_file:org/platanios/tensorflow/api/ops/rnn/attention/LuongAttention.class */
/**
 * Attention mechanism whose score is the product of the query with the projected memory keys.
 *
 * <p>Keys are produced by right-multiplying the memory values by {@link #memoryWeights()}. A query is scored
 * by multiplying it (expanded along axis 1) with the transposed keys and squeezing axis 1 back out. The
 * result is optionally multiplied by {@link #scaleFactor()} (skipped when that is {@code null}). Scores are
 * turned into probabilities by {@link #probabilityFn()}.
 *
 * <p>This source was decompiled from Scala ({@code LuongAttention.scala}). Names ending in {@code $} are
 * Scala companion objects. Operator methods are name-encoded: {@code $plus} is {@code +}, {@code $times} is
 * {@code *}, and {@code $colon$colon} is {@code ::}.
 *
 * @param <T> element type of the attention tensors
 */
public class LuongAttention<T> extends SimpleAttention<T> {
    /** Memory-size tensor supplied by the caller; also forwarded to the {@code SimpleAttention} constructor. */
    private final Output<Object> memorySize;
    /** Projection matrix that maps memory values to keys (its last dimension is the number of key units). */
    private final Output<T> memoryWeights;
    /** Function that maps raw scores to attention probabilities. */
    private final Function1<Output<T>, Output<T>> probabilityFn;
    /** Optional multiplier applied to the scores; {@code null} means no scaling. */
    private final Output<T> scaleFactor;
    /** Name of this attention mechanism. */
    private final String name;
    /** Type evidence for {@code T}, threaded through to the tensor ops. */
    private final Cpackage.TF<T> tfEvidence;
    /**
     * Scala union-type bound on {@code T}, threaded through to {@code matmul} and {@code *}.
     * NOTE(review): presumably restricts {@code T} to floating-point types (TruncatedHalf, ...) — confirm.
     */
    private final Predef$.less.colon.less<Function1<Function1<T, Nothing$>, Nothing$>, Function1<Function1<Cpackage.TruncatedHalf, Nothing$>, Nothing$>> decimalEvidence;

    /**
     * Static forwarder that delegates to the Scala companion object's {@code apply}.
     *
     * @param memorySize     memory-size tensor
     * @param memoryWeights  projection matrix used to compute keys
     * @param probabilityFn  maps scores to probabilities
     * @param scaleFactor    optional score multiplier ({@code null} disables scaling)
     * @param scoreMaskValue value forwarded to {@code SimpleAttention} as its score-mask value
     * @param name           name of the mechanism
     * @param tfEvidence     type evidence for {@code T}
     * @param decimalEvidence type-bound evidence for {@code T}
     * @return whatever the companion's {@code apply} returns
     */
    public static <T> LuongAttention<T> apply(Output<Object> memorySize, Output<T> memoryWeights, Function1<Output<T>, Output<T>> probabilityFn, Output<T> scaleFactor, Output<Object> scoreMaskValue, String name, Cpackage.TF<T> tfEvidence, Predef$.less.colon.less<Function1<Function1<T, Nothing$>, Nothing$>, Function1<Function1<Cpackage.TruncatedHalf, Nothing$>, Nothing$>> decimalEvidence) {
        return LuongAttention$.MODULE$.apply(memorySize, memoryWeights, probabilityFn, scaleFactor, scoreMaskValue, name, tfEvidence, decimalEvidence);
    }

    /** Returns the memory-size tensor given at construction. */
    @Override
    public Output<Object> memorySize() {
        return this.memorySize;
    }

    /** Returns the projection matrix used to compute keys. */
    public Output<T> memoryWeights() {
        return this.memoryWeights;
    }

    /** Returns the function that maps scores to probabilities. */
    public Function1<Output<T>, Output<T>> probabilityFn() {
        return this.probabilityFn;
    }

    /** Returns the score multiplier, or {@code null} when scores are not scaled. */
    public Output<T> scaleFactor() {
        return this.scaleFactor;
    }

    /** Returns the score-mask value held by {@code SimpleAttention}. */
    @Override
    public Output<Object> scoreMaskValue() {
        return super.scoreMaskValue();
    }

    /** Returns the name given at construction. */
    @Override
    public String name() {
        return this.name;
    }

    /**
     * Computes the static keys shape: every dimension of {@code shape} except the last, followed by the last
     * dimension of {@link #memoryWeights()}.
     *
     * @param shape static shape of the memory values
     * @return static shape of the keys
     */
    @Override
    public Shape keysShape(Shape shape) {
        // Scala: shape(0 :: -1), i.e. all dimensions but the last.
        Shape leadingDims = shape.apply(IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(Implicits$.MODULE$.intToIndexerConstruction(-1).$colon$colon(Implicits$.MODULE$.intToIndexerConstruction(0))));
        int keyUnits = memoryWeights().shape().apply(-1);
        return leadingDims.$plus(keyUnits);
    }

    /**
     * Projects the memory values into keys by right-multiplying them with {@link #memoryWeights()}.
     *
     * <p>Values whose rank is not 3 are multiplied directly. Rank-3 values are handled in three steps:
     * flatten to {@code [-1, lastDim]}, multiply, then reshape back to the original leading dimensions followed
     * by the last dimension of the weights. The static shape of that result is then set explicitly.
     *
     * @param memory unused by this implementation
     * @param output memory values to project
     * @return the projected keys
     */
    @Override
    public Output<T> keys(Attention.Memory<T> memory, Output<T> output) {
        if (output.rank() != 3) {
            return matmulWithDefaults(output, memoryWeights(), Math$.MODULE$.matmul$default$4());
        }

        // The intermediates below are created in the same left-to-right order as the original single expression,
        // so graph ops are created in an unchanged sequence.
        Output<Object> valuesShape = Basic$.MODULE$.shape(output, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), this.tfEvidence);

        // Dynamic shape [-1, valuesShape(-1)] used to flatten all leading dimensions into one.
        Output<Object> flattenedShape = Basic$.MODULE$.stack(
            Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{
                Basic$.MODULE$.constant(Implicits$.MODULE$.intToTensor(-1), Basic$.MODULE$.constant$default$2(), Basic$.MODULE$.constant$default$3()),
                valuesShape.apply(Implicits$.MODULE$.intToIndex(-1), Predef$.MODULE$.wrapRefArray(new Indexer[0]))})),
            Basic$.MODULE$.stack$default$2(),
            Basic$.MODULE$.stack$default$3(),
            package$TF$.MODULE$.intEvTF());
        Output<T> flattenedValues = Basic$.MODULE$.reshape(output, flattenedShape, Basic$.MODULE$.reshape$default$3(), this.tfEvidence, package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms());
        Output<T> flattenedKeys = matmulWithDefaults(flattenedValues, memoryWeights(), Math$.MODULE$.matmul$default$4());

        // Target shape: valuesShape(0 :: -1) followed by the weights' last dimension as a 1-element vector.
        Output<Object> unflattenedShape = Basic$.MODULE$.concatenate(
            Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{
                valuesShape.apply(IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(Implicits$.MODULE$.intToIndexerConstruction(-1).$colon$colon(Implicits$.MODULE$.intToIndexerConstruction(0))), Predef$.MODULE$.wrapRefArray(new Indexer[0])),
                Basic$.MODULE$.shape(memoryWeights(), Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), this.tfEvidence).slice(Implicits$.MODULE$.intToIndex(-1), Predef$.MODULE$.wrapRefArray(new Indexer[]{NewAxis$.MODULE$}))})),
            Implicits$.MODULE$.intToOutput(0),
            Basic$.MODULE$.concatenate$default$3(),
            package$TF$.MODULE$.intEvTF());
        Output<T> projected = Basic$.MODULE$.reshape(flattenedKeys, unflattenedShape, Basic$.MODULE$.reshape$default$3(), this.tfEvidence, package$TF$.MODULE$.intEvTF(), Predef$.MODULE$.$conforms());

        // Restore the static shape information lost by the dynamic reshape.
        projected.setShape(output.shape().apply(IndexerConstructionWithTwoNumbers$.MODULE$.indexerConstructionToIndex(Implicits$.MODULE$.intToIndexerConstruction(-1).$colon$colon(Implicits$.MODULE$.intToIndexerConstruction(0)))).$plus(memoryWeights().shape().apply(-1)));
        return projected;
    }

    /**
     * Scores the query against the keys stored in {@code state}.
     *
     * <p>The query's last static dimension must equal the keys' last static dimension. Otherwise an
     * {@code InvalidArgumentException} is thrown. The query is expanded along axis 1 and multiplied with the
     * keys, passing {@code true} as {@code matmul}'s fourth argument. Axis 1 is then squeezed out, and the
     * result is multiplied by {@link #scaleFactor()} when that is non-null.
     *
     * @param output query tensor
     * @param state  attention state providing the keys
     * @return the (optionally scaled) scores
     * @throws InvalidArgumentException when the query and keys depths differ
     */
    @Override
    public Output<T> score(Output<T> output, Attention.State<T, Output<T>> state) throws InvalidArgumentException {
        int queryDepth = output.shape().apply(-1);
        int keysDepth = state.keys().shape().apply(-1);
        if (queryDepth != keysDepth) {
            String message = "Incompatible or unknown inner dimensions between query and keys. "
                + "Query (" + output.name() + ") has " + queryDepth + " units. "
                + "Keys (" + state.keys().name() + ") have " + keysDepth + " units. "
                + "Perhaps you need to set the number of units of the attention model "
                + "to the keys' number of units.";
            throw package$exception$.MODULE$.InvalidArgumentException().apply(message);
        }

        Output<T> expandedQuery = Implicits$.MODULE$.outputBasicOps(output).expandDims(Implicits$.MODULE$.intToOutput(1));
        Output<T> rawScores = matmulWithDefaults(expandedQuery, state.keys(), true);
        Output<T> squeezedScores = Basic$.MODULE$.squeeze(rawScores, (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{1})), Basic$.MODULE$.squeeze$default$3(), this.tfEvidence);

        if (scaleFactor() == null) {
            return squeezedScores;
        }
        return scaleFactor().$times(squeezedScores, this.decimalEvidence);
    }

    /**
     * Converts scores into probabilities by applying {@link #probabilityFn()}; {@code state} is not used.
     *
     * @param output scores
     * @param state  unused
     * @return the probabilities produced by {@link #probabilityFn()}
     */
    @Override
    public Output<T> probability(Output<T> output, Attention.State<T, Output<T>> state) {
        return probabilityFn().apply(output);
    }

    /**
     * Calls {@code Math.matmul} on {@code left} and {@code right}. Every optional argument uses its library
     * default except the fourth positional flag, which is supplied by the caller.
     *
     * @param left        left operand
     * @param right       right operand
     * @param transposeB  fourth positional {@code matmul} argument (presumably {@code transposeB} — confirm)
     * @return the product tensor
     */
    private Output<T> matmulWithDefaults(Output<T> left, Output<T> right, boolean transposeB) {
        return Math$.MODULE$.matmul(left, right, Math$.MODULE$.matmul$default$3(), transposeB, Math$.MODULE$.matmul$default$5(), Math$.MODULE$.matmul$default$6(), Math$.MODULE$.matmul$default$7(), Math$.MODULE$.matmul$default$8(), Math$.MODULE$.matmul$default$9(), this.tfEvidence, this.decimalEvidence);
    }

    /**
     * Creates the attention mechanism and forwards the memory size, score-mask value, name and evidences to
     * {@code SimpleAttention}.
     *
     * <p>NOTE(review): the decompiler reported that the {@code super} call was moved to the top. The Scala
     * bytecode may assign these fields before invoking the parent constructor, but Java forces the opposite
     * order. If {@code SimpleAttention}'s constructor calls the overridden {@link #memorySize()} or
     * {@link #name()}, it would see {@code null} here — confirm.
     */
    public LuongAttention(Output<Object> memorySize, Output<T> memoryWeights, Function1<Output<T>, Output<T>> probabilityFn, Output<T> scaleFactor, Output<Object> scoreMaskValue, String name, Cpackage.TF<T> tfEvidence, Predef$.less.colon.less<Function1<Function1<T, Nothing$>, Nothing$>, Function1<Function1<Cpackage.TruncatedHalf, Nothing$>, Nothing$>> decimalEvidence) {
        super(memorySize, scoreMaskValue, name, tfEvidence, decimalEvidence);
        this.memorySize = memorySize;
        this.memoryWeights = memoryWeights;
        this.probabilityFn = probabilityFn;
        this.scaleFactor = scaleFactor;
        this.name = name;
        this.tfEvidence = tfEvidence;
        this.decimalEvidence = decimalEvidence;
    }
}
