Skip to content

Commit 4e8009d

Browse files
committed
Squashed 'astrbot-sdk/' changes from 9724f6230..56943300b
56943300b chore: refresh vendor snapshot [skip ci] 215e06572 Merge pull request #104 from united-pooh:dev ea10d593d feat: 增强插件导入机制,支持命名空间和动态导入 REVERT: 9724f6230 chore: refresh vendor snapshot [skip ci] git-subtree-dir: astrbot-sdk git-subtree-split: 56943300b29038e57ab859d19d900e2e967a3a8a
1 parent 35feedd commit 4e8009d

1 file changed

Lines changed: 168 additions & 51 deletions

File tree

src/astrbot_sdk/runtime/loader.py

Lines changed: 168 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@
5252

5353
from __future__ import annotations
5454

55+
import builtins
5556
import copy
57+
import hashlib
5658
import importlib
5759
import inspect
5860
import json
@@ -62,6 +64,7 @@
6264
import shutil
6365
import sys
6466
import threading
67+
import types
6568
import typing
6669
from dataclasses import dataclass, field
6770
from importlib import import_module
@@ -72,6 +75,7 @@
7275

7376
from .._internal.command_model import resolve_command_model_param
7477
from .._internal.injected_params import is_framework_injected_parameter
78+
from .._internal.invocation_context import caller_plugin_scope, current_caller_plugin_id
7579
from .._internal.plugin_ids import (
7680
capability_belongs_to_plugin,
7781
plugin_capability_prefix,
@@ -116,12 +120,16 @@
116120
_LOGGER = logging.getLogger(__name__)
117121
_PLUGIN_IMPORT_LOCK = threading.RLock()
118122
_VALID_HANDLER_KINDS: tuple[HandlerKind, ...] = ("handler", "hook", "tool", "session")
123+
_PLUGIN_PACKAGE_PREFIX = "astrbot_ext_"
119124
_GITHUB_REPO_NAME_RE = re.compile(r"^[A-Za-z0-9_.-]+$")
120125
_GITHUB_REPO_SLUG_RE = re.compile(r"^[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+$")
121126
_GITHUB_REPO_URL_RE = re.compile(
122127
r"^https://github\.com/[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+/?$",
123128
re.IGNORECASE,
124129
)
130+
_PLUGIN_IMPORT_NAMESPACES: dict[str, _PluginImportNamespace] = {}
131+
_ORIGINAL_BUILTIN_IMPORT = builtins.__import__
132+
_PLUGIN_IMPORT_HOOK_INSTALLED = False
125133

126134

127135
def _default_python_version() -> str:
@@ -234,13 +242,35 @@ class _ResolvedComponent:
234242
index: int
235243

236244

245+
@dataclass(slots=True)
246+
class _PluginImportNamespace:
247+
plugin_id: str
248+
plugin_dir: Path
249+
package_name: str
250+
251+
237252
@dataclass(slots=True)
238253
class _ParamTypeInfo:
239254
type_name: ParamTypeName
240255
inner_type: OptionalInnerType
241256
required: bool
242257

243258

259+
def _sanitize_package_component(plugin_id: str) -> str:
260+
sanitized = re.sub(r"[^A-Za-z0-9_]+", "_", plugin_id).strip("_")
261+
return sanitized or "plugin"
262+
263+
264+
def _plugin_package_name(plugin_id: str) -> str:
265+
digest = hashlib.sha256(plugin_id.encode("utf-8")).hexdigest()[:8]
266+
return f"{_PLUGIN_PACKAGE_PREFIX}{_sanitize_package_component(plugin_id)}_{digest}"
267+
268+
269+
def _plugin_module_name(package_name: str, module_name: str) -> str:
270+
normalized = module_name.strip()
271+
return f"{package_name}.{normalized}" if normalized else package_name
272+
273+
244274
def _iter_handler_names(instance: Any) -> list[str]:
245275
handler_names = getattr(instance.__class__, "__handlers__", ())
246276
if handler_names:
@@ -669,7 +699,7 @@ def _plugin_component_classes(plugin: PluginSpec) -> list[_ResolvedComponent]:
669699
"必须是 '<module>:<Class>'。"
670700
)
671701
try:
672-
cls = import_string(class_path, plugin.plugin_dir)
702+
cls = _import_plugin_string(class_path, plugin)
673703
except Exception as exc:
674704
raise ValueError(
675705
f"{_component_context(plugin, class_path=class_path, index=index)} "
@@ -1168,11 +1198,13 @@ def load_plugin(plugin: PluginSpec) -> LoadedPlugin:
11681198
"""
11691199
with _PLUGIN_IMPORT_LOCK:
11701200
_LOGGER.debug("Loading SDK plugin %s from %s", plugin.name, plugin.plugin_dir)
1171-
plugin_path = str(plugin.plugin_dir)
1172-
if plugin_path not in sys.path:
1173-
sys.path.insert(0, plugin_path)
1201+
_ensure_plugin_import_hook_installed()
1202+
namespace = _register_plugin_import_namespace(plugin)
11741203
_purge_plugin_bytecode(plugin.plugin_dir)
1204+
_purge_plugin_package(namespace.package_name)
11751205
_purge_plugin_modules(plugin.plugin_dir)
1206+
_ensure_plugin_package(namespace)
1207+
importlib.invalidate_caches()
11761208

11771209
instances: list[Any] = []
11781210
handlers: list[LoadedHandler] = []
@@ -1181,29 +1213,30 @@ def load_plugin(plugin: PluginSpec) -> LoadedPlugin:
11811213
agents: list[LoadedAgent] = []
11821214
seen_agents: set[str] = set()
11831215
seen_capability_sources: dict[str, str] = {}
1184-
resolved_components = _plugin_component_classes(plugin)
1185-
1186-
for resolved_component in resolved_components:
1187-
instance = _load_component_instance(plugin, resolved_component)
1188-
instances.append(instance)
1189-
agents.extend(
1190-
_collect_component_agents(
1191-
plugin,
1192-
resolved_component.cls,
1193-
seen_agents=seen_agents,
1216+
with caller_plugin_scope(plugin.name):
1217+
resolved_components = _plugin_component_classes(plugin)
1218+
1219+
for resolved_component in resolved_components:
1220+
instance = _load_component_instance(plugin, resolved_component)
1221+
instances.append(instance)
1222+
agents.extend(
1223+
_collect_component_agents(
1224+
plugin,
1225+
resolved_component.cls,
1226+
seen_agents=seen_agents,
1227+
)
11941228
)
1195-
)
1196-
component_handlers, component_capabilities, component_tools = (
1197-
_collect_component_members(
1198-
plugin,
1199-
resolved_component=resolved_component,
1200-
instance=instance,
1201-
seen_capability_sources=seen_capability_sources,
1229+
component_handlers, component_capabilities, component_tools = (
1230+
_collect_component_members(
1231+
plugin,
1232+
resolved_component=resolved_component,
1233+
instance=instance,
1234+
seen_capability_sources=seen_capability_sources,
1235+
)
12021236
)
1203-
)
1204-
handlers.extend(component_handlers)
1205-
capabilities.extend(component_capabilities)
1206-
llm_tools.extend(component_tools)
1237+
handlers.extend(component_handlers)
1238+
capabilities.extend(component_capabilities)
1239+
llm_tools.extend(component_tools)
12071240

12081241
_LOGGER.debug(
12091242
"Loaded SDK plugin %s: %d components, %d handlers, %d capabilities, %d llm tools, %d agents",
@@ -1238,6 +1271,45 @@ def _plugin_defines_module_root(plugin_dir: Path, root_name: str) -> bool:
12381271
).exists()
12391272

12401273

1274+
def _register_plugin_import_namespace(plugin: PluginSpec) -> _PluginImportNamespace:
1275+
existing = _PLUGIN_IMPORT_NAMESPACES.get(plugin.name)
1276+
package_name = (
1277+
existing.package_name
1278+
if existing is not None
1279+
else _plugin_package_name(plugin.name)
1280+
)
1281+
namespace = _PluginImportNamespace(
1282+
plugin_id=plugin.name,
1283+
plugin_dir=plugin.plugin_dir.resolve(),
1284+
package_name=package_name,
1285+
)
1286+
_PLUGIN_IMPORT_NAMESPACES[plugin.name] = namespace
1287+
return namespace
1288+
1289+
1290+
def _ensure_plugin_package(namespace: _PluginImportNamespace) -> types.ModuleType:
1291+
existing = sys.modules.get(namespace.package_name)
1292+
if isinstance(existing, types.ModuleType):
1293+
existing.__path__ = [str(namespace.plugin_dir)]
1294+
existing.__package__ = namespace.package_name
1295+
return existing
1296+
1297+
module = types.ModuleType(namespace.package_name)
1298+
module.__file__ = str(namespace.plugin_dir)
1299+
module.__package__ = namespace.package_name
1300+
module.__path__ = [str(namespace.plugin_dir)]
1301+
module.__loader__ = None
1302+
spec = importlib.machinery.ModuleSpec(
1303+
namespace.package_name,
1304+
loader=None,
1305+
is_package=True,
1306+
)
1307+
spec.submodule_search_locations = [str(namespace.plugin_dir)]
1308+
module.__spec__ = spec
1309+
sys.modules[namespace.package_name] = module
1310+
return module
1311+
1312+
12411313
def _module_belongs_to_plugin(module: Any, plugin_dir: Path) -> bool:
12421314
file_path = getattr(module, "__file__", None)
12431315
if isinstance(file_path, str) and _path_within_root(Path(file_path), plugin_dir):
@@ -1261,6 +1333,12 @@ def _purge_plugin_modules(plugin_dir: Path) -> None:
12611333
sys.modules.pop(module_name, None)
12621334

12631335

1336+
def _purge_plugin_package(package_name: str) -> None:
1337+
for module_name in list(sys.modules):
1338+
if module_name == package_name or module_name.startswith(f"{package_name}."):
1339+
sys.modules.pop(module_name, None)
1340+
1341+
12641342
def _purge_plugin_bytecode(plugin_dir: Path) -> None:
12651343
plugin_root = plugin_dir.resolve()
12661344
for path in plugin_root.rglob("*"):
@@ -1274,43 +1352,82 @@ def _purge_plugin_bytecode(plugin_dir: Path) -> None:
12741352
continue
12751353

12761354

1277-
def _purge_module_root(root_name: str) -> None:
1278-
for module_name in list(sys.modules):
1279-
if module_name == root_name or module_name.startswith(f"{root_name}."):
1280-
sys.modules.pop(module_name, None)
1355+
def _import_plugin_string(path: str, plugin: PluginSpec) -> Any:
1356+
module_name, attr = path.split(":", 1)
1357+
namespace = _PLUGIN_IMPORT_NAMESPACES.get(plugin.name)
1358+
if namespace is None:
1359+
raise RuntimeError(f"plugin import namespace missing: {plugin.name}")
1360+
module = import_module(_plugin_module_name(namespace.package_name, module_name))
1361+
return getattr(module, attr)
12811362

12821363

1283-
def _prepare_plugin_import(module_name: str, plugin_dir: Path | None) -> None:
1284-
if plugin_dir is None:
1285-
return
1364+
def _plugin_import_namespace_for_current_caller() -> _PluginImportNamespace | None:
1365+
plugin_id = current_caller_plugin_id()
1366+
if not plugin_id:
1367+
return None
1368+
return _PLUGIN_IMPORT_NAMESPACES.get(plugin_id)
12861369

1287-
plugin_root = plugin_dir.resolve()
1288-
plugin_path = str(plugin_root)
1289-
sys.path[:] = [entry for entry in sys.path if entry != plugin_path]
1290-
sys.path.insert(0, plugin_path)
12911370

1292-
root_name = module_name.split(".", 1)[0]
1293-
if not _plugin_defines_module_root(plugin_root, root_name):
1294-
return
1371+
def _rewrite_plugin_import_name(
1372+
namespace: _PluginImportNamespace,
1373+
name: str,
1374+
) -> str | None:
1375+
normalized = name.strip()
1376+
if not normalized:
1377+
return None
1378+
if normalized.startswith(_PLUGIN_PACKAGE_PREFIX):
1379+
return None
1380+
root_name = normalized.split(".", 1)[0]
1381+
if not _plugin_defines_module_root(namespace.plugin_dir, root_name):
1382+
return None
1383+
return _plugin_module_name(namespace.package_name, normalized)
1384+
1385+
1386+
def _plugin_scoped_import(
1387+
name: str,
1388+
globals: dict[str, Any] | None = None,
1389+
locals: dict[str, Any] | None = None,
1390+
fromlist: tuple[Any, ...] | list[Any] = (),
1391+
level: int = 0,
1392+
) -> Any:
1393+
with _PLUGIN_IMPORT_LOCK:
1394+
if level != 0:
1395+
return _ORIGINAL_BUILTIN_IMPORT(name, globals, locals, fromlist, level)
1396+
1397+
namespace = _plugin_import_namespace_for_current_caller()
1398+
rewritten_name = (
1399+
None if namespace is None else _rewrite_plugin_import_name(namespace, name)
1400+
)
1401+
if rewritten_name is None:
1402+
return _ORIGINAL_BUILTIN_IMPORT(name, globals, locals, fromlist, level)
1403+
1404+
imported = _ORIGINAL_BUILTIN_IMPORT(
1405+
rewritten_name,
1406+
globals,
1407+
locals,
1408+
fromlist,
1409+
level,
1410+
)
1411+
if fromlist:
1412+
return imported
1413+
root_name = name.split(".", 1)[0].strip()
1414+
root_module = sys.modules.get(
1415+
_plugin_module_name(namespace.package_name, root_name)
1416+
)
1417+
return root_module if root_module is not None else imported
12951418

1296-
cached_root = sys.modules.get(root_name)
1297-
cached_module = sys.modules.get(module_name)
1298-
if cached_root is not None and not _module_belongs_to_plugin(
1299-
cached_root, plugin_root
1300-
):
1301-
_purge_module_root(root_name)
1302-
elif cached_module is not None and not _module_belongs_to_plugin(
1303-
cached_module, plugin_root
1304-
):
1305-
_purge_module_root(root_name)
13061419

1307-
importlib.invalidate_caches()
1420+
def _ensure_plugin_import_hook_installed() -> None:
1421+
global _PLUGIN_IMPORT_HOOK_INSTALLED
1422+
if _PLUGIN_IMPORT_HOOK_INSTALLED:
1423+
return
1424+
builtins.__import__ = _plugin_scoped_import
1425+
_PLUGIN_IMPORT_HOOK_INSTALLED = True
13081426

13091427

13101428
def import_string(path: str, plugin_dir: Path | None = None) -> Any:
13111429
"""通过字符串路径导入对象。"""
13121430
with _PLUGIN_IMPORT_LOCK:
13131431
module_name, attr = path.split(":", 1)
1314-
_prepare_plugin_import(module_name, plugin_dir)
13151432
module = import_module(module_name)
13161433
return getattr(module, attr)

0 commit comments

Comments
 (0)