"""Algorytm A3C (Asynchronous Advantage Actor-Critic) dla uczenia przez wzmacnianie."""

import numpy as np
import tensorflow as tf
import copy
import multiprocessing
import os
import re
import threading
from collections import Sequence
import pickle
import threading
import time
import numpy as np
import os
import six
import tensorflow as tf
import tempfile

from tensorgraph import TensorGraph
from tensorgraph import Layer
from tensorgraph import Dense
from tensorgraph import Squeeze
from tensorgraph import Flatten
from tensorgraph import BatchNorm
from tensorgraph import SoftMax
from tensorgraph import Input


class A3CLoss(Layer):
  """Ta warstwa oblicza funkcję straty dla A3C."""

  def __init__(self, value_weight, entropy_weight, **kwargs):
    super(A3CLoss, self).__init__(**kwargs)
    self.value_weight = value_weight
    self.entropy_weight = entropy_weight

  def create_tensor(self, **kwargs):
    reward, action, prob, value, advantage = [
        layer.out_tensor for layer in self.in_layers
    ]
    prob = prob + np.finfo(np.float32).eps
    log_prob = tf.log(prob)
    policy_loss = -tf.reduce_mean(
        advantage * tf.reduce_sum(action * log_prob, axis=1))
    value_loss = tf.reduce_mean(tf.square(reward - value))
    entropy = -tf.reduce_mean(tf.reduce_sum(prob * log_prob, axis=1))
    self.out_tensor = policy_loss + self.value_weight * value_loss - self.entropy_weight * entropy
    return self.out_tensor


