"""Konstruowanie, uczenie i ocena modelu MINE

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.keras.layers import Input, Dense, Add, Activation, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.utils import plot_model
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

import numpy as np
import os
import argparse
import vgg

import matplotlib.pyplot as plt
from scipy.stats.contingency import margins
from data_generator import DataGenerator
from utils import unsupervised_labels, center_crop, AccuracyCallback, lr_schedule

def sample(joint=True,
           mean=[0, 0],
           cov=[[1, 0.5], [0.5, 1]],
           n_data=1000000):
    """Funkcja pomocnicza do otrzymania próbek 
        do dwuwymiarowego rozkładu Gaussa

    Argumenty:
        joint (Bool): jeśli potrzebny jest rozkład łączny 
        mean (list): wartości średnie dwuwymiarowego rozkładu Gaussa 
        cov (list): macierz kowariancji dwuwymiarowego rozkładu Gaussa
        n_data (int): liczba próbek dwuwymiarowego rozkładu Gaussa
    """
    xy = np.random.multivariate_normal(mean=mean,
                                       cov=cov,
                                       size=n_data)
    # próbki rozkładu łącznego
    if joint:
        return xy 
    y = np.random.multivariate_normal(mean=mean,
                                      cov=cov,
                                      size=n_data)

    # próbki rozkładu brzegowego
    x = xy[:,0].reshape(-1,1)
    y = y[:,1].reshape(-1,1)
   
    xy = np.concatenate([x, y], axis=1)
    return xy

def compute_mi(cov_xy=0.5, n_bins=100):
    """Analityczne obliczenie MI z wykorzystaniem zdyskretyzowanego 
        dwuwymiarowego rozkładu Gaussa

    Argumenty:
        cov_xy (list): elementy macierzy kowariancji (poza przekątną główną)
        n_bins (int): liczba przedziałów dyskretyzujących ciągły
            dwuwymiarowy rozkład Gaussa
    """
    cov=[[1, cov_xy], [cov_xy, 1]]
    data = sample(cov=cov)
    # pobierz próbki rozkładu łącznego
    # przygotowanie histogramu dyskretyzacji
    joint, edge = np.histogramdd(data, bins=n_bins)
    joint /= joint.sum()
    eps = np.finfo(float).eps
    joint[joint<eps] = eps
    # obliczenie rozkładu brzegowego
    x, y = margins(joint)

    xy = x*y
    xy[xy<eps] = eps
    # MI to P(X,Y)*log(P(X,Y)/P(X)*P(Y))
    mi = joint*np.log(joint/xy)
    mi = mi.sum()
    print("Obliczona MI: %0.6f" % mi)
    return mi

class SimpleMINE:
    def __init__(self,
                 args,
                 input_dim=1,
                 hidden_units=16,
                 output_dim=1):
        """Uczenie, jak obliczyć MI, używając MINE (algorytm 13.1)

        Argumenty:
            args: argumenty definiowane przez użytkownika, takie jak elementy
                 macierzy kowariancji spoza przekątnej, rozmiar partii, liczba epok itd.
            input_dim (int): wymiar danych wejściowych
            hidden_units (int): liczba ukrytych neuronów sieci MINE MLP
            output_dim (int): wymiar danych wyjściowych
        """
        self.args = args
        self._model = None
        self.build_model(input_dim,
                         hidden_units,
                         output_dim)


    def build_model(self,
                    input_dim,
                    hidden_units,
                    output_dim):
        """Konstruowanie prostego modelu MINE
        Argumenty:
            patrz argumenty klasy
        """
        inputs1 = Input(shape=(input_dim), name="x")
        inputs2 = Input(shape=(input_dim), name="y")
        x1 = Dense(hidden_units)(inputs1)
        x2 = Dense(hidden_units)(inputs2)
        x = Add()([x1, x2])
        x = Activation('relu', name="ReLU")(x)
        outputs = Dense(output_dim, name="MI")(x)
        inputs = [inputs1, inputs2]
        self._model = Model(inputs,
                            outputs,
                            name='MINE')
        self._model.summary()


    def mi_loss(self, y_true, y_pred):
        """ funkcja straty MINE
        Argumenty:
            y_true (tensor): nieużywany, bo to uczenie nienadzorowane
            y_pred (tensor): stos predykcji dla rozkładu łącznego T(x,y) i brzegowego T(x,y)
        """
        size = self.args.batch_size
        # dolna połowa jest predykcją dla rozkładu wspólnego
        pred_xy = y_pred[0: size, :]

        # dolna polowa jest predykcją dla rozkładu brzegowego 
        pred_x_y = y_pred[size : y_pred.shape[0], :]
        # implementacja straty MINE (równanie 13.23)
        loss = K.mean(pred_xy) \
               - K.log(K.mean(K.exp(pred_x_y)))
        return -loss


    def train(self):
        """Uczenie MINE do szacowania MI pomiędzy 
            dwuwymiarowymi rozkładami Gaussa X i Y
        """
        optimizer = Adam(lr=0.01)
        self._model.compile(optimizer=optimizer,
                            loss=self.mi_loss)
        plot_loss = []
        cov=[[1, self.args.cov_xy], [self.args.cov_xy, 1]]
        loss = 0.
        for epoch in range(self.args.epochs):
            # próbki rozkładu łącznego
            xy = sample(n_data=self.args.batch_size,
                        cov=cov)
            x1 = xy[:,0].reshape(-1,1)
            y1 = xy[:,1].reshape(-1,1)
             # próbki rozkładu brzegowego
            xy = sample(joint=False,
                        n_data=self.args.batch_size,
                        cov=cov)
            x2 = xy[:,0].reshape(-1,1)
            y2 = xy[:,1].reshape(-1,1)
    
            # uczenie na partii danych z próbek z rozkładu łącznego i brzegowego
            x =  np.concatenate((x1, x2))
            y =  np.concatenate((y1, y2))
            loss_item = self._model.train_on_batch([x, y],
                                                   np.zeros(x.shape))
            loss += loss_item
            plot_loss.append(-loss_item)
            if (epoch + 1) % 100 == 0:
                fmt = "Epok %d MINE MI: %0.6f" 
                print(fmt % ((epoch+1), -loss/100))
                loss = 0.

        plt.plot(plot_loss, color='black')
        plt.xlabel('epok')
        plt.ylabel('MI')
        plt.savefig("simple_mine_mi.png", dpi=300, color='black')
        plt.show()


    @property
    def model(self):
        return self._model


class LinearClassifier:
    def __init__(self,
                 latent_dim=10,
                 n_classes=10):
        """prosty liniowy klasyfikator oparty na MLP. 
            Klasyfikator liniowy jest siecią MLP 
            bez nieliniowej funkcji aktywacji takiej jak ReLU. 
            Może być użyty jako substytut algorytmu przypisania liniowego

        Argumenty:
            latent_dim (int): wymiarowość wektora niejawnego
            n_classes (int): liczba klas, 
                            do jakiej ma być przekształcony wymiar niejawny
        """
        self.build_model(latent_dim, n_classes)


    def build_model(self, latent_dim, n_classes):
        """Konstruktor modelu klasyfikatora liniowego

        Argumenty: (patrz argumenty klasy)
        """
        inputs = Input(shape=(latent_dim,), name="cluster")
        x = Dense(256)(inputs)
        outputs = Dense(n_classes,
                        activation='softmax',
                        name="class")(x)
        name = "classifier"
        self._model = Model(inputs, outputs, name=name)
        self._model.compile(loss='categorical_crossentropy',
                            optimizer='adam',
                            metrics=['accuracy'])
        self._model.summary()


    def train(self, x_test, y_test):
        """Uczenie klasyfikatora liniowego. 

        Argumenty:
            x_test (tensor): Obraz ze zbioru testowego
            y_test (tensor): Etykieta odpowiadająca obrazowi ze zbioru testowego
        """
        self._model.fit(x_test,
                        y_test,
                        epochs=10,
                        batch_size=128)


    def eval(self, x_test, y_test):
        """Ocena klasyfikatora liniowego. 

        Argumenty:
            x_test (tensor): Obraz ze zbioru testowego
            y_test (tensor): Etykieta odpowiadająca obrazowi ze zbioru testowego
        """
        self._model.fit(x_test,
                        y_test,
                        epochs=10,
                        batch_size=128)

        score = self._model.evaluate(x_test,
                                     y_test,
                                     batch_size=128,
                                     verbose=0)
        accuracy = score[1] * 100
        return accuracy


    @property
    def model(self):
        return self._model



class MINE:
    def __init__(self,
                 args,
                 backbone):
        """Zawiera koder, SimpleMINE, liniowy model klasyfikujący 
            funkcję straty, ładowanie zbioru, procedurę uczącą i oceny
            umożliwiającą implementację nienadzorowanego grupowania 
            z użyciem MINE przez maksymalizację informacji wzajemnej

        Argumenty:
            args: argumenty linii poleceń do określenia rozmiaru partii
                katalogu do zapisu pliku z wagami, nazwy pliku z wagami itp.
            backbone (Model): szkielet koderów MINE (tzn. VGG)
        """
        self.args = args
        self.latent_dim = args.latent_dim
        self.backbone = backbone
        self._model = None
        self._encoder = None
        self.train_gen = DataGenerator(args, 
                                       siamese=True,
                                       mine=True)
        self.n_labels = self.train_gen.n_labels
        self.build_model()
        self.accuracy = 0


    def build_model(self):
        """Konstruowanie modelu MINE do nienadzorowanego grupowania
        """
        inputs = Input(shape=self.train_gen.input_shape,
                       name="x")
        x = self.backbone(inputs)
        x = Flatten()(x)
        y = Dense(self.latent_dim,
                  activation='linear',
                  name="encoded_x")(x)
        # koder jest oparty na szkielecie (tzn. VGG) ekstraktorów cech
        self._encoder = Model(inputs, y, name="encoder")
        # użyta SimpleMINE dla dwuwymiarowego rozkładu Gaussa jako 
        # funkcja T(x,y) w MINE (algorytm 13.1)
        self._mine = SimpleMINE(self.args,
                                input_dim=self.latent_dim,
                                hidden_units=1024,
                                output_dim=1)
        inputs1 = Input(shape=self.train_gen.input_shape,
                        name="x")
        inputs2 = Input(shape=self.train_gen.input_shape,
                        name="y")
        x1 = self._encoder(inputs1)
        x2 = self._encoder(inputs2)
        outputs = self._mine.model([x1, x2])
        # model oblicza MI pomiędzy wejście_1 a wejście_2 (czyli x i y) 
        self._model = Model([inputs1, inputs2],
                            outputs, 
                            name='encoder')
        optimizer = Adam(lr=1e-3)
        self._model.compile(optimizer=optimizer, 
                            loss=self.mi_loss)
        self._model.summary()
        self.load_eval_dataset()
        self._classifier = LinearClassifier(\
                            latent_dim=self.latent_dim)


    def mi_loss(self, y_true, y_pred):
        """ Funkcja straty MINE

        Argumenty:
            y_true (tensor): Nie używane jako że jest to uczenie nienadzorowane.
            y_pred (tensor): Stos predykcji dla łącznego T(x,y) 
                            i brzegowego T(x,y)
        """
        size = self.args.batch_size
        # lower half is pred for joint dist
        pred_xy = y_pred[0: size, :]

        # upper half is pred for marginal dist
        pred_x_y = y_pred[size : y_pred.shape[0], :]
        loss = K.mean(K.exp(pred_x_y))
        loss = K.clip(loss, K.epsilon(), np.finfo(float).max)
        loss = K.mean(pred_xy) - K.log(loss)
        return -loss


    def train(self):
        """Uczenie MINE do zadania szacowania MI pomiędzy X i Y
        (tzn. obrazem MNIST i jego zmodyfikowaną wersją)
        """
        accuracy = AccuracyCallback(self)
        lr_scheduler = LearningRateScheduler(lr_schedule,
                                             verbose=1)
        callbacks = [accuracy, lr_scheduler]
        self._model.fit(x=self.train_gen,
                        use_multiprocessing=False,
                        epochs=self.args.epochs,
                        callbacks=callbacks,
                        shuffle=True)


    def load_eval_dataset(self):
        """Wstępne załadowanie zbioru uczącego do oceny
        """
        (_, _), (x_test, self.y_test) = \
                self.args.dataset.load_data()
        image_size = x_test.shape[1]
        x_test = np.reshape(x_test,
                            [-1, image_size, image_size, 1])
        x_test = x_test.astype('float32') / 255
        x_eval = np.zeros([x_test.shape[0],
                          *self.train_gen.input_shape])
        for i in range(x_eval.shape[0]):
            x_eval[i] = center_crop(x_test[i])

        self.y_test = to_categorical(self.y_test)
        self.x_test = x_eval


    def load_weights(self):
        """Przeładowanie wag modelu do oceny.
        """
        if self.args.restore_weights is None:
            error_msg = "Wagi musza zostać załadowane do przeprowadzenia oceny"
            raise ValueError(error_msg)

        if self.args.restore_weights:
            folder = "weights"
            os.makedirs(folder, exist_ok=True) 
            path = os.path.join(folder, self.args.restore_weights)
            print("Ładowanie wag... ", path)
            self._model.load_weights(path)


    def eval(self):
        """Ocena jakości wag bieżącego modelu
        """
        # przewidywanie grup dla danych testowych
        y_pred = self._encoder.predict(self.x_test)
        # uczenie klasyfikatora liniowego 
        # wejście: pogrupowane dane
        # wyjście: etykiety danych referencyjnych
        self._classifier.train(y_pred, self.y_test)
        accuracy = self._classifier.eval(y_pred, self.y_test)

        info = "Dokładność: %0.2f%%"
        if self.accuracy > 0:
            info += ", Najlepsza stara dokładność: %0.2f%%" 
            data = (accuracy, self.accuracy)
        else:
            data = (accuracy)
        print(info % data)
        # jeśli dokładność się polepszyła, zapisz wagi modelu do pliku 
        if accuracy > self.accuracy \
            and self.args.save_weights is not None:
            folder = self.args.save_dir
            os.makedirs(folder, exist_ok=True) 
            args = (self.latent_dim, self.args.save_weights)
            filename = "%d-dim-%s" % args
            path = os.path.join(folder, filename)
            print("Zapisywanie wag... ", path)
            self._model.save_weights(path)

        if accuracy > self.accuracy: 
            self.accuracy = accuracy


    @property
    def model(self):
        return self._model


    @property
    def encoder(self):
        return self._encoder


    @property
    def classifier(self):
        return self._classifier



if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MI dla 2D rozkładu Gaussa')
    parser.add_argument('--cov_xy',
                        type=float,
                        default=0.5,
                        help='Rozkład Gaussa poza przekątną')
    parser.add_argument('--save-dir',
                       default="weights",
                       help='Folder do przechowywania wag modelu')
    parser.add_argument('--save-weights',
                       default=None,
                       help='Plik z (dodanymi wymiarami) wag modelu (h5).')
    parser.add_argument('--dataset',
                       default=mnist,
                       help='Używany zbiór danych')
    parser.add_argument('--epochs',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='Liczba epok uczenia')
    parser.add_argument('--batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='Rozmiar partii uczącej')
    parser.add_argument('--gaussian',
                        default=False,
                        action='store_true',
                        help='Obliczenie MI dla 2D rozkładu Gaussa')
    parser.add_argument('--plot-model',
                        default=False,
                        action='store_true',
                        help='Narysowanie wszystkich modeli sieci')
    parser.add_argument('--train',
                        default=False,
                        action='store_true',
                        help='Uczenie modelu')
    parser.add_argument('--latent-dim',
                        type=int,
                        default=10,
                        metavar='N',
                        help='Koder niejawnych wymiarów MNIST')
    parser.add_argument('--restore-weights',
                        default=None,
                        help='Przywrócenie zapisanych wag modelu')
    parser.add_argument('--eval',
                        default=False,
                        action='store_true',
                        help='Ocena wstępnie wytrenowanego modelu. Musisz wskazać plik z wagami.')

    args = parser.parse_args()
    if args.gaussian:
        print("Kowariancja poza przekątną:", args.cov_xy)
        simple_mine = SimpleMINE(args)
        simple_mine.train()
        compute_mi(cov_xy=args.cov_xy)
        if args.plot_model:
            plot_model(simple_mine.model,
                       to_file="simple_mine.png",
                       show_shapes=True)
    else:
        # zbudowanie szkieletu
        backbone = vgg.VGG(vgg.cfg['F'])
        backbone.model.summary()
        # utworzenie instancji obiektu MINE
        mine = MINE(args, backbone.model)
        if args.plot_model:
            plot_model(mine.classifier.model,
                       to_file="classifier.png",
                       show_shapes=True)
            plot_model(mine.encoder,
                       to_file="encoder.png",
                       show_shapes=True)
            plot_model(mine.model,
                       to_file="model-mine.png",
                       show_shapes=True)
        if args.train:
            mine.train()
    
        if args.eval:
            mine.load_weights()
            mine.eval()
