"""Uczenie DQN/DDQN by rozwiązać problem wózka z kijkiem (CartPole-v0)

"""

from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from collections import deque
import numpy as np
import random
import argparse
import gym
from gym import wrappers, logger

class DQNAgent:
    def __init__(self,
                 state_space, 
                 action_space, 
                 episodes=500):
        """Agent DQN w środowisku CartPole-v0

        Argumenty:
            state_space (tensor): przestrzeń stanów
            action_space (tensor): przestrzeń akcji
            episodes (int): liczba epizodów uczenia
        """

        self.action_space = action_space

        # bufor doświadczeń
        self.memory = []

        # współczynnik dyskontowy
        self.gamma = 0.9

        # początkowo 90% eksploracji, 10% eksploatacji
        self.epsilon = 1.0
        # iteracyjne stosowanie zaniku aż do osiągnięcia 10% eksploracji
        # i 90% eksploatacji
        self.epsilon_min = 0.1
        self.epsilon_decay = self.epsilon_min / self.epsilon
        self.epsilon_decay = self.epsilon_decay ** \
                             (1. / float(episodes))

        # nazwa pliku wag sieci Q
        self.weights_file = 'dqn_cartpole.h5'
        # sieć Q do uczenia
        n_inputs = state_space.shape[0]
        n_outputs = action_space.n
        self.q_model = self.build_model(n_inputs, n_outputs)
        self.q_model.compile(loss='mse', optimizer=Adam())
        # docelowa sieć Q
        self.target_q_model = self.build_model(n_inputs, n_outputs)
        # kopiowanie parametrów sieci Q do docelowej sieci Q
        self.update_weights()

        self.replay_counter = 0

    
    def build_model(self, n_inputs, n_outputs):
        """sieć Q to MLP 256-256-256

        Argumenty:
            n_inputs (int): wymiar wejścia
            n_outputs (int): wymiar wyjścia

        Zwraca:
            q_model (Model): DQN
        """
        inputs = Input(shape=(n_inputs, ), name='state')
        x = Dense(256, activation='relu')(inputs)
        x = Dense(256, activation='relu')(x)
        x = Dense(256, activation='relu')(x)
        x = Dense(n_outputs,
                  activation='linear', 
                  name='action')(x)
        q_model = Model(inputs, x)
        q_model.summary()
        return q_model


    def save_weights(self):
        """zapis parametrów sieci Q do pliku"""
        self.q_model.save_weights(self.weights_file)


    def update_weights(self):
        """kopiowanie wyuczonych parametrów sieci Q do docelowej sieci Q"""
        self.target_q_model.set_weights(self.q_model.get_weights())


    def act(self, state):
        """strategia eps-zachłanna
        Zwraca:
            action (tensor): akcja do wykonania
        """

        if np.random.rand() < self.epsilon:
            # eksploruj — wykonaj losową akcję
            return self.action_space.sample()

        # eksploatacja
        q_values = self.q_model.predict(state)
        # wybierz akcję o maksymalnej wartości Q
        action = np.argmax(q_values[0])
        return action


    def remember(self, state, action, reward, next_state, done):
        """zachowaj doświadczenia w buforze do ponownego użycia
        Argumenty:
            state (tensor): stan środowiska
            action (tensor): akcja agenta
            reward (float): nagroda otrzymana po wykonaniu akcji w stanie 
            next_state (tensor): następny stan
        """

        item = (state, action, reward, next_state, done)
        self.memory.append(item)


    def get_target_q_value(self, next_state, reward):
        """Oblicz maksymalną wartość Q_max.
           Użycie sieci docelowej rozwiązuje problem niestabilności.
        Argumenty:
            reward (float): nagroda otrzymana po wykonaniu akcji w stanie
            next_state (tensor): następny stan
        Zwraca:
            q_value (float): obliczona wartość maksymalna Q
        """

        # maksymalna wartość Q pomiędzy akcjami następnych stanów
        # DQN wybiera maksymalną wartość Q spośród następnych akcji
        # selekcja i ocena akcji jest dokonywana w docelowej sieci Q
        # Q_max = max_a' Q_docelowa(s', a')

        q_value = np.amax(\
                     self.target_q_model.predict(next_state)[0])

        # Q_max = nagroda+gamma * Q_max
        q_value *= self.gamma
        q_value += reward
        return q_value


    def replay(self, batch_size):
        """powtórka doświadczenia pomaga w rozwiązaniu problemu korelacji między próbkami
        Argumenty:
            batch_size (int): rozmiar próbki partii bufora powtórki doświadczenia 
        """
        # sars = stan, akcja, nagroda, stan' (next_state)

        sars_batch = random.sample(self.memory, batch_size)
        state_batch, q_values_batch = [], []

        # do poprawienia: z uwagi na prędkość może to być wykonane 
        # na poziomie tensorów, ale łatwiej zrozumieć z użyciem pętli 

        for state, action, reward, next_state, done in sars_batch:
            # predykcja strategii dla danego stanu
            q_values = self.q_model.predict(state)
            
            # pobierz Q_max
            q_value = self.get_target_q_value(next_state, reward)

            # korekcja wartości Q dla użytej akcji
            q_values[0][action] = reward if done else q_value

            # zbierz mapowanie partia-stan-wartość_q
            state_batch.append(state[0])
            q_values_batch.append(q_values[0])

        # uczenie sieci Q
        self.q_model.fit(np.array(state_batch),
                         np.array(q_values_batch),
                         batch_size=batch_size,
                         epochs=1,
                         verbose=0)

        # aktualizacja prawdopodobieństwa eksploracja-eksploatacja
        self.update_epsilon()

        # kopia nowych parametrów w miejsce starych docelowych 
        # po każdych 10 aktualizacjach treningu
        if self.replay_counter % 10 == 0:
            self.update_weights()

        self.replay_counter += 1

    
    def update_epsilon(self):
        """zmniejszanie eksploracji, zwiększanie eksploatacji"""
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
        

