Source code for maeser.chat.chat_session_manager
# SPDX-License-Identifier: LGPL-3.0-or-later
"""
Module for managing chat sessions and interactions with multiple chat interfaces.
"""
from maeser.chat.chat_logs import BaseChatLogsManager
from maeser.user_manager import User
import time
from uuid import uuid4 as uid
from langchain_community.callbacks import get_openai_callback
from langgraph.graph.graph import CompiledGraph
class ChatSessionManager:
    """
    Manages and directs sessions for multiple chat interfaces.

    Each chat interface is registered as a named *branch* backed by a compiled
    graph (see :meth:`register_branch`). Sessions are identified by string IDs
    produced by :meth:`get_new_session_id`. Activity is recorded through an
    optional chat logs manager.

    Args:
        chat_logs_manager (BaseChatLogsManager | None):
            The chat logs manager to use for logging chat data.
            This can be a ChatLogsManager object or a custom chat logs manager
            that inherits from BaseChatLogsManager. If ``None``, nothing is
            logged, feedback is dropped, and history lookups return ``{}``.

    Attributes:
        chat_logs_manager (BaseChatLogsManager | None): The configured logs
            manager, or ``None`` when logging is disabled.
        graphs (dict): Registered branches keyed by branch name; each value is
            a dict with ``"label"`` and ``"graph"`` keys.
    """

    def __init__(
        self,
        chat_logs_manager: BaseChatLogsManager | None = None,
    ) -> None:
        self.chat_logs_manager: BaseChatLogsManager | None = chat_logs_manager
        # branch_name -> {"label": ..., "graph": ...}; filled by register_branch().
        self.graphs: dict[str, dict] = {}

    def register_branch(
        self, branch_name: str, branch_label: str, graph: CompiledGraph
    ) -> None:
        """
        Registers a new chat branch with its name, label, and compiled RAG graph.

        See maeser.graphs for built-in RAG graphs. Registering a name that is
        already present replaces the previous entry.

        Args:
            branch_name (str): The name of the branch; used as its lookup key.
            branch_label (str): The label of the branch.
            graph (CompiledGraph): The compiled RAG graph for the branch.

        Returns:
            None
        """
        self.graphs[branch_name] = {"label": branch_label, "graph": graph}

    def get_new_session_id(self, branch_name: str, user: User | None = None) -> str:
        """
        Creates a new chat session for the given branch and user.

        The session ID has the form ``<uuid4>-<auth_method>-<ident>`` when a
        user is provided, or ``<uuid4>-anon`` otherwise. If a chat logs manager
        is configured, the new session is immediately logged with
        ``{"user": user}`` so the session has a log entry from the start.

        Args:
            branch_name (str): The name of the branch to create a session for.
            user (User | None): The user to create the session for.

        Returns:
            str: The session ID for the new session.
        """
        # Embed the user's identity in the ID when available; otherwise "anon".
        owner = f"{user.auth_method}-{user.ident}" if user else "anon"
        session_id = f"{uid()}-{owner}"

        if self.chat_logs_manager is not None:
            self.chat_logs_manager.log(branch_name, session_id, {"user": user})
        return session_id

    def ask_question(self, message: str, branch_name: str, sess_id: str) -> dict:
        """
        Asks a question in a specific session of a branch.

        The message is sent to the branch's graph with ``sess_id`` passed as
        the ``thread_id`` in the graph's ``configurable`` config. The response
        is augmented with these keys and then logged, if a chat logs manager
        is configured:

        - ``tokens_used`` and ``cost``: as reported by ``get_openai_callback``
          (presumably only OpenAI model calls are counted; confirm for other
          providers).
        - ``execution_time``: elapsed seconds.

        Args:
            message (str): The question to ask.
            branch_name (str): The chat branch to ask the question in.
            sess_id (str): The session ID to ask the question in.

        Returns:
            dict: The graph's response with the metrics above added.

        Raises:
            KeyError: If ``branch_name`` has not been registered.
        """
        # Resolve the graph up front so an unknown branch fails before metering.
        graph = self.graphs[branch_name]["graph"]
        config = {"configurable": {"thread_id": sess_id}}

        # perf_counter() is monotonic: the measured duration cannot be skewed
        # by wall-clock adjustments the way time.time() can.
        started = time.perf_counter()
        with get_openai_callback() as cb:
            response = graph.invoke({"messages": [message]}, config=config)
        response["tokens_used"] = cb.total_tokens
        response["cost"] = cb.total_cost
        response["execution_time"] = time.perf_counter() - started

        if self.chat_logs_manager is not None:
            self.chat_logs_manager.log(branch_name, sess_id, response)
        return response

    def add_feedback(
        self, branch_name: str, session_id: str, message_index: int, feedback: str
    ) -> None:
        """
        Adds feedback to the log for a specific response in a specific session.

        Does nothing when no chat logs manager is configured.

        Args:
            branch_name (str): The name of the branch.
            session_id (str): The session ID for the conversation.
            message_index (int): The index of the message to add feedback to.
            feedback (str): The feedback to add to the message.

        Returns:
            None
        """
        if self.chat_logs_manager is None:
            return
        self.chat_logs_manager.log_feedback(
            branch_name, session_id, message_index, feedback
        )

    def get_conversation_history(self, branch_name: str, session_id: str) -> dict:
        """
        Gets the conversation history for a specific session in a specific branch.

        Args:
            branch_name (str): The name of the branch to get the history from.
            session_id (str): The session ID to get the conversation history from.

        Returns:
            dict: The conversation history for the session, or an empty dict
            when no chat logs manager is configured.
        """
        if self.chat_logs_manager is None:
            return {}
        return self.chat_logs_manager.get_chat_history(branch_name, session_id)

    @property
    def branches(self) -> dict:
        """dict: Registered branches keyed by branch name.

        Each value is a dict with ``"label"`` and ``"graph"`` keys. This is the
        manager's internal mapping, not a copy.
        """
        return self.graphs

    @property
    def chat_log_path(self) -> str | None:
        """str | None: The logs manager's ``chat_log_path``, or ``None`` if
        no chat logs manager is configured."""
        if self.chat_logs_manager is None:
            return None
        return self.chat_logs_manager.chat_log_path