'''Przykład sieci CVAE na zbiorze MNIST z wykorzystaniem CNN

Ta sieć VAE ma budowę modułową. Koder, dekoder i VAE 
są trzema modelami współdzielącymi wagi. Po wytrenowaniu modelu VAE,
koder może być używany do generowania wektorów niejawnych.
Dekoder może być używany do generowania cyfr MNIST przez próbkowanie wektora niejawnego z rozkładu Gaussa o średniej mean=0 i odchyleniu standardowym std=1.

[1] Sohn, Kihyuk, Honglak Lee, and Xinchen Yan.
"Learning structured output representation using
deep conditional generative models."
Advances in Neural Information Processing Systems. 2015.
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.layers import Conv2D, Flatten, Lambda
from tensorflow.keras.layers import Reshape, Conv2DTranspose
from tensorflow.keras.layers import concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.losses import mse, binary_crossentropy
from tensorflow.keras.utils import plot_model
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical

import numpy as np
import matplotlib.pyplot as plt
import argparse
import os


# sztuczka z reparametryzacją
# zamiast próbkować z Q(z|X), próbka eps = N(0,I)
# z = z_mean+sqrt(var)*eps
def sampling(args):
    """Trik z reparametryzacją przez próbkowanie z gaussowskiej jednostki izotropowej

    # Argumenty:
        args (tensor): średnia i log wariancji Q(z|X)

    # Zwraca:
        z (tensor): próbkowany wektor niejawny
    """


    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # domyślnie random_normal ma średnią = 0 i odchylenie standardowe = 1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon


def plot_results(models,
                 data,
                 y_label,
                 batch_size=128,
                 model_name="cvae_mnist"):
    """Wykres dwuwymiarowych wartości średniej Q(z|X) z użyciem etykiet
        jako gradientu koloru, a następnie narysowanie cyfr MNIST 
        w funkcji dwuwymiarowego wektora niejawnego

    Argumenty:
        models (list): modele kodera i dekodera
        data (list): dane testowe i etykieta
        y_label (array): wektor OH oznaczający która cyfrę wyświetlić
        batch_size (int): przewidywany rozmiar próbki
        model_name (string): który model używa tej funkcji
    """

    encoder, decoder = models
    x_test, y_test = data
    xmin = ymin = -4
    xmax = ymax = +4
    os.makedirs(model_name, exist_ok=True)

    filename = os.path.join(model_name, "vae_mean.png")
    # wyświetlenie dwuwymiarowego wykresu klas cyfr w przestrzeni niejawnej
    z, _, _ = encoder.predict([x_test, to_categorical(y_test)],
                              batch_size=batch_size)
    plt.figure(figsize=(12, 10))

    # zakresy na osiach x oraz y
    axes = plt.gca()
    axes.set_xlim([xmin,xmax])
    axes.set_ylim([ymin,ymax])

    # podpróbkowanie aby zredukować gęstość punktów na wykresie
    z = z[0::2]
    y_test = y_test[0::2]
    plt.scatter(z[:, 0], z[:, 1], marker="")
    for i, digit in enumerate(y_test):
        axes.annotate(digit, (z[i, 0], z[i, 1]))
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.savefig(filename)
    plt.show()

    filename = os.path.join(model_name, "%05d.png" % np.argmax(y_label))
    # Wyświetlenie dwuwymiarowej rozmaitości cyfr o wymiarach 10x10 (y_label)
    n = 10
    digit_size = 28
    figure = np.zeros((digit_size * n, digit_size * n))
    # Rozłożone równomiernie współrzędne odpowiadające dwuwymiarowemu 
    # wykresowi klas cyfr w przestrzeni niejawnej
    grid_x = np.linspace(-4, 4, n)
    grid_y = np.linspace(-4, 4, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = decoder.predict([z_sample, y_label])
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[i * digit_size: (i + 1) * digit_size,
                   j * digit_size: (j + 1) * digit_size] = digit

    plt.figure(figsize=(10, 10))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap='Greys_r')
    plt.savefig(filename)
    plt.show()


# zbiór MNIST
(x_train, y_train), (x_test, y_test) = mnist.load_data()

image_size = x_train.shape[1]
x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
x_test = np.reshape(x_test, [-1, image_size, image_size, 1])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# oblicz liczbę etykiet
num_labels = len(np.unique(y_train))

# parametry sieci
input_shape = (image_size, image_size, 1)
label_shape = (num_labels, )
batch_size = 128
kernel_size = 3
filters = 16
latent_dim = 2
epochs = 30

# model VAE = koder+dekoder
# konstruowanie modelu kodera
inputs = Input(shape=input_shape, name='encoder_input')
y_labels = Input(shape=label_shape, name='class_labels')
x = Dense(image_size * image_size)(y_labels)
x = Reshape((image_size, image_size, 1))(x)
x = concatenate([inputs, x])
for i in range(2):
    filters *= 2
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               activation='relu',
               strides=2,
               padding='same')(x)

# potrzebna informacja o kształcie, aby zbudować model dekodera
shape = K.int_shape(x)

# generowanie niejawnego wektora Q(z|X)
x = Flatten()(x)
x = Dense(16, activation='relu')(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)

# użycie triku z reparametryzacją, aby potraktować próbkowanie jako wejście
# zauważ, że "output_shape" nie jest konieczne 
# w backendzie TensorFlow 
z = Lambda(sampling,
           output_shape=(latent_dim,),
           name='z')([z_mean, z_log_var])

# instancja modelu kodera
encoder = Model([inputs, y_labels],
                [z_mean, z_log_var, z], 
                name='encoder')
encoder.summary()
plot_model(encoder,
           to_file='cvae_cnn_encoder.png', 
           show_shapes=True)

# konstruowanie modelu dekodera
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = concatenate([latent_inputs, y_labels])
x = Dense(shape[1]*shape[2]*shape[3], activation='relu')(x)
x = Reshape((shape[1], shape[2], shape[3]))(x)

for i in range(2):
    x = Conv2DTranspose(filters=filters,
                        kernel_size=kernel_size,
                        activation='relu',
                        strides=2,
                        padding='same')(x)
    filters //= 2

outputs = Conv2DTranspose(filters=1,
                          kernel_size=kernel_size,
                          activation='sigmoid',
                          padding='same',
                          name='decoder_output')(x)

# instancja modelu dekodera
decoder = Model([latent_inputs, y_labels],
                outputs, 
                name='decoder')
decoder.summary()
plot_model(decoder,
           to_file='cvae_cnn_decoder.png', 
           show_shapes=True)

# instancja modelu VAE
outputs = decoder([encoder([inputs, y_labels])[2], y_labels])
cvae = Model([inputs, y_labels], outputs, name='cvae')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    help_ = "Zaladuj model tf z wyuczonymi wagami"
    parser.add_argument("-w", "--weights", help=help_)
    help_ = "Użycie binarnej entropii krzyżowej zamiast domyślnej mse"
    parser.add_argument("--bce", help=help_, action='store_true')
    help_ = "Podaj konkretną cyfrę do wygenerowania"
    parser.add_argument("-d", "--digit", type=int, help=help_)
    help_ = "Beta in Beta-CVAE. Beta > 1. Wartość domyślna to 1.0 (CVAE)"
    parser.add_argument("-b", "--beta", type=float, help=help_)
    args = parser.parse_args()
    models = (encoder, decoder)
    data = (x_test, y_test)

    if args.beta is None or args.beta < 1.0:
        beta = 1.0
        print("CVAE")
        model_name = "cvae_cnn_mnist"
        save_dir = "wagi_cvae"
    else:
        beta = args.beta
        print("Beta-CVAE z beta=", beta)
        model_name = "beta-cvae_cnn_mnist"
        save_dir = "wagi_beta-cvae"

    # VAE loss = mse_loss lub xent_loss + kl_loss
    if args.bce:
        reconstruction_loss = binary_crossentropy(K.flatten(inputs),
                                                  K.flatten(outputs))
    else:
        reconstruction_loss = mse(K.flatten(inputs), K.flatten(outputs))

    reconstruction_loss *= image_size * image_size
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5 * beta
    cvae_loss = K.mean(reconstruction_loss + kl_loss)
    cvae.add_loss(cvae_loss)
    cvae.compile(optimizer='rmsprop')
    cvae.summary()
    plot_model(cvae, to_file='cvae_cnn.png', show_shapes=True)

    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    if args.weights:
        filepath = os.path.join(save_dir, args.weights)
        cvae = cvae.load_weights(filepath)
    else:
        cvae.fit([x_train, to_categorical(y_train)],
                 epochs=epochs,
                 batch_size=batch_size,
                 validation_data=([x_test, to_categorical(y_test)], None))
        filename = model_name + '.tf'
        filepath = os.path.join(save_dir, filename)
        cvae.save_weights(filepath)

    if args.digit in range(0, num_labels):
        digit = np.array([args.digit])
    else:
        digit = np.random.randint(0, num_labels, 1)

    print("CVAE dla cyfry %d" % digit)
    y_label = np.eye(num_labels)[digit]
    plot_results(models,
                 data,
                 y_label=y_label,
                 batch_size=batch_size,
                 model_name=model_name)