class DDQNAgent(DQNAgent):
    def __init__(self,
                 state_space, 
                 action_space, 
                 episodes=500):
        super().__init__(state_space, 
                         action_space, 
                         episodes)
        """Agent DDQN w środowisku CartPole-v0

        Argumenty:
            state_space (tensor): przestrzeń stanów
            action_space (tensor): przestrzeń akcji 
            episodes (int): liczba epizodów uczenia
        """

        # nazwa pliku wag sieci Q

        self.weights_file = 'ddqn_cartpole.h5'
        print("-------------DDQN------------")

    def get_target_q_value(self, next_state, reward):
        """oblicz Q_max
            Użycie docelowej sieci Q rozwiązuje problem niestabilności 
        Argumenty:
            reward (float): nagroda otrzymana po wykonaniu akcji w stanie
            next_state (tensor): następny stan
        Zwraca:
            q_value (float): obliczona wartość maksymalna Q
        """
        # maksymalna wartość Q pomiędzy akcjami następnych stanów DDQN
        # bieżąca sieć Q wybiera akcję
        # a'_max = argmax_a' Q(s', a')

        action = np.argmax(self.q_model.predict(next_state)[0])
        # docelowa sieć Q ocenia akcję
        # Q_max = Q_docelowe(s', a'_max)
        q_value = self.target_q_model.predict(\
                                      next_state)[0][action]

        # Q_max = nagroda+gamma * Q_max
        q_value *= self.gamma
        q_value += reward
        return q_value



if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=None)
    parser.add_argument('env_id',
                        nargs='?',
                        default='CartPole-v0',
                        help='Wybierz srodowisko do uruchomienia')
    parser.add_argument("-d",
                        "--ddqn",
                        action='store_true',
                        help="Użyj podwojnej sieci DQN")
    parser.add_argument("-r",
                        "--no-render",
                        action='store_true',
                        help="Dezaktywuj renderowanie (dla srodowisk bez grafiki")
    args = parser.parse_args()

    # liczba prób przed uznaniem próby za nieudaną
    win_trials = 100

    # problem CartPole-v0 uwżamy za rozwiązany  
    # jeśli po 100 następujących po sobie próbach kijek na wózku 
    # sie nie przewraca i otrzymał średnią nagrodę w wysokości 195.0 
    # nagroda jest zwiększana o +1 w każdym kroku czasowym  
    # w którym kijek utrzymuje się prosto
    win_reward = { 'CartPole-v0' : 195.0 }

    # zachowanie wartości nagrody na epizod
    scores = deque(maxlen=win_trials)

    logger.setLevel(logger.ERROR)
    env = gym.make(args.env_id)

    outdir = "/tmp/dqn-%s" % args.env_id
    if args.ddqn:
        outdir = "/tmp/ddqn-%s" % args.env_id

    if args.no_render:
        env = wrappers.Monitor(env,
                               directory=outdir,
                               video_callable=False,
                               force=True)
    else:
        env = wrappers.Monitor(env, directory=outdir, force=True)
    env.seed(0)

    # utworzenie instancji agenta DQN/DDQN
    if args.ddqn:
        agent = DDQNAgent(env.observation_space, env.action_space)
    else:
        agent = DQNAgent(env.observation_space, env.action_space)

    # powinno udać się rozwiązać w takiej liczbie epizodów
    episode_count = 3000
    state_size = env.observation_space.shape[0]
    batch_size = 64

    # domyślnie, liczba kroków w epizodzie dla CartPole-v0 jest ustawiona maksymalnie na 200
    # możesz tutaj ją zmienić i poeksperymentować 
    # env._max_episode_steps = 4000

# próbkowanie i dopasowanie Q-uczenia
    for episode in range(episode_count):
        state = env.reset()
        state = np.reshape(state, [1, state_size])
        done = False
        total_reward = 0
        while not done:
            # w CartPole-v0, akcja = 0 oznacza „w lewo”, 
            # a akcja = 1 „w prawo”
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            # w CartPole-v0: stan = [pozycja, prędkość,
            #                        theta, prędkość_kątowa]
            next_state = np.reshape(next_state, [1, state_size])
            # zachowaj każdą jednostkę doświadczenia w buforze powtórek
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward


        # wywołaj powtórkę doświadczeń
        if len(agent.memory) >= batch_size:
            agent.replay(batch_size)
    
        scores.append(total_reward)
        mean_score = np.mean(scores)
        if mean_score >= win_reward[args.env_id] \
                and episode >= win_trials:
            print("Rozwiazane w epizodzie  %d: \
                   Mean survival = %0.2lf w %d epizodach"
                  % (episode, mean_score, win_trials))
            print("Epsilon: ", agent.epsilon)
            agent.save_weights()
            break
        if (episode + 1) % win_trials == 0:
            print("Epizod %d: Mean survival = \
                   %0.2lf in %d episodes" %
                  ((episode + 1), mean_score, win_trials))

    # zamyka środowisko i zapisuje monitorowane wyniki na dysk
    env.close() 
