"""
Uzupełniający moduł, zawierający funkcje do wizualizacji grafów i sieci neuronowych.
"""


import networkx as nx
import torch_geometric as pyg
import matplotlib.pyplot as plt
import torch_geometric.data as pyg_data
import torch_geometric.utils as pyg_utils

from deepsnap import graph as deepsnap_graph
from typing import Tuple, Iterable, Dict

from utils.vis_utils import image_size_in_cm

# Mapowanie jasnych odcieni skali szarości na podstawie klas lub etykiet
_GRAYSCALE_SHADES = [0.9, 0.78, 0.66, 0.54, 0.42]


def _func_idx_to_gray(idx: int) -> str:
    """Zwraca jasny odcień szarości zależny od indeksu."""
    return str(_GRAYSCALE_SHADES[idx % len(_GRAYSCALE_SHADES)])


def add_node_to_collections(
    nodes_collection: list,
    node_colors: list,
    node_id: int,
    color_assign_func: callable,
    pyg_graph: pyg_data.Data,
):
    """Pomocnicza funkcja dodająca wierzchołki grafu do kolekcji

    Parameters
    ----------
    nodes_collection : list
        Kolekcja, do której należy dodać wierzchołki.
    node_colors : list
        Kolory wierzchołków - lista z palety matplotlib
    node_id : int
        Identyfikator wierzchołka
    color_assign_func : callable
        Funkcja przypisująca kolor w
    pyg_graph : pyg_data.Data
        Graf w formacie PyTorch Geometric
    """
    nodes_collection.append(node_id)
    node_colors.append(color_assign_func(pyg_graph.y[node_id].item()))


def plot_network_with_colors_and_shapes(
    pyg_graph: pyg_data.Data,
    color_attr: str = "y",
    use_mask_shapes: bool = False,
    title: str = "",
    seed: int = 222,
    legend_loc: str = "best",
    **kwargs,
) -> plt.Figure:
    """Funkcja rysująca graf wraz z kolorami i kształtami wierzchołków zależnymi od ich przypisania.

    Parameters
    ----------
    pyg_graph : pyg_data.Data
        Graf w formacie PyTorch Geometric
    color_attr : str, optional
        Atrybut wierzchołków, który ma być użyty do przypisania kolorów, domyślnie 'y'
    use_mask_shapes : bool, optional
        Czy używać kształtów wierzchołków zależnie od maski treningowej, walidacyjnej i testowej, domyślnie False
    seed: int
        Ziarno generatora liczb losowych
    title : str, optional
        Tytuł wykresu, domyślnie ''
    legend_loc: str, optional
        Lokalizacja legendy, domyślnie 'best'
    Returns
    -------
    plt.Figure
        Wykres z narysowanym grafem
    """
    kwargs["figsize"] = image_size_in_cm(12, 12)
    nx_graph = pyg_utils.to_networkx(pyg_graph, node_attrs=[color_attr])
    pos = nx.spring_layout(nx_graph, seed=seed)
    colors = [_func_idx_to_gray(cls) for cls in pyg_graph.y]
    node_size = 180 if len(pyg_graph.x) > 30 else 260

    fig, ax = plt.subplots(**kwargs)
    if use_mask_shapes:
        train_nodes = []
        train_node_colors = []
        val_nodes = []
        val_node_colors = []
        test_nodes = []
        test_node_colors = []
        for i in range(len(pyg_graph.x)):
            if pyg_graph.train_mask[i]:
                add_node_to_collections(
                    train_nodes, train_node_colors, i, _func_idx_to_gray, pyg_graph
                )
            elif pyg_graph.val_mask[i]:
                add_node_to_collections(
                    val_nodes, val_node_colors, i, _func_idx_to_gray, pyg_graph
                )
            elif pyg_graph.test_mask[i]:
                add_node_to_collections(
                    test_nodes, test_node_colors, i, _func_idx_to_gray, pyg_graph
                )
        nx.draw_networkx_nodes(
            nx_graph,
            pos,
            nodelist=train_nodes,
            node_shape="s",
            node_color=train_node_colors,
            node_size=node_size,
            ax=ax,
        )
        nx.draw_networkx_nodes(
            nx_graph,
            pos,
            nodelist=val_nodes,
            node_shape="^",
            node_color=val_node_colors,
            node_size=node_size,
            ax=ax,
        )
        nx.draw_networkx_nodes(
            nx_graph,
            pos,
            nodelist=test_nodes,
            node_shape="o",
            node_color=test_node_colors,
            node_size=node_size,
            ax=ax,
        )

        legend_elements = [
            plt.Line2D(
                [0],
                [0],
                marker="s",
                color="w",
                label="Trening",
                markerfacecolor="0.7",
                markersize=10,
            ),
            plt.Line2D(
                [0],
                [0],
                marker="^",
                color="w",
                label="Walidacja",
                markerfacecolor="0.55",
                markersize=10,
            ),
            plt.Line2D(
                [0],
                [0],
                marker="o",
                color="w",
                label="Test",
                markerfacecolor="0.4",
                markersize=10,
            ),
        ]
        ax.legend(handles=legend_elements, loc=legend_loc)
    else:
        nx.draw_networkx_nodes(
            nx_graph, pos, node_color=colors, node_size=node_size, ax=ax
        )
        legend_elements = [
            plt.Line2D(
                [0],
                [0],
                marker="o",
                color="w",
                label=f"Klasa {c}",
                markerfacecolor=_func_idx_to_gray(c),
                markersize=10,
            )
            for c in pyg_graph.y.unique().tolist()
        ]
    ax.legend(handles=legend_elements, loc=legend_loc)
    nx.draw_networkx_edges(nx_graph, pos, ax=ax, edge_color="0.65")
    nx.draw_networkx_labels(nx_graph, pos, ax=ax, font_size=12)
    plt.title(title)
    plt.tight_layout()
    return fig


def plot_multiple_networks_with_edge_labels(
    base_graph: pyg_data.Data,
    graphs_with_colors: Iterable[Tuple[pyg_data.Data, str, str]],
    title: str = "",
    plot_messsage_passing_links: bool = False,
    use_curved_edges: bool = False,
    legend_loc: str = "best",
    seed: int = 222,
    font_size: int = 12,
    node_size: int = 260,
    **kwargs,
) -> plt.Figure:
    """Funkcja rysująca wiele grafów z kolorami krawędzi zależnymi od ich etykiet oraz różnymi stylami linii.

    Parameters
    ----------
    base_graph: pyg_data.Data
        Graf bazowy, który posłuży do narysowania pozycji wierzchołków.
    graphs_with_colors : Iterable[Tuple[pyg_data.Data, str, str]]
        Krotka: graf, kolor, nazwa.
    edge_label_attr : str, optional
        Atrybut krawędzi, który ma być użyty do przypisania kolorów, domyślnie 'edge_label'
    seed: int
        Ziarno generatora liczb losowych
    plot_messsage_passing_links: bool
        Czy rysować krawędzie związane z przekazywaniem informacji, domyślnie False.
    use_curved_edges: bool
        Czy dla jednokierunkowych krawędzi stosować łuki w celu uniknięcia nakładania, domyślnie False.
    title : str, optional
        Tytuł wykresu, domyślnie ''
    legend_loc: str, optional
        Lokalizacja legendy, domyślnie 'best'
    font_size : int, optional
        Wielkość czcionki etykiet wierzchołków i legendy, domyślnie 12
    node_size : int, optional
        Rozmiar wierzchołków, domyślnie 260
    Returns
    -------
    plt.Figure
        Wykres z narysowanymi grafami
    """
    kwargs["figsize"] = image_size_in_cm(*kwargs.get("figsize", (12, 12)))
    fig, ax = plt.subplots(**kwargs)
    nx_graph = pyg_utils.to_networkx(base_graph)
    pos = nx.spring_layout(nx_graph, seed=seed)
    # Lista wierzchołków, do których dochodzą krawędzie
    nodes_with_edges = set()

    legend_elements = []

    split_styles = [
        {"color": "0.8", "linestyle": "solid"},
        {"color": "0.6", "linestyle": "solid"},  # clearly dashed
        {"color": "0.45", "linestyle": "solid"},  # clearly dotted
    ]

    for idx_graph, (pyg_graph, _, name) in enumerate(graphs_with_colors):
        split_style = split_styles[idx_graph % len(split_styles)]
        edges_by_pair: Dict[Tuple[int, int], list] = {}
        curved_edges: list = []
        curved_colors: list = []
        curved_styles: list = []
        straight_edges: list = []
        straight_colors: list = []
        straight_styles: list = []

        def _is_solid(style_val):
            return style_val in ("solid", "-")

        for idx, e in enumerate(pyg_graph.edge_label_index.T):
            e_tpl = tuple(e.tolist())
            pair_key = tuple(sorted(e_tpl))
            nodes_with_edges.update([e_tpl[0], e_tpl[1]])
            base_style = split_style["linestyle"]
            style = base_style if pyg_graph.edge_label[idx].item() == 1 else "dotted"
            edges_by_pair.setdefault(pair_key, []).append(
                {
                    "edge": e_tpl,
                    "color": split_style["color"],
                    "style": style,
                    "label": pyg_graph.edge_label[idx].item(),
                }
            )

        if plot_messsage_passing_links:
            for idx, e in enumerate(pyg_graph.edge_index.T):
                e_tpl = tuple(e.tolist())
                pair_key = tuple(sorted(e_tpl))
                if pair_key not in edges_by_pair:
                    edges_by_pair[pair_key] = [
                        {"edge": e_tpl, "color": "0.35", "style": "solid", "label": 1}
                    ]
                nodes_with_edges.update([e_tpl[0], e_tpl[1]])

        for pair_key, entries in edges_by_pair.items():
            if len(entries) == 1:
                entry = entries[0]
                target_curved = use_curved_edges and not _is_solid(entry["style"])
                (curved_edges if target_curved else straight_edges).append(
                    entry["edge"]
                )
                (curved_colors if target_curved else straight_colors).append(
                    entry["color"]
                )
                (curved_styles if target_curved else straight_styles).append(
                    entry["style"]
                )
            else:
                for idx_entry, entry in enumerate(entries):
                    has_other_style = any(
                        (entry["color"], entry["style"])
                        != (other["color"], other["style"])
                        for j, other in enumerate(entries)
                        if j != idx_entry
                    )
                    target_curved = (
                        use_curved_edges
                        and not _is_solid(entry["style"])
                        and has_other_style
                    )
                    (curved_edges if target_curved else straight_edges).append(
                        entry["edge"]
                    )
                    (curved_colors if target_curved else straight_colors).append(
                        entry["color"]
                    )
                    (curved_styles if target_curved else straight_styles).append(
                        entry["style"]
                    )

        if straight_edges:
            nx.draw_networkx_edges(
                nx_graph,
                pos,
                edgelist=straight_edges,
                edge_color=straight_colors,
                style=straight_styles,
                ax=ax,
                width=2.5,
            )
        if curved_edges:
            nx.draw_networkx_edges(
                nx_graph,
                pos,
                edgelist=curved_edges,
                edge_color=curved_colors,
                style=curved_styles,
                ax=ax,
                connectionstyle="arc3,rad=0.1",
                width=2.5,
            )
        legend_elements.append(
            plt.Line2D(
                [0],
                [0],
                color=split_style["color"],
                label=f"{name}",
                linestyle=split_style["linestyle"],
                linewidth=2.5,
            )
        )
    if plot_messsage_passing_links:
        legend_elements.append(
            plt.Line2D(
                [0],
                [0],
                color="0.35",
                label="Przekazywanie informacji",
                linestyle="solid",
                linewidth=2.5,
            )
        )

    legend_elements.append(
        plt.Line2D(
            [0],
            [0],
            color="0.5",
            label="Krawędź negatywna",
            linestyle=":",
            linewidth=2.5,
        )
    )

    # NetworkX stosuje ciągłą numerację wierzchołków.
    # Jeśli mamy krawędź (3,4), to NetworkX oczekuje, że istnieją też wierzchołki 0, 1, 2.
    # Usuwamy je z wizualizacji, na potrzeby ilustracji połączeń krawędzi
    networkx_nodes = list(nx_graph.nodes())
    for node in networkx_nodes:
        if node not in nodes_with_edges:
            nx_graph.remove_node(node)

    nx.draw_networkx_nodes(
        nx_graph, pos, node_color="0.85", node_size=node_size, edgecolors="0.6", ax=ax
    )
    nx.draw_networkx_labels(nx_graph, pos, ax=ax, font_size=font_size)

    ax.legend(handles=legend_elements, loc=legend_loc)
    plt.title(title)
    return fig


def plot_multiple_subgraphs(
    subraph_indices: Dict[str, list],
    base_graph: nx.Graph,
    title: str,
    figsize: Tuple[float, float] = (12, 12),
) -> plt.Figure:
    """Pomocnicza funkcja rysująca wiele podgrafów na jednym wykresie.

    Parameters
    ----------
    subraph_indices : Dict[str, list]
        Słownik, w którym klucze to nazwy podgrafów, a wartości to listy wierzchołków w tych podgrafach.
    base_graph : nx.Graph
        Bazowy graf, na którym zostaną narysowane podgrafy.
    title: str
        Tytuł wykresu
    figsize: Tuple[float, float]
        Rozmiar wykresu w cm
    Returns
    -------
    plt.Figure
        Wykres z narysowanymi podgrafami
    """
    figsize_cm = image_size_in_cm(*figsize)
    fig, axs = plt.subplots(figsize=figsize_cm, nrows=len(subraph_indices), ncols=1)
    idx = 0
    pos = nx.spring_layout(base_graph, seed=123)
    unique_classes = set(
        cls for cls in nx.get_node_attributes(base_graph, "y").values()
    )
    legend_elements = [
        plt.Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            markerfacecolor=_func_idx_to_gray(c),
            markersize=10,
            label=f"Klasa {c}",
        )
        for c in unique_classes
    ]
    for name, (nodes, _) in subraph_indices.items():
        subgraph = nx.subgraph(base_graph, nodes)
        node_colors = [
            _func_idx_to_gray(y) for y in nx.get_node_attributes(subgraph, "y").values()
        ]
        nx.draw_networkx_edges(
            subgraph, pos, ax=axs[idx], label=name, edge_color="0.6", width=2
        )
        nx.draw_networkx_nodes(
            subgraph, pos, ax=axs[idx], node_color=node_colors, node_size=220
        )
        axs[idx].set_title(name)
        idx += 1
    axs[0].legend(handles=legend_elements, loc="upper right")
    plt.suptitle(title)
    plt.tight_layout()
    return fig


def plot_snap_graph_split_edges(
    snap_graph: deepsnap_graph.Graph,
    base_pyg_graph: pyg_data.Data,
    split_name: str,
    split_color: str,
    **kwargs,
):
    pyg_from_snap = pyg_data.Data(
        edge_index=snap_graph[0].edge_index,
        edge_label=snap_graph[0].edge_label,
        edge_label_index=snap_graph[0].edge_label_index,
    )
    return plot_multiple_networks_with_edge_labels(
        base_pyg_graph, [(pyg_from_snap, split_color, split_name)], **kwargs
    )


def plot_batch_with_renaming(
    full_graph: nx.Graph,
    batch: pyg_data.Data,
    color_attr: str = "y",
    name_attr: str = "n_id",
    title: str = "",
    seed: int = 123,
    font_size: int = 12,
    node_size: int = 240,
    figsize: Tuple[float, float] = (12, 12),
    f: plt.Figure = None,
    ax: plt.Axes = None,
    **kwargs,
) -> plt.Figure:
    """Rysuje podgraf batcha z zachowaniem oryginalnych identyfikatorów i etykiet klas.

    Parameters
    ----------
    full_graph : nx.Graph
        Pełny graf służący do wyznaczenia położeń wierzchołków.
    batch : pyg_data.Data
        Podgraf batcha w formacie PyTorch Geometric.
    color_attr : str, optional
        Atrybut wierzchołków do kolorowania (w skali szarości), domyślnie 'y'.
    name_attr : str, optional
        Atrybut z nazwą wierzchołka używany przy relabelingu, domyślnie 'n_id'.
    title : str, optional
        Tytuł wykresu, domyślnie ''.
    seed : int, optional
        Ziarno układu współrzędnych, domyślnie 123.
    font_size : int, optional
        Wielkość czcionki etykiet i legendy, domyślnie 12.
    node_size : int, optional
        Rozmiar wierzchołków, domyślnie 240.
    figsize : Tuple[float, float], optional
        Rozmiar wykresu w centymetrach, domyślnie (12, 12).

    Returns
    -------
    plt.Figure
        Wykres przedstawiający batch z relabelingiem i legendą klas.
    """
    pos = nx.spring_layout(full_graph, seed=seed)
    attrs_to_take = []
    if name_attr:
        attrs_to_take.append(name_attr)
    if color_attr:
        attrs_to_take.append(color_attr)
    batch_nx = pyg.utils.to_networkx(batch, node_attrs=attrs_to_take)
    if name_attr:
        batch_nx = nx.relabel_nodes(
            batch_nx,
            {
                batch_idx: data[name_attr]
                for batch_idx, data in batch_nx.nodes(data=True)
            },
        )
    if color_attr:
        node_colors = {
            node: _func_idx_to_gray(data[color_attr])
            for node, data in batch_nx.nodes(data=True)
        }
    else:
        node_colors = {node: "0.85" for node in batch_nx.nodes()}
    if f is None and ax is None:
        kwargs["figsize"] = image_size_in_cm(*figsize)
        f, ax = plt.subplots(**kwargs)
    nx.draw(
        batch_nx,
        pos=pos,
        with_labels=True,
        node_color=list(node_colors.values()),
        node_size=node_size,
        ax=ax,
        edge_color="0.65",
        font_size=font_size,
    )
    legend_elements = [
        plt.Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            markerfacecolor=_func_idx_to_gray(c),
            markersize=10,
            label=f"Klasa {c}",
        )
        for c in batch.y.unique().tolist()
    ]
    plt.tight_layout()
    ax.legend(handles=legend_elements, loc="best", fontsize=font_size)
    ax.set_title(title)
    return f
