package org.deeplearning4j.rl4j.mdp;

import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.environment.IntegerActionSchema;
import org.deeplearning4j.rl4j.environment.Schema;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.nd4j.linalg.api.rng.Random;

/* loaded from: input_file:org/deeplearning4j/rl4j/mdp/TMazeEnvironment.class */
/**
 * A T-maze environment: a straight corridor of {@code lengthOfMaze} cells whose last cell is a
 * junction with an "up" and a "down" branch. One branch is the goal, the other is a trap; which
 * one is the goal is drawn at random on every {@link #reset()}.
 *
 * <p>Observations are returned as a map holding a 5-element {@code double[]} under the key
 * {@code "data"}:
 * <ul>
 *   <li>index 0 - set only in the observation returned by {@link #reset()};</li>
 *   <li>index 1 - set in every step observation where the agent is not at the junction;</li>
 *   <li>index 2 - set in every step observation where the agent is at the junction;</li>
 *   <li>index 3 / 4 - on reset, a one-hot hint telling whether the goal is up (index 3) or down
 *       (index 4); in every step observation both are {@code -1.0}, i.e. the hint is hidden.</li>
 * </ul>
 *
 * <p>Actions are the integers {@code 0} (left), {@code 1} (right), {@code 2} (up) and
 * {@code 3} (down). Any other value is ignored: the agent does not move and gets no move reward.
 */
public class TMazeEnvironment implements Environment<Integer> {
    /** Reward for a move that is useless or impossible at the current location. */
    private static final double BAD_MOVE_REWARD = -0.1d;
    /** Reward for taking the correct branch at the junction. */
    private static final double GOAL_REWARD = 4.0d;
    /** Reward for taking the wrong branch at the junction. */
    private static final double TRAP_REWARD = -4.0d;
    /** One-time bonus granted the first time the junction is reached in an episode. */
    private static final double BRANCH_REWARD = 1.0d;

    private static final int NUM_ACTIONS = 4;
    private static final int ACTION_LEFT = 0;
    private static final int ACTION_RIGHT = 1;
    private static final int ACTION_UP = 2;
    private static final int ACTION_DOWN = 3;

    /** Key under which the observation vector is stored in the returned maps. */
    private static final String OBSERVATION_KEY = "data";
    /** Value of an active flag in the observation vector (unrelated to any reward). */
    private static final double FLAG_SET = 1.0d;
    /** Value of an inactive flag in the observation vector. */
    private static final double FLAG_CLEAR = 0.0d;
    /** Value used for the goal-direction hint once it is no longer shown. */
    private static final double HINT_HIDDEN = -1.0d;

    private final int lengthOfMaze;
    private final Random rnd;
    private final Schema<Integer> schema;
    private int currentLocation;
    private boolean hasNavigatedToBranch;
    private boolean hasNavigatedToSolution;
    private boolean isSolutionUp;
    boolean episodeFinished;

    /**
     * @return {@code true} if the agent took the goal branch during the current episode;
     *         {@code false} if it took the trap or has not taken any branch since the last
     *         {@link #reset()}.
     */
    public boolean hasNavigatedToSolution() {
        return this.hasNavigatedToSolution;
    }

    /**
     * Creates a T-maze.
     *
     * @param i      number of cells in the corridor; the junction is the cell at index
     *               {@code i - 1}
     * @param random random source used to pick the goal side on each reset; it is also handed
     *               to the action schema
     */
    public TMazeEnvironment(int i, Random random) {
        this.lengthOfMaze = i;
        this.rnd = random;
        this.schema = new Schema<>(new IntegerActionSchema(NUM_ACTIONS, ACTION_RIGHT, random));
    }

    /**
     * Starts a new episode: puts the agent back at cell 0, clears all per-episode flags and
     * randomly picks which branch holds the goal.
     *
     * @return the initial observation, which is the only one that reveals the goal direction
     */
    @Override // org.deeplearning4j.rl4j.environment.Environment
    public Map<String, Object> reset() {
        this.episodeFinished = false;
        this.currentLocation = 0;
        this.hasNavigatedToBranch = false;
        // Clear the success flag too; otherwise an episode that ends without taking a branch
        // would still report the previous episode's success.
        this.hasNavigatedToSolution = false;
        this.isSolutionUp = this.rnd.nextBoolean();

        double[] observation = {
            FLAG_SET,
            FLAG_CLEAR,
            FLAG_CLEAR,
            this.isSolutionUp ? FLAG_SET : FLAG_CLEAR,
            this.isSolutionUp ? FLAG_CLEAR : FLAG_SET
        };
        return toChannelsData(observation);
    }

