#!/usr/bin/env python3
"""
Solver using MCTS and trained model
"""
import time
import argparse
import random
import logging
import datetime
import collections
import csv

from tqdm import tqdm
import seaborn as sns
import matplotlib.pylab as plt
import torch

from libcube import cubes
from libcube import model
from libcube import mcts

log = logging.getLogger("solver")


# One record per attempted cube: timing, scramble, outcome and search-tree statistics
# (filled in by gather_data() and serialized by save_output()).
DataPoint = collections.namedtuple(
    "DataPoint",
    [
        "start_dt", "stop_dt", "duration",
        "depth", "scramble",
        "is_solved", "solve_steps",
        "sol_len_naive", "sol_len_bfs",
        "depth_max", "depth_mean",
    ],
)


# Default per-cube time budget, in seconds (also the --max-time default).
DEFAULT_MAX_SECONDS = 60
# Default maximum scramble depth for --plot / --output runs (--max-depth).
PLOT_MAX_DEPTHS = 50
# Default number of cubes generated per scramble depth (--samples).
PLOT_TASKS = 20


def generate_task(env, depth):
    """
    Build a random scramble of the given length.

    Each action is drawn with env.sample_action(), passing the previously drawn
    action so the environment can constrain the next choice.

    :param env: CubeEnv providing sample_action(prev_action=...)
    :param depth: number of actions in the scramble
    :return: list of the drawn actions' .value attributes, in draw order
    """
    actions = []
    previous = None
    for _ in range(depth):
        previous = env.sample_action(prev_action=previous)
        actions.append(previous.value)
    return actions


def gather_data(cube_env, net, max_seconds, max_steps, max_depth, samples_per_depth, batch_size, device):
    """
    Try to solve lots of random cubes to get data.

    For every scramble depth from 1 to max_depth (inclusive), samples_per_depth
    random scrambles are generated with generate_task() and attempted quietly
    with solve_task(). A KeyboardInterrupt (Ctrl+C) stops gathering early; the
    data points collected up to that moment are still returned.

    :param cube_env: CubeEnv
    :param net: model to be used
    :param max_seconds: time limit per cube in seconds
    :param max_steps: limit of search steps; if not None it supersedes max_seconds
    :param max_depth: maximum depth of scramble
    :param samples_per_depth: how many cubes of every depth to generate
    :param batch_size: search batch size forwarded to solve_task()
    :param device: torch.device
    :return: list of DataPoint entries
    """
    result = []
    try:
        for depth in range(1, max_depth + 1):
            solved_count = 0
            for task_idx in tqdm(range(samples_per_depth)):
                # NOTE: utcnow() gives naive UTC timestamps (deprecated since Python 3.12);
                # kept so the CSV timestamp format written by save_output() is unchanged.
                start_dt = datetime.datetime.utcnow()
                task = generate_task(cube_env, depth)
                tree, solution = solve_task(cube_env, task, net, cube_idx=task_idx, max_seconds=max_seconds,
                                            max_steps=max_steps, device=device, quiet=True, batch_size=batch_size)
                is_solved = solution is not None
                stop_dt = datetime.datetime.utcnow()
                duration = (stop_dt - start_dt).total_seconds()
                scramble = " ".join(map(str, task))
                tree_depth_stats = tree.get_depth_stats()
                # -1 marks "no solution" in the length columns
                sol_len_naive, sol_len_bfs = -1, -1
                if is_solved:
                    sol_len_naive = len(solution)
                    sol_len_bfs = len(tree.find_solution())
                result.append(DataPoint(start_dt=start_dt, stop_dt=stop_dt, duration=duration, depth=depth,
                                        scramble=scramble, is_solved=is_solved, solve_steps=len(tree),
                                        sol_len_naive=sol_len_naive, sol_len_bfs=sol_len_bfs,
                                        depth_max=tree_depth_stats['max'], depth_mean=tree_depth_stats['mean']))
                if is_solved:
                    solved_count += 1
            # Guard the percentage against samples_per_depth == 0 (ZeroDivisionError otherwise)
            solved_pct = 100.0 * solved_count / samples_per_depth if samples_per_depth > 0 else 0.0
            log.info("Przetworzono do gbokoci %d, rozwizano %d/%d (%.2f%%)", depth, solved_count, samples_per_depth,
                     solved_pct)
    except KeyboardInterrupt:
        log.info("Przerwanie! Uzyskano %d prbek danych, ktre mona uyc", len(result))
    return result


def save_output(data, output_file):
    """
    Write gathered DataPoint entries to a CSV file.

    The first row is a header. Timestamps are written with isoformat() and
    is_solved is written as 0/1.

    :param data: iterable of DataPoint entries (as returned by gather_data)
    :param output_file: path of the CSV file to create (overwritten if it exists)
    """
    # newline='' lets the csv writer control line endings; without it the "\r\n"
    # row terminator is translated to "\r\r\n" on Windows (blank lines between rows).
    with open(output_file, "wt", encoding='utf-8', newline='') as fd:
        writer = csv.writer(fd)
        writer.writerow(['start_dt', 'stop_dt', 'duration', 'depth', 'scramble', 'is_solved', 'solve_steps',
                         'sol_len_naive', 'sol_len_bfs', 'tree_depth_max', 'tree_depth_mean'])
        writer.writerows(
            [
                dp.start_dt.isoformat(),
                dp.stop_dt.isoformat(),
                dp.duration,
                dp.depth,
                dp.scramble,
                int(dp.is_solved),
                dp.solve_steps,
                dp.sol_len_naive,
                dp.sol_len_bfs,
                dp.depth_max,
                dp.depth_mean,
            ]
            for dp in data
        )


def solve_task(env, task, net, cube_idx=None, max_seconds=DEFAULT_MAX_SECONDS, max_steps=None,
               device=torch.device("cpu"), quiet=False, batch_size=1):
    """
    Run MCTS on one scrambled cube until it is solved or the search budget runs out.

    :param env: CubeEnv used to build the scrambled start state
    :param task: sequence of integer action values forming the scramble
    :param net: model guiding the search
    :param cube_idx: optional index, used only as a log prefix
    :param max_seconds: time budget in seconds, used only when max_steps is None
    :param max_steps: search-step budget; if not None it supersedes max_seconds
    :param device: torch.device passed to mcts.MCTS
    :param quiet: if True, suppress progress logging
    :param batch_size: if > 1, tree.search_batch(batch_size) is used instead of tree.search()
    :return: tuple (tree, solution); solution is None when the cube was not solved
    """
    if not quiet:
        log_prefix = "" if cube_idx is None else "cube %d: " % cube_idx
        log.info("%sZadanie %s, trwa rozwizywanie...", log_prefix, task)
    cube_state = env.scramble(map(env.action_enum, task))
    tree = mcts.MCTS(env, cube_state, net, device=device)
    step_no = 0
    # Monotonic clock: the time budget and speed stats are immune to wall-clock adjustments.
    ts = time.monotonic()

    while True:
        if batch_size > 1:
            solution = tree.search_batch(batch_size)
        else:
            solution = tree.search()
        if solution:
            if not quiet:
                log.info("W kroku %d znaleziono stan docelowy, rozwijanie. Prdko: %.2f wyszukiwa/s",
                         step_no, _searches_per_second(step_no * batch_size, time.monotonic() - ts))
                log.info("Wysoko drzewa: %s", tree.get_depth_stats())
                bfs_solution = tree.find_solution()
                log.info("Rozwizania: proste %d, bfs %d", len(solution), len(bfs_solution))
                log.info("BFS: %s", bfs_solution)
                log.info("Proste: %s", solution)
            return tree, solution
        step_no += 1
        if max_steps is not None:
            # NOTE(review): the check is "> max_steps", so up to max_steps + 1 searches run
            # before giving up; kept as-is so gathered statistics stay comparable.
            if step_no > max_steps:
                if not quiet:
                    log.info("Osignito maks. liczb krokw, kostka nie zostaa uoona. "
                             "Wykonano %d wyszukiwa, prdko: %.2f wyszukiwa/s",
                             step_no, _searches_per_second(step_no * batch_size, time.monotonic() - ts))
                    log.info("Tree depths: %s", tree.get_depth_stats())
                return tree, None
        elif time.monotonic() - ts > max_seconds:
            if not quiet:
                log.info("Czas upyn, kostka nie zostaa uoona. Wykonano %d wyszukiwa, prdko: %.2f wyszukiwa/s.",
                         step_no, _searches_per_second(step_no * batch_size, time.monotonic() - ts))
                log.info("Wysoko drzewa: %s", tree.get_depth_stats())
            return tree, None


def _searches_per_second(searches, elapsed):
    """Return searches/elapsed, or 0.0 when no measurable time has elapsed (avoids ZeroDivisionError)."""
    return searches / elapsed if elapsed > 0 else 0.0


def produce_plots(data, prefix, max_seconds, max_steps=None):
    """
    Save two line plots summarising gathered solver data.

    Writes prefix + "-solve_vs_depth.png" (solved flag vs scramble depth, all cubes)
    and prefix + "-steps_vs_depth.png" (search steps vs depth, solved cubes only).

    :param data: iterable of DataPoint entries (as returned by gather_data)
    :param prefix: path prefix for the output PNG files
    :param max_seconds: per-cube time limit, shown in the titles when max_steps is None
    :param max_steps: per-cube step limit; if not None it is shown instead of max_seconds
    """
    data = list(data)
    data_solved = [(dp.depth, int(dp.is_solved)) for dp in data]
    data_steps = [(dp.depth, dp.solve_steps) for dp in data if dp.is_solved]

    # zip(*[]) cannot be unpacked into two names, so bail out on empty input
    if not data_solved:
        log.warning("No data points gathered, skipping plots")
        return

    if max_steps is not None:
        suffix = "(maks. liczba krokw: %d)" % max_steps
    else:
        suffix = "(ograniczenie czasowe: %d sekund)" % max_seconds

    sns.set()
    d, v = zip(*data_solved)
    # seaborn >= 0.12 accepts x/y only as keyword arguments
    plot = sns.lineplot(x=list(d), y=list(v))
    plot.set_title("Wspczynnik rozwizywania w porwnaniu z gebokoci %s" % suffix)
    plot.get_figure().savefig(prefix + "-solve_vs_depth.png")

    plt.clf()
    if not data_steps:
        log.warning("No cube was solved, skipping steps-vs-depth plot")
        return
    d, v = zip(*data_steps)
    plot = sns.lineplot(x=list(d), y=list(v))
    plot.set_title("Liczba krokw w porwnaniu z gbokoci %s" % suffix)
    plot.get_figure().savefig(prefix + "-steps_vs_depth.png")


if __name__ == "__main__":
    # Command-line entry point: load the environment and model, then run one of the
    # mutually exclusive modes (--random, --perm, --input, --plot, --output).
    logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--env", required=True, help="Typ rodowiska do trenowania, wspierane typy=%s" % cubes.names())
    parser.add_argument("-m", "--model", required=True, help="Plik modelu do wczytania, musi odpowiada typowi rodowiska")
    parser.add_argument("--max-time", type=int, default=DEFAULT_MAX_SECONDS,
                        help="Limit czasu dla zadania (w sekundach), domylnie=%s" % DEFAULT_MAX_SECONDS)
    parser.add_argument("--max-steps", type=int, help="Limit wyszukiwa dla algorytmu MCTS. "
                                                      "Jeli podano, zastpuje opcj --max-time")
    parser.add_argument("--max-depth", type=int, default=PLOT_MAX_DEPTHS,
                        help="Maksymalny poziom zoonoci dla wykresw i danych, domylnie=%s" % PLOT_MAX_DEPTHS)
    parser.add_argument("--samples", type=int, default=PLOT_TASKS,
                        help="Licznik testw, domylnie=%s" % PLOT_TASKS)
    parser.add_argument("-b", "--batch", type=int, default=1, help="Rozmiar paczki uywanej podczas wyszukiwania, domylnie=1")
    parser.add_argument("--cuda", default=False, action="store_true", help="Opcja CUDA")
    parser.add_argument("--seed", type=int, default=42, help="Ziarno generatora; jeli opcja rwna 0, ziarno nie zostanie uyte. domylnie=42")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("-i", "--input", help="Plik tekstowy z permutacjami dla kostek, ktre maj zosta uoone, "
                                             "prawdopodobnie wygenerowany przez gen_cubes.py")
    group.add_argument("-p", "--perm", help="Permutacje w postaci listy akcji oddzielnych przecinkami")
    group.add_argument("-r", "--random", metavar="DEPTH", type=int, help="Losowy ukad wstpny o okrelonym poziomie zoonoci")
    group.add_argument("--plot", metavar="PREFIX", help="Generowanie wykresw prezentujcych dokadno procesu rozwizywania")
    group.add_argument("-o", "--output", help="Zapisywanie wynikw testu do pliku CSV o podanej nazwie")
    args = parser.parse_args()

    # A seed of 0 means "do not seed"
    if args.seed:
        random.seed(args.seed)
    device = torch.device("cuda" if args.cuda else "cpu")

    cube_env = cubes.get(args.env)
    log.info("rodowisko %s", cube_env)
    assert isinstance(cube_env, cubes.CubeEnv)

    net = model.Net(cube_env.encoded_shape, len(cube_env.action_enum)).to(device)
    # NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary code;
    # only load model files from trusted sources.
    net.load_state_dict(torch.load(args.model, map_location=lambda storage, loc: storage))
    net.eval()
    log.info("Sie: %s", args.model)

    if args.random is not None:
        task = generate_task(cube_env, args.random)
        solve_task(cube_env, task, net, max_seconds=args.max_time, max_steps=args.max_steps, device=device,
                   batch_size=args.batch)
    elif args.perm is not None:
        task = list(map(int, args.perm.split(',')))
        solve_task(cube_env, task, net, max_seconds=args.max_time, max_steps=args.max_steps, device=device,
                   batch_size=args.batch)
    elif args.input is not None:
        log.info("Szyfrowanie: %s", args.input)
        count = 0
        solved = 0
        with open(args.input, 'rt', encoding='utf-8') as fd:
            for idx, line in enumerate(fd):
                line = line.strip()
                # Skip blank lines (e.g. trailing empty lines), which would crash int('')
                if not line:
                    continue
                task = list(map(int, line.split(',')))
                _, solution = solve_task(cube_env, task, net, cube_idx=idx, max_seconds=args.max_time,
                                         max_steps=args.max_steps, device=device, batch_size=args.batch)
                if solution is not None:
                    solved += 1
                count += 1
        # Avoid ZeroDivisionError when the input file contains no scrambles
        if count:
            log.info("Uoono %d z %d kostek, co oznacza %.2f%% wspczynnik sukcesu", solved, count, 100*solved / count)
        else:
            log.warning("No cubes found in %s", args.input)
    elif args.plot is not None:
        log.info("Tworzenie wykresw z przedrostkiem %s", args.plot)
        data = gather_data(cube_env, net, args.max_time, args.max_steps, args.max_depth, args.samples,
                           args.batch, device)
        # max_steps must be passed so the titles reflect the actual limit used
        produce_plots(data, args.plot, args.max_time, args.max_steps)
    elif args.output is not None:
        data = gather_data(cube_env, net, args.max_time, args.max_steps, args.max_depth, args.samples,
                           args.batch, device)
        save_output(data, args.output)
