diff --git a/backend/common/lifespan.py b/backend/common/lifespan.py new file mode 100644 index 000000000..534f104fc --- /dev/null +++ b/backend/common/lifespan.py @@ -0,0 +1,52 @@ +from collections.abc import Callable +from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager +from typing import Any + +from fastapi import FastAPI + +LifespanFunc = Callable[[FastAPI], AbstractAsyncContextManager[dict[str, Any] | None]] + + +class LifespanManager: + """FastAPI lifespan 管理器""" + + def __init__(self) -> None: + self._lifespans: list[LifespanFunc] = [] + + def register(self, func: LifespanFunc) -> LifespanFunc: + """ + 注册 lifespan hook + + :param func: lifespan hook + :return: + """ + if func not in self._lifespans: + self._lifespans.append(func) + return func + + def build(self) -> LifespanFunc: + """ + 构建组合后的 lifespan hook + + :return: + """ + + @asynccontextmanager + async def combined_lifespan(app: FastAPI): # noqa: ANN202 + state: dict[str, Any] = {} + async with AsyncExitStack() as exit_stack: + for lifespan_fn in self._lifespans: + result = await exit_stack.enter_async_context(lifespan_fn(app)) + if isinstance(result, dict): + state.update(result) + + for key, value in state.items(): + setattr(app.state, key, value) + + yield state or None + + return combined_lifespan + + +# 创建 lifespan_manager 单例 +lifespan_manager = LifespanManager() diff --git a/backend/core/registrar.py b/backend/core/registrar.py index bc64058bd..a3e464b52 100644 --- a/backend/core/registrar.py +++ b/backend/core/registrar.py @@ -18,6 +18,7 @@ from backend import __version__ from backend.common.cache.pubsub import cache_pubsub_manager from backend.common.exception.exception_handler import register_exception +from backend.common.lifespan import lifespan_manager from backend.common.log import set_custom_logfile, setup_logging from backend.common.observability.otel import init_otel from backend.common.response.response_code import StandardResponseCode @@ -30,7 +31,7 @@ from backend.middleware.jwt_auth_middleware import JwtAuthMiddleware from backend.middleware.opera_log_middleware import OperaLogMiddleware from backend.middleware.state_middleware import StateMiddleware -from backend.plugin.core import build_final_router +from backend.plugin.core import build_final_router, setup_plugins from backend.utils.demo_mode import demo_site from backend.utils.openapi import ensure_unique_route_names, simplify_operation_ids from backend.utils.serializers import MsgSpecJSONResponse @@ -38,6 +39,7 @@ from backend.utils.trace_id import OtelTraceIdPlugin +@lifespan_manager.register @asynccontextmanager async def register_init(app: FastAPI) -> AsyncGenerator[None, None]: """ @@ -84,7 +86,7 @@ def register_app() -> FastAPI: redoc_url=settings.FASTAPI_REDOC_URL, openapi_url=settings.FASTAPI_OPENAPI_URL, default_response_class=MsgSpecJSONResponse, - lifespan=register_init, + lifespan=lifespan_manager.build(), ) # 注册组件 @@ -96,6 +98,9 @@ def register_app() -> FastAPI: register_page(app) register_exception(app) + # 初始化插件 + setup_plugins(app) + if settings.GRAFANA_METRICS_ENABLE: register_metrics(app) diff --git a/backend/plugin/core.py b/backend/plugin/core.py index cceb60de0..132911d8c 100644 --- a/backend/plugin/core.py +++ b/backend/plugin/core.py @@ -1,3 +1,4 @@ +import inspect import json import os import warnings @@ -8,10 +9,11 @@ import anyio import rtoml -from fastapi import APIRouter, Depends, Request +from fastapi import APIRouter, Depends, FastAPI, Request from backend.common.enums import DataBaseType, PluginLevelType, PrimaryKeyType, StatusType from backend.common.exception import errors +from backend.common.lifespan import lifespan_manager from backend.common.log import log from backend.core.conf import settings from backend.core.path_conf import PLUGIN_DIR @@ -132,53 +134,130 @@ def load_plugin_config(plugin: str) -> dict[str, Any]: return rtoml.load(f) +def get_plugin_enable(plugin_info: str | None, default_status: int) -> str: + """ + 解析插件启用状态 + + :param plugin_info: 插件缓存信息 + :param default_status: 默认状态值 + :return: + """ + if not plugin_info: + return str(default_status) + + try: + return json.loads(plugin_info)['plugin']['enable'] + except Exception: + return str(default_status) + + +def get_enabled_plugins(plugins: tuple[str, ...] | None = None) -> set[str]: + """ + 获取已启用的插件列表 + + :param plugins: 插件名称列表 + :return: + """ + plugin_names = plugins or get_plugins() + enabled_plugins = set(plugin_names) + + current_redis_client = RedisCli() + run_await(current_redis_client.init)() + + try: + for plugin in plugin_names: + plugin_info = run_await(current_redis_client.get)(f'{settings.PLUGIN_REDIS_PREFIX}:{plugin}') + if get_plugin_enable(plugin_info, StatusType.enable.value) != str(StatusType.enable.value): + enabled_plugins.discard(plugin) + finally: + run_await(current_redis_client.aclose)() + + return enabled_plugins + + +def register_plugin_lifespan_hook(plugin: str, module: Any) -> None: + """ + 注册插件 lifespan hook + + :param plugin: 插件名称 + :param module: 插件 hooks 模块 + :return: + """ + lifespan_hook = getattr(module, 'lifespan', None) + if lifespan_hook is None: + return + + if not callable(lifespan_hook): + log.warning(f'插件 {plugin} 的 lifespan 不是可调用对象,已跳过') + return + + lifespan_manager.register(lifespan_hook) + log.info(f'插件 {plugin} lifespan hook 注册成功') + + +def run_plugin_startup_hook(plugin: str, module: Any, app: FastAPI) -> None: + """ + 执行插件 startup hook + + :param plugin: 插件名称 + :param module: 插件 hooks 模块 + :param app: FastAPI 应用实例 + :return: + """ + setup_hook = getattr(module, 'setup', None) + if setup_hook is None: + return + + if not callable(setup_hook): + log.warning(f'插件 {plugin} 的 setup 不是可调用对象,已跳过') + return + + setup_result = setup_hook(app) + if inspect.isawaitable(setup_result): + run_await(lambda: setup_result)() # type: ignore + log.info(f'插件 {plugin} startup hook 执行成功') + + def parse_plugin_config() -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """解析插件配置""" extend_plugins = [] app_plugins = [] - plugins = get_plugins() # 使用独立连接 current_redis_client = RedisCli() run_await(current_redis_client.init)() - # 清理未知插件信息 - exclude_keys = [f'{settings.PLUGIN_REDIS_PREFIX}:{key}' for key in plugins] - run_await(current_redis_client.delete_prefix)( - settings.PLUGIN_REDIS_PREFIX, - exclude=exclude_keys, - ) - - for plugin in plugins: - data = load_plugin_config(plugin) - plugin_type = validate_plugin_config(plugin, data) - - if plugin_type == PluginLevelType.extend: - extend_plugins.append(data) - else: - app_plugins.append(data) - - # 补充插件信息 - data['plugin']['name'] = plugin - plugin_cache_key = f'{settings.PLUGIN_REDIS_PREFIX}:{plugin}' - plugin_cache_info = run_await(current_redis_client.get)(plugin_cache_key) - if plugin_cache_info: - try: - data['plugin']['enable'] = json.loads(plugin_cache_info)['plugin']['enable'] - except Exception: - data['plugin']['enable'] = str(StatusType.enable.value) - else: - data['plugin']['enable'] = str(StatusType.enable.value) - - # 缓存最新插件信息 - run_await(current_redis_client.set)(plugin_cache_key, json.dumps(data, ensure_ascii=False)) - - # 重置插件变更状态 - run_await(current_redis_client.delete)(f'{settings.PLUGIN_REDIS_PREFIX}:changed') - - # 关闭连接 - run_await(current_redis_client.aclose)() + try: + # 清理未知插件信息 + exclude_keys = [f'{settings.PLUGIN_REDIS_PREFIX}:{key}' for key in plugins] + run_await(current_redis_client.delete_prefix)( + settings.PLUGIN_REDIS_PREFIX, + exclude=exclude_keys, + ) + + for plugin in plugins: + plugin_config = load_plugin_config(plugin) + plugin_type = validate_plugin_config(plugin, plugin_config) + + if plugin_type == PluginLevelType.extend: + extend_plugins.append(plugin_config) + else: + app_plugins.append(plugin_config) + + # 补充插件信息 + plugin_config['plugin']['name'] = plugin + plugin_cache_key = f'{settings.PLUGIN_REDIS_PREFIX}:{plugin}' + plugin_cache_info = run_await(current_redis_client.get)(plugin_cache_key) + plugin_config['plugin']['enable'] = get_plugin_enable(plugin_cache_info, StatusType.enable.value) + + # 缓存最新插件信息 + run_await(current_redis_client.set)(plugin_cache_key, json.dumps(plugin_config, ensure_ascii=False)) + + # 重置插件变更状态 + run_await(current_redis_client.delete)(f'{settings.PLUGIN_REDIS_PREFIX}:changed') + finally: + run_await(current_redis_client.aclose)() return extend_plugins, app_plugins @@ -288,6 +367,41 @@ def build_final_router() -> APIRouter: return main_router +def setup_plugins(app: FastAPI) -> None: + """ + 注册并执行插件 hooks + + :param app: FastAPI 应用实例 + :return: + """ + plugins = get_plugins() + enabled_plugins = get_enabled_plugins(plugins) + + for plugin in plugins: + if plugin not in enabled_plugins: + log.info(f'插件 {plugin} 未启用,已跳过 hooks 注册与执行') + continue + + module_path = f'backend.plugin.{plugin}.hooks' + try: + module = import_module_cached(module_path) + except ModuleNotFoundError as e: + if e.name == module_path: + # 未定义 hooks.py + continue + log.warning(f'插件 {plugin} hooks 模块加载失败: {e}') + continue + except Exception as e: + log.warning(f'插件 {plugin} hooks 模块加载失败: {e}') + continue + + try: + register_plugin_lifespan_hook(plugin, module) + run_plugin_startup_hook(plugin, module, app) + except Exception as e: + log.error(f'插件 {plugin} hooks 执行失败: {e}') + + class PluginStatusChecker: """插件状态检查器""" @@ -312,10 +426,5 @@ async def __call__(self, request: Request) -> None: log.error('插件状态未初始化或丢失,需重启服务自动修复') raise PluginInjectError('插件状态未初始化或丢失,请联系系统管理员') - try: - is_enabled = int(json.loads(plugin_info)['plugin']['enable']) - except Exception: - is_enabled = 0 - - if not is_enabled: + if get_plugin_enable(plugin_info, StatusType.disable.value) != str(StatusType.enable.value): raise errors.ServerError(msg=f'插件 {self.plugin} 未启用,请联系系统管理员') diff --git a/backend/plugin/patching.py b/backend/plugin/patching.py new file mode 100644 index 000000000..dc7477e59 --- /dev/null +++ b/backend/plugin/patching.py @@ -0,0 +1,25 @@ +from fastapi import FastAPI +from starlette.middleware import Middleware + + +def replace_middleware( + app: FastAPI, + original_middleware_cls: type, + replacement_middleware_cls: type, + **replacement_kwargs, +) -> None: + """ + 替换中间件(应在插件的 startup hook 中调用) + + :param app: FastAPI 应用实例 + :param original_middleware_cls: 原始中间件类 + :param replacement_middleware_cls: 替换后的中间件类 + :param replacement_kwargs: 传给替换后中间件的初始化参数 + :return: + """ + for index, middleware in enumerate(app.user_middleware): + if middleware.cls is original_middleware_cls: + app.user_middleware[index] = Middleware(replacement_middleware_cls, **replacement_kwargs) + return + + raise ValueError(f'{original_middleware_cls.__name__} not found in app.user_middleware')