Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions backend/common/lifespan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import Any

from fastapi import FastAPI

LifespanFunc = Callable[[FastAPI], AbstractAsyncContextManager[dict[str, Any] | None]]


class LifespanManager:
"""FastAPI lifespan 管理器"""

def __init__(self) -> None:
self._lifespans: list[LifespanFunc] = []

def register(self, func: LifespanFunc) -> LifespanFunc:
"""
注册 lifespan hook

:param func: lifespan hook
:return:
"""
if func not in self._lifespans:
self._lifespans.append(func)
return func

def build(self) -> LifespanFunc:
"""
构建组合后的 lifespan hook

:return:
"""

@asynccontextmanager
async def combined_lifespan(app: FastAPI): # noqa: ANN202
state: dict[str, Any] = {}
async with AsyncExitStack() as exit_stack:
for lifespan_fn in self._lifespans:
result = await exit_stack.enter_async_context(lifespan_fn(app))
if isinstance(result, dict):
state.update(result)

for key, value in state.items():
setattr(app.state, key, value)

yield state or None

return combined_lifespan


# 创建 lifespan_manager 单例
lifespan_manager = LifespanManager()
9 changes: 7 additions & 2 deletions backend/core/registrar.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from backend import __version__
from backend.common.cache.pubsub import cache_pubsub_manager
from backend.common.exception.exception_handler import register_exception
from backend.common.lifespan import lifespan_manager
from backend.common.log import set_custom_logfile, setup_logging
from backend.common.observability.otel import init_otel
from backend.common.response.response_code import StandardResponseCode
Expand All @@ -30,14 +31,15 @@
from backend.middleware.jwt_auth_middleware import JwtAuthMiddleware
from backend.middleware.opera_log_middleware import OperaLogMiddleware
from backend.middleware.state_middleware import StateMiddleware
from backend.plugin.core import build_final_router
from backend.plugin.core import build_final_router, setup_plugins
from backend.utils.demo_mode import demo_site
from backend.utils.openapi import ensure_unique_route_names, simplify_operation_ids
from backend.utils.serializers import MsgSpecJSONResponse
from backend.utils.snowflake import snowflake
from backend.utils.trace_id import OtelTraceIdPlugin


@lifespan_manager.register
@asynccontextmanager
async def register_init(app: FastAPI) -> AsyncGenerator[None, None]:
"""
Expand Down Expand Up @@ -84,7 +86,7 @@ def register_app() -> FastAPI:
redoc_url=settings.FASTAPI_REDOC_URL,
openapi_url=settings.FASTAPI_OPENAPI_URL,
default_response_class=MsgSpecJSONResponse,
lifespan=register_init,
lifespan=lifespan_manager.build(),
)

# 注册组件
Expand All @@ -96,6 +98,9 @@ def register_app() -> FastAPI:
register_page(app)
register_exception(app)

# 初始化插件
setup_plugins(app)

if settings.GRAFANA_METRICS_ENABLE:
register_metrics(app)

Expand Down
197 changes: 153 additions & 44 deletions backend/plugin/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import json
import os
import warnings
Expand All @@ -8,10 +9,11 @@
import anyio
import rtoml

from fastapi import APIRouter, Depends, Request
from fastapi import APIRouter, Depends, FastAPI, Request

from backend.common.enums import DataBaseType, PluginLevelType, PrimaryKeyType, StatusType
from backend.common.exception import errors
from backend.common.lifespan import lifespan_manager
from backend.common.log import log
from backend.core.conf import settings
from backend.core.path_conf import PLUGIN_DIR
Expand Down Expand Up @@ -132,53 +134,130 @@ def load_plugin_config(plugin: str) -> dict[str, Any]:
return rtoml.load(f)


def get_plugin_enable(plugin_info: str | None, default_status: int) -> str:
"""
解析插件启用状态

:param plugin_info: 插件缓存信息
:param default_status: 默认状态值
:return:
"""
if not plugin_info:
return str(default_status)

try:
return json.loads(plugin_info)['plugin']['enable']
except Exception:
return str(default_status)


def get_enabled_plugins(plugins: tuple[str, ...] | None = None) -> set[str]:
"""
获取已启用的插件列表

:param plugins: 插件名称列表
:return:
"""
plugin_names = plugins or get_plugins()
enabled_plugins = set(plugin_names)

current_redis_client = RedisCli()
run_await(current_redis_client.init)()

try:
for plugin in plugin_names:
plugin_info = run_await(current_redis_client.get)(f'{settings.PLUGIN_REDIS_PREFIX}:{plugin}')
if get_plugin_enable(plugin_info, StatusType.enable.value) != str(StatusType.enable.value):
enabled_plugins.discard(plugin)
finally:
run_await(current_redis_client.aclose)()

return enabled_plugins


def register_plugin_lifespan_hook(plugin: str, module: Any) -> None:
"""
注册插件 lifespan hook

:param plugin: 插件名称
:param module: 插件 hooks 模块
:return:
"""
lifespan_hook = getattr(module, 'lifespan', None)
if lifespan_hook is None:
return

if not callable(lifespan_hook):
log.warning(f'插件 {plugin} 的 lifespan 不是可调用对象,已跳过')
return

lifespan_manager.register(lifespan_hook)
log.info(f'插件 {plugin} lifespan hook 注册成功')


def run_plugin_startup_hook(plugin: str, module: Any, app: FastAPI) -> None:
"""
执行插件 startup hook