class A3C(object):
  """
  Implementowanie algorytmu A3C (Asynchronous Advantage Actor-Critic) dla uczenia przez wzmacnianie.

  Algorytm ten opisany jest w pracy "Asynchronous Methods for Deep Reinforcement
  Learning" (https://arxiv.org/abs/1602.01783), autorstwa Mniha i innych. Klasa ta 
  wymaga polityki tworzenia dwóch wartości wyjściowych: wektora dającego 
  prawdopodobieństwo podjęcia każdego działania oraz oszacowania funkcji wartości
  dla obecnego stanu. Optymalizuje ona oba wyjścia jednocześnie, wykorzystując
  stratę, która jest sumą trzech pojęć:

  1) straty polityki, która ma na celu maksymalizację zdyskontowanej nagrody
     za każde działanie,
  2) straty wartości, która stara się, aby oszacowanie wartości odpowiadało
     faktycznej zdyskontowanej nagrodzie osiągniętej na każdym etapie,
  3) pojęcia entropii w celu zachęcenia do poszukiwań.

  Klasa ta obsługuje tylko środowiska z przestrzeniami działań dyskretnych,
  a nie ciągłych.  Argument "action" przekazywany do środowiska jest liczbą
  całkowitą, podającą indeks czynności do wykonania.

  Klasa ta wspiera uogólnione szacowanie korzyści, opisane w dokumencie
  "High-Dimensional Continuous Control Using Generalized Advantage Estimation"
  (https://arxiv.org/abs/1506.02438) autorstwa Schulmana i innych.  Jest to metoda
  osiągania kompromisu pomiędzy obciążeniem i wariancją w szacowaniu korzyści, która
  czasami może poprawić wskaźnik konwergencji.  W celu skorygowania kompromisu
  należy użyć parametru advantage_lambda.
  """

  def __init__(self,
               env,
               max_rollout_length=20,
               discount_factor=0.99,
               advantage_lambda=0.98,
               value_weight=1.0,
               entropy_weight=0.01,
               optimizer=None,
               model_dir=None):
    """Tworzenie obiektu do optymalizacji polityki.

    Parametry
    ---------
    env: środowisko
      środowisko, z którym będzie miała miejsce interakcja
    max_rollout_length: liczba całkowita
      maksymalna długość rozwinięć do wygenerowania
    discount_factor: liczba zmiennoprzecinkowa
      współczynnik dyskontowy, który należy stosować przy obliczaniu nagród
    advantage_lambda: liczba zmiennoprzecinkowa
      parametr dla wyznaczania kompromisu pomiędzy obciążeniem a wariancją
      w uogólnionym szacowaniu korzyści
    value_weight: liczba zmiennoprzecinkowa
      współczynnik skali dla pojęcia straty wartości w funkcji straty
    entropy_weight: liczba zmiennoprzecinkowa
      współczynnik skali dla pojęcia entropii w funkcji straty
    optimizer: Optymalizator
      optymalizator do wykorzystania.  Jeśli None, użyty zostanie domyślny
      optymalizator.
    model_dir: łańcuch
      katalog, w którym zostanie zapisany model.  Jeśli None, zostanie utworzony
      katalog tymczasowy.
    """
    self._env = env
    self.max_rollout_length = max_rollout_length
    self.discount_factor = discount_factor
    self.advantage_lambda = advantage_lambda
    self.value_weight = value_weight
    self.entropy_weight = entropy_weight
    self._optimizer = None
    (self._graph, self._features, self._rewards, self._actions,
     self._action_prob, self._value, self._advantages) = self.build_graph(
         None, "globalnie", model_dir)
    with self._graph._get_tf("Graf").as_default():
      self._session = tf.Session()

  def build_graph(self, tf_graph, scope, model_dir):
    """Konstruowanie obiektu TensorGraph zawierającego obliczenia polityki i strat."""
    state_shape = self._env.state_shape
    features = []
    for s in state_shape:
      features.append(Input(shape=[None] + list(s), dtype=tf.float32))
    d1 = Flatten(in_layers=features)
    d2 = Dense(
        in_layers=[d1],
        activation_fn=tf.nn.relu,
        normalizer_fn=tf.nn.l2_normalize,
        normalizer_params={"wymiary": 1},
        out_channels=64)
    d3 = Dense(
        in_layers=[d2],
        activation_fn=tf.nn.relu,
        normalizer_fn=tf.nn.l2_normalize,
        normalizer_params={"wymiary": 1},
        out_channels=32)
    d4 = Dense(
        in_layers=[d3],
        activation_fn=tf.nn.relu,
        normalizer_fn=tf.nn.l2_normalize,
        normalizer_params={"wymiary": 1},
        out_channels=16)
    d4 = BatchNorm(in_layers=[d4])
    d5 = Dense(in_layers=[d4], activation_fn=None, out_channels=9)
    value = Dense(in_layers=[d4], activation_fn=None, out_channels=1)
    value = Squeeze(squeeze_dims=1, in_layers=[value])
    action_prob = SoftMax(in_layers=[d5])

    rewards = Input(shape=(None,))
    advantages = Input(shape=(None,))
    actions = Input(shape=(None, self._env.n_actions))
    loss = A3CLoss(
        self.value_weight,
        self.entropy_weight,
        in_layers=[rewards, actions, action_prob, value, advantages])
    graph = TensorGraph(
        batch_size=self.max_rollout_length,
        graph=tf_graph,
        model_dir=model_dir)
    for f in features:
      graph._add_layer(f)
    graph.add_output(action_prob)
    graph.add_output(value)
    graph.set_loss(loss)
    graph.set_optimizer(self._optimizer)
    with graph._get_tf("Graf").as_default():
      with tf.variable_scope(scope):
        graph.build()
    return graph, features, rewards, actions, action_prob, value, advantages

  def fit(self,
          total_steps,
          max_checkpoints_to_keep=5,
          checkpoint_interval=600,
          restore=False):
    """Trenowanie polityki.

    Parametry
    ---------
    total_steps: liczba całkowita
      całkowita liczba kroków czasowych do wykonania w środowisku, przez wszystkie
	  rozwinięcia na wszystkich wątkach
    max_checkpoints_to_keep: liczba całkowita
      maksymalna liczba plików punktów kontrolnych do przechowania.  Po osiągnięciu
      tej liczby starsze pliki są usuwane.
    checkpoint_interval: liczba zmiennoprzecinkowa
      przedział czasowy, w którym zapisywane są punkty kontrolne, mierzony w sekundach
    restore: wartość logiczna
      jeśli True, przywraca model z ostatniego punktu kontrolnego i kontynuuje
      szkolenie od tego miejsca.  Jeśli False, ponownie szkoli model od podstaw.
    """
    with self._graph._get_tf("Graf").as_default():
      step_count = [0]
      workers = []
      threads = []
      for i in range(multiprocessing.cpu_count()):
        workers.append(Worker(self, i))
      self._session.run(tf.global_variables_initializer())
      if restore:
        self.restore()
      for worker in workers:
        thread = threading.Thread(
            name=worker.scope,
            target=lambda: worker.run(step_count, total_steps))
        threads.append(thread)
        thread.start()
      variables = tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope="globalnie")
      saver = tf.train.Saver(variables, max_to_keep=max_checkpoints_to_keep)
      checkpoint_index = 0
      while True:
        threads = [t for t in threads if t.isAlive()]
        if len(threads) > 0:
          threads[0].join(checkpoint_interval)
        checkpoint_index += 1
        saver.save(
            self._session, self._graph.save_file, global_step=checkpoint_index)
        if len(threads) == 0:
          break

  def predict(self, state):
    """Obliczanie przewidywań dotyczących wyników polityki dla danego stanu.

    Parametry
    ---------
    state: tablica
      stan środowiska, dla którego mają być tworzone prognozy

    Zwracane wartości
    -----------------
    tablica prawdopodobieństw działań i szacunkowa funkcja wartości
    """
    with self._graph._get_tf("Graf").as_default():
      feed_dict = self.create_feed_dict(state)
      tensors = [self._action_prob.out_tensor, self._value.out_tensor]
      results = self._session.run(tensors, feed_dict=feed_dict)
      return results[:2]

  def select_action(self,
                    state,
                    deterministic=False):
    """Wybór akcji do wykonania na podstawie stanu środowiska.

    Parametry
    ----------
    state: tablica
      stan środowiska, dla którego należy wybrać działanie
    deterministic: wartość logiczna
      jeśli True, zawsze zwraca najlepsze działanie (tj. działanie o najwyższym
      prawdopodobieństwie).  Jeśli False, losowo wybiera akcję na podstawie
      obliczonych prawdopodobieństw.
    Zwracane wartości
    -------
    indeks wybranego działania
    """
    with self._graph._get_tf("Graf").as_default():
      feed_dict = self.create_feed_dict(state)
      tensors = [self._action_prob.out_tensor]
      results = self._session.run(tensors, feed_dict=feed_dict)
      probabilities = results[0]
      if deterministic:
        return probabilities.argmax()
      else:
        return np.random.choice(
            np.arange(self._env.n_actions), p=probabilities[0])

  def restore(self):
    """Ponowne wczytanie parametrów modelu z ostatniego pliku punktu kontrolnego."""
    last_checkpoint = tf.train.latest_checkpoint(self._graph.model_dir)
    if last_checkpoint is None:
      raise ValueError("Nie znaleziono punktu kontrolnego")
    with self._graph._get_tf("Graf").as_default():
      variables = tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope="globalnie")
      saver = tf.train.Saver(variables)
      saver.restore(self._session, last_checkpoint)

  def create_feed_dict(self, state):
    """Utworzenie słownika zasilającego do użycia przez predict() lub select_action()."""
    feed_dict = dict((f.out_tensor, np.expand_dims(s, axis=0))
                     for f, s in zip(self._features, state))
    return feed_dict


