import zipfile
import numpy as np
from skimage import color, exposure, transform
from skimage import io
import os
import glob
import pickle
import random
import pandas as pd
import seaborn as sns
from absl import flags
import matplotlib.pyplot as plt
from matplotlib import figure
import matplotlib.gridspec as gridspec
from matplotlib.backends import backend_agg
import tensorflow as tf

from parameters import *


DATA_DIR = os.path.join(os.getcwd(), "..", 'Data')
FLAGS = flags.FLAGS

def extract_dataset():
    # Wyodrębnianie danych treningowych
    if not os.path.exists(os.path.join(DATA_DIR, "GTSRB")):
        zip_ref = zipfile.ZipFile(os.path.join(DATA_DIR, 'GTSRB_Final_Training_Images.zip'), 'r')
        zip_ref.extractall(DATA_DIR)
        zip_ref.close()
        # Wyodrębnianie danych testowych
        zip_ref = zipfile.ZipFile(os.path.join(DATA_DIR, 'GTSRB_Final_Test_Images.zip'), 'r')
        zip_ref.extractall(DATA_DIR)
        zip_ref.close()
        # Wyodrębnianie pliku csv zawierającego adnotacje
        zip_ref = zipfile.ZipFile(os.path.join(DATA_DIR,'GTSRB_Final_Test_GT.zip'), 'r')
        zip_ref.extractall(os.path.join(DATA_DIR,"GTSRB"))
    else:
        print ("Dane zostały już wyodrębnione do folderu. Nie ma potrzeby ponownego wyodrębniania danych z pliku zip.")

def normalize_and_reshape_img(img):
    # Normalizacja histogramu w kanale v
    hsv = color.rgb2hsv(img)
    hsv[:, :, 2] = exposure.equalize_hist(hsv[:, :, 2])
    img = color.hsv2rgb(hsv)

    # Przycięcie środka
    min_side = min(img.shape[:-1])
    centre = img.shape[0] // 2, img.shape[1] // 2
    img = img[centre[0] - min_side // 2:centre[0] + min_side // 2,
              centre[1] - min_side // 2:centre[1] + min_side // 2,
              :]
    # Przeskalowanie do żądanego rozmiaru
    img = transform.resize(img, (IMG_SIZE, IMG_SIZE))
    return img

def get_class(img_path):
    try:
        return int(img_path.split('/')[-2])
    except:
        return int(img_path.split('\\')[-2])

def preprocess_and_save_data(data_type ='train'):
    '''
    Wstępnie przetworzenie danych obrazu i zapisanie cech obrazu i etykiet jako plików pickle do wykorzystania w modelu.
    :parametr data_type: data_type to 'train' lub 'test'
    :zwracane wartości: brak
    '''
    if data_type =='train':
        root_dir = os.path.join(DATA_DIR, 'GTSRB/Final_Training/Images/')
        imgs = []
        labels = []

        all_img_paths = glob.glob(os.path.join(root_dir, '*/*.ppm'))
        np.random.shuffle(all_img_paths)
        for img_path in all_img_paths:
            img = normalize_and_reshape_img(io.imread(img_path))
            label = get_class(img_path)
            imgs.append(img)
            labels.append(label)
        X_train = np.array(imgs, dtype='float32')
        # Tworzenie celów gorącojedynkowych
        Y_train = np.array(labels, dtype = 'uint8')

        train_data = {"features": X_train, "labels": Y_train}
        if not os.path.exists(os.path.join(DATA_DIR,"Preprocessed_Data")):
            os.makedirs(os.path.join(DATA_DIR,"Preprocessed_Data"))
        pickle.dump(train_data,open(os.path.join(DATA_DIR,"Preprocessed_Data","preprocessed_train.p"),"wb"))
        return train_data
    elif data_type == 'test':
        # Odczyt pliku testowego
        test = pd.read_csv(os.path.join(DATA_DIR, "GTSRB", 'GT-final_test.csv'), sep=';')
        X_test = []
        y_test = []
        i = 0
        for file_name, class_id in zip(list(test['Filename']), list(test['ClassId'])):
            img_path = os.path.join(DATA_DIR, 'GTSRB/Final_Test/Images/', file_name)
            X_test.append(normalize_and_reshape_img(io.imread(img_path)))
            y_test.append(class_id)

        test_data = {"features": np.array(X_test,dtype ='float32'), "labels": np.array(y_test,dtype = 'uint8')}
        if not os.path.exists(os.path.join(DATA_DIR,"Preprocessed_Data")):
            os.makedirs(os.path.join(DATA_DIR,"Preprocessed_Data"))
        pickle.dump(test_data,open(os.path.join(DATA_DIR,"Preprocessed_Data","preprocessed_test.p"),"wb"))
        return test_data

