Source code for maeser.user_manager

"""
User management module for authentication and authorization.

This module provides classes and utilities for managing users,
including authentication methods, database operations, and request tracking.

© 2024 Carson Bush, Blaine Freestone

This file is part of Maeser.

Maeser is free software: you can redistribute it and/or modify it under the terms of
the GNU Lesser General Public License as published by the Free Software Foundation,
either version 3 of the License, or (at your option) any later version.

Maeser is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License along with
Maeser. If not, see <https://www.gnu.org/licenses/>.
"""


import secrets
import sqlite3
from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any, Tuple, Union
from urllib.parse import quote, urlencode

import requests


class User:
    """
    Provides default implementations for the methods that Flask-Login expects
    user objects to have, plus per-user request accounting.

    Two users compare equal when their ``get_id()`` values match, i.e. when
    both the authentication method and the identifier are the same.
    """

    # Python 3 implicitly sets __hash__ to None if we override __eq__
    # We set it back to its default implementation
    __hash__ = object.__hash__

    def __init__(self, ident: str, blacklisted=False, admin=False, realname='Student',
                 usergroup='b\'guest\'', authmethod='invalid', requests_left=10,
                 max_requests=10, aka=None):
        """
        Initialize a User object.

        Args:
            ident (str): The user's identifier.
            blacklisted (bool, optional): Whether the user is blacklisted. Defaults to False.
            admin (bool, optional): Whether the user is an admin. Defaults to False.
            realname (str, optional): The user's real name. Defaults to 'Student'.
            usergroup (str, optional): The user's group. Defaults to 'b\\'guest\\''.
            authmethod (str, optional): The authentication method. Defaults to 'invalid'.
            requests_left (int, optional): The number of requests left. Defaults to 10.
            max_requests (int, optional): The maximum number of requests. Defaults to 10.
            aka (list, optional): A list of alternate names. Defaults to a new
                empty list for every user.
        """
        self.ident = ident
        self.is_active = not blacklisted
        self.admin = admin
        self.realname = realname
        self.usergroup = usergroup
        self.auth_method = authmethod
        self._requests_remaining = requests_left
        self._max_requests = max_requests
        # A ``[]`` default in the signature would be one list shared by every
        # User created without ``aka``; build a fresh list per instance instead.
        self.aka: list = aka if aka is not None else []

    def __str__(self) -> str:
        """Return a multi-line, human-readable summary of the user."""
        return '\n'.join([
            f'User Information for {self.ident}:',
            f'    Authentication Method: {self.auth_method}',
            f'    Real Name: {self.realname}',
            f"    Admin: {'Yes' if self.admin else 'No'}",
            f"    Banned: {'Yes' if not self.is_active else 'No'}",
            f'    User Group: {self.usergroup}',
            f'    Requests Remaining: {self.requests_remaining}/{self._max_requests}',
        ])

    @property
    def json(self) -> dict[str, Any]:
        """Return the user's public attributes as a plain dictionary."""
        return {
            'ident': self.ident,
            'is_active': self.is_active,
            'admin': self.admin,
            'realname': self.realname,
            'usergroup': self.usergroup,
            'auth_method': self.auth_method,
            'requests_remaining': self.requests_remaining,
            'max_requests': self._max_requests,
            'aka': self.aka,
        }

    @property
    def is_authenticated(self):
        """Return True if the user is authenticated (mirrors ``is_active``)."""
        return self.is_active

    @property
    def is_anonymous(self):
        """Return False, as anonymous users are not supported."""
        return False

    def get_id(self):
        """Return the user's full identifier name including authentication method."""
        return self.full_id_name

    @property
    def full_id_name(self):
        """Return the user's full identifier name including authentication method."""
        return f'{self.auth_method}.{self.ident}'

    @property
    def requests_remaining(self):
        """Return the number of requests remaining for the user."""
        return self._requests_remaining

    @requests_remaining.setter
    def requests_remaining(self, num: int):
        """
        Set the number of requests remaining for the user.

        The value is clamped to the range ``0 .. max_requests``.

        Args:
            num (int): The new number of requests remaining.
        """
        if num >= self._max_requests:
            self._requests_remaining = self._max_requests
        elif num <= 0:
            self._requests_remaining = 0
        else:
            self._requests_remaining = num

    def __eq__(self, other):
        """
        Check the equality of two User objects using get_id.

        Args:
            other (User): The other user to compare.

        Returns:
            bool: True if the users are equal, False otherwise.
            NotImplemented when ``other`` is not a User.
        """
        if isinstance(other, User):
            return self.get_id() == other.get_id()
        return NotImplemented

    def __ne__(self, other):
        """
        Check the inequality of two User objects using get_id.

        Args:
            other (User): The other user to compare.

        Returns:
            bool: True if the users are not equal, False otherwise.
            NotImplemented when ``other`` is not a User.
        """
        equal = self.__eq__(other)
        # ``not NotImplemented`` would wrongly report "equal" for foreign
        # types; hand the decision back to Python's comparison fallback.
        if equal is NotImplemented:
            return NotImplemented
        return not equal