class Worker(object):
  """Dla każdego wątku treningowego tworzony jest obiekt Worker."""

  def __init__(self, a3c, index):
    self.a3c = a3c
    self.index = index
    self.scope = "worker%d" % index
    self.env = copy.deepcopy(a3c._env)
    self.env.reset()
    (self.graph, self.features, self.rewards, self.actions, self.action_prob,
     self.value, self.advantages) = a3c.build_graph(
        a3c._graph._get_tf("Graf"), self.scope, None)
    with a3c._graph._get_tf("Graf").as_default():
      local_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     self.scope)
      global_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                      "globalnie")
      gradients = tf.gradients(self.graph.loss.out_tensor, local_vars)
      grads_and_vars = list(zip(gradients, global_vars))
      self.train_op = a3c._graph._get_tf("Optymalizator").apply_gradients(
          grads_and_vars)
      self.update_local_variables = tf.group(
          * [tf.assign(v1, v2) for v1, v2 in zip(local_vars, global_vars)])
      self.global_step = self.graph.get_global_step()

  def run(self, step_count, total_steps):
    with self.graph._get_tf("Graf").as_default():
      while step_count[0] < total_steps:
        self.a3c._session.run(self.update_local_variables)
        states, actions, rewards, values = self.create_rollout()
        self.process_rollout(states, actions, rewards, values, step_count[0])
        step_count[0] += len(actions)

  def create_rollout(self):
    """Generowanie rozwinięcia."""
    n_actions = self.env.n_actions
    session = self.a3c._session
    states = []
    actions = []
    rewards = []
    values = []

    # Generowanie rozwinięcia.
    for i in range(self.a3c.max_rollout_length):
      if self.env.terminated:
        break
      state = self.env.state
      states.append(state)
      feed_dict = self.create_feed_dict(state)
      results = session.run(
          [self.action_prob.out_tensor, self.value.out_tensor],
          feed_dict=feed_dict)
      probabilities, value = results[:2]
      action = np.random.choice(np.arange(n_actions), p=probabilities[0])
      actions.append(action)
      values.append(float(value))
      rewards.append(self.env.step(action))

    # Compute an estimate of the reward for the rest of the episode.
    if not self.env.terminated:
      feed_dict = self.create_feed_dict(self.env.state)
      final_value = self.a3c.discount_factor * float(
          session.run(self.value.out_tensor, feed_dict))
    else:
      final_value = 0.0
    values.append(final_value)
    if self.env.terminated:
      self.env.reset()
    return states, actions, np.array(rewards), np.array(values)

  def process_rollout(self, states, actions, rewards, values, step_count):
    """Szkolenie sieci w oparciu o rozwinięcia."""

    # Obliczanie zdyskontowanych nagród i korzyści.
    if len(states) == 0:
      # Tworzenie rozwinięć czasami się nie udaje w środowisku wielowątkowym
      # Przerwij przetwarzanie w przypadku deformacji
      print("Tworzenie rozwinięcia nie powiodło się. Pomijam.")    
      return

    discounted_rewards = rewards.copy()
    discounted_rewards[-1] += values[-1]
    advantages = rewards - values[:-1] + self.a3c.discount_factor * np.array(
        values[1:])
    for j in range(len(rewards) - 1, 0, -1):
      discounted_rewards[j-1] += self.a3c.discount_factor * discounted_rewards[j]
      advantages[j-1] += (
          self.a3c.discount_factor * self.a3c.advantage_lambda * advantages[j])

    # Przekształcenie działań w gorącojedynkowe.
    n_actions = self.env.n_actions
    actions_matrix = []
    for action in actions:
      a = np.zeros(n_actions)
      a[action] = 1.0
      actions_matrix.append(a)

    # Przestawienie stanów na odpowiedni zestaw tablic.
    state_arrays = [[] for i in range(len(self.features))]
    for state in states:
      for j in range(len(state)):
        state_arrays[j].append(state[j])
    
    # Budowanie słownika zasilającego i stosowanie gradientów..
    feed_dict = {}
    for f, s in zip(self.features, state_arrays):
      feed_dict[f.out_tensor] = s
    feed_dict[self.rewards.out_tensor] = discounted_rewards
    feed_dict[self.actions.out_tensor] = actions_matrix
    feed_dict[self.advantages.out_tensor] = advantages
    feed_dict[self.global_step] = step_count
    self.a3c._session.run(self.train_op, feed_dict=feed_dict)

  def create_feed_dict(self, state):
    """Utworzenie słownika zasilającego do użycia podczas rozwijania."""
    feed_dict = dict((f.out_tensor, np.expand_dims(s, axis=0))
                     for f, s in zip(self.features, state))
    return feed_dict