def load_preprocessed_data():

    '''
    Wczytywanie wstępnie przetworzonych danych, jeśli są już obecne. W przeciwnym należy 
	wstępnie przetworzyć dane, a następnie je wczytać.
    :zwraca:
    '''
    print ("Wczytywanie danych treningowych")
    if not os.path.isfile(os.path.join(DATA_DIR, "Preprocessed_Data", "preprocessed_train.p")):
        print ("Przetworzony plik nie istnieje. Najpierw należy wykonać wstępne przetwarzanie danych.")
        train_data = preprocess_and_save_data(data_type='train')
    else:
       train_data= pickle.load(open(os.path.join(DATA_DIR,"Preprocessed_Data","preprocessed_train.p"),"rb"))
    X_train, y_train = train_data['features'], train_data['labels']

    print ("Wczytywanie danych testowych")
    if not os.path.isfile(os.path.join(DATA_DIR, "Preprocessed_Data", "preprocessed_test.p")):
        print ("Przetworzony plik nie istnieje. Najpierw należy wykonać wstępne przetwarzanie danych.")
        test_data = preprocess_and_save_data(data_type='test')
    else:
       test_data= pickle.load(open(os.path.join(DATA_DIR,"Preprocessed_Data","preprocessed_test.p"),"rb"))
    X_test, y_test = test_data['features'], test_data['labels']

    return X_train, y_train, X_test,y_test


def convert_to_grayscale(data):
    data_gray = np.zeros((data.shape[0], data.shape[1], data.shape[2], 1))
    for i in range(len(data)):
        if i % 10000 == 0:
            print("Liczba obrazów skonwertowanych do skali szarości ", i)
        temp = color.rgb2gray(data[i])
        temp = temp.reshape((temp.shape[0], temp.shape[1], 1))
        data_gray[i] = temp
    return data_gray.astype(np.float32)

def load_grayscale_images(data, data_type = 'train'):
    '''
    Konwersja danych do skali szarości, ponieważ zależy nam tylko na klasyfikacji, a nie na kolorze znaku drogowego
    :parametr data: dane obrazu, który ma być skonwertowany do skali szarości
    :zwraca: obraz (obrazy) w skali szarości
    '''
    if data_type == 'train':
        if not os.path.exists(os.path.join(DATA_DIR, "Preprocessed_Data", 'preprocessed_train_gray.p')):
            data_gray = convert_to_grayscale(data)
            # Zapisanie data_gray jako plik pickle
            pickle.dump(data_gray, open(os.path.join(DATA_DIR, "Preprocessed_Data", "preprocessed_train_gray.p"), "wb"))
        else:
            data_gray = pickle.load(open(os.path.join(DATA_DIR, "Preprocessed_Data", "preprocessed_train_gray.p"), "r"))
    elif data_type == 'test':
        if not os.path.exists(os.path.join(DATA_DIR, "Preprocessed_Data", 'preprocessed_test_gray.p')):
            data_gray = convert_to_grayscale(data)
            # Zapisanie data_gray jako plik pickle
            pickle.dump(data_gray, open(os.path.join(DATA_DIR, "Preprocessed_Data", "preprocessed_test_gray.p"), "wb"))
        else:
            data_gray = pickle.load(open(os.path.join(DATA_DIR, "Preprocessed_Data", "preprocessed_test_gray.p"), "r"))

    return data_gray


def build_data_pipeline(X_train, X_test,y_train, y_test):
    '''
    Iterator zbioru danych do szkolenia modelu
    :parametr X_train: macierz Numpy składająca się z obrazów treningowych
    :parametr X_test: macierz Numpy składająca się z obrazów testowych
    :parametr y_train: macierz Numpy składająca się z etykiet treningowych
    :parametr y_test: macierz Numpy składająca się z etykiet testowych
    :zwraca: iteratory do szkolenia i testowania
    '''

    train_data = tf.data.Dataset.from_tensor_slices((np.float32(X_train), np.int32(y_train)))
    train_batches = train_data.shuffle(50000, reshuffle_each_iteration=True).repeat().batch(BATCH_SIZE)
    train_iterator = train_batches.make_one_shot_iterator()

    # Bodowanie iteratora z test_dataset przy batch_size = X_test.shape[0].
	# Używamy całości danych testowych dla jednorazowego użycia iteratora.
    test_data = tf.data.Dataset.from_tensor_slices((np.float32(X_test),np.int32(y_test)))
    test_frozen = (test_data.take(X_test.shape[0]).repeat().batch(X_test.shape[0]))
    test_iterator = test_frozen.make_one_shot_iterator()

    # Wprowadzamy dane do iteratora, który potrafi przełączać się pomiędzy wejściem treningowym, a walidacyjnym.
    iter_handle = tf.placeholder(tf.string, shape=[])
    iterator_feed = tf.data.Iterator.from_string_handle(iter_handle, train_batches.output_types, train_batches.output_shapes)
    images, labels = iterator_feed.get_next()

    return images, labels, iter_handle, train_iterator, test_iterator