class LoginStyle:
    """
    Describes how an authenticator is presented on the login page: its icon,
    the controller that receives the submission, and an optional HTML form.
    """

    def __init__(self, icon: str, login_submit: str, direct_submit: bool=False):
        """
        Initialize the login style.

        Args:
            icon (str): Icon name; it is appended to the ``bi bi-`` CSS class prefix.
            login_submit (str): Not a url, but a controller name for url_for,
                i.e. 'maeser.github_authorize' or 'localauth'.
            direct_submit (bool, optional): When True no form is available and
                reading ``form_html`` raises. Defaults to False.
        """
        self.login_submit = login_submit
        self.direct_submit = direct_submit

        # HTML for a custom form (labels and inputs only)
        default_form_parts = (
            '<label for="username" class="form-label">Username</label>',
            '<input type="text" id="username" name="username" class="form-input" required>',
            '<label for="password" class="form-label">Password</label>',
            '<input type="password" id="password" name="password" class="form-input" required>',
        )
        self._custom_form: str = ''.join(default_form_parts)

        self.icon_html = '<i class="bi bi-{}"></i>'.format(icon)

    @property
    def form_html(self) -> str:
        """
        Return the HTML of the login form.

        Raises:
            ValueError: If the style was created with ``direct_submit=True``.
        """
        if not self.direct_submit:
            return self._custom_form
        raise ValueError("Cannot use form_html with direct_submit=True")

    @form_html.setter
    def form_html(self, html: str):
        """Replace the login form HTML with ``html``."""
        self._custom_form = html
class BaseAuthenticator(ABC):
    """
    Abstract interface that every authenticator must implement.

    Subclasses have to provide ``__init__``, ``__str__``, ``authenticate``,
    ``fetch_user`` and the ``style`` property before they can be instantiated.
    """

    @abstractmethod
    def __init__(self, *args, **kwargs):
        """Set up the authenticator with whatever arguments it requires."""

    @abstractmethod
    def __str__(self) -> str:
        """Return the string representation of the authenticator."""

    @abstractmethod
    def authenticate(self, *args, **kwargs) -> Union[tuple, None]:
        """
        Authenticate a user.

        Args:
            *args: Positional arguments for authentication.
            **kwargs: Keyword arguments for authentication.

        Returns:
            tuple or None: On success, a tuple holding the user's username,
            real name, and user group; None when authentication fails.
        """

    @abstractmethod
    def fetch_user(self, ident: str) -> Union[User, None]:
        """
        Fetch a user from the authenticator.

        Args:
            ident (str): The identifier of the user to fetch.

        Returns:
            User or None: The fetched user object or None if not found.

        ENSURE THAT YOU SET max_requests TO THE CORRECT VALUE FOR THE USER!
        """

    @property
    @abstractmethod
    def style(self) -> LoginStyle:
        """
        Get the login style for the authenticator.

        Returns:
            LoginStyle: The login style object.
        """
