package org.platanios.tensorflow.examples;

import com.typesafe.scalalogging.Logger;
import com.typesafe.scalalogging.Logger$;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.platanios.tensorflow.api.core.Shape;
import org.platanios.tensorflow.api.implicits.helpers.DataTypeAuxToDataType$;
import org.platanios.tensorflow.api.implicits.helpers.DataTypeToOutput$;
import org.platanios.tensorflow.api.implicits.helpers.OutputToTensor$;
import org.platanios.tensorflow.api.learn.SupervisedTrainableModel;
import org.platanios.tensorflow.api.learn.hooks.Hook;
import org.platanios.tensorflow.api.learn.layers.Input;
import org.platanios.tensorflow.api.learn.layers.Layer;
import org.platanios.tensorflow.api.learn.layers.rnn.RNN$;
import org.platanios.tensorflow.api.learn.layers.rnn.cell.BasicLSTMCell$;
import org.platanios.tensorflow.api.learn.layers.rnn.cell.DropoutWrapper;
import org.platanios.tensorflow.api.learn.layers.rnn.cell.DropoutWrapper$;
import org.platanios.tensorflow.api.ops.Output;
import org.platanios.tensorflow.api.ops.control_flow.WhileLoopVariable$;
import org.platanios.tensorflow.api.ops.io.data.Dataset;
import org.platanios.tensorflow.api.ops.io.data.PrefetchDataset;
import org.platanios.tensorflow.api.ops.io.data.RepeatDataset;
import org.platanios.tensorflow.api.ops.rnn.cell.DropoutWrapper$Supported$;
import org.platanios.tensorflow.api.ops.rnn.cell.package$LSTMState$;
import org.platanios.tensorflow.api.package$;
import org.platanios.tensorflow.api.package$tf$;
import org.platanios.tensorflow.api.package$tf$learn$;
import org.platanios.tensorflow.api.tensors.Tensor;
import org.platanios.tensorflow.api.types.DataType;
import org.platanios.tensorflow.data.text.PTBDataset;
import org.platanios.tensorflow.data.text.PTBLoader$;
import org.slf4j.LoggerFactory;
import scala.Function0;
import scala.Predef$;
import scala.Some;
import scala.runtime.BoxesRunTime;

/* compiled from: RNNTutorialUsingPTB.scala */
/* loaded from: input_file:org/platanios/tensorflow/examples/RNNTutorialUsingPTB$.class */
/**
 * Decompiled singleton ("module") class for the Scala object {@code RNNTutorialUsingPTB}
 * (compiled from RNNTutorialUsingPTB.scala).
 *
 * <p>Holds the hyper-parameters of an LSTM sequence-model example on the PTB dataset
 * (presumably the Penn Treebank) and a supervised trainable model that is assembled in the
 * constructor. {@link #main(String[])} loads the dataset and trains the model with an
 * in-memory estimator.
 *
 * <p>NOTE(review): this is decompiler output, not maintainable Java. For example,
 * {@code MODULE$} is declared {@code final} yet assigned in the constructor, which javac
 * rejects. Make changes in the Scala source rather than here.
 */
public final class RNNTutorialUsingPTB$ {
    /** The singleton instance; assigned by the constructor, which the static initializer runs. */
    public static final RNNTutorialUsingPTB$ MODULE$ = null;
    /** Logger named "Tutorials / RNN-PTB"; NOTE(review): created but never read in this class. */
    private final Logger logger;
    /** Batch size passed to {@code PTBLoader.tokensToBatchedTFDataset} (20). */
    private final int batchSize;
    /** Step count passed to {@code PTBLoader.tokensToBatchedTFDataset} (20). */
    private final int numSteps;
    /** Number of elements passed to {@code prefetch} on the training dataset (10). */
    private final int prefetchSize;
    /**
     * FLOAT32. NOTE(review): not read by code in this class; the LSTM cell is built with
     * FLOAT32 directly. It may be read via {@link #dataType()} by other generated classes — verify.
     */
    private final DataType dataType;
    /**
     * 10000. NOTE(review): not read by code in this class; presumably used via
     * {@link #vocabularySize()} by the embedding/output layers defined elsewhere — confirm.
     */
    private final int vocabularySize;
    /** Number of units of the {@code BasicLSTMCell} built in the constructor (200). */
    private final int numHidden;
    /** 1. NOTE(review): never read here; the constructor builds a single LSTM cell regardless. */
    private final int numLayers;
    /** Keep probability for the "Embedding/Dropout" layer (0.5). */
    private final float dropoutKeepProbability;
    /** Model assembled in the constructor and trained by {@link #main(String[])}. */
    private final SupervisedTrainableModel<Tensor, Output, DataType, Shape, Output, Tensor, Output, DataType, Shape, Output> model;

