#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# maze.py
#
# Copyright 2017-2018 Tohgoroh Matsui All Rights Reserved.
#
from environment import Environment
import numpy as np


class Maze(Environment):
    """Sutton & Barto 4x3迷路問題。状態は (x座標, y座標) で表される。"""

    walls = np.array([[0, 0, 0, 0],
                      [0, 1, 0, 0],
                      [0, 0, 0, 0]]).T  # 壁
    terminals = np.array([[ 0, 0, 0, 1],
                          [ 0, 0, 0, 1],
                          [ 0, 0, 0, 0]]).T # 終端状態
#    rewards = np.array([[-0.02, -0.02, -0.02,  1.],
#                        [-0.02,  0.,   -0.02, -1.],
#                        [-0.02, -0.02, -0.02, -0.02]]).T    # 報酬
    rewards = np.array([[0., 0., 0.,  1.],
                        [0., 0., 0., -1.],
                        [0., 0., 0.,  0.]]).T    # 報酬
    width = len(walls)  # 迷路の幅
    height = len(walls[0])   # 迷路の高さ
    states = width * height   # 状態数
    actions = 4   # 行動数
    deterministic = True # 状態遷移が決定的

    @classmethod
    def init_state(cls):
        """初期状態を返す。"""
        return (0, 2)

    @classmethod
    def is_terminal(cls, state):
        """渡された状態が終端状態ならTrue, そうでないならFalseを返す。"""
        x, y = state[0], state[1]
        return True if cls.terminals[x, y] == 1 else False

    @classmethod
    def is_wall(cls, state):
        """渡された状態が壁ならTrue, そうでないならFalseを返す。"""
        x, y = state[0], state[1]
        return True if x < 0 or x >= cls.width or y < 0 or y >= cls.height or cls.walls[x, y] == 1 else False

    @classmethod
    def get_s(cls, state):
        """状態を状態番号に変換して返す。"""
        x, y = state[0], state[1]
        s = y * cls.width + x
        return s

    @classmethod
    def get_a(cls, action):
        """行動を行動番号に変換して返す。"""
        return action

    @classmethod
    def get_reward(cls, state):
        """報酬を返す。"""
        x, y = state[0], state[1]
        reward = cls.rewards[x, y]
        return reward

    @classmethod
    def take_action(cls, state, action):
        """行動を実行して状態を更新し、報酬と次の状態を返す。"""
        if action == 0:
            forward = cls.north(state)
            left = cls.west(state)
            right = cls.east(state)
        elif action == 1:
            forward = cls.east(state)
            left = cls.north(state)
            right = cls.south(state)
        elif action == 2:
            forward = cls.west(state)
            left = cls.south(state)
            right = cls.north(state)
        else:
            forward = cls.south(state)
            left = cls.east(state)
            right = cls.west(state)
        if cls.deterministic:
            state_ = forward
        else:
            r = np.random.random()
            if r < 0.8:
                state_ = forward
            elif r < 0.9:
                state_ = left
            else:
                state_ = right
        reward = cls.get_reward(state_)  # 報酬
        return reward, state_

    @classmethod
    def str_state(cls, state):
        """状態を表す文字列を返す。"""
        x, y = state[0], state[1]
        return '(%d, %d)' % (x, y)

    @classmethod
    def str_action(cls, action):
        """行動を表す文字列を返す。"""
        if action == 0:
            str = 'north'
        elif action == 1:
            str = 'east'
        elif action == 2:
            str = 'west'
        else:
            str = 'south'
        return str

    @classmethod
    def north(cls, state):
        """北に壁がない場合は北隣の状態を、壁がある場合は現在の状態を返す。"""
        x, y = state[0], state[1]
        x_, y_ = x, y - 1
        return (x_, y_) if not cls.is_wall((x_, y_)) else state

    @classmethod
    def east(cls, state):
        """東に壁がない場合は東隣の状態を、壁がある場合は現在の状態を返す。"""
        x, y = state[0], state[1]
        x_, y_ = x + 1, y
        return (x_, y_) if not cls.is_wall((x_, y_)) else state

    @classmethod
    def west(cls, state):
        """西に壁がない場合は西隣の状態を、壁がある場合は現在の状態を返す。"""
        x, y = state[0], state[1]
        x_, y_ = x - 1, y
        return (x_, y_) if not cls.is_wall((x_, y_)) else state

    @classmethod
    def south(cls, state):
        """南に壁がない場合は南隣の状態を、壁がある場合は現在の状態を返す。"""
        x, y = state[0], state[1]
        x_, y_ = x, y + 1
        return (x_, y_) if not cls.is_wall((x_, y_)) else state

    def __init__(self, deterministic=False):
        Maze.deterministic = deterministic
        super().__init__()