class GithubAuthenticator(BaseAuthenticator):
    """
    Handles authentication with GitHub OAuth.
    """

    def __init__(self, client_id: str, client_secret: str, auth_callback_uri: str,
                 timeout: int = 10, max_requests: int = 10):
        """
        Initialize the GitHub authenticator.

        Args:
            client_id (str): The GitHub client ID.
            client_secret (str): The GitHub client secret.
            auth_callback_uri (str): The callback URI for GitHub authentication.
            timeout (int, optional): Timeout passed to every HTTP request made
                to GitHub. Defaults to 10.
            max_requests (int, optional): ``max_requests`` value given to users
                created by ``fetch_user``. Defaults to 10.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self._max_requests = max_requests
        # Generally this should be set from your Flask app as this will differ between applications
        # url_for('github_auth_callback', _external=True)
        self.auth_callback_uri = auth_callback_uri
        self.timeout = timeout
        self._login_style = LoginStyle('github', 'maeser.github_authorize', direct_submit=True)

    def __str__(self) -> str:
        """Return the display name of this authenticator."""
        return 'GitHub'

    @property
    def style(self) -> LoginStyle:
        """Return the login style (GitHub icon, direct submit)."""
        return self._login_style

    def authenticate(self, request_args: dict, oauth_state: str) -> Union[tuple, None]:
        """
        Authenticate a user with GitHub OAuth.

        Args:
            request_args (dict): The request arguments containing the authorization code and state.
            oauth_state (str): The state value used to prevent CSRF attacks.

        Returns:
            tuple or None: A tuple containing the user's username, real name, and user group
            if authentication is successful, otherwise None.
        """
        # Reject the callback when we never issued a state, when the callback
        # carries no (or a different) state, or when the code is missing.
        # ``.get`` avoids a KeyError on a callback without ``state``; the
        # truthiness check stops ``None == None`` from passing as a match.
        if (
            not oauth_state
            or request_args.get('state') != oauth_state
            or 'code' not in request_args
        ):
            # The state values themselves are deliberately not printed.
            print('GitHub authentication failed: OAuth state mismatch or missing authorization code', 'ERROR')
            return None

        token_url = 'https://github.com/login/oauth/access_token'
        user_info_url = 'https://api.github.com/user'

        # exchange the authorization code for an access token
        response = requests.post(token_url, data={
            'client_id': self.client_id,
            'client_secret': self.client_secret,
            'code': request_args['code'],
            'grant_type': 'authorization_code',
            'redirect_uri': self.auth_callback_uri
        }, headers={'Accept': 'application/json'}, timeout=self.timeout)
        if response.status_code != 200:
            print(f'GitHub authentication failed during token exchange: {response.status_code}', 'ERROR')
            return None

        oauth2_token = response.json().get('access_token')
        if not oauth2_token:
            print('GitHub authentication failed: No access token received', 'ERROR')
            return None

        response = requests.get(user_info_url, headers={
            'Authorization': 'Bearer ' + oauth2_token,
            'Accept': 'application/json',
        }, timeout=self.timeout)
        if response.status_code != 200:
            print(f'GitHub authentication failed when fetching user info: {response.status_code}', 'ERROR')
            return None

        json_response = response.json()
        # ``name`` can be null (or absent) in the profile; fall back to an
        # empty string so the literal text 'None' is never stored as a name.
        return json_response['login'], json_response.get('name') or '', 'b\'guest\''

    def fetch_user(self, ident: str) -> Union[User, None]:
        """
        Fetch a user from the GitHub API.

        Args:
            ident (str): The username of the user to fetch.

        Returns:
            User or None: The fetched user object or None if the user is not found.
        """
        # Percent-encode the identifier so characters such as '/' or '?'
        # cannot change which API path is requested.
        user_info_url = f'https://api.github.com/users/{quote(str(ident), safe="")}'
        # Use the configured timeout, like every other request in this class;
        # without one the call can block indefinitely.
        response = requests.get(user_info_url, timeout=self.timeout)

        if response.status_code == 200:
            json_response = response.json()
            return User(
                json_response['login'],
                realname=json_response.get('name') or '',
                usergroup='b\'guest\'',
                authmethod='github',
                max_requests=self._max_requests,
            )

        print(f'WARNING: No GitHub user "{ident}" found')
        return None

    def get_auth_info(self) -> Tuple[str, str]:
        """
        Get the GitHub authorization information.

        Returns:
            tuple: A tuple containing the OAuth state and provider URL.
        """
        authorize_url = 'https://github.com/login/oauth/authorize'
        scopes = ['user:email']

        # generate a random string for the state parameter
        oauth_state = secrets.token_urlsafe(16)

        query_string = urlencode({
            'client_id': self.client_id,
            'redirect_uri': self.auth_callback_uri,
            'response_type': 'code',
            'scope': ' '.join(scopes),
            'state': oauth_state,
        })
        provider_url = authorize_url + '?' + query_string
        return oauth_state, provider_url
class UserManager:
    """
    Manages user operations including authentication, database interactions,
    and request tracking.

    Users are stored in one SQLite table per registered authentication method,
    named ``<auth_method>Users``.
    """

    # Column list shared by every query that builds a User from a row.
    _USER_COLUMNS = 'user_id, blacklisted, admin, realname, usertype, requests_left'

    def __init__(self, db_file_path: str, max_requests: int = 10, rate_limit_interval: int = 180):
        """
        Initialize the UserManager.

        Args:
            db_file_path (str): The file path to the SQLite database.
            max_requests (int, optional): The maximum number of requests a user can have.
                Defaults to 10.
            rate_limit_interval (int, optional): Stored on the manager as-is; it is not
                read anywhere in this class. Defaults to 180.
                NOTE(review): presumably seconds between request refreshes — confirm
                with the caller.
        """
        self.db_file_path = db_file_path
        self.authenticators: dict[str, BaseAuthenticator] = {}
        self.max_requests = max_requests
        self.rate_limit_interval = rate_limit_interval
        self._create_tables()

    def register_authenticator(self, name: str, authenticator: BaseAuthenticator):
        """
        Register a new authentication method and create its user table.

        Registering a name that is already present replaces the previous
        authenticator; the existing table is kept.

        Args:
            name (str): The shorthand name of the authentication method. Must only contain letters.
            authenticator (BaseAuthenticator): The authenticator object.

        Raises:
            ValueError: If the provided name is invalid.
        """
        if not name.isalpha():
            raise ValueError(f"Invalid authenticator name: {name}, must only contain letters!")

        self.authenticators[name] = authenticator
        with self._transaction() as db:
            self._create_table(db, name)

    @property
    def db_connection(self) -> sqlite3.Connection:
        """
        Open a new connection to the SQLite database.

        Every access opens a fresh connection; the caller is responsible for
        closing it. Methods of this class go through ``_transaction`` instead,
        which commits and closes automatically.

        NOTE: each ':memory:' fallback connection is its own empty database, so
        tables created through one fallback connection are not visible to the next.

        Returns:
            sqlite3.Connection: The database connection.
        """
        try:
            return sqlite3.connect(self.db_file_path)
        except sqlite3.OperationalError as e:
            print(f'Unable to open sqlite db, using tempory storage: {e}')
            return sqlite3.connect(':memory:')

    @contextmanager
    def _transaction(self) -> Iterator[sqlite3.Connection]:
        """
        Yield a connection that is committed on success, rolled back on error,
        and always closed afterwards.

        sqlite3's own ``with connection:`` only ends the transaction; it leaves
        the connection open, so each call used to leak one connection.
        """
        db = self.db_connection
        try:
            with db:
                yield db
        finally:
            db.close()

    def _require_registered(self, auth_method: str) -> None:
        """Raise ValueError unless ``auth_method`` is a registered authenticator."""
        if auth_method not in self.authenticators:
            raise ValueError(f"Invalid authenticator name: {auth_method}")

    def _row_to_user(self, row: tuple, auth_method: str) -> User:
        """Build a User from a row selected with ``_USER_COLUMNS``."""
        return User(
            row[0],
            bool(row[1]),
            bool(row[2]),
            realname=row[3],
            usergroup=str(row[4]),
            requests_left=row[5],
            authmethod=auth_method,
            max_requests=self.max_requests,
        )

    def _create_tables(self):
        """Create the user table of every currently registered authenticator."""
        with self._transaction() as db:
            for auth_method in self.authenticators:
                self._create_table(db, auth_method)

    def _create_table(self, db: sqlite3.Connection, auth_method: str):
        """
        Create the ``<auth_method>Users`` table if it does not exist yet.

        Raises:
            ValueError: If ``auth_method`` is not alphanumeric. The name is
                interpolated into SQL, so it must not contain other characters.
        """
        if not auth_method.isalnum():
            raise ValueError(f"Invalid authenticator name: {auth_method}")

        table_name = f"{auth_method}Users"
        db.execute(f'''
            CREATE TABLE IF NOT EXISTS "{table_name}" (
                user_id TEXT PRIMARY KEY,
                blacklisted BOOL,
                admin BOOL,
                realname TEXT,
                usertype TEXT,
                requests_left INT,
                aka TEXT
            )
        ''')

    def get_user(self, auth_method: str, ident: str) -> Union[User, None]:
        """
        Retrieve a user from the database.

        Args:
            auth_method (str): The authentication method used.
            ident (str): The unique identifier of the user.

        Returns:
            User: The user object, or None if not found.

        Raises:
            ValueError: If the provided auth_method is invalid.
        """
        if not auth_method.isalnum():
            raise ValueError(f"Invalid authenticator name: {auth_method}")

        table_name = f"{auth_method}Users"
        with self._transaction() as db:
            row = db.execute(
                f'SELECT {self._USER_COLUMNS} FROM "{table_name}" WHERE user_id=?',
                (ident,)
            ).fetchone()

        return self._row_to_user(row, auth_method) if row else None

    def list_users(self, auth_filter: str | None = None, admin_filter: str | None = None,
                   banned_filter: str | None = None) -> list[User]:
        """
        List all users in the database, optionally filtered by authentication method,
        admin status, and banned status.

        Args:
            auth_filter (str, optional): The authentication method to list users for.
                If None or 'all', list users from all authentication methods.
            admin_filter (str, optional): Filter users by admin status.
                Can be 'all', 'admin', or 'non-admin'.
            banned_filter (str, optional): Filter users by banned status.
                Can be 'all', 'banned', or 'non-banned'.

        Returns:
            list[User]: A list of user objects.

        Raises:
            ValueError: If the provided auth_method is invalid or if admin_filter or
                banned_filter have invalid values.
        """
        if auth_filter is not None and auth_filter != 'all' and not auth_filter.isalnum():
            raise ValueError(f"Invalid authenticator name: {auth_filter}")
        if admin_filter is not None and admin_filter not in ['all', 'admin', 'non-admin']:
            raise ValueError(f"Invalid admin_filter value: {admin_filter}")
        if banned_filter is not None and banned_filter not in ['all', 'banned', 'non-banned']:
            raise ValueError(f"Invalid banned_filter value: {banned_filter}")

        # The WHERE clause is the same for every table, so build it once.
        conditions = []
        if admin_filter == 'admin':
            conditions.append("admin = 1")
        elif admin_filter == 'non-admin':
            conditions.append("admin = 0")
        if banned_filter == 'banned':
            conditions.append("blacklisted = 1")
        elif banned_filter == 'non-banned':
            conditions.append("blacklisted = 0")
        where_clause = " WHERE " + " AND ".join(conditions) if conditions else ""

        users: list[User] = []
        with self._transaction() as db:
            if auth_filter is None or auth_filter == 'all':
                cursor = db.execute(
                    "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE '%Users'"
                )
                tables = [row[0] for row in cursor.fetchall()]
            else:
                tables = [f"{auth_filter}Users"]

            for table_name in tables:
                # Strip only the trailing "Users"; ``replace`` would also remove
                # the word from inside an authenticator name.
                auth_method_from_table = table_name.removesuffix("Users")
                query = f'SELECT {self._USER_COLUMNS} FROM "{table_name}"' + where_clause
                users.extend(
                    self._row_to_user(row, auth_method_from_table)
                    for row in db.execute(query).fetchall()
                )

        return users

    def authenticate(self, auth_method: str, *args: Any, **kwargs: Any) -> Union[User, None]:
        """
        Authenticate a user using the specified authentication method.

        Args:
            auth_method (str): The authentication method to use.
            *args: Positional arguments for the authentication method.
            **kwargs: Keyword arguments for the authentication method.

        Returns:
            User: The authenticated user object, or None if authentication fails.

        Raises:
            ValueError: If the provided auth_method is invalid.
        """
        authenticator = self.authenticators.get(auth_method)
        if not authenticator:
            raise ValueError(f"Unsupported authentication method: {auth_method}")

        # The result holds user data, so it is not echoed to stdout.
        auth_result = authenticator.authenticate(*args, **kwargs)
        if not auth_result:
            return None

        user_id, display_name, user_group = auth_result
        return self._create_or_update_user(auth_method, user_id, display_name, user_group)

    def _create_or_update_user(self, auth_method: str, user_id: str, display_name: str,
                               user_group: str) -> User:
        """
        Return the stored user, inserting a new row first when none exists.

        An existing row is returned as stored; it is not overwritten with the
        supplied display name or group.

        Args:
            auth_method (str): The authentication method used.
            user_id (str): The unique identifier of the user.
            display_name (str): The display name of the user.
            user_group (str): The group the user belongs to.

        Returns:
            User: The user object.

        Raises:
            ValueError: If the provided auth_method is invalid.
        """
        if auth_method not in self.authenticators:
            raise ValueError(f"Unsupported authentication method: {auth_method}")

        table_name = f"{auth_method}Users"
        with self._transaction() as db:
            row = db.execute(
                f'SELECT {self._USER_COLUMNS} FROM "{table_name}" WHERE user_id=?',
                (user_id,)
            ).fetchone()

            if row:
                # Includes the stored usertype, which was previously dropped in
                # favour of the User default group.
                return self._row_to_user(row, auth_method)

            db.execute(
                f'INSERT INTO "{table_name}" ({self._USER_COLUMNS}) VALUES (?, ?, ?, ?, ?, ?)',
                (str(user_id), False, False, str(display_name), str(user_group),
                 int(self.max_requests))
            )

        # Mirror the inserted row: a new user starts with the full quota, not
        # with the User class default of 10.
        return User(
            user_id,
            realname=display_name,
            usergroup=user_group,
            authmethod=auth_method,
            requests_left=self.max_requests,
            max_requests=self.max_requests,
        )

    def update_admin_status(self, auth_method: str, ident: str, is_admin: bool):
        """
        Update the admin status of a user.

        Args:
            auth_method (str): The authentication method used.
            ident (str): The identifier of the user.
            is_admin (bool): Whether the user should be an admin or not.

        Raises:
            ValueError: If the provided auth_method is invalid.
        """
        self._require_registered(auth_method)

        table_name = f"{auth_method}Users"
        with self._transaction() as db:
            db.execute(f'UPDATE "{table_name}" SET admin=? WHERE user_id=?', (is_admin, ident))

    def update_banned_status(self, auth_method: str, ident: str, is_banned: bool):
        """
        Update the banned status of a user.

        Args:
            auth_method (str): The authentication method used.
            ident (str): The identifier of the user.
            is_banned (bool): Whether the user should be banned or not.

        Raises:
            ValueError: If the provided auth_method is invalid.
        """
        self._require_registered(auth_method)

        table_name = f"{auth_method}Users"
        with self._transaction() as db:
            db.execute(f'UPDATE "{table_name}" SET blacklisted=? WHERE user_id=?', (is_banned, ident))

    def refresh_requests(self, inc_by: int = 1):
        """
        Refresh the number of requests for all users by the given amount.

        The result is clamped to ``0 .. max_requests`` for every user of every
        registered authenticator.

        Args:
            inc_by (int, optional): The amount to increase the requests by. Defaults to 1.
        """
        with self._transaction() as db:
            for auth_method in self.authenticators:
                table_name = f"{auth_method}Users"
                db.execute(f'''
                    UPDATE "{table_name}"
                    SET requests_left = MIN(?, MAX(0, requests_left + ?))
                ''', (self.max_requests, inc_by))

    def decrease_requests(self, auth_method: str, user_id: str, dec_by: int = 1):
        """
        Decrease the number of requests remaining for a user.

        The result is clamped to ``0 .. max_requests``.

        Args:
            auth_method (str): The authentication method used.
            user_id (str): The identifier of the user.
            dec_by (int, optional): The amount to decrease the requests by. Defaults to 1.

        Raises:
            ValueError: If the provided auth_method is invalid.
        """
        self._require_registered(auth_method)

        dec_by = min(dec_by, self.max_requests)
        table_name = f"{auth_method}Users"
        with self._transaction() as db:
            # The upper clamp keeps a negative ``dec_by`` from pushing the
            # balance above max_requests.
            db.execute(f'''
                UPDATE "{table_name}"
                SET requests_left = MIN(?, MAX(0, requests_left - ?))
                WHERE user_id = ?
            ''', (self.max_requests, dec_by, user_id))

    def increase_requests(self, auth_method: str, user_id: str, inc_by: int = 1):
        """
        Increase the number of requests remaining for a user.

        The result is clamped to ``0 .. max_requests``.

        Args:
            auth_method (str): The authentication method used.
            user_id (str): The identifier of the user.
            inc_by (int, optional): The amount to increase the requests by. Defaults to 1.

        Raises:
            ValueError: If the provided auth_method is invalid.
        """
        self._require_registered(auth_method)

        inc_by = min(inc_by, self.max_requests)
        table_name = f"{auth_method}Users"
        with self._transaction() as db:
            db.execute(f'''
                UPDATE "{table_name}"
                SET requests_left = MIN(?, MAX(0, requests_left + ?))
                WHERE user_id = ?
            ''', (self.max_requests, inc_by, user_id))

    def get_requests_remaining(self, auth_method: str, user_id: str) -> Union[int, None]:
        """
        Get the number of requests remaining for a user.

        Args:
            auth_method (str): The authentication method used.
            user_id (str): The identifier of the user.

        Returns:
            Union[int, None]: The number of requests remaining, or None if the user is not found.

        Raises:
            ValueError: If the provided auth_method is invalid.
        """
        self._require_registered(auth_method)

        user = self.get_user(auth_method, user_id)
        return user.requests_remaining if user else None

    def fetch_user(self, auth_method: str, ident: str) -> bool:
        """
        Fetch a user from the authentication source and add them to the cache
        without modifying their admin or banned status.

        Args:
            auth_method (str): The authentication method ('caedm' or 'github').
            ident (str): The user's identifier.

        Returns:
            bool: True if the user was successfully fetched and cached, False otherwise.

        Raises:
            ValueError: If the provided auth_method is invalid.
        """
        self._require_registered(auth_method)

        user = self.authenticators[auth_method].fetch_user(ident)
        if not user:
            return False

        self._create_or_update_user(auth_method, user.ident, user.realname, user.usergroup)
        return True

    def remove_user_from_cache(self, auth_method: str, ident: str) -> bool:
        """
        Remove a user from the cache.

        Args:
            auth_method (str): The authentication method used.
            ident (str): The identifier of the user.

        Returns:
            bool: True if the user was removed, False otherwise.

        Raises:
            ValueError: If the provided auth_method is invalid.
        """
        self._require_registered(auth_method)

        table_name = f"{auth_method}Users"
        with self._transaction() as db:
            cursor = db.execute(f'DELETE FROM "{table_name}" WHERE user_id=?', (ident, ))
            return bool(cursor.rowcount)

    def list_cleanables(self):
        """
        List non-banned and non-admin users in the cache/database.

        Returns:
            list[str]: A list of user identifiers in the format "auth_method:user_id".
        """
        cleanables = []
        with self._transaction() as db:
            for auth_method in self.authenticators:
                table_name = f"{auth_method}Users"
                cursor = db.execute(
                    f'SELECT user_id FROM "{table_name}" WHERE blacklisted=0 AND admin=0'
                )
                cleanables.extend(f'{auth_method}:{row[0]}' for row in cursor.fetchall())
        return cleanables

    def clean_cache(self) -> int:
        """
        Clean the cache by removing non-banned and non-admin users.

        Returns:
            int: The number of users removed from the cache.
        """
        removed_count = 0
        with self._transaction() as db:
            for auth_method in self.authenticators:
                table_name = f"{auth_method}Users"
                removed_count += db.execute(
                    f'DELETE FROM "{table_name}" WHERE blacklisted=0 AND admin=0'
                ).rowcount
        return removed_count