Skip to content

Commit 122e6c7

Browse files
authored
fix: make desktop plugin dependency loading safer on Windows (#7446)
* fix: make desktop plugin dependency loading safer on Windows * fix: restore dependency recovery after precheck fallback * test: cover version mismatch reinstall path * refactor: clarify dependency recovery state handling * style: format star manager with ruff * fix: skip dependency recovery for plugin import errors * fix: surface unexpected dependency recovery failures
1 parent 9c14a50 commit 122e6c7

8 files changed

Lines changed: 756 additions & 37 deletions

File tree

astrbot/core/star/star_manager.py

Lines changed: 116 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import sys
1212
import tempfile
1313
import traceback
14+
from dataclasses import dataclass
15+
from enum import Enum, auto
1416
from types import ModuleType
1517

1618
import yaml
@@ -37,6 +39,7 @@
3739
from astrbot.core.utils.io import remove_dir
3840
from astrbot.core.utils.metrics import Metric
3941
from astrbot.core.utils.requirements_utils import (
42+
MissingRequirementsPlan,
4043
plan_missing_requirements_install,
4144
)
4245

@@ -77,6 +80,19 @@ def __init__(
7780
self.error = error
7881

7982

83+
class ImportDependencyRecoveryMode(Enum):
84+
DISABLED = auto()
85+
PRELOAD_AND_RECOVER = auto()
86+
RECOVER_ON_FAILURE = auto()
87+
REINSTALL_ON_FAILURE = auto()
88+
89+
90+
@dataclass(frozen=True)
91+
class ImportDependencyRecoveryState:
92+
mode: ImportDependencyRecoveryMode
93+
install_plan: MissingRequirementsPlan | None = None
94+
95+
8096
@contextlib.contextmanager
8197
def _temporary_filtered_requirements_file(
8298
*,
@@ -137,7 +153,10 @@ async def _install_requirements_with_precheck(
137153
requirements_path,
138154
fallback_reason,
139155
)
140-
await pip_installer.install(requirements_path=requirements_path)
156+
await pip_installer.install(
157+
requirements_path=requirements_path,
158+
allow_target_upgrade=bool(install_plan.version_mismatch_names),
159+
)
141160
return
142161

143162
logger.info(
@@ -148,7 +167,10 @@ async def _install_requirements_with_precheck(
148167
with _temporary_filtered_requirements_file(
149168
install_lines=install_plan.install_lines,
150169
) as filtered_requirements_path:
151-
await pip_installer.install(requirements_path=filtered_requirements_path)
170+
await pip_installer.install(
171+
requirements_path=filtered_requirements_path,
172+
allow_target_upgrade=bool(install_plan.version_mismatch_names),
173+
)
152174

153175

154176
class PluginManager:
@@ -332,33 +354,106 @@ async def _ensure_plugin_requirements(
332354
logger.exception(str(dependency_error))
333355
raise dependency_error from e
334356

357+
@staticmethod
358+
def _resolve_import_dependency_recovery_state(
359+
requirements_path: str,
360+
*,
361+
reserved: bool,
362+
) -> ImportDependencyRecoveryState:
363+
if reserved or not os.path.exists(requirements_path):
364+
return ImportDependencyRecoveryState(ImportDependencyRecoveryMode.DISABLED)
365+
366+
install_plan = plan_missing_requirements_install(requirements_path)
367+
if install_plan is None:
368+
return ImportDependencyRecoveryState(
369+
ImportDependencyRecoveryMode.RECOVER_ON_FAILURE
370+
)
371+
if install_plan.version_mismatch_names:
372+
return ImportDependencyRecoveryState(
373+
ImportDependencyRecoveryMode.REINSTALL_ON_FAILURE,
374+
install_plan=install_plan,
375+
)
376+
377+
return ImportDependencyRecoveryState(
378+
ImportDependencyRecoveryMode.PRELOAD_AND_RECOVER,
379+
install_plan=install_plan,
380+
)
381+
382+
@staticmethod
383+
def _try_import_from_installed_dependencies(
384+
path: str,
385+
module_str: str,
386+
root_dir_name: str,
387+
requirements_path: str,
388+
import_exc: Exception,
389+
) -> ModuleType | None:
390+
try:
391+
logger.info(
392+
f"插件 {root_dir_name} 导入失败,尝试从已安装依赖恢复: {import_exc!s}"
393+
)
394+
pip_installer.prefer_installed_dependencies(
395+
requirements_path=requirements_path
396+
)
397+
module = __import__(path, fromlist=[module_str])
398+
logger.info(
399+
f"插件 {root_dir_name} 已从 site-packages 恢复依赖,跳过重新安装。"
400+
)
401+
return module
402+
except (ImportError, ModuleNotFoundError) as recover_exc:
403+
logger.info(
404+
f"插件 {root_dir_name} 已安装依赖恢复失败,将重新安装依赖: {recover_exc!s}"
405+
)
406+
return None
407+
335408
async def _import_plugin_with_dependency_recovery(
336409
self,
337410
path: str,
338411
module_str: str,
339412
root_dir_name: str,
340413
requirements_path: str,
414+
*,
415+
reserved: bool = False,
341416
) -> ModuleType:
417+
recovery_state = self._resolve_import_dependency_recovery_state(
418+
requirements_path,
419+
reserved=reserved,
420+
)
421+
422+
if recovery_state.mode is ImportDependencyRecoveryMode.PRELOAD_AND_RECOVER:
423+
try:
424+
pip_installer.prefer_installed_dependencies(
425+
requirements_path=requirements_path
426+
)
427+
except Exception as preload_exc:
428+
logger.info(
429+
f"插件 {root_dir_name} 预加载已安装依赖失败,将继续常规导入: {preload_exc!s}"
430+
)
431+
342432
try:
343433
return __import__(path, fromlist=[module_str])
344-
except (ModuleNotFoundError, ImportError) as import_exc:
345-
if os.path.exists(requirements_path):
346-
try:
347-
logger.info(
348-
f"插件 {root_dir_name} 导入失败,尝试从已安装依赖恢复: {import_exc!s}"
349-
)
350-
pip_installer.prefer_installed_dependencies(
351-
requirements_path=requirements_path
352-
)
353-
module = __import__(path, fromlist=[module_str])
354-
logger.info(
355-
f"插件 {root_dir_name} 已从 site-packages 恢复依赖,跳过重新安装。"
356-
)
357-
return module
358-
except Exception as recover_exc:
359-
logger.info(
360-
f"插件 {root_dir_name} 已安装依赖恢复失败,将重新安装依赖: {recover_exc!s}"
361-
)
434+
except ModuleNotFoundError as import_exc:
435+
if recovery_state.mode in {
436+
ImportDependencyRecoveryMode.PRELOAD_AND_RECOVER,
437+
ImportDependencyRecoveryMode.RECOVER_ON_FAILURE,
438+
}:
439+
recovered_module = self._try_import_from_installed_dependencies(
440+
path,
441+
module_str,
442+
root_dir_name,
443+
requirements_path,
444+
import_exc,
445+
)
446+
if recovered_module is not None:
447+
return recovered_module
448+
elif (
449+
recovery_state.mode is ImportDependencyRecoveryMode.REINSTALL_ON_FAILURE
450+
):
451+
assert recovery_state.install_plan is not None
452+
logger.info(
453+
"插件 %s 预检查检测到版本不匹配,跳过已安装依赖恢复: %s",
454+
root_dir_name,
455+
sorted(recovery_state.install_plan.version_mismatch_names),
456+
)
362457

363458
await self._check_plugin_dept_update(target_plugin=root_dir_name)
364459
return __import__(path, fromlist=[module_str])
@@ -788,6 +883,7 @@ async def load(
788883
module_str=module_str,
789884
root_dir_name=root_dir_name,
790885
requirements_path=requirements_path,
886+
reserved=reserved,
791887
)
792888
except Exception as e:
793889
error_trace = traceback.format_exc()

astrbot/core/utils/pip_installer.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -982,6 +982,7 @@ async def install(
982982
package_name: str | None = None,
983983
requirements_path: str | None = None,
984984
mirror: str | None = None,
985+
allow_target_upgrade: bool = True,
985986
) -> None:
986987
args, requested_requirements = self._build_pip_args(
987988
package_name, requirements_path, mirror
@@ -995,15 +996,17 @@ async def install(
995996
target_site_packages = get_astrbot_site_packages_path()
996997
os.makedirs(target_site_packages, exist_ok=True)
997998
_prepend_sys_path(target_site_packages)
998-
args.extend(
999-
[
1000-
"--target",
1001-
target_site_packages,
1002-
"--upgrade",
1003-
"--upgrade-strategy",
1004-
"only-if-needed",
1005-
]
1006-
)
999+
# `allow_target_upgrade` only matters for packaged desktop installs that
1000+
# write into the shared `data/site-packages` target directory.
1001+
args.extend(["--target", target_site_packages])
1002+
if allow_target_upgrade:
1003+
args.extend(
1004+
[
1005+
"--upgrade",
1006+
"--upgrade-strategy",
1007+
"only-if-needed",
1008+
]
1009+
)
10071010

10081011
with self._core_constraints.constraints_file() as constraints_file_path:
10091012
if constraints_file_path:

astrbot/core/utils/requirements_utils.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,17 @@ class ParsedPackageInput:
2929
requirement_names: frozenset[str]
3030

3131

32+
@dataclass(frozen=True)
33+
class MissingRequirementsAnalysis:
34+
missing_names: frozenset[str]
35+
version_mismatch_names: frozenset[str] = frozenset()
36+
37+
3238
@dataclass(frozen=True)
3339
class MissingRequirementsPlan:
3440
missing_names: frozenset[str]
3541
install_lines: tuple[str, ...]
42+
version_mismatch_names: frozenset[str] = frozenset()
3643
fallback_reason: str | None = None
3744

3845

@@ -394,24 +401,39 @@ def find_missing_requirements(requirements_path: str) -> set[str] | None:
394401
def find_missing_requirements_from_lines(
395402
requirement_lines: Sequence[str],
396403
) -> set[str] | None:
404+
analysis = classify_missing_requirements_from_lines(requirement_lines)
405+
if analysis is None:
406+
return None
407+
408+
return set(analysis.missing_names)
409+
410+
411+
def classify_missing_requirements_from_lines(
412+
requirement_lines: Sequence[str],
413+
) -> MissingRequirementsAnalysis | None:
397414
required = list(iter_requirements(lines=requirement_lines))
398415
if not required:
399-
return set()
416+
return MissingRequirementsAnalysis(missing_names=frozenset())
400417

401418
installed = collect_installed_distribution_versions(get_requirement_check_paths())
402419
if installed is None:
403420
return None
404421

405422
missing: set[str] = set()
423+
version_mismatch_names: set[str] = set()
406424
for name, specifier in required:
407425
installed_version = installed.get(name)
408426
if not installed_version:
409427
missing.add(name)
410428
continue
411429
if specifier and not _specifier_contains_version(specifier, installed_version):
412430
missing.add(name)
431+
version_mismatch_names.add(name)
413432

414-
return missing
433+
return MissingRequirementsAnalysis(
434+
missing_names=frozenset(missing),
435+
version_mismatch_names=frozenset(version_mismatch_names),
436+
)
415437

416438

417439
def build_missing_requirements_install_lines(
@@ -449,9 +471,11 @@ def plan_missing_requirements_install(
449471
if not can_precheck or requirement_lines is None:
450472
return None
451473

452-
missing = find_missing_requirements_from_lines(requirement_lines)
453-
if missing is None:
474+
analysis = classify_missing_requirements_from_lines(requirement_lines)
475+
if analysis is None:
454476
return None
477+
missing = analysis.missing_names
478+
version_mismatch_names = analysis.version_mismatch_names
455479

456480
install_lines = build_missing_requirements_install_lines(
457481
requirements_path,
@@ -468,12 +492,14 @@ def plan_missing_requirements_install(
468492
)
469493
return MissingRequirementsPlan(
470494
missing_names=frozenset(missing),
495+
version_mismatch_names=frozenset(version_mismatch_names),
471496
install_lines=(),
472497
fallback_reason="unmapped missing requirement names",
473498
)
474499

475500
return MissingRequirementsPlan(
476501
missing_names=frozenset(missing),
502+
version_mismatch_names=frozenset(version_mismatch_names),
477503
install_lines=install_lines,
478504
)
479505

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def check_env() -> None:
5151

5252
site_packages_path = get_astrbot_site_packages_path()
5353
if site_packages_path not in sys.path:
54-
sys.path.insert(0, site_packages_path)
54+
sys.path.append(site_packages_path)
5555

5656
os.makedirs(get_astrbot_config_path(), exist_ok=True)
5757
os.makedirs(get_astrbot_plugin_path(), exist_ok=True)

tests/test_main.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,48 @@ def test_check_env(monkeypatch):
6060
check_env()
6161

6262

63+
def test_check_env_appends_user_site_packages_after_runtime_paths(monkeypatch):
64+
astrbot_root = "/tmp/astrbot-root"
65+
site_packages_path = "/tmp/astrbot-site-packages"
66+
original_sys_path = list(sys.path)
67+
68+
monkeypatch.setattr(sys, "version_info", _version_info(3, 12))
69+
monkeypatch.setattr("main.get_astrbot_root", lambda: astrbot_root)
70+
monkeypatch.setattr("main.get_astrbot_site_packages_path", lambda: site_packages_path)
71+
monkeypatch.setattr("main.get_astrbot_config_path", lambda: "/tmp/config")
72+
monkeypatch.setattr("main.get_astrbot_plugin_path", lambda: "/tmp/plugins")
73+
monkeypatch.setattr("main.get_astrbot_temp_path", lambda: "/tmp/temp")
74+
monkeypatch.setattr("main.get_astrbot_knowledge_base_path", lambda: "/tmp/kb")
75+
monkeypatch.setattr(sys, "path", ["/runtime/lib", *original_sys_path])
76+
77+
with mock.patch("os.makedirs"):
78+
check_env()
79+
80+
assert sys.path[0] == astrbot_root
81+
assert sys.path[-1] == site_packages_path
82+
assert sys.path.index(site_packages_path) > sys.path.index("/runtime/lib")
83+
84+
85+
def test_check_env_does_not_append_duplicate_user_site_packages(monkeypatch):
86+
astrbot_root = "/tmp/astrbot-root"
87+
site_packages_path = "/tmp/astrbot-site-packages"
88+
original_sys_path = list(sys.path)
89+
90+
monkeypatch.setattr(sys, "version_info", _version_info(3, 12))
91+
monkeypatch.setattr("main.get_astrbot_root", lambda: astrbot_root)
92+
monkeypatch.setattr("main.get_astrbot_site_packages_path", lambda: site_packages_path)
93+
monkeypatch.setattr("main.get_astrbot_config_path", lambda: "/tmp/config")
94+
monkeypatch.setattr("main.get_astrbot_plugin_path", lambda: "/tmp/plugins")
95+
monkeypatch.setattr("main.get_astrbot_temp_path", lambda: "/tmp/temp")
96+
monkeypatch.setattr("main.get_astrbot_knowledge_base_path", lambda: "/tmp/kb")
97+
monkeypatch.setattr(sys, "path", [astrbot_root, *original_sys_path, site_packages_path])
98+
99+
with mock.patch("os.makedirs"):
100+
check_env()
101+
102+
assert sys.path.count(site_packages_path) == 1
103+
104+
63105
def test_version_info_comparisons():
64106
"""Test _version_info comparison operators with tuples and other instances."""
65107
v3_10 = _version_info(3, 10)

0 commit comments

Comments
 (0)