'''Uczenie sieci LSGAN na zbiorze MNIST z użyciem Keras

LSGAN jest podobny do DCGAN z wyjątkiem tego, że używamy funkcji straty MSE przez Dyskryminator i siec współzawodniczącą.
  
[1] Radford, Alec, Luke Metz, and Soumith Chintala.
"Unsupervised representation learning with deep convolutional
generative adversarial networks." arXiv preprint arXiv:1511.06434 (2015).

[2] Mao, Xudong, et al. "Least squares generative adversarial networks." 2017 IEEE International Conference on Computer Vision (ICCV). IEEE, 2017.
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.keras.layers import Input
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import load_model

import numpy as np
import argparse

import sys
sys.path.append("..")
from lib import gan


def build_and_train_models():
    """Załadowanie zbioru danych, konstruowanie dyskryminatora LSGAN,
    generatora i modeli sieci współzawodniczącej
    Wywołanie procedury uczenia LSGAN.
    """
    # załadowanie zbioru MNIST
    (x_train, _), (_, _) = mnist.load_data()

    # zmiana rozmiaru danych dla CNN na (28, 28, 1) i normalizacja
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, 
                         [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255

    model_name = "lsgan_mnist"
    # parametry sieci
    # 100D wektor niejawny lub wektor z
    latent_size = 100
    input_shape = (image_size, image_size, 1)
    batch_size = 64
    lr = 2e-4
    decay = 6e-8
    train_steps = 40000

    # konstruowanie modelu dyskryminatora
    inputs = Input(shape=input_shape, name='discriminator_input')
    discriminator = gan.discriminator(inputs, activation=None)
    # W [1] użyto Adam, ale dyskryminator łatwiej osiąga zbieżność dla RMSprop
    optimizer = RMSprop(lr=lr, decay=decay)
    # W LSGAN używamy MSE jako funkcji straty [2]
    discriminator.compile(loss='mse',
                          optimizer=optimizer,
                          metrics=['accuracy'])
    discriminator.summary()

    # budowanie modelu generatora
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    generator = gan.generator(inputs, image_size)
    generator.summary()

    # budowanie modelu sieci współzawodniczącej = generator+dyskryminator
    optimizer = RMSprop(lr=lr*0.5, decay=decay*0.5)
    # zamrożenie wag dyskryminatora podczas trenowania sieci współzawodniczącej
    discriminator.trainable = False
    adversarial = Model(inputs,
                        discriminator(generator(inputs)),
                        name=model_name)
    # W LSGAN użyto straty MSE [2]
    adversarial.compile(loss='mse',
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary()

    # uczenie dyskryminatora i sieci współzawodniczącej
    models = (generator, discriminator, adversarial)
    params = (batch_size, latent_size, train_steps, model_name)
    gan.train(models, x_train, params)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    help_ = "Load generator h5 model with trained weights"
    parser.add_argument("-g", "--generator", help=help_)
    args = parser.parse_args()
    if args.generator:
        generator = load_model(args.generator)
        gan.test_generator(generator)
    else:
        build_and_train_models()
