import os
from parameter_config import *
import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('agg')


## Funkcje dla sieci kapsułowych
def squash(vectors, name=None):
    """
    Funkcja squashingu zaimplementowana zgodnie z opracowaniem
    :parametr vectors: wejście wektorowe, które ma być poddane squashingowi
    :parametr name: nazwa tensora na grafie
    :zwraca: tensor o takim samym kształcie jak wektory, ale poddany squashingowi, jak wspomniano w opracowaniu
    """
    with tf.name_scope(name, default_name="squash_op"):
        s_squared_norm = tf.reduce_sum(tf.square(vectors), axis=-2, keepdims=True)
        scale = s_squared_norm / (1. + s_squared_norm) / tf.sqrt(s_squared_norm + tf.keras.backend.epsilon())
        return scale*vectors


def routing(u):
    """
    Ta funkcja wykonuje algorytm trasowania, o którym wspomniano w opracowaniu
    :parametr u: Tensor wejściowy o kształcie [batch_size, num_caps_input_layer=1152, 1, caps_dim_input_layer=8, 1].
                NCAPS_CAPS1: liczba kapsuł w pierwszej warstwie kapsuł głównych
                CAPS_DIM_CAPS2: wymiary wektorów wyjściowych pierwszej warstwy kapsuł głównych

    :zwraca: wektor "v_j" (tensor) w warstwie kapsuł cyfrowych
             Kształt:[batch_size, NCAPS_CAPS1=10, CAPS_DIM_CAPS2=16, 1]
    """

    #zmienna lokalna b_ij: [batch_size, num_caps_input_layer=1152, num_caps_output_layer=10, 1, 1]
                #num_caps_output_layer: liczba kapsuł w pierwszej warstwie cyfrowej powiększona o 1
    b_ij = tf.zeros([BATCH_SIZE, NCAPS_CAPS1, NCAPS_CAPS2, 1, 1], dtype=np.float32, name="b_ij")

    # Przygotowanie tensora wejściowego na całkowitą liczbę kapsuł cyfrowych w celu przemnożenia przez W
    u = tf.tile(u, [1, 1, b_ij.shape[2].value, 1, 1])   # u => [batch_size, 1152, 10, 8, 1]


    # W: [num_caps_input_layer, num_caps_output_layer, len_u_i, len_v_j], jak wspomniano w opracowaniu
    W = tf.get_variable('W', shape=(1, u.shape[1].value, b_ij.shape[2].value, u.shape[3].value, CAPS_DIM_CAPS2),
                        dtype=tf.float32, initializer=tf.random_normal_initializer(stddev=STDEV))
    W = tf.tile(W, [BATCH_SIZE, 1, 1, 1, 1]) # W => [batch_size, 1152, 10, 8, 16]

    #Obliczanie u_hat (jak wspomniano w opracowaniu)
    u_hat = tf.matmul(W, u, transpose_a=True)  # [batch_size, 1152, 10, 16, 1]

    # Przy przesyłaniu w przód u_hat_stopped = u_hat;
    # Przy przesyłaniu wstecz żaden gradient nie przechodzi z u_hat_stopped do u_hat
    u_hat_stopped = tf.stop_gradient(u_hat, name='gradient_stop')

    # Tu zaczyna się algorytm trasowania
    for r in range(ROUTING_ITERATIONS):
        with tf.variable_scope('iterations_' + str(r)):
            c_ij = tf.nn.softmax(b_ij, axis=2) # [batch_size, 1152, 10, 1, 1]

            # W ostatniej iteracji użyj `u_hat`, aby uzyskać wsteczną propagację gradientu
            if r == ROUTING_ITERATIONS - 1:
                s_j = tf.multiply(c_ij, u_hat) # [batch_size, 1152, 10, 16, 1]
                # następnie wykonaj sumowanie zgodnie z opracowaniem
                s_j = tf.reduce_sum(s_j, axis=1, keep_dims=True) # [batch_size, 1, 10, 16, 1]

                v_j = squash(s_j) # [batch_size, 1, 10, 16, 1]

            elif r < ROUTING_ITERATIONS - 1:  # W tych iteracjach brak wstecznej propagacji
                s_j = tf.multiply(c_ij, u_hat_stopped)
                s_j = tf.reduce_sum(s_j, axis=1, keepdims=True)
                v_j = squash(s_j)
                v_j = tf.tile(v_j, [1, u.shape[1].value, 1, 1, 1]) # [batch_size, 1152, 10, 16, 1]

                # Mnożenie w ostatnich dwóch wymiarach: [16, 1]^T x [16, 1] daje [1, 1]
                u_hat_dot_v = tf.matmul(u_hat_stopped, v_j, transpose_a=True) # [batch_size, 1152, 10, 1, 1]

                b_ij = tf.add(b_ij,u_hat_dot_v)
    return tf.squeeze(v_j, axis=1) # [batch_size, 10, 16, 1]