def plot_input_data(X_train,y_train):
    # Wyrysujmy obrazy konkretnego znaku i zobaczmy różnice.
    num_rows = 9
    num_cols = 5

    fig = plt.figure(figsize=(num_cols, num_rows))
    gs = gridspec.GridSpec(num_rows, num_cols, wspace=0.0)
    ax = [plt.subplot(gs[i]) for i in range(num_rows * num_cols)]
    for i in range(num_rows * num_cols):
        ax[i].axis('off')
        if i < 43:
            indexes = list(np.where(y_train == i))[0]
            image = X_train[random.choice(indexes)]
            ax[i].imshow(image, interpolation='nearest')

    image_name = 'Input_Images.png'
    if not os.path.exists(os.path.join(DATA_DIR, "..", "Plots")):
        os.makedirs(os.path.join(DATA_DIR, "..", "Plots"))
    fig.savefig(os.path.join(DATA_DIR, "..", "Plots",image_name), dpi=fig.dpi)
    plt.clf()

def plot_weight_posteriors(names, qm_vals, qs_vals, fname):
  """Zapisanie wykresu PNG plot z histogramami średnich wag i odchyleń standardowych.
  Argumenty:
    names: Obiekt `iterable` Pythona zawierający nazwy zmiennych jako `str`.
    qm_vals: Obiekt `iterable` Pythona, tej samej długości co `names`,
      którego elementy są tablicami Numpy, dowolnego kształtu, zawierającymi 
      średnie a posteriori zmiennych wagowych.
    qs_vals: Obiekt `iterable` Pythona, tej samej długości co `names`,
      którego elementy są tablicami Numpy, dowolnego kształtu, zawierającymi
      odchylenia standardowe a posteriori zmiennych wagowych.
    fname: nazwa pliku jako obiekt `str` Pythona, do którego ma być zapisany wykres.
  """
  fig = figure.Figure(figsize=(6, 3))
  canvas = backend_agg.FigureCanvasAgg(fig)

  ax = fig.add_subplot(1, 2, 1)
  for n, qm in zip(names, qm_vals):
    sns.distplot(qm.flatten(), ax=ax, label=n)
  ax.set_title("Średnia wag")
  ax.set_xlim([-1.5, 1.5])
  ax.legend()

  ax = fig.add_subplot(1, 2, 2)
  for n, qs in zip(names, qs_vals):
    sns.distplot(qs.flatten(), ax=ax)
  ax.set_title("Odchylenie st. wag")
  ax.set_xlim([0, 1.])

  fig.tight_layout()
  save_dir = os.path.join(DATA_DIR, "..","Plots")
  canvas.print_figure(os.path.join(save_dir, fname), format="png")
  print("saved {}".format(fname))


def plot_heldout_prediction(input_vals, probs,
                            fname,  title=""):
  """Wykreślanie niepewności w przewidywaniu próbkowanego obrazu.
  Argumenty:
    input_vals: tablica NymPy w stylu `float` o kształcie
      `IMAGE_SHAPE`, zawierająca próbkowany obraz testowy.
    probs: tablica NymPy w stylu `float` o kształcie `[num_monte_carlo,
      1, num_classes]` zawierająca próbki Monte Carlo 
      prawdopodobieństw klas dla obrazu testowego.
    fname: nazwa pliku jako obiekt `str` Pythona, do którego ma być zapisany wykres.
    title: obiekt `str` Pythona określający tytuł wykresu.
  """
  save_dir = os.path.join(DATA_DIR, "..", "Plots")
  fig = figure.Figure(figsize=(1, 1))
  canvas = backend_agg.FigureCanvasAgg(fig)
  ax = fig.add_subplot(1,1,1)
  ax.imshow(input_vals.reshape((IMG_SIZE,IMG_SIZE)), interpolation="None")
  canvas.print_figure(os.path.join(save_dir, fname + "_image.png"), format="png")

  fig = figure.Figure(figsize=(10, 5))
  canvas = backend_agg.FigureCanvasAgg(fig)
  ax = fig.add_subplot(1,1,1)
  #Przewidywania
  y_pred_list = list(np.argmax(probs,axis=1).astype(np.int32))
  bin_range = [x for x in range(43)]
  ax.hist(y_pred_list,bins = bin_range)
  ax.set_xticks(bin_range)
  ax.set_title("Histogram przewidywanej klasy: " + title)
  ax.set_xlabel("Klasa")
  ax.set_ylabel("Częstotliwość")
  fig.tight_layout()
  save_dir = os.path.join(DATA_DIR, "..", "Plots")
  canvas.print_figure(os.path.join(save_dir, fname + "_predicted_class.png"), format="png")
  print("saved {}".format(fname))