    /**
     * Applies one action.
     *
     * <p>Left always costs {@code BAD_MOVE_REWARD} and moves the agent one cell back when it is
     * not already at cell 0. Right moves one cell forward, or costs {@code BAD_MOVE_REWARD} at
     * the junction. Up and down cost {@code BAD_MOVE_REWARD} anywhere but the junction; at the
     * junction they end the episode with {@code GOAL_REWARD} or {@code TRAP_REWARD}. The first
     * time the agent is at the junction after a step, {@code BRANCH_REWARD} is added. Once the
     * episode is finished, actions are ignored.
     *
     * @param num the action to perform
     * @return the observation after the move, the reward obtained and whether the episode is over
     */
    @Override // org.deeplearning4j.rl4j.environment.Environment
    public StepResult step(Integer num) {
        final int junction = this.lengthOfMaze - 1;
        final boolean atJunctionBeforeMove = this.currentLocation == junction;
        double reward = 0.0d;

        if (!this.episodeFinished) {
            switch (num.intValue()) {
                case ACTION_LEFT:
                    // Going back is always penalised, even when it actually moves the agent.
                    reward = BAD_MOVE_REWARD;
                    if (this.currentLocation > 0) {
                        this.currentLocation--;
                    }
                    break;
                case ACTION_RIGHT:
                    if (atJunctionBeforeMove) {
                        reward = BAD_MOVE_REWARD;
                    } else {
                        this.currentLocation++;
                    }
                    break;
                case ACTION_UP:
                    reward = atJunctionBeforeMove ? takeBranch(this.isSolutionUp) : BAD_MOVE_REWARD;
                    break;
                case ACTION_DOWN:
                    reward = atJunctionBeforeMove ? takeBranch(!this.isSolutionUp) : BAD_MOVE_REWARD;
                    break;
                default:
                    // Unknown action: no movement and no move reward.
                    break;
            }
        }

        final boolean atJunctionAfterMove = this.currentLocation == junction;
        if (!this.hasNavigatedToBranch && atJunctionAfterMove) {
            reward += BRANCH_REWARD;
            this.hasNavigatedToBranch = true;
        }

        double[] observation = atJunctionAfterMove
            ? new double[] {FLAG_CLEAR, FLAG_CLEAR, FLAG_SET, HINT_HIDDEN, HINT_HIDDEN}
            : new double[] {FLAG_CLEAR, FLAG_SET, FLAG_CLEAR, HINT_HIDDEN, HINT_HIDDEN};
        return new StepResult(toChannelsData(observation), reward, this.episodeFinished);
    }

    /** No resources to release. */
    @Override // org.deeplearning4j.rl4j.environment.Environment
    public void close() {
    }

    /** @return the schema created in the constructor */
    @Override // org.deeplearning4j.rl4j.environment.Environment
    public Schema<Integer> getSchema() {
        return this.schema;
    }

    /** @return {@code true} once a branch has been taken at the junction in the current episode */
    @Override // org.deeplearning4j.rl4j.environment.Environment
    public boolean isEpisodeFinished() {
        return this.episodeFinished;
    }

    /** Ends the episode on a branch and returns the goal or trap reward accordingly. */
    private double takeBranch(boolean isGoalBranch) {
        this.hasNavigatedToSolution = isGoalBranch;
        this.episodeFinished = true;
        return isGoalBranch ? GOAL_REWARD : TRAP_REWARD;
    }

    /** Wraps an observation vector in a fresh, mutable map under the {@code "data"} key. */
    private static Map<String, Object> toChannelsData(double[] observation) {
        Map<String, Object> channelsData = new HashMap<>();
        channelsData.put(OBSERVATION_KEY, observation);
        return channelsData;
    }
}
