import os
from time import time
import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from parameters import *


def create_model_dir():
    current_time = time()
    model_dir = LOGGING_DIR + "/model_files_{}".format(current_time)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    return model_dir

def create_freeze_graph_dir(model_dir):
    freeze_graph_dir = os.path.join(model_dir, "freeze")
    if not os.path.exists(freeze_graph_dir):
        os.makedirs(freeze_graph_dir)
    return freeze_graph_dir

def create_optimized_graph_dir(model_dir):
    optimized_graph_dir = os.path.join(model_dir, "optimized")
    if not os.path.exists(optimized_graph_dir):
        os.makedirs(optimized_graph_dir)
    return optimized_graph_dir

def create_frozen_graph(sess,output_name,freeze_graph_dir):
    frozen_graph = freeze_session(sess,
                                  output_names=output_name)
    tf.train.write_graph(frozen_graph, freeze_graph_dir, FREEZE_FILE_NAME , as_text=False)

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Zamraża stan sesji do przyciętego grafu obliczeniowego.

    Przekształca istniejący graf w nowy graf, w którym węzły zmiennych są zastępowane stałymi.
    Nowy graf obcina istniejący graf z wszelkich operacji, które nie są wymagane
    do obliczenia wymaganego wyjścia.

    Wyjścia są usuwane.
    @parametr session: Sesja TensorFlow do zamrożenia.
    @parametr keep_var_names: Lista nazw zmiennych, które nie powinny być zamrażane
                              lub None, aby zamrozić wszystkie zmienne w grafie.
    @parametr output_names: Nazwy odpowiednich wyjść grafu.
    @parametr clear_devices: Usuwa dyrektywy urządzeń z grafu, aby zapewnić lepszą przenośność.
    @zwraca: Definicję zamrożonego grafu.
    """

    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph

def pb_to_tensorboard(input_graph_dir,graph_type ="freeze"):
    '''
    Konwersja pliku grafu ".pb" na format możliwy do odczytu przez Tensorboard
    :parametr input_graph_dir: Katalog, w którym przechowywany jest plik grafu
    :parametr graph_type: "freeze" lub "optimize", w zależności od operacji.
    :zwraca: Zapisuje plik w folderze, który może być otwarty przez Tensorboard
    '''
    file_name = ""
    if graph_type == "freeze":
        file_name = FREEZE_FILE_NAME
    elif graph_type == "optimize":
        file_name = OPTIMIZE_FILE_NAME

    with tf.Session() as sess:
        model_filename = input_graph_dir + "/" + file_name
        with gfile.FastGFile(model_filename, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            g_in = tf.import_graph_def(graph_def)
    train_writer = tf.summary.FileWriter(input_graph_dir)
    train_writer.add_graph(sess.graph)

def strip(input_graph, drop_scope, input_before, output_after, pl_name):
    '''
    Ta funkcja odcina z grafu węzeł drop_scope.
    :parametr input_graph: Graf wejściowy
    :parametr drop_scope: Zakres typu "porzuć", który należy usunąć z grafu
    :parametr input_before: Wejście przed drop_scope
    :parametr output_after:  Wyjście po drop_scope
    :parametr pl_name: Nazwa pl
    :zwraca: obcięty graf wyjściowy
    '''
    input_nodes = input_graph.node
    nodes_after_strip = []
    for node in input_nodes:
        if node.name.startswith(drop_scope + '/'):
            continue

        if node.name == pl_name:
            continue

        new_node = node_def_pb2.NodeDef()
        new_node.CopyFrom(node)
        if new_node.name == output_after:
            new_input = []
            for node_name in new_node.input:
                if node_name == drop_scope + '/cond/Merge':
                    new_input.append(input_before)
                else:
                    new_input.append(node_name)
            del new_node.input[:]
            new_node.input.extend(new_input)
        else:
            new_input= []
            for node_name in new_node.input:
                if node_name == drop_scope + '/cond/Merge':
                    new_input.append(input_before)
                else:
                    new_input.append(node_name)
            del new_node.input[:]
            new_node.input.extend(new_input)

        nodes_after_strip.append(new_node)

    output_graph = graph_pb2.GraphDef()
    output_graph.node.extend(nodes_after_strip)
    return output_graph


def optimize_graph(input_dir, output_dir):
    '''
    Służy do optymalizacji zamrożonego grafu poprzez usunięcie zbędnych operacji
    :parametr input_dir: katalog, w którym przechowywany jest graf wejściowy.
    :parametr output_dir: katalog, w którym powinien być przechowywany graf końcowy.
    :zwracane wartości: Brak
    '''
    input_graph = os.path.join(input_dir, FREEZE_FILE_NAME)
    output_graph = os.path.join(output_dir, OPTIMIZE_FILE_NAME)

    input_graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(input_graph, "rb") as f:
        input_graph_def.ParseFromString(f.read())

    output_graph_def = strip(input_graph_def, u'dropout_1', u'conv2d_2/bias', u'dense_1/kernel', u'training')
    output_graph_def = strip(output_graph_def, u'dropout_3', u'max_pooling2d_2/MaxPool', u'flatten_2/Shape',
                             u'training')
    output_graph_def = strip(output_graph_def, u'dropout_4', u'dense_3/Relu', u'dense_4/kernel', u'training')
    output_graph_def = strip(output_graph_def, u'Adadelta_1', u'softmax_tensor_1/Softmax',
                             u'training/Adadelta/Variable', u'training')
    output_graph_def = strip(output_graph_def, u'training', u'softmax_tensor_1/Softmax',
                             u'_', u'training')

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