    // Creating the instance is what assigns MODULE$ (see the first statement of the constructor).
    static {
        new RNNTutorialUsingPTB$();
    }

    /** @return the training batch size */
    public int batchSize() {
        return this.batchSize;
    }

    /** @return the step count used when batching the training tokens */
    public int numSteps() {
        return this.numSteps;
    }

    /** @return the prefetch size of the training dataset */
    public int prefetchSize() {
        return this.prefetchSize;
    }

    /** @return the configured data type (FLOAT32) */
    public DataType dataType() {
        return this.dataType;
    }

    /** @return the configured vocabulary size */
    public int vocabularySize() {
        return this.vocabularySize;
    }

    /** @return the number of LSTM units */
    public int numHidden() {
        return this.numHidden;
    }

    /** @return the configured number of layers (not used by the model built here) */
    public int numLayers() {
        return this.numLayers;
    }

    /** @return the keep probability of the embedding dropout layer */
    public float dropoutKeepProbability() {
        return this.dropoutKeepProbability;
    }

    /**
     * Loads PTB from {@code datasets/PTB} and trains {@link #model} on its training split.
     *
     * <p>All outputs (configuration working directory, step-rate logs, summaries, checkpoints,
     * TensorBoard) go under the relative path {@code temp/rnn-ptb}.
     *
     * @param strArr command-line arguments; ignored
     */
    public void main(String[] strArr) {
        PTBDataset load = PTBLoader$.MODULE$.load(Paths.get("datasets/PTB", new String[0]), PTBLoader$.MODULE$.load$default$2());
        // Training pipeline: batch the train tokens (batchSize x numSteps), repeat with default
        // arguments (presumably indefinitely — verify), then prefetch prefetchSize elements.
        package$ package_ = package$.MODULE$;
        RepeatDataset.RepeatDatasetOps datasetToRepeatDatasetOps = package$.MODULE$.datasetToRepeatDatasetOps(PTBLoader$.MODULE$.tokensToBatchedTFDataset(load.train(), batchSize(), numSteps(), "TrainDataset"));
        PrefetchDataset.PrefetchDatasetOps datasetToPrefetchDatasetOps = package_.datasetToPrefetchDatasetOps(datasetToRepeatDatasetOps.repeat(datasetToRepeatDatasetOps.repeat$default$1(), datasetToRepeatDatasetOps.repeat$default$2()));
        Dataset prefetch = datasetToPrefetchDatasetOps.prefetch(prefetchSize(), datasetToPrefetchDatasetOps.prefetch$default$2());
        Path path = Paths.get("temp/rnn-ptb", new String[0]);
        // Build an InMemoryEstimator and train it. The estimator is configured with:
        //   - Configuration with working directory `path`;
        //   - StopCriteria whose 2nd argument is 100000 (presumably a max step count);
        //   - hooks: LossLogger every 10 steps, StepRateLogger every 100 steps (first arg false),
        //     SummarySaver every 10 steps, CheckpointSaver every 1000 steps, all under `path`;
        //   - TensorBoardConfig on `path` with last argument 1 (meaning not visible here).
        // train() gets an input function wrapping `prefetch` and a StopCriteria of 10000.
        // NOTE(review): the estimator-level limit (100000) differs from the one passed to
        // train() (10000); which one applies is not visible here — confirm the intent.
        package$tf$learn$.MODULE$.InMemoryEstimator().apply(package$.MODULE$.supervisedTrainableModelToModelFunction(this.model), package$tf$learn$.MODULE$.Configuration().apply(new Some(path), package$tf$learn$.MODULE$.Configuration().apply$default$2(), package$tf$learn$.MODULE$.Configuration().apply$default$3(), package$tf$learn$.MODULE$.Configuration().apply$default$4()), package$tf$learn$.MODULE$.StopCriteria().apply(package$tf$learn$.MODULE$.StopCriteria().apply$default$1(), new Some(BoxesRunTime.boxToLong(100000L)), package$tf$learn$.MODULE$.StopCriteria().apply$default$3(), package$tf$learn$.MODULE$.StopCriteria().apply$default$4(), package$tf$learn$.MODULE$.StopCriteria().apply$default$5(), package$tf$learn$.MODULE$.StopCriteria().apply$default$6(), package$tf$learn$.MODULE$.StopCriteria().apply$default$7()), Predef$.MODULE$.Set().apply(Predef$.MODULE$.wrapRefArray(new Hook[]{package$tf$learn$.MODULE$.LossLogger().apply(package$tf$learn$.MODULE$.LossLogger().apply$default$1(), package$tf$learn$.MODULE$.LossLogger().apply$default$2(), package$tf$learn$.MODULE$.StepHookTrigger().apply(10), package$tf$learn$.MODULE$.LossLogger().apply$default$4(), package$tf$learn$.MODULE$.LossLogger().apply$default$5()), package$tf$learn$.MODULE$.StepRateLogger().apply(false, path, package$tf$learn$.MODULE$.StepHookTrigger().apply(100), package$tf$learn$.MODULE$.StepRateLogger().apply$default$4(), package$tf$learn$.MODULE$.StepRateLogger().apply$default$5()), package$tf$learn$.MODULE$.SummarySaver().apply(path, package$tf$learn$.MODULE$.StepHookTrigger().apply(10), package$tf$learn$.MODULE$.SummarySaver().apply$default$3(), package$tf$learn$.MODULE$.SummarySaver().apply$default$4()), package$tf$learn$.MODULE$.CheckpointSaver().apply(path, package$tf$learn$.MODULE$.StepHookTrigger().apply(1000), package$tf$learn$.MODULE$.CheckpointSaver().apply$default$3(), package$tf$learn$.MODULE$.CheckpointSaver().apply$default$4())})), 
package$tf$learn$.MODULE$.InMemoryEstimator().apply$default$5(), package$tf$learn$.MODULE$.InMemoryEstimator().apply$default$6(), package$tf$learn$.MODULE$.InMemoryEstimator().apply$default$7(), package$tf$learn$.MODULE$.TensorBoardConfig().apply(path, package$tf$learn$.MODULE$.TensorBoardConfig().apply$default$2(), package$tf$learn$.MODULE$.TensorBoardConfig().apply$default$3(), 1), package$tf$learn$.MODULE$.InMemoryEstimator().apply$default$9()).train(new RNNTutorialUsingPTB$$anonfun$main$1(prefetch), package$tf$learn$.MODULE$.StopCriteria().apply(package$tf$learn$.MODULE$.StopCriteria().apply$default$1(), new Some(BoxesRunTime.boxToLong(10000L)), package$tf$learn$.MODULE$.StopCriteria().apply$default$3(), package$tf$learn$.MODULE$.StopCriteria().apply$default$4(), package$tf$learn$.MODULE$.StopCriteria().apply$default$5(), package$tf$learn$.MODULE$.StopCriteria().apply$default$6(), package$tf$learn$.MODULE$.StopCriteria().apply$default$7()));
    }

    /**
     * Initializes the singleton: publishes {@code MODULE$}, sets the hyper-parameters and
     * assembles {@link #model}. Statement order matters; do not reorder.
     */
    private RNNTutorialUsingPTB$() {
        // Published before any field is set, presumably so that code run during the rest of
        // initialization can already reach the singleton through MODULE$.
        MODULE$ = this;
        this.logger = Logger$.MODULE$.apply(LoggerFactory.getLogger("Tutorials / RNN-PTB"));
        this.batchSize = 20;
        this.numSteps = 20;
        this.prefetchSize = 10;
        this.dataType = package$.MODULE$.FLOAT32();
        this.vocabularySize = 10000;
        this.numHidden = 200;
        this.numLayers = 1;
        this.dropoutKeepProbability = 0.5f;
        // Two INT32 inputs of shape (-1, -1), i.e. both dimensions unspecified. By their position
        // in the Model call below, `apply` is presumably the model input and `apply2` the labels.
        Input apply = package$tf$learn$.MODULE$.Input().apply(package$.MODULE$.INT32(), package$.MODULE$.Shape().apply(Predef$.MODULE$.wrapIntArray(new int[]{-1, -1})), package$tf$learn$.MODULE$.Input().apply$default$3(), DataTypeAuxToDataType$.MODULE$.dataTypeAuxToDataType(), DataTypeToOutput$.MODULE$.dataTypeToOutput(), OutputToTensor$.MODULE$.outputToTensor(), package$.MODULE$.tensorDataHelper());
        Input apply2 = package$tf$learn$.MODULE$.Input().apply(package$.MODULE$.INT32(), package$.MODULE$.Shape().apply(Predef$.MODULE$.wrapIntArray(new int[]{-1, -1})), package$tf$learn$.MODULE$.Input().apply$default$3(), DataTypeAuxToDataType$.MODULE$.dataTypeAuxToDataType(), DataTypeToOutput$.MODULE$.dataTypeToOutput(), OutputToTensor$.MODULE$.outputToTensor(), package$.MODULE$.tensorDataHelper());
        // RNN cell: a BasicLSTMCell "LSTMCell" with numHidden() units, FLOAT32 and 0.0f (presumably
        // the forget bias), wrapped in a DropoutWrapper "DropoutCell".
        // NOTE(review): the wrapper's third argument is 1.0E-5f. If that is a keep probability,
        // the wrapper would drop almost every input. It is also inconsistent with
        // dropoutKeepProbability() = 0.5. Confirm against the Scala source and the
        // DropoutWrapper signature.
        DropoutWrapper apply3 = DropoutWrapper$.MODULE$.apply("DropoutCell", BasicLSTMCell$.MODULE$.apply("LSTMCell", numHidden(), package$.MODULE$.FLOAT32(), 0.0f, BasicLSTMCell$.MODULE$.apply$default$5(), BasicLSTMCell$.MODULE$.apply$default$6(), BasicLSTMCell$.MODULE$.apply$default$7()), 1.0E-5f, DropoutWrapper$.MODULE$.apply$default$4(), DropoutWrapper$.MODULE$.apply$default$5(), DropoutWrapper$.MODULE$.apply$default$6(), WhileLoopVariable$.MODULE$.outputWhileLoopVariable(), package$LSTMState$.MODULE$.lstmStateWhileLoopVariable(WhileLoopVariable$.MODULE$.outputWhileLoopVariable()), DropoutWrapper$Supported$.MODULE$.outputSupported(), DropoutWrapper$Supported$.MODULE$.lstmStateSupported());
        // Decompiler artifact: a Scala default argument is evaluated and its result discarded.
        // The RNN call below passes (Function0) null in that argument position.
        RNN$.MODULE$.apply$default$3();
        // Model assembly:
        //   layer chain: a layer from RNNTutorialUsingPTB$$anonfun$1 placed on "/device:CPU:0"
        //     (presumably an embedding, given the next layer's name) >> Dropout "Embedding/Dropout"
        //     with dropoutKeepProbability() >> RNN "RNN" over apply3 >> RNNOutputLayer;
        //   loss: SequenceLoss "Loss/SequenceLoss" >> Sum "Loss/Sum" >> ScalarSummary "Loss";
        //   optimizer: GradientDescent with first argument 1.0 (presumably the learning rate);
        //   gradients clipped via ClipGradientsByGlobalNorm(5.0f).
        this.model = package$tf$learn$.MODULE$.Model().apply(apply, ((Layer) package$tf$learn$.MODULE$.device("/device:CPU:0", new RNNTutorialUsingPTB$$anonfun$1())).$greater$greater(package$tf$learn$.MODULE$.Dropout().apply("Embedding/Dropout", dropoutKeepProbability(), package$tf$learn$.MODULE$.Dropout().apply$default$3(), package$tf$learn$.MODULE$.Dropout().apply$default$4())).$greater$greater(RNN$.MODULE$.apply("RNN", apply3, (Function0) null, false, RNN$.MODULE$.apply$default$5(), RNN$.MODULE$.apply$default$6(), RNN$.MODULE$.apply$default$7(), WhileLoopVariable$.MODULE$.outputWhileLoopVariable(), package$LSTMState$.MODULE$.lstmStateWhileLoopVariable(WhileLoopVariable$.MODULE$.outputWhileLoopVariable()))).$greater$greater(RNNTutorialUsingPTB$RNNOutputLayer$.MODULE$), apply2, package$tf$learn$.MODULE$.SequenceLoss().apply("Loss/SequenceLoss", false, true, package$tf$learn$.MODULE$.SequenceLoss().apply$default$4(), package$tf$learn$.MODULE$.SequenceLoss().apply$default$5()).$greater$greater(package$tf$learn$.MODULE$.Sum().apply("Loss/Sum")).$greater$greater(package$tf$learn$.MODULE$.ScalarSummary().apply("Loss/Summary", "Loss", package$tf$learn$.MODULE$.ScalarSummary().apply$default$3(), package$tf$learn$.MODULE$.ScalarSummary().apply$default$4())), package$tf$.MODULE$.train().GradientDescent().apply(1.0d, package$tf$.MODULE$.train().GradientDescent().apply$default$2(), package$tf$.MODULE$.train().GradientDescent().apply$default$3(), package$tf$.MODULE$.train().GradientDescent().apply$default$4(), package$tf$.MODULE$.train().GradientDescent().apply$default$5(), package$tf$.MODULE$.train().GradientDescent().apply$default$6(), package$tf$.MODULE$.train().GradientDescent().apply$default$7()), package$tf$learn$.MODULE$.ClipGradientsByGlobalNorm().apply(5.0f));
    }
}
