import os
import pickle
from parameters import *
import tensorflow as tf
import numpy as np

def load_data():
    """
    Wczytywanie danych
    """
    input_file = os.path.join(TEXT_SAVE_DIR)
    with open(input_file, "r") as f:
        data = f.read()

    return data

def preprocess_and_save_data():
    """
    Wstępne przetwarzanie zbioru danych dla skryptu książki
    """
    text = load_data()
    token_dict = define_tokens()
    for key, token in token_dict.items():
        text = text.replace(key, ' {} '.format(token))

    text = text.lower()
    text = text.split()

    vocab_to_int, int_to_vocab = create_map(text)
    int_text = [vocab_to_int[word] for word in text]
    pickle.dump((int_text, vocab_to_int, int_to_vocab, token_dict), open('processed_text.p', 'wb'))


def load_preprocess_file():
    """
    Wczytywanie przetworzonych danych skryptu książki
    """
    return pickle.load(open('processed_text.p', mode='rb'))


def save_params(params):
    """
    Zapisywanie parametrów do pliku
    """
    pickle.dump(params, open('parameters.p', 'wb'))


def load_params():
    """
    Wczytywanie parametrów z pliku
    """
    return pickle.load(open('parameters.p', mode='rb'))

def create_map(input_text):
    """
    Przyporządkowanie słów do indeksów i vice versa dla ułatwienia wyszukiwania
    :parametr input_text: dane skryptu podzielone na słowa
    :zwraca: krotkę słowników (vocab_to_int, int_to_vocab)
    """
    vocab = set(input_text)
    vocab_to_int = {c: i for i, c in enumerate(vocab)}
    int_to_vocab = dict(enumerate(vocab))
    return vocab_to_int, int_to_vocab

def define_tokens():
    """
    Generowanie słownika w celu zamiany znaków interpunkcyjnych w tokeny. Sym na początku tokena oznacza symbol.
    :zwraca: stokenizowany słownik, w którym kluczem jest znak interpunkcyjny, a wartością jest token
    """
    dict = {'.':'_Sym_Period_',
            ',':'_Sym_Comma_',
            '"':'_Sym_Quote_',
            ';':'_Sym_Semicolon_',
            '!':'_Sym_Exclamation_',
            '?':'_Sym_Question_',
            '(':'_Sym_Left_Parentheses_',
            ')':'_Sym_Right_Parentheses_',
            '--':'_Sym_Dash_',
            '\n':'_Sym_Return_',
           }
    return dict

def generate_batch_data(int_text):
    """
    Generowanie danych dla grup x (wejścia) i y (cele)
    :parametr int_text: tekst z wyrazami zastąpionymi przez ich identyfikatory
    :zwraca: grupy jako tablica Numpy
    """
    num_batches = len(int_text) // (BATCH_SIZE * SEQ_LENGTH)

    x = np.array(int_text[:num_batches * (BATCH_SIZE * SEQ_LENGTH)])
    y = np.array(int_text[1:num_batches * (BATCH_SIZE * SEQ_LENGTH) + 1])

    x_batches = np.split(x.reshape(BATCH_SIZE, -1), num_batches, 1)
    y_batches = np.split(y.reshape(BATCH_SIZE, -1), num_batches, 1)
    batches = np.array(list(zip(x_batches, y_batches)))
    return batches

def extract_tensors(tf_graph):
    """
    Pobierz z grafu tensory: input, initial_state, final_state i probs
    :parametr loaded_graph: graf TensorFlow wczytany z pliku
    :zwraca: krotkę (tensor_input, tensor_initial_state, tensor_final_state, tensor_probs)
    """
    tensor_input = tf_graph.get_tensor_by_name("Input/input:0")
    tensor_initial_state = tf_graph.get_tensor_by_name("Network/initial_state:0")
    tensor_final_state = tf_graph.get_tensor_by_name("Network/final_state:0")
    tensor_probs = tf_graph.get_tensor_by_name("Network/probs:0")
    return tensor_input, tensor_initial_state, tensor_final_state, tensor_probs

def select_next_word(probs, int_to_vocab):
    """
    Wybór następnego słowa dla generowanego tekstu
    :parametr probs: lista prawdopodobieństw wszystkich słów w słowniku, które mogą być wybrane jako następne słowo
    :parametr int_to_vocab: słownik identyfikatorów słów jako kluczy i słów jako wartości
    :zwraca: przewidywane następne słowo
    """
    index = np.argmax(probs)
    word = int_to_vocab[index]
    return word


def predict_book_script():
    _, vocab_to_int, int_to_vocab, token_dict = load_preprocess_file()
    seq_length, load_dir = load_params()

    script_length = 250 # Długość skryptu książki do wygenerowania. 250 oznacza 250 słów

    first_word = 'postgresql' # postgresql lub inne dowolne słowo z książki

    loaded_graph = tf.Graph()
    with tf.Session(graph=loaded_graph) as sess:
        # Wczytanie zapisanego modelu
        loader = tf.train.import_meta_graph(load_dir + '.meta')
        loader.restore(sess, load_dir)

        # Pobranie tensorów z wczytanego modelu
        input_text, initial_state, final_state, probs = extract_tensors(loaded_graph)

        # Ustawienia generowania zdań
        sentences = [first_word]
        previous_state = sess.run(initial_state, {input_text: np.array([[1]])})
        # Generowanie zdań
        for i in range(script_length):
            # Dynamiczne wejście
            dynamic_input = [[vocab_to_int[word] for word in sentences[-seq_length:]]]
            dynamic_seq_length = len(dynamic_input[0])

            # Uzyskanie przewidywań
            probabilities, previous_state = sess.run([probs, final_state], {input_text: dynamic_input, initial_state: previous_state})
            probabilities= np.squeeze(probabilities)

            pred_word = select_next_word(probabilities[dynamic_seq_length - 1], int_to_vocab)
            sentences.append(pred_word)

        # Wyciąganie tokenów ze słów
        book_script = ' '.join(sentences)
        for key, token in token_dict.items():
            book_script = book_script.replace(' ' + token.lower(), key)
        book_script = book_script.replace('\n ', '\n')
        book_script = book_script.replace('( ', '(')

        # Zapis wygenerowanego skryptu do pliku
        with open("book_script", "w") as text_file:
            text_file.write(book_script)

        print(book_script)