def load_data(load_type='train'):
    '''

    :parametr load_type: train lub test w zależności od przypadku użycia
    :zwraca: x (obrazy), y (etykiety)
    '''
    data_dir = os.path.join('data','fashion-mnist')
    if load_type == 'train':
        image_file = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
        image_data = np.fromfile(file=image_file, dtype=np.uint8)
        x = image_data[16:].reshape((60000, 28, 28, 1)).astype(np.float32)

        label_file = open(os.path.join(data_dir, 'train-labels-idx1-ubyte'))
        label_data = np.fromfile(file=label_file, dtype=np.uint8)
        y = label_data[8:].reshape(60000).astype(np.int32)

        x_train = x[:55000] / 255.
        y_train = y[:55000]
        y_train = (np.arange(N_CLASSES) == y_train[:, None]).astype(np.float32)

        x_valid = x[55000:, ] / 255.
        y_valid = y[55000:]
        y_valid = (np.arange(N_CLASSES) == y_valid[:, None]).astype(np.float32)
        return x_train, y_train, x_valid, y_valid
    elif load_type == 'test':
        image_file = open(os.path.join(data_dir, 't10k-images-idx3-ubyte'))
        image_data = np.fromfile(file=image_file, dtype=np.uint8)
        x_test = image_data[16:].reshape((10000, 28, 28, 1)).astype(np.float)

        label_file = open(os.path.join(data_dir, 't10k-labels-idx1-ubyte'))
        label_data = np.fromfile(file=label_file, dtype=np.uint8)
        y_test = label_data[8:].reshape(10000).astype(np.int32)
        y_test = (np.arange(N_CLASSES) == y_test[:, None]).astype(np.float32)
        return x_test / 255., y_test


def shuffle_data(x, y):
    """ Wymieszanie cech i etykiet danych wejściowych"""
    perm = np.arange(y.shape[0])
    np.random.shuffle(perm)
    shuffle_x = x[perm,:,:,:]
    shuffle_y = y[perm]
    return shuffle_x, shuffle_y

def write_progress(op_type = 'train'):
    """
    Tworzenie uchwytów do zapisywania wyników w pliku .csv
    :zwraca: odpowiednie pliki dzienników
    """
    if not os.path.exists(RESULTS_DIR):
        os.mkdir(RESULTS_DIR)
    if op_type == 'train':
        train_path = RESULTS_DIR  + '/' + 'train.csv'
        val_path = RESULTS_DIR + '/' + 'validation.csv'

        if os.path.exists(train_path):
            os.remove(train_path)
        if os.path.exists(val_path):
            os.remove(val_path)

        train_file = open(train_path, 'w')
        train_file.write('step,accuracy,loss\n')
        val_file = open(val_path, 'w')
        val_file.write('epoch,accuracy,loss\n')
        return train_file, val_file
    else:
        test_path = RESULTS_DIR + '/test.csv'
        if os.path.exists(test_path):
            os.remove(test_path)
        test_file = open(test_path, 'w')
        test_file.write('accuracy,loss\n')
        return test_file


def load_existing_details():
    """
    Ta funkcja wczytuje plik treningowy i walidacyjny, aby kontynuować szkolenie
    :zwraca: uchwyty dla pliku treningowego i walidacyjnego oraz minimalną stratę walidacji
    """
    train_path = RESULTS_DIR  + '/' + 'train.csv'
    val_path = RESULTS_DIR + '/' + 'validation.csv'
    # Ustalenie dotychczasowej minimalnej straty walidacji
    f_val = open(val_path, 'r')
    lines = f_val.readlines()
    data = np.genfromtxt(lines[-1:], delimiter=',')
    min_loss = np.min(data[1:, 2])
    # Wczytanie pliku treningowego i walidacyjnego w celu kontynuowania szkolenia
    train_file = open(train_path, 'a')
    val_file = open(val_path, 'a')
    return train_file, val_file, min_loss


def eval_performance(sess, model, x, y):
    '''
    Ta funkcja  jest wykorzystywana głównie do oceny dokładności zestawów testowych i walidacyjnych
    :parametr sess: sesja
    :parametr model: model do wykorzystania
    :parametr x: obrazy
    :parametr y: etykiety
    :zwraca: średnią dokładność i stratę dla zbioru danych
    '''
    acc_all = loss_all = np.array([])
    num_batches = int(y.shape[0] / BATCH_SIZE)
    for batch_num in range(num_batches):
        start = batch_num * BATCH_SIZE
        end = start + BATCH_SIZE
        x_batch, y_batch = x[start:end], y[start:end]
        acc_batch, loss_batch, prediction_batch = sess.run([model.accuracy, model.combined_loss, model.y_predicted],
                                                     feed_dict={model.X: x_batch, model.Y: y_batch})
        acc_all = np.append(acc_all, acc_batch)
        loss_all = np.append(loss_all, loss_batch)
    return np.mean(acc_all), np.mean(loss_all)

def reconstruction(x, y, decoder_output, y_pred, n_samples):
    '''
    Funkcja ta służy do rekonstrukcji przykładowych obrazów do analizy
    :parametr x: Obrazy
    :parametr y: Etykiety
    :parametr decoder_output: wyjście z dekodera
    :parametr y_pred: predykcje z modelu
    :parametr n_samples: liczba obrazów
    :zwraca: zapisuje zrekonstruowane obrazy
    '''

    sample_images = x.reshape(-1, IMG_WIDTH, IMG_HEIGHT)
    decoded_image = decoder_output.reshape([-1, IMG_WIDTH, IMG_WIDTH])

    fig = plt.figure(figsize=(n_samples * 2, 3))
    for i in range(n_samples):
        plt.subplot(1, n_samples, i+ 1)
        plt.imshow(sample_images[i], cmap="binary")
        plt.title("Etykieta:" + IMAGE_LABELS[np.argmax(y[i])])
        plt.axis("off")
    fig.savefig(RESULTS_DIR + '/' + 'input_images.png')
    plt.show()

    fig = plt.figure(figsize=(n_samples * 2, 3))
    for i in range(n_samples):
        plt.subplot(1, n_samples, i + 1)
        plt.imshow(decoded_image[i], cmap="binary")
        plt.title("Predykcja:" + IMAGE_LABELS[y_pred[i]])
        plt.axis("off")
    fig.savefig(RESULTS_DIR + '/' + 'decoder_images.png')
    plt.show()