:param plugin: 插件名称
:param module: 插件 hooks 模块
:param app: FastAPI 应用实例
:return:
"""
setup_hook = getattr(module, 'setup', None)
if setup_hook is None:
return

if not callable(setup_hook):
log.warning(f'插件 {plugin} 的 setup 不是可调用对象,已跳过')
return

setup_result = setup_hook(app)
if inspect.isawaitable(setup_result):
run_await(lambda: setup_result)() # type: ignore
log.info(f'插件 {plugin} startup hook 执行成功')


def parse_plugin_config() -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""解析插件配置"""
extend_plugins = []
app_plugins = []

plugins = get_plugins()

# 使用独立连接
current_redis_client = RedisCli()
run_await(current_redis_client.init)()

# 清理未知插件信息
exclude_keys = [f'{settings.PLUGIN_REDIS_PREFIX}:{key}' for key in plugins]
run_await(current_redis_client.delete_prefix)(
settings.PLUGIN_REDIS_PREFIX,
exclude=exclude_keys,
)

for plugin in plugins:
data = load_plugin_config(plugin)
plugin_type = validate_plugin_config(plugin, data)

if plugin_type == PluginLevelType.extend:
extend_plugins.append(data)
else:
app_plugins.append(data)

# 补充插件信息
data['plugin']['name'] = plugin
plugin_cache_key = f'{settings.PLUGIN_REDIS_PREFIX}:{plugin}'
plugin_cache_info = run_await(current_redis_client.get)(plugin_cache_key)
if plugin_cache_info:
try:
data['plugin']['enable'] = json.loads(plugin_cache_info)['plugin']['enable']
except Exception:
data['plugin']['enable'] = str(StatusType.enable.value)
else:
data['plugin']['enable'] = str(StatusType.enable.value)

# 缓存最新插件信息
run_await(current_redis_client.set)(plugin_cache_key, json.dumps(data, ensure_ascii=False))

# 重置插件变更状态
run_await(current_redis_client.delete)(f'{settings.PLUGIN_REDIS_PREFIX}:changed')

# 关闭连接
run_await(current_redis_client.aclose)()
try:
# 清理未知插件信息
exclude_keys = [f'{settings.PLUGIN_REDIS_PREFIX}:{key}' for key in plugins]
run_await(current_redis_client.delete_prefix)(
settings.PLUGIN_REDIS_PREFIX,
exclude=exclude_keys,
)

for plugin in plugins:
plugin_config = load_plugin_config(plugin)
plugin_type = validate_plugin_config(plugin, plugin_config)

if plugin_type == PluginLevelType.extend:
extend_plugins.append(plugin_config)
else:
app_plugins.append(plugin_config)

# 补充插件信息
plugin_config['plugin']['name'] = plugin
plugin_cache_key = f'{settings.PLUGIN_REDIS_PREFIX}:{plugin}'
plugin_cache_info = run_await(current_redis_client.get)(plugin_cache_key)
plugin_config['plugin']['enable'] = get_plugin_enable(plugin_cache_info, StatusType.enable.value)

# 缓存最新插件信息
run_await(current_redis_client.set)(plugin_cache_key, json.dumps(plugin_config, ensure_ascii=False))

# 重置插件变更状态
run_await(current_redis_client.delete)(f'{settings.PLUGIN_REDIS_PREFIX}:changed')
finally:
run_await(current_redis_client.aclose)()

return extend_plugins, app_plugins

Expand Down Expand Up @@ -288,6 +367,41 @@ def build_final_router() -> APIRouter:
return main_router


def setup_plugins(app: FastAPI) -> None:
"""
注册并执行插件 hooks

:param app: FastAPI 应用实例
:return:
"""
plugins = get_plugins()
enabled_plugins = get_enabled_plugins(plugins)

for plugin in plugins:
if plugin not in enabled_plugins:
log.info(f'插件 {plugin} 未启用,已跳过 hooks 注册与执行')
continue

module_path = f'backend.plugin.{plugin}.hooks'
try:
module = import_module_cached(module_path)
except ModuleNotFoundError as e:
if e.name == module_path:
# 未定义 hooks.py
continue
log.warning(f'插件 {plugin} hooks 模块加载失败: {e}')
continue
except Exception as e:
log.warning(f'插件 {plugin} hooks 模块加载失败: {e}')
continue

try:
register_plugin_lifespan_hook(plugin, module)
run_plugin_startup_hook(plugin, module, app)
except Exception as e:
log.error(f'插件 {plugin} hooks 执行失败: {e}')


class PluginStatusChecker:
"""插件状态检查器"""

Expand All @@ -312,10 +426,5 @@ async def __call__(self, request: Request) -> None:
log.error('插件状态未初始化或丢失,需重启服务自动修复')
raise PluginInjectError('插件状态未初始化或丢失,请联系系统管理员')

try:
is_enabled = int(json.loads(plugin_info)['plugin']['enable'])
except Exception:
is_enabled = 0

if not is_enabled:
if get_plugin_enable(plugin_info, StatusType.disable.value) != str(StatusType.enable.value):
raise errors.ServerError(msg=f'插件 {self.plugin} 未启用,请联系系统管理员')
25 changes: 25 additions & 0 deletions backend/plugin/patching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from fastapi import FastAPI
from starlette.middleware import Middleware


def replace_middleware(
app: FastAPI,
original_middleware_cls: type,
replacement_middleware_cls: type,
**replacement_kwargs,
) -> None:
"""
替换中间件(应在插件的 startup hook 中调用)

:param app: FastAPI 应用实例
:param original_middleware_cls: 原始中间件类
:param replacement_middleware_cls: 替换后的中间件类
:param replacement_kwargs: 传给替换后中间件的初始化参数
:return:
"""
for index, middleware in enumerate(app.user_middleware):
if middleware.cls is original_middleware_cls:
app.user_middleware[index] = Middleware(replacement_middleware_cls, **replacement_kwargs)
return

raise ValueError(f'{original_middleware_cls.__name__} not found in app.user_middleware')