Source code for maeser.admin_portal.design_model

# SPDX-License-Identifier: LGPL-3.0-or-later

"""
This module is used by flask_admin_portal.py and provides helper functions for creating, editing, and deleting class chatbot models.
"""

import os
import shutil

from flask import request
from werkzeug.datastructures import FileStorage
from werkzeug.utils import secure_filename

from maeser.admin_portal.extract_figures import extract_all_figures
from maeser.admin_portal.extract_text import extract_all_pdf_texts
from maeser.admin_portal.vector_store_operator import vectorize_data

def get_model_config(upload_root: str) -> tuple[str, str, str, list[str], dict[str, list[FileStorage]]]:
    """Retrieves the config for a course model from a POST request.

    Args:
        upload_root (str): The root directory where all course models are created and modified.

    Raises:
        AttributeError: If the request form is missing 'course_id'.

    Returns:
        (course_id, model_dir, bot_path, rules, datasets): Parsed config data from the
        request form. See the parameters of `save_model` for more info.
    """
    # Make sure course_id is defined
    course_id = request.form.get('course_id', '').strip()
    if not course_id:
        raise AttributeError("No course ID found in request form.")

    # Get important file paths
    model_dir = os.path.join(upload_root, secure_filename(course_id))
    os.makedirs(model_dir, exist_ok=True)  # TODO: remove this?
    bot_path = os.path.join(model_dir, 'bot.txt')

    # Get ruleset
    rules = request.form.getlist('rules[]')

    # Get datasets
    datasets = {}
    for key in request.files:
        if key.startswith('file_groups'):
            # Derive the dataset path from the group's name field
            idx = key.split('[')[1].split(']')[0]
            dataset_name = request.form.get(f'file_groups[{idx}][name]', f'Group_{idx}').lower()
            dataset_path = os.path.join(model_dir, secure_filename(dataset_name))

            # Collect the files for this dataset; directory creation and
            # file saving are deferred to save_model
            files = request.files.getlist(f'file_groups[{idx}][files]')
            datasets[dataset_path] = files

    return course_id, model_dir, bot_path, rules, datasets
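
# A minimal usage sketch, not part of this module: get_model_config reads from
# flask.request, so it must run inside a request context. The `app` object, route,
# form values, and upload root below are assumptions for illustration only.
def _example_get_model_config(app):
    with app.test_request_context(
        '/admin/models', method='POST',
        data={'course_id': 'CS101', 'rules[]': ['Be concise.', 'Cite sources.']}
    ):
        return get_model_config('/data/bots')
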
def save_model(
    upload_root: str,
    course_id: str,
    model_dir: str,
    bot_path: str,
    rules: list[str],
    datasets: dict[str, list[FileStorage]]
):
    """Saves the course model using the provided config.

    Args:
        upload_root (str): The root directory where all course models are created and modified.
        course_id (str): The code used to identify the class.
        model_dir (str): The directory where the model's data will be saved.
        bot_path (str): The path where the model's 'bot.txt' file will be saved.
        rules (list[str]): The list of rules for the model's chatbot to follow.
        datasets (dict[str, list[FileStorage]]): A dictionary mapping each dataset's
            path (key) to a list of its files (value).
    """
    # Make model dir
    os.makedirs(model_dir, exist_ok=True)

    # Save datasets
    for dataset_path, files in datasets.items():
        os.makedirs(dataset_path, exist_ok=True)
        for f in files:
            if f and f.filename.endswith('.pdf'):
                filename = secure_filename(f.filename)
                f.save(os.path.join(dataset_path, filename))

    # Write bot.txt with all required sections
    with open(bot_path, 'w', encoding='utf-8') as bot_file:
        bot_file.write("#NAME\n")
        bot_file.write(f"{course_id}\n")
        bot_file.write("#RULES\n")
        for rule in rules:
            bot_file.write(f"{rule}\n")
        bot_file.write("#DATASETS\n")
        for dataset in os.listdir(model_dir):
            dataset_dir = os.path.join(model_dir, dataset)
            if os.path.isdir(dataset_dir):
                bot_file.write(f"{dataset.lower()}\n")

    # Process datasets
    process_datasets(model_dir)
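
# For reference, save_model writes a bot.txt of roughly this shape (a sketch based on
# the code above; the rule text and dataset names here are illustrative):
#
#   #NAME
#   CS101
#   #RULES
#   Be concise.
#   Cite sources.
#   #DATASETS
#   lectures
#   homework
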
def process_datasets(model_dir: str):
    """Converts each dataset in a model directory into a vector store.

    Each subdirectory of `model_dir` is treated as a dataset: its PDFs are converted
    to markdown, figures are extracted, the data is vectorized, and the intermediate
    .md and .pdf files are deleted. Datasets that already contain an 'index.faiss'
    file are skipped.

    Args:
        model_dir (str): The directory containing the model's data.
    """
    print(f"Processing subdirectories in {model_dir}...")
    dataset_dirs = sorted([
        os.path.join(model_dir, d)
        for d in os.listdir(model_dir)
        if os.path.isdir(os.path.join(model_dir, d))
    ])
    for dataset_dir in dataset_dirs:
        # Skip datasets that have already been vectorized
        if os.path.exists(os.path.join(dataset_dir, "index.faiss")):
            print(f"(dataset for {dataset_dir} already exists, skipping.)")
            continue
        print(f"---- Processing {dataset_dir} ----")
        print("1. Converting PDFs to markdown (this may take a moment)...")
        extract_all_pdf_texts(dataset_dir)
        print("2. Extracting figures from PDFs...")
        extract_all_figures(dataset_dir)
        print("3. Running vector store operator...")
        vectorize_data(dataset_dir)
        print("4. Deleting .md and .pdf files...")
        for f in os.listdir(dataset_dir):
            if f.lower().endswith((".pdf", ".md")):
                os.remove(os.path.join(dataset_dir, f))
        print(f"✔ Completed {dataset_dir}")
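
# A hedged usage sketch (the path is hypothetical): process_datasets can safely be
# re-run on a model directory, since datasets that already contain an index.faiss
# are skipped; only newly added file groups are vectorized.
def _example_reprocess_model():
    process_datasets('/data/bots/CS101')
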
def delete_datasets(model_dir: str):
    """Deletes specified datasets from a course model.

    The datasets to delete are retrieved from the current request form.

    Args:
        model_dir (str): The directory containing the model's data.
    """
    to_delete = request.form.getlist('delete_datasets[]')
    for dataset in to_delete:
        group_path = os.path.join(model_dir, secure_filename(dataset))
        if os.path.isdir(group_path):
            shutil.rmtree(group_path)
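
# A minimal sketch (the `app` object, route, and paths are assumptions): like
# get_model_config, delete_datasets reads the request form, so it must be called
# inside a request context.
def _example_delete_datasets(app):
    with app.test_request_context(
        '/admin/models', method='POST',
        data={'delete_datasets[]': ['old_lectures', 'old_homework']}
    ):
        delete_datasets('/data/bots/CS101')
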
def remove_class_model(upload_root: str, course_id: str):
    """Removes a course model from the root bot data directory.

    Args:
        upload_root (str): The root directory where all course models are created and modified.
        course_id (str): The code used to identify the class.

    Raises:
        NotADirectoryError: If the model's directory does not exist/cannot be found.
    """
    print(f"Removing {course_id} from {upload_root} directory...")
    # Use secure_filename so the path matches the directory created in get_model_config
    class_path = os.path.join(upload_root, secure_filename(course_id))
    if not os.path.isdir(class_path):
        raise NotADirectoryError(f"Unable to find directory {class_path}")
    try:
        shutil.rmtree(class_path)
        print(f"Successfully removed {course_id}.")
    except Exception as e:
        print(f"Unable to remove {course_id}: {e}")
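
# A hedged sketch of a delete endpoint built on remove_class_model; the `app` object,
# route, and upload root are hypothetical names for illustration.
def _example_register_delete_route(app, upload_root: str):
    @app.route('/admin/models/<course_id>', methods=['DELETE'])
    def delete_model(course_id: str):
        try:
            remove_class_model(upload_root, course_id)
        except NotADirectoryError:
            return 'Model not found', 404
        return 'Model removed', 200
    return delete_model
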
def load_rules(bot_path: str) -> list[str]:
    """Loads the rules for a course from the model's 'bot.txt' file.

    Args:
        bot_path (str): The path of the model's 'bot.txt' file.

    Returns:
        list[str]: A list of the model's rules.
    """
    rules = []
    if os.path.exists(bot_path):
        with open(bot_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        # Collect the lines between the #RULES header and the next # section
        collecting = False
        for line in lines:
            if line.strip() == "#RULES":
                collecting = True
                continue
            if line.strip().startswith("#") and collecting:
                break
            if collecting:
                rules.append(line.strip())
    return rules
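
# A self-contained sketch of the parsing contract: load_rules returns the lines between
# the #RULES header and the next # section. The file contents here are illustrative.
def _example_load_rules():
    import tempfile
    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as tmp:
        tmp.write("#NAME\nCS101\n#RULES\nBe concise.\nCite sources.\n#DATASETS\nlectures\n")
        bot_path = tmp.name
    assert load_rules(bot_path) == ['Be concise.', 'Cite sources.']
    os.remove(bot_path)
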
def load_datasets(model_dir: str) -> list[str]:
    """Loads the names of a course model's existing datasets.

    Args:
        model_dir (str): The directory containing the model's data.

    Returns:
        list[str]: The names of the model's datasets.
    """
    datasets = sorted([
        d for d in os.listdir(model_dir)
        if os.path.isdir(os.path.join(model_dir, d)) and d != '__pycache__'
    ])
    return datasets
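
# A hedged sketch (names are hypothetical): load_rules and load_datasets together can
# pre-populate the admin portal's edit form for an existing course model.
def _example_edit_form_context(upload_root: str, course_id: str) -> dict:
    model_dir = os.path.join(upload_root, secure_filename(course_id))
    return {
        'course_id': course_id,
        'rules': load_rules(os.path.join(model_dir, 'bot.txt')),
        'datasets': load_datasets(model_dir),
    }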