import collections
import random
import math
from tic_tac_toe import has_winner, available_moves, apply_move


def monte_carlo_sample(board_state, side):
    """Przykadowy pojedynczy ruch dla biecego stanu board_state i strony. Ruchy s wprowadzane do biecego stanu board_state a do momentu 
     osignicia stanu kocowego. Nastpnie zwracany jest wynik i pierwszy ruch wykonany po to, aby si tam dosta.

    Argumenty:
        board_state (krotka liczb cakowitych o wymiarach 3x3): stan planszy
        side (int): strona, po ktrej gramy. +1 d;a gracza plus, -1 dla gracza minus

    Funkcja zwraca:
        (result(int), move(int,int)): Wynik z tej symulacji, +1 dla wygranej gracza plus -1, dla wygranej  
            gracza minus, 0 w przypadku remisu.
    """
    result = has_winner(board_state)
    if result != 0:
        return result, None
    moves = list(available_moves(board_state))
    if not moves:
        return 0, None

    # wybr losowego ruchu
    move = random.choice(moves)
    result, next_move = monte_carlo_sample(apply_move(board_state, move, side), -side)
    return result, move


def monte_carlo_tree_search(board_state, side, number_of_samples):
    """Oceniaj najlepszego z aktualnego stanu board_state dla okrelonej strony, z wykorzystaniem prbkowania Monte Carlo.

    Argumenty:
        board_state (krotka liczb cakowitych o wymiarach 3x3): stan planszy
        side (int): strona, po ktrej gramy. +1 dla gracza plus, -1 dla gracza minus
        number_of_samples (int): liczba prbek symulacji do uruchomienia. Im wiksza liczba, 
            tym lepsze oszacowanie pozycji.

    Funkcja zwraca:
        (result(int), move(int,int)): redni wynik dla najlepszego ruchu z tej pozycji oraz wykonany ruch.
    """
    move_wins = collections.defaultdict(int)
    move_samples = collections.defaultdict(int)
    for _ in range(number_of_samples):
        result, move = monte_carlo_sample(board_state, side)
        # zapamitanie wyniku i liczby prb wykonania tego ruchu
        if result == side:
            move_wins[move] += 1
        move_samples[move] += 1

    # wybranie ruchu z najlepszym rednim wynikiem:
    move = max(move_wins, key=lambda x: move_wins.get(x) / move_samples[move])

    return move_wins[move] / move_samples[move], move


def _upper_confidence_bounds(payout, samples_for_this_machine, log_total_samples):
    return payout / samples_for_this_machine + math.sqrt((2 * log_total_samples) / samples_for_this_machine)


def monte_carlo_tree_search_uct(board_state, side, number_of_samples):
    """Ocena najlepszego z aktualnego stanu board_state dla okrelonej strony, z wykorzystaniem prbkowania Monte Carlo z grnym 
    poziomem zaufania dla drzew.

    Argumenty:
        board_state (krotka liczb cakowitych o wymiarach 3x3): stan planszy
        side (int): strona, po ktrej gramy. +1 dla gracza plus, -1 dla gracza minus
        number_of_samples (int): liczba prbek symulacji do uruchomienia z biecej pozycji. Im wiksza liczba 
            tym lepsze oszacowanie pozycji.

    Funkcja zwraca:
        (result(int), move(int,int)): redni wynik dla najlepszego ruchu z tej pozycji oraz wykonany ruch.
    """
    state_results = collections.defaultdict(float)
    state_samples = collections.defaultdict(float)

    for _ in range(number_of_samples):
        current_side = side
        current_board_state = board_state
        first_unvisited_node = True
        rollout_path = []
        result = 0

        while result == 0:
            move_states = {move: apply_move(current_board_state, move, current_side)
                           for move in available_moves(current_board_state)}

            if not move_states:
                result = 0
                break

            if all((state in state_samples) for _, state in move_states):
                log_total_samples = math.log(sum(state_samples[s] for s in move_states.values()))
                move, state = max(move_states, key=lambda _, s: _upper_confidence_bounds(state_results[s],
                                                                                         state_samples[s],
                                                                                         log_total_samples))
            else:
                move = random.choice(list(move_states.keys()))

            current_board_state = move_states[move]

            if first_unvisited_node:
                rollout_path.append((current_board_state, current_side))
                if current_board_state not in state_samples:
                    first_unvisited_node = False

            current_side = -current_side

            result = has_winner(current_board_state)

        for path_board_state, path_side in rollout_path:
            state_samples[path_board_state] += 1.
            result *= path_side
            # normalizacja wynikw do zakresu od 0 do 1. Przed normaliczaj wyniki s w zakresie od -1 do 1 
            result /= 2.
            result += .5
            state_results[path_board_state] += result

    move_states = {move: apply_move(board_state, move, side) for move in available_moves(board_state)}

    move = max(move_states, key=lambda x: state_results[move_states[x]] / state_samples[move_states[x]])

    return state_results[move_states[move]] / state_samples[move_states[move]], move


if __name__ == '__main__':
    board_state = ((1, 0, -1),
                   (1, 0, 0),
                   (0, -1, 0))

    print(monte_carlo_tree_search_uct(board_state, -1, 10000))
