Source code for maeser.chat.chat_session_manager
"""
Module for managing chat sessions and interactions with multiple chat interfaces.
© 2024 Blaine Freestone, Carson Bush
This file is part of Maeser.
Maeser is free software: you can redistribute it and/or modify it under the terms of
the GNU Lesser General Public License as published by the Free Software Foundation,
either version 3 of the License, or (at your option) any later version.
Maeser is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with
Maeser. If not, see <https://www.gnu.org/licenses/>.
"""
from maeser.chat.chat_logs import BaseChatLogsManager
from maeser.user_manager import User
import time
from uuid import uuid4 as uid
from langchain_community.callbacks import get_openai_callback
from langgraph.graph.graph import CompiledGraph
[docs]
class ChatSessionManager:
"""
Manages and directs sessions for multiple chat interfaces.
"""
def __init__(
self,
chat_logs_manager: BaseChatLogsManager | None = None,
) -> None:
"""
Initializes the chat session manager.
Args:
chat_logs_manager (BaseChatLogsManager | None): The chat logs manager to use for logging chat data.
Returns:
None
"""
self.chat_logs_manager: BaseChatLogsManager | None = chat_logs_manager
self.graphs: dict = {}
[docs]
def register_branch(self, branch_name: str, branch_label: str, graph: CompiledGraph) -> None:
"""
Registers a branch with its information and graph.
Args:
branch_name (str): The name of the branch.
branch_label (str): The label of the branch.
graph (CompiledGraph): The graph for the branch.
Returns:
None
"""
self.graphs[branch_name] = {
'label': branch_label,
'graph': graph
}
[docs]
def get_new_session_id(self, branch_name: str, user: User | None = None) -> str:
"""
Creates a new chat session for the given branch action.
Includes creating a new log file for the session.
Args:
branch_name (str): The action of the branch to create a session for.
user (User | None): The user to create the session for.
Returns:
str: The session ID for the new session.
"""
# Generate session ID with user information if it exists
if user:
session_id: str = f'{uid()}-{user.auth_method}-{user.ident}'
else:
session_id: str = f'{uid()}-anon'
# Create log file if chat logs manager is available
if self.chat_logs_manager:
self.chat_logs_manager.log(branch_name, session_id, {'user': user})
return session_id
[docs]
def ask_question(self, message: str, branch_name: str, sess_id: str) -> dict:
"""
Asks a question in a specific session of a branch.
Args:
message (str): The question to ask.
branch_name (str): The action of the branch to ask the question in.
sess_id (str): The session ID to ask the question in.
Returns:
dict: The response to the question.
"""
config = {'configurable': {'thread_id': sess_id}}
start_time = time.time()
# Get token count for the response
with get_openai_callback() as cb:
response = self.graphs[branch_name]['graph'].invoke({
'messages': [message],
}, config=config)
response['tokens_used'] = cb.total_tokens
response['cost'] = cb.total_cost
end_time = time.time()
execution_time = end_time - start_time
response['execution_time'] = execution_time
if self.chat_logs_manager:
self.chat_logs_manager.log(branch_name, sess_id, response)
return response
[docs]
def add_feedback(self, branch_name: str, session_id: str, message_index: int, feedback: str) -> None:
"""
Adds feedback to the log for a specific response in a specific session.
Args:
branch_name (str): The name of the branch.
session_id (str): The session ID for the conversation.
message_index (int): The index of the message to add feedback to.
feedback (str): The feedback to add to the message.
Returns:
None
"""
# Return if no chat logs manager
if not self.chat_logs_manager:
return
self.chat_logs_manager.log_feedback(branch_name, session_id, message_index, feedback)
[docs]
def get_conversation_history(self, branch_name: str, session_id: str) -> dict:
"""
Gets the conversation history for a specific session in a specific branch.
Args:
branch_name (str): The action of the branch to get the conversation history from.
session_id (str): The session ID to get the conversation history from.
Returns:
dict: The conversation history for the session.
"""
if not self.chat_logs_manager:
return {}
return self.chat_logs_manager.get_chat_history(branch_name, session_id)
@property
def branches(self) -> dict:
"""
Returns the list of branches available for chat.
Returns:
dict: The list of branches available for chat.
"""
return self.graphs
@property
def chat_log_path(self) -> str | None:
"""
Returns the path to the logs directory.
Returns:
str | None: The path to the logs directory.
"""
return self.chat_logs_manager.chat_log_path if self.chat_logs_manager else None