# Source code for maeser.discord_handler._discord_handler
# SPDX-License-Identifier: LGPL-3.0-or-later
import discord
import os
import re
from maeser.generate_response import handle_message, get_valid_course_ids, BOT_DATA_PATH
from maeser.config import COURSE_ID, DISCORD_BOT_TOKEN, DISCORD_INTRO
import shlex
import maeser.graphs.universal_rag as RAG_VARS
# Gateway intents: begin with discord.py's defaults, then switch on every
# message-related flag this module sets before building the client.
intents = discord.Intents.default()
for _intent_flag in ("messages", "message_content", "dm_messages"):
    setattr(intents, _intent_flag, True)
del _intent_flag  # keep the module namespace free of the loop variable

# Single shared client; the event handlers in this module register on it.
client = discord.Client(intents=intents)
# Helper: pull figure IDs such as "13.2" out of references like "Figure 13.2"
def extract_figures_from_text(text: str) -> list[str]:
    """Collect the figure IDs referenced in **text**.

    A reference is the literal word "Figure", a single space, and a dotted
    number (e.g. "Figure 13.2"). Only the numeric part ("13.2") is returned.
    IDs are listed in order of appearance; repeated references are kept.

    Args:
        text (str): The text to scan for figure references.

    Returns:
        list[str]: The figure IDs found, or an empty list if there are none.
    """
    figure_reference = re.compile(r"Figure (\d+\.\d+)")
    return [found.group(1) for found in figure_reference.finditer(text)]
def split_string(text: str, max_length: int = 1999) -> list[str]:
    """Split **text** into chunks that are each at most **max_length** characters.

    A split point is chosen in this order of preference: the last newline
    before **max_length**, the last space before **max_length**, and finally a
    hard cut at exactly **max_length**. Whitespace surrounding each split
    point is trimmed, and chunks that would be empty after trimming are
    dropped so callers never receive a blank chunk from a split.

    Text that already fits within **max_length** is returned unchanged as a
    single chunk (or as an empty list when **text** is empty).

    Args:
        text (str): The text to split.
        max_length (int, optional): The maximum length of any chunk.
            Defaults to 1999.

    Returns:
        list[str]: The chunks, in order, that make up the original **text**.

    Raises:
        ValueError: If **max_length** is smaller than 1 (no progress could be
            made and the loop would never terminate).
    """
    if max_length < 1:
        raise ValueError("max_length must be at least 1")

    chunks = []
    while len(text) > max_length:
        # Prefer the last newline before max_length
        split_index = text.rfind("\n", 0, max_length)
        if split_index == -1:
            # Otherwise the last space before max_length
            split_index = text.rfind(" ", 0, max_length)
        if split_index == -1:
            # No good split point; force split
            split_index = max_length

        chunk = text[:split_index].strip()
        # The head can be whitespace-only (e.g. text starting with a newline);
        # skip it rather than emitting an empty chunk.
        if chunk:
            chunks.append(chunk)
        # Stripping always removes the separator at split_index, so the
        # remaining text shrinks on every iteration.
        text = text[split_index:].strip()

    if text:
        chunks.append(text)
    return chunks
def is_admin_message(message: discord.Message) -> bool:
    """Tell whether the sender of **message** is an administrator of its channel.

    Args:
        message (discord.Message): The message whose author should be checked.

    Returns:
        bool: False when the message has no guild; otherwise the
        ``administrator`` flag of the author's permissions in the channel the
        message was sent in.
    """
    # No guild attached to the message: there is nothing to administer.
    if message.guild is None:
        return False

    author_permissions = message.channel.permissions_for(message.author)
    return author_permissions.administrator
async def command_say(channel: discord.abc.Messageable, content: str) -> None:
    """Post **content** as a message in **channel**.

    Args:
        channel (discord.abc.Messageable): Where the message is sent.
        content (str): The text of the message.
    """
    await channel.send(content)
async def command_intro(channel: discord.abc.Messageable) -> None:
    """Post the configured intro message in **channel**.

    The text comes from ``DISCORD_INTRO`` (the "discord:intro" field in
    `config.yaml`). Every occurrence of "@self" in it is replaced with a
    mention of the bot's own user (e.g. "@BotName") before sending.

    Args:
        channel (discord.abc.Messageable): Where the intro message is sent.
    """
    bot_mention = client.user.mention
    await channel.send(DISCORD_INTRO.replace("@self", bot_mention))
async def run_admin_command(message: discord.Message, command_args: list[str]) -> None:
    """Run the admin command described by **command_args**.

    Supported commands:

    * ``!say CONTENT [MORE...]`` - posts the arguments, joined by newlines,
      in the channel of **message**. With no arguments, a usage hint is
      posted instead and the invoking message is left in place.
    * ``!intro`` - posts the configured intro message.

    After a command has run, the invoking message is deleted. Any other
    command name is ignored.

    Args:
        message (discord.Message): The message that invoked the command.
        command_args (list[str]): The command name followed by its arguments.
            Must not be empty.
    """
    target_channel = message.channel
    command_name = command_args[0]
    extra_args = command_args[1:]

    if command_name == "!say":
        if not extra_args:
            usage_hint = (
                "Usage: `!say [CONTENT]`\n"
                'Use quotes around CONTENT (e.g. `!say "Hello World!"`)\n'
                "Additional arguments will be sent on a new line."
            )
            await target_channel.send(usage_hint)
            return
        await command_say(target_channel, "\n".join(extra_args))
        await message.delete()
    elif command_name == "!intro":
        await command_intro(target_channel)
        await message.delete()
@client.event
async def on_ready():
    """Print a confirmation line naming the user the bot is connected as."""
    connected_as = client.user
    print(f"✅ Discord Bot connected as {connected_as}")
@client.event
async def on_message(message: discord.Message):
    """Handle an incoming Discord message.

    Flow:

    1. Messages from bots and blank messages are ignored.
    2. Outside of DMs, a message from a channel administrator that starts
       with a mention of this bot followed by at least one more argument is
       treated as an admin command and handed to ``run_admin_command``.
    3. Every other message is passed to ``handle_message`` and the reply is
       sent back to the channel, split into several messages when it is
       longer than 1999 characters.
    4. Any "Figure X.X" references in the reply are looked up as
       ``X.X.png`` files and attached in follow-up messages.

    Failures while sending figures are logged and do not affect the text
    reply; any other failure is reported to the channel.

    Args:
        message (discord.Message): The message that was received.
    """
    # Get data from message
    user_id = str(message.author.id)
    msg_text = message.content.strip()
    channel = message.channel

    # Ignore if message is from bot or if message is blank
    if message.author.bot or not msg_text:
        return

    # Only run admin commands if message is not a DM
    if not isinstance(channel, discord.DMChannel) and is_admin_message(message):
        try:
            # Gets list of args as if msg_text was a terminal command
            msg_args = shlex.split(msg_text)
        except ValueError:
            # shlex rejects unbalanced quotes (e.g. the apostrophe in "what's").
            # Such text cannot be a well-formed command, so treat it as an
            # ordinary message instead of crashing the handler.
            msg_args = []

        # Bot must be mentioned first
        # and message must contain a command (not just a bot mention)
        if len(msg_args) > 1 and msg_args[0] == client.user.mention:
            await run_admin_command(message, msg_args[1:])  # drop the mention
            return

    # -- MESSAGE PROCESSING --
    async with channel.typing():
        try:
            reply = handle_message(user_id, COURSE_ID, msg_text)

            # Send text reply
            if len(reply) > 1999:
                for chunk in split_string(reply):
                    if chunk:  # never try to send a blank message
                        await channel.send(chunk)
            else:
                await channel.send(reply)

            try:
                # Extract figures referenced in the reply; dict.fromkeys
                # removes repeats while keeping first-seen order so the same
                # image is not attached twice.
                figure_names = list(dict.fromkeys(extract_figures_from_text(reply)))
                files = []
                if figure_names:
                    # Resolved only when needed, so a reply without figures
                    # never touches RAG_VARS.recommended_topics.
                    # NOTE(review): presumably the first recommended topic
                    # names the folder holding the figures - confirm.
                    figure_dir = (
                        f"{BOT_DATA_PATH}/{COURSE_ID}/{RAG_VARS.recommended_topics[0]}"
                    )
                    for fig in figure_names:
                        image_path = os.path.join(figure_dir, f"{fig}.png")
                        if os.path.exists(image_path):
                            files.append(discord.File(image_path, filename=f"{fig}.png"))
                        else:
                            print(f"[WARN] Figure not found: {image_path}")

                # discord.py documents a maximum of 10 files per message,
                # so send the attachments in batches.
                for start in range(0, len(files), 10):
                    await channel.send(files=files[start:start + 10])
            except Exception as figure_error:
                # Best effort: the text reply has already been delivered.
                print(f"❌ There was an issue sending figures: {figure_error}")
        except Exception as e:
            await channel.send(f"❌ Error: {e}")
def run_discord_handler(course_id: str = COURSE_ID, bot_token: str = DISCORD_BOT_TOKEN) -> None:
    """Validate **course_id** and run the Discord bot with **bot_token**.

    The call blocks inside ``client.run`` for as long as the bot is running.

    Args:
        course_id (str, optional): The course ID used when answering
            messages. Defaults to maeser.config.COURSE_ID.
        bot_token (str, optional): The Discord bot token used to log in.
            Defaults to maeser.config.DISCORD_BOT_TOKEN.

    Raises:
        SystemExit: With exit code 1 if **course_id** is not one of the IDs
            returned by ``get_valid_course_ids``.
    """
    global COURSE_ID

    if course_id not in get_valid_course_ids():
        print(f"ERROR: Course ID {course_id} not a valid course ID.")
        # The bare exit() helper comes from the site module and is meant for
        # interactive use; raising SystemExit works in every context.
        raise SystemExit(1)

    # The message handler reads the module-level COURSE_ID, so rebind it here;
    # otherwise a non-default course_id would be validated but never used.
    COURSE_ID = course_id
    client.run(bot_token)
# Script entry point: start the handler with the configured course and token.
if __name__ == "__main__":
    run_discord_handler(course_id=COURSE_ID, bot_token=DISCORD_BOT_TOKEN)