"""Czat z modułem pobierania i osadzeniami."""
import logging
import os
import tempfile

from langchain.chains import (
    ConversationalRetrievalChain,
    FlareChain,
    OpenAIModerationChain,
    SimpleSequentialChain,
)
from langchain.chains.base import Chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.schema import BaseRetriever, Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import DocArrayInMemorySearch

from czat_z_RAG.utils import MEMORY, load_document
from config import set_environment

logging.basicConfig(encoding="utf-8", level=logging.INFO)
LOGGER = logging.getLogger()
set_environment()

# Konfiguracja LLM i łańcucha dla odpowiadania na pytani. Ustawienie niskiej temperatury w celu minimalizacji zjawiska halucynacji.
LLM = ChatOpenAI(
    model_name="gpt-3.5-turbo", temperature=0, streaming=True
)


def configure_retriever(
        docs: list[Document],
        use_compression: bool = False
) -> BaseRetriever:
    """Moduł pobierania."""
    # Podział każdego dokumentu:
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)

    # Utworzenie osadzeń i zapisani ich vectordb:
    embeddings = OpenAIEmbeddings()
    # albo: HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    # Utworzenie vectordb z pojedynczym wywołaniem modelu osadzającego dla tekstów:
    vectordb = DocArrayInMemorySearch.from_documents(splits, embeddings)
    retriever = vectordb.as_retriever(
        search_type="mmr", search_kwargs={
            "k": 5,
            "fetch_k": 7,
            "include_metadata": True
        },
    )
    if not use_compression:
        return retriever

    embeddings_filter = EmbeddingsFilter(
        embeddings=embeddings, similarity_threshold=0.2
    )
    return ContextualCompressionRetriever(
        base_compressor=embeddings_filter,
        base_retriever=retriever,
    )


def configure_chain(retriever: BaseRetriever, use_flare: bool = True) -> Chain:
    """Konfiguracja łańcucha z modułem pobierającym.

    Automatyczne przekazywanie maksymalnej liczby tokenów max_tokens_limit
    Ucinanie tokenów podczas wysyłania monitów do modelu LLM!
    """
    params = dict(
        llm=LLM,
        retriever=retriever,
        memory=MEMORY,
        verbose=True,
        max_tokens_limit=4000,
    )
    if use_flare:
        # różny zestaw parametrów
        # niestety trzeba użyć klasy typu "protected" 
        return FlareChain.from_llm(
            **params
        )
    return ConversationalRetrievalChain.from_llm(
        **params
    )


def configure_retrieval_chain(
        uploaded_files,
        use_compression: bool = False,
        use_flare: bool = False,
        use_moderation: bool = False
) -> Chain:
    """Wczytywanie dokumentów, konfiguracja modułu pobierającego i łańcucha."""
    docs = []
    temp_dir = tempfile.TemporaryDirectory()
    for file in uploaded_files:
        temp_filepath = os.path.join(temp_dir.name, file.name)
        with open(temp_filepath, "wb") as f:
            f.write(file.getvalue())
        docs.extend(load_document(temp_filepath))

    retriever = configure_retriever(docs=docs, use_compression=use_compression)
    chain = configure_chain(retriever=retriever, use_flare=use_flare)
    if not use_moderation:
        return chain

    moderation_chain = OpenAIModerationChain()
    return SimpleSequentialChain(chains=[chain, moderation_chain])

