diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 25df73f642..1ec0b83662 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -11,6 +11,8 @@ import sys import tempfile import traceback +from dataclasses import dataclass +from enum import Enum, auto from types import ModuleType import yaml @@ -37,6 +39,7 @@ from astrbot.core.utils.io import remove_dir from astrbot.core.utils.metrics import Metric from astrbot.core.utils.requirements_utils import ( + MissingRequirementsPlan, plan_missing_requirements_install, ) @@ -77,6 +80,19 @@ def __init__( self.error = error +class ImportDependencyRecoveryMode(Enum): + DISABLED = auto() + PRELOAD_AND_RECOVER = auto() + RECOVER_ON_FAILURE = auto() + REINSTALL_ON_FAILURE = auto() + + +@dataclass(frozen=True) +class ImportDependencyRecoveryState: + mode: ImportDependencyRecoveryMode + install_plan: MissingRequirementsPlan | None = None + + @contextlib.contextmanager def _temporary_filtered_requirements_file( *, @@ -137,7 +153,10 @@ async def _install_requirements_with_precheck( requirements_path, fallback_reason, ) - await pip_installer.install(requirements_path=requirements_path) + await pip_installer.install( + requirements_path=requirements_path, + allow_target_upgrade=bool(install_plan.version_mismatch_names), + ) return logger.info( @@ -148,7 +167,10 @@ async def _install_requirements_with_precheck( with _temporary_filtered_requirements_file( install_lines=install_plan.install_lines, ) as filtered_requirements_path: - await pip_installer.install(requirements_path=filtered_requirements_path) + await pip_installer.install( + requirements_path=filtered_requirements_path, + allow_target_upgrade=bool(install_plan.version_mismatch_names), + ) class PluginManager: @@ -332,33 +354,106 @@ async def _ensure_plugin_requirements( logger.exception(str(dependency_error)) raise dependency_error from e + @staticmethod + def _resolve_import_dependency_recovery_state( + requirements_path: str, + *, + reserved: bool, + ) -> ImportDependencyRecoveryState: + if reserved or not os.path.exists(requirements_path): + return ImportDependencyRecoveryState(ImportDependencyRecoveryMode.DISABLED) + + install_plan = plan_missing_requirements_install(requirements_path) + if install_plan is None: + return ImportDependencyRecoveryState( + ImportDependencyRecoveryMode.RECOVER_ON_FAILURE + ) + if install_plan.version_mismatch_names: + return ImportDependencyRecoveryState( + ImportDependencyRecoveryMode.REINSTALL_ON_FAILURE, + install_plan=install_plan, + ) + + return ImportDependencyRecoveryState( + ImportDependencyRecoveryMode.PRELOAD_AND_RECOVER, + install_plan=install_plan, + ) + + @staticmethod + def _try_import_from_installed_dependencies( + path: str, + module_str: str, + root_dir_name: str, + requirements_path: str, + import_exc: Exception, + ) -> ModuleType | None: + try: + logger.info( + f"插件 {root_dir_name} 导入失败,尝试从已安装依赖恢复: {import_exc!s}" + ) + pip_installer.prefer_installed_dependencies( + requirements_path=requirements_path + ) + module = __import__(path, fromlist=[module_str]) + logger.info( + f"插件 {root_dir_name} 已从 site-packages 恢复依赖,跳过重新安装。" + ) + return module + except (ImportError, ModuleNotFoundError) as recover_exc: + logger.info( + f"插件 {root_dir_name} 已安装依赖恢复失败,将重新安装依赖: {recover_exc!s}" + ) + return None + async def _import_plugin_with_dependency_recovery( self, path: str, module_str: str, root_dir_name: str, requirements_path: str, + *, + reserved: bool = False, ) -> ModuleType: + recovery_state = self._resolve_import_dependency_recovery_state( + requirements_path, + reserved=reserved, + ) + + if recovery_state.mode is ImportDependencyRecoveryMode.PRELOAD_AND_RECOVER: + try: + pip_installer.prefer_installed_dependencies( + requirements_path=requirements_path + ) + except Exception as preload_exc: + logger.info( + f"插件 {root_dir_name} 预加载已安装依赖失败,将继续常规导入: {preload_exc!s}" + ) + try: return __import__(path, fromlist=[module_str]) - except (ModuleNotFoundError, ImportError) as import_exc: - if os.path.exists(requirements_path): - try: - logger.info( - f"插件 {root_dir_name} 导入失败,尝试从已安装依赖恢复: {import_exc!s}" - ) - pip_installer.prefer_installed_dependencies( - requirements_path=requirements_path - ) - module = __import__(path, fromlist=[module_str]) - logger.info( - f"插件 {root_dir_name} 已从 site-packages 恢复依赖,跳过重新安装。" - ) - return module - except Exception as recover_exc: - logger.info( - f"插件 {root_dir_name} 已安装依赖恢复失败,将重新安装依赖: {recover_exc!s}" - ) + except ModuleNotFoundError as import_exc: + if recovery_state.mode in { + ImportDependencyRecoveryMode.PRELOAD_AND_RECOVER, + ImportDependencyRecoveryMode.RECOVER_ON_FAILURE, + }: + recovered_module = self._try_import_from_installed_dependencies( + path, + module_str, + root_dir_name, + requirements_path, + import_exc, + ) + if recovered_module is not None: + return recovered_module + elif ( + recovery_state.mode is ImportDependencyRecoveryMode.REINSTALL_ON_FAILURE + ): + assert recovery_state.install_plan is not None + logger.info( + "插件 %s 预检查检测到版本不匹配,跳过已安装依赖恢复: %s", + root_dir_name, + sorted(recovery_state.install_plan.version_mismatch_names), + ) await self._check_plugin_dept_update(target_plugin=root_dir_name) return __import__(path, fromlist=[module_str]) @@ -788,6 +883,7 @@ async def load( module_str=module_str, root_dir_name=root_dir_name, requirements_path=requirements_path, + reserved=reserved, ) except Exception as e: error_trace = traceback.format_exc() diff --git a/astrbot/core/utils/pip_installer.py b/astrbot/core/utils/pip_installer.py index 8aad8db75a..e5f7138209 100644 --- a/astrbot/core/utils/pip_installer.py +++ b/astrbot/core/utils/pip_installer.py @@ -982,6 +982,7 @@ async def install( package_name: str | None = None, requirements_path: str | None = None, mirror: str | None = None, + allow_target_upgrade: bool = True, ) -> None: args, requested_requirements = self._build_pip_args( package_name, requirements_path, mirror @@ -995,15 +996,17 @@ async def install( target_site_packages = get_astrbot_site_packages_path() os.makedirs(target_site_packages, exist_ok=True) _prepend_sys_path(target_site_packages) - args.extend( - [ - "--target", - target_site_packages, - "--upgrade", - "--upgrade-strategy", - "only-if-needed", - ] - ) + # `allow_target_upgrade` only matters for packaged desktop installs that + # write into the shared `data/site-packages` target directory. + args.extend(["--target", target_site_packages]) + if allow_target_upgrade: + args.extend( + [ + "--upgrade", + "--upgrade-strategy", + "only-if-needed", + ] + ) with self._core_constraints.constraints_file() as constraints_file_path: if constraints_file_path: diff --git a/astrbot/core/utils/requirements_utils.py b/astrbot/core/utils/requirements_utils.py index f2c749db1a..969976a4fc 100644 --- a/astrbot/core/utils/requirements_utils.py +++ b/astrbot/core/utils/requirements_utils.py @@ -29,10 +29,17 @@ class ParsedPackageInput: requirement_names: frozenset[str] +@dataclass(frozen=True) +class MissingRequirementsAnalysis: + missing_names: frozenset[str] + version_mismatch_names: frozenset[str] = frozenset() + + @dataclass(frozen=True) class MissingRequirementsPlan: missing_names: frozenset[str] install_lines: tuple[str, ...] + version_mismatch_names: frozenset[str] = frozenset() fallback_reason: str | None = None @@ -394,15 +401,26 @@ def find_missing_requirements(requirements_path: str) -> set[str] | None: def find_missing_requirements_from_lines( requirement_lines: Sequence[str], ) -> set[str] | None: + analysis = classify_missing_requirements_from_lines(requirement_lines) + if analysis is None: + return None + + return set(analysis.missing_names) + + +def classify_missing_requirements_from_lines( + requirement_lines: Sequence[str], +) -> MissingRequirementsAnalysis | None: required = list(iter_requirements(lines=requirement_lines)) if not required: - return set() + return MissingRequirementsAnalysis(missing_names=frozenset()) installed = collect_installed_distribution_versions(get_requirement_check_paths()) if installed is None: return None missing: set[str] = set() + version_mismatch_names: set[str] = set() for name, specifier in required: installed_version = installed.get(name) if not installed_version: @@ -410,8 +428,12 @@ def find_missing_requirements_from_lines( continue if specifier and not _specifier_contains_version(specifier, installed_version): missing.add(name) + version_mismatch_names.add(name) - return missing + return MissingRequirementsAnalysis( + missing_names=frozenset(missing), + version_mismatch_names=frozenset(version_mismatch_names), + ) def build_missing_requirements_install_lines( @@ -449,9 +471,11 @@ def plan_missing_requirements_install( if not can_precheck or requirement_lines is None: return None - missing = find_missing_requirements_from_lines(requirement_lines) - if missing is None: + analysis = classify_missing_requirements_from_lines(requirement_lines) + if analysis is None: return None + missing = analysis.missing_names + version_mismatch_names = analysis.version_mismatch_names install_lines = build_missing_requirements_install_lines( requirements_path, @@ -468,12 +492,14 @@ def plan_missing_requirements_install( ) return MissingRequirementsPlan( missing_names=frozenset(missing), + version_mismatch_names=frozenset(version_mismatch_names), install_lines=(), fallback_reason="unmapped missing requirement names", ) return MissingRequirementsPlan( missing_names=frozenset(missing), + version_mismatch_names=frozenset(version_mismatch_names), install_lines=install_lines, ) diff --git a/main.py b/main.py index 1cc9009826..14e0c23a81 100644 --- a/main.py +++ b/main.py @@ -51,7 +51,7 @@ def check_env() -> None: site_packages_path = get_astrbot_site_packages_path() if site_packages_path not in sys.path: - sys.path.insert(0, site_packages_path) + sys.path.append(site_packages_path) os.makedirs(get_astrbot_config_path(), exist_ok=True) os.makedirs(get_astrbot_plugin_path(), exist_ok=True) diff --git a/tests/test_main.py b/tests/test_main.py index c4bab2c2c3..0f20016aeb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -60,6 +60,48 @@ def test_check_env(monkeypatch): check_env() +def test_check_env_appends_user_site_packages_after_runtime_paths(monkeypatch): + astrbot_root = "/tmp/astrbot-root" + site_packages_path = "/tmp/astrbot-site-packages" + original_sys_path = list(sys.path) + + monkeypatch.setattr(sys, "version_info", _version_info(3, 12)) + monkeypatch.setattr("main.get_astrbot_root", lambda: astrbot_root) + monkeypatch.setattr("main.get_astrbot_site_packages_path", lambda: site_packages_path) + monkeypatch.setattr("main.get_astrbot_config_path", lambda: "/tmp/config") + monkeypatch.setattr("main.get_astrbot_plugin_path", lambda: "/tmp/plugins") + monkeypatch.setattr("main.get_astrbot_temp_path", lambda: "/tmp/temp") + monkeypatch.setattr("main.get_astrbot_knowledge_base_path", lambda: "/tmp/kb") + monkeypatch.setattr(sys, "path", ["/runtime/lib", *original_sys_path]) + + with mock.patch("os.makedirs"): + check_env() + + assert sys.path[0] == astrbot_root + assert sys.path[-1] == site_packages_path + assert sys.path.index(site_packages_path) > sys.path.index("/runtime/lib") + + +def test_check_env_does_not_append_duplicate_user_site_packages(monkeypatch): + astrbot_root = "/tmp/astrbot-root" + site_packages_path = "/tmp/astrbot-site-packages" + original_sys_path = list(sys.path) + + monkeypatch.setattr(sys, "version_info", _version_info(3, 12)) + monkeypatch.setattr("main.get_astrbot_root", lambda: astrbot_root) + monkeypatch.setattr("main.get_astrbot_site_packages_path", lambda: site_packages_path) + monkeypatch.setattr("main.get_astrbot_config_path", lambda: "/tmp/config") + monkeypatch.setattr("main.get_astrbot_plugin_path", lambda: "/tmp/plugins") + monkeypatch.setattr("main.get_astrbot_temp_path", lambda: "/tmp/temp") + monkeypatch.setattr("main.get_astrbot_knowledge_base_path", lambda: "/tmp/kb") + monkeypatch.setattr(sys, "path", [astrbot_root, *original_sys_path, site_packages_path]) + + with mock.patch("os.makedirs"): + check_env() + + assert sys.path.count(site_packages_path) == 1 + + def test_version_info_comparisons(): """Test _version_info comparison operators with tuples and other instances.""" v3_10 = _version_info(3, 10) diff --git a/tests/test_pip_helper_modules.py b/tests/test_pip_helper_modules.py index 506dd09453..1ce4967139 100644 --- a/tests/test_pip_helper_modules.py +++ b/tests/test_pip_helper_modules.py @@ -226,8 +226,11 @@ def test_plan_missing_requirements_install_returns_none_when_missing_names_canno monkeypatch.setattr( requirements_utils, - "find_missing_requirements_from_lines", - lambda lines: {"botocore"}, + "classify_missing_requirements_from_lines", + lambda lines: requirements_utils.MissingRequirementsAnalysis( + missing_names=frozenset({"botocore"}), + version_mismatch_names=frozenset(), + ), ) plan = requirements_utils.plan_missing_requirements_install(str(requirements_path)) @@ -238,6 +241,29 @@ def test_plan_missing_requirements_install_returns_none_when_missing_names_canno assert plan.fallback_reason == "unmapped missing requirement names" +def test_classify_missing_requirements_from_lines_tracks_missing_and_version_mismatches( + monkeypatch, +): + monkeypatch.setattr( + requirements_utils, + "collect_installed_distribution_versions", + lambda paths: {"boto3": "1.0"}, + ) + monkeypatch.setattr( + requirements_utils, + "get_requirement_check_paths", + lambda: ["/tmp/site-packages"], + ) + + analysis = requirements_utils.classify_missing_requirements_from_lines( + ["boto3>=2.0", "botocore"] + ) + + assert analysis is not None + assert analysis.missing_names == frozenset({"boto3", "botocore"}) + assert analysis.version_mismatch_names == frozenset({"boto3"}) + + def test_plan_missing_requirements_install_loads_requirement_lines_once( monkeypatch, tmp_path, @@ -274,6 +300,31 @@ def mock_load(path): assert calls == [str(requirements_path)] +def test_plan_missing_requirements_install_tracks_version_mismatches( + monkeypatch, tmp_path +): + requirements_path = tmp_path / "requirements.txt" + requirements_path.write_text("boto3>=2.0\nbotocore\n", encoding="utf-8") + + monkeypatch.setattr( + requirements_utils, + "collect_installed_distribution_versions", + lambda paths: {"boto3": "1.0"}, + ) + monkeypatch.setattr( + requirements_utils, + "get_requirement_check_paths", + lambda: ["/tmp/site-packages"], + ) + + plan = requirements_utils.plan_missing_requirements_install(str(requirements_path)) + + assert plan is not None + assert plan.missing_names == frozenset({"boto3", "botocore"}) + assert plan.version_mismatch_names == frozenset({"boto3"}) + assert plan.install_lines == ("boto3>=2.0", "botocore") + + def test_build_missing_requirements_install_lines_logs_why_option_lines_fall_back( monkeypatch, tmp_path, diff --git a/tests/test_pip_installer.py b/tests/test_pip_installer.py index adbd174e32..bfddf60e1c 100644 --- a/tests/test_pip_installer.py +++ b/tests/test_pip_installer.py @@ -128,6 +128,58 @@ async def test_install_targets_site_packages_for_desktop_client(monkeypatch, tmp assert ensure_preferred_calls == [(str(site_packages_path), {"demo-package"})] +@pytest.mark.asyncio +async def test_install_keeps_target_upgrade_enabled_by_default_for_desktop_client( + monkeypatch, tmp_path +): + monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1") + monkeypatch.delattr("sys.frozen", raising=False) + + site_packages_path = tmp_path / "site-packages" + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer.get_astrbot_site_packages_path", + lambda: str(site_packages_path), + ) + + installer = PipInstaller("") + await installer.install(package_name="demo-package") + + recorded_args = run_pip.await_args_list[0].args[0] + + assert "--target" in recorded_args + assert "--upgrade" in recorded_args + assert "--upgrade-strategy" in recorded_args + + +@pytest.mark.asyncio +async def test_install_skips_target_upgrade_when_disabled_for_desktop_client( + monkeypatch, tmp_path +): + monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1") + monkeypatch.delattr("sys.frozen", raising=False) + + site_packages_path = tmp_path / "site-packages" + run_pip = _make_run_pip_mock() + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer.get_astrbot_site_packages_path", + lambda: str(site_packages_path), + ) + + installer = PipInstaller("") + await installer.install(package_name="demo-package", allow_target_upgrade=False) + + recorded_args = run_pip.await_args_list[0].args[0] + + assert "--target" in recorded_args + assert "--upgrade" not in recorded_args + assert "--upgrade-strategy" not in recorded_args + + @pytest.mark.asyncio async def test_run_pip_in_process_streams_output_lines(monkeypatch): logged_lines = [] diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 6e6db7da3a..2479de5626 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -1,12 +1,12 @@ import asyncio import os from pathlib import Path - from typing import Any, cast import pytest import yaml +from astrbot.core.star import star_manager as star_manager_module from astrbot.core.star.star_manager import PluginDependencyInstallError, PluginManager from astrbot.core.utils.pip_installer import PipInstallError from astrbot.core.utils.requirements_utils import MissingRequirementsPlan @@ -114,12 +114,14 @@ def _mock_missing_requirements_plan( missing_names, install_lines, *, + version_mismatch_names=(), fallback_reason: str | None = None, ): monkeypatch.setattr( "astrbot.core.star.star_manager.plan_missing_requirements_install", lambda requirements_path: MissingRequirementsPlan( missing_names=frozenset(missing_names), + version_mismatch_names=frozenset(version_mismatch_names), install_lines=tuple(install_lines), fallback_reason=fallback_reason, ), @@ -440,6 +442,417 @@ async def mock_install_requirements(*args, **kwargs): assert any("按 requirements.txt 安装" in line for line in logged_lines) +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("version_mismatch_names", "expected_allow_target_upgrade"), + [ + (set(), False), + ({"networkx"}, True), + ], +) +async def test_ensure_plugin_requirements_sets_target_upgrade_based_on_version_mismatch( + plugin_manager_pm: PluginManager, + local_updator: Path, + monkeypatch, + version_mismatch_names, + expected_allow_target_upgrade: bool, +): + _write_requirements(local_updator) + _mock_missing_requirements_plan( + monkeypatch, + {"networkx"}, + ["networkx"], + version_mismatch_names=version_mismatch_names, + ) + observed_calls = [] + + async def mock_install_requirements(*args, **kwargs): + observed_calls.append(kwargs) + + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.install", + mock_install_requirements, + ) + + await plugin_manager_pm._ensure_plugin_requirements( + str(local_updator), + TEST_PLUGIN_DIR, + ) + + assert len(observed_calls) == 1 + assert observed_calls[0]["allow_target_upgrade"] is expected_allow_target_upgrade + + +@pytest.mark.asyncio +async def test_import_plugin_prefers_installed_dependencies_before_first_import( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch +): + requirements_path = local_updator / "requirements.txt" + requirements_path.write_text("networkx\n", encoding="utf-8") + events = [] + sentinel_module = object() + + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.prefer_installed_dependencies", + lambda *, requirements_path: events.append(("prefer", requirements_path)), + ) + monkeypatch.setattr( + "astrbot.core.star.star_manager.plan_missing_requirements_install", + lambda requirements_path: MissingRequirementsPlan( + missing_names=frozenset(), + install_lines=(), + version_mismatch_names=frozenset(), + ), + ) + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + del globals, locals, level + events.append(("import", name, tuple(fromlist))) + return sentinel_module + + monkeypatch.setattr(star_manager_module, "__import__", fake_import, raising=False) + + imported_module = await plugin_manager_pm._import_plugin_with_dependency_recovery( + path="data.plugins.helloworld.main", + module_str="main", + root_dir_name=TEST_PLUGIN_DIR, + requirements_path=str(requirements_path), + ) + + assert imported_module is sentinel_module + assert events == [ + ("prefer", str(requirements_path)), + ("import", "data.plugins.helloworld.main", ("main",)), + ] + + +@pytest.mark.asyncio +async def test_import_reserved_plugin_skips_preloading_user_site_dependencies( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch +): + requirements_path = local_updator / "requirements.txt" + requirements_path.write_text("networkx\n", encoding="utf-8") + events = [] + sentinel_module = object() + + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.prefer_installed_dependencies", + lambda *, requirements_path: events.append(("prefer", requirements_path)), + ) + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + del globals, locals, level + events.append(("import", name, tuple(fromlist))) + return sentinel_module + + monkeypatch.setattr(star_manager_module, "__import__", fake_import, raising=False) + + imported_module = await plugin_manager_pm._import_plugin_with_dependency_recovery( + path="astrbot.builtin_stars.web_searcher.main", + module_str="main", + root_dir_name="web_searcher", + requirements_path=str(requirements_path), + reserved=True, + ) + + assert imported_module is sentinel_module + assert events == [ + ("import", "astrbot.builtin_stars.web_searcher.main", ("main",)), + ] + + +@pytest.mark.asyncio +async def test_import_plugin_skips_preloading_when_requirements_version_mismatch_detected( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch +): + requirements_path = local_updator / "requirements.txt" + requirements_path.write_text("networkx>=3\n", encoding="utf-8") + events = [] + sentinel_module = object() + + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.prefer_installed_dependencies", + lambda *, requirements_path: events.append(("prefer", requirements_path)), + ) + monkeypatch.setattr( + "astrbot.core.star.star_manager.plan_missing_requirements_install", + lambda requirements_path: MissingRequirementsPlan( + missing_names=frozenset({"networkx"}), + install_lines=("networkx>=3",), + version_mismatch_names=frozenset({"networkx"}), + ), + ) + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + del globals, locals, level + events.append(("import", name, tuple(fromlist))) + return sentinel_module + + monkeypatch.setattr(star_manager_module, "__import__", fake_import, raising=False) + + imported_module = await plugin_manager_pm._import_plugin_with_dependency_recovery( + path="data.plugins.helloworld.main", + module_str="main", + root_dir_name=TEST_PLUGIN_DIR, + requirements_path=str(requirements_path), + ) + + assert imported_module is sentinel_module + assert events == [ + ("import", "data.plugins.helloworld.main", ("main",)), + ] + + +@pytest.mark.asyncio +async def test_import_plugin_reinstalls_when_version_mismatch_import_fails( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch +): + requirements_path = local_updator / "requirements.txt" + requirements_path.write_text("networkx>=3\n", encoding="utf-8") + events = [] + sentinel_module = object() + import_attempts = {"count": 0} + + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.prefer_installed_dependencies", + lambda *, requirements_path: events.append(("prefer", requirements_path)), + ) + monkeypatch.setattr( + "astrbot.core.star.star_manager.plan_missing_requirements_install", + lambda requirements_path: MissingRequirementsPlan( + missing_names=frozenset({"networkx"}), + install_lines=("networkx>=3",), + version_mismatch_names=frozenset({"networkx"}), + ), + ) + + async def mock_check_plugin_dept_update(*, target_plugin=None): + events.append(("reinstall", target_plugin)) + + monkeypatch.setattr( + plugin_manager_pm, + "_check_plugin_dept_update", + mock_check_plugin_dept_update, + ) + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + del globals, locals, level + import_attempts["count"] += 1 + events.append(("import", name, tuple(fromlist), import_attempts["count"])) + if import_attempts["count"] == 1: + raise ModuleNotFoundError("networkx") + return sentinel_module + + monkeypatch.setattr(star_manager_module, "__import__", fake_import, raising=False) + + imported_module = await plugin_manager_pm._import_plugin_with_dependency_recovery( + path="data.plugins.helloworld.main", + module_str="main", + root_dir_name=TEST_PLUGIN_DIR, + requirements_path=str(requirements_path), + ) + + assert imported_module is sentinel_module + assert events == [ + ("import", "data.plugins.helloworld.main", ("main",), 1), + ("reinstall", TEST_PLUGIN_DIR), + ("import", "data.plugins.helloworld.main", ("main",), 2), + ] + + +@pytest.mark.asyncio +async def test_import_plugin_skips_preloading_when_requirement_precheck_is_unavailable( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch +): + requirements_path = local_updator / "requirements.txt" + requirements_path.write_text("networkx\n", encoding="utf-8") + events = [] + sentinel_module = object() + + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.prefer_installed_dependencies", + lambda *, requirements_path: events.append(("prefer", requirements_path)), + ) + monkeypatch.setattr( + "astrbot.core.star.star_manager.plan_missing_requirements_install", + lambda requirements_path: None, + ) + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + del globals, locals, level + events.append(("import", name, tuple(fromlist))) + return sentinel_module + + monkeypatch.setattr(star_manager_module, "__import__", fake_import, raising=False) + + imported_module = await plugin_manager_pm._import_plugin_with_dependency_recovery( + path="data.plugins.helloworld.main", + module_str="main", + root_dir_name=TEST_PLUGIN_DIR, + requirements_path=str(requirements_path), + ) + + assert imported_module is sentinel_module + assert events == [ + ("import", "data.plugins.helloworld.main", ("main",)), + ] + + +@pytest.mark.asyncio +async def test_import_plugin_attempts_dependency_recovery_when_precheck_is_unavailable( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch +): + requirements_path = local_updator / "requirements.txt" + requirements_path.write_text("networkx\n", encoding="utf-8") + events = [] + sentinel_module = object() + import_attempts = {"count": 0} + + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.prefer_installed_dependencies", + lambda *, requirements_path: events.append(("prefer", requirements_path)), + ) + monkeypatch.setattr( + "astrbot.core.star.star_manager.plan_missing_requirements_install", + lambda requirements_path: None, + ) + + async def unexpected_check_plugin_dept_update(*args, **kwargs): + raise AssertionError("dependency install fallback should not run") + + monkeypatch.setattr( + plugin_manager_pm, + "_check_plugin_dept_update", + unexpected_check_plugin_dept_update, + ) + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + del globals, locals, level + import_attempts["count"] += 1 + events.append(("import", name, tuple(fromlist), import_attempts["count"])) + if import_attempts["count"] == 1: + raise ModuleNotFoundError("networkx") + return sentinel_module + + monkeypatch.setattr(star_manager_module, "__import__", fake_import, raising=False) + + imported_module = await plugin_manager_pm._import_plugin_with_dependency_recovery( + path="data.plugins.helloworld.main", + module_str="main", + root_dir_name=TEST_PLUGIN_DIR, + requirements_path=str(requirements_path), + ) + + assert imported_module is sentinel_module + assert events == [ + ("import", "data.plugins.helloworld.main", ("main",), 1), + ("prefer", str(requirements_path)), + ("import", "data.plugins.helloworld.main", ("main",), 2), + ] + + +@pytest.mark.asyncio +async def test_import_plugin_does_not_recover_from_plain_import_error( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch +): + requirements_path = local_updator / "requirements.txt" + requirements_path.write_text("networkx\n", encoding="utf-8") + events = [] + + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.prefer_installed_dependencies", + lambda *, requirements_path: events.append(("prefer", requirements_path)), + ) + monkeypatch.setattr( + "astrbot.core.star.star_manager.plan_missing_requirements_install", + lambda requirements_path: MissingRequirementsPlan( + missing_names=frozenset(), + install_lines=(), + version_mismatch_names=frozenset(), + ), + ) + + async def unexpected_check_plugin_dept_update(*args, **kwargs): + raise AssertionError("dependency install fallback should not run") + + monkeypatch.setattr( + plugin_manager_pm, + "_check_plugin_dept_update", + unexpected_check_plugin_dept_update, + ) + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + del globals, locals, level + events.append(("import", name, tuple(fromlist))) + raise ImportError("plugin import error") + + monkeypatch.setattr(star_manager_module, "__import__", fake_import, raising=False) + + with pytest.raises(ImportError, match="plugin import error"): + await plugin_manager_pm._import_plugin_with_dependency_recovery( + path="data.plugins.helloworld.main", + module_str="main", + root_dir_name=TEST_PLUGIN_DIR, + requirements_path=str(requirements_path), + ) + + assert events == [ + ("prefer", str(requirements_path)), + ("import", "data.plugins.helloworld.main", ("main",)), + ] + + +@pytest.mark.asyncio +async def test_import_plugin_surfaces_unexpected_recovery_errors( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch +): + requirements_path = local_updator / "requirements.txt" + requirements_path.write_text("networkx\n", encoding="utf-8") + events = [] + + def raising_prefer_installed_dependencies(*, requirements_path): + events.append(("prefer", requirements_path)) + raise RuntimeError("unexpected recovery failure") + + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.prefer_installed_dependencies", + raising_prefer_installed_dependencies, + ) + monkeypatch.setattr( + "astrbot.core.star.star_manager.plan_missing_requirements_install", + lambda requirements_path: None, + ) + + async def unexpected_check_plugin_dept_update(*args, **kwargs): + raise AssertionError("dependency install fallback should not run") + + monkeypatch.setattr( + plugin_manager_pm, + "_check_plugin_dept_update", + unexpected_check_plugin_dept_update, + ) + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + del globals, locals, level + events.append(("import", name, tuple(fromlist))) + raise ModuleNotFoundError("networkx") + + monkeypatch.setattr(star_manager_module, "__import__", fake_import, raising=False) + + with pytest.raises(RuntimeError, match="unexpected recovery failure"): + await plugin_manager_pm._import_plugin_with_dependency_recovery( + path="data.plugins.helloworld.main", + module_str="main", + root_dir_name=TEST_PLUGIN_DIR, + requirements_path=str(requirements_path), + ) + + assert events == [ + ("import", "data.plugins.helloworld.main", ("main",)), + ("prefer", str(requirements_path)), + ] + + @pytest.mark.asyncio @pytest.mark.parametrize("dependency_install_fails", [False, True]) async def test_update_plugin_dependency_install_flow( @@ -641,6 +1054,42 @@ async def test_ensure_plugin_requirements_falls_back_when_missing_names_have_no_ assert events == [("deps", str(requirements_path))] +@pytest.mark.asyncio +async def test_ensure_plugin_requirements_fallback_full_install_keeps_upgrade_for_version_mismatch( + plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch +): + requirements_path = local_updator / "requirements.txt" + requirements_path.write_text("boto3>=2\n", encoding="utf-8") + observed_calls = [] + + monkeypatch.setattr( + "astrbot.core.star.star_manager.plan_missing_requirements_install", + lambda path: MissingRequirementsPlan( + missing_names=frozenset({"boto3"}), + install_lines=(), + version_mismatch_names=frozenset({"boto3"}), + fallback_reason="unmapped missing requirement names", + ), + ) + + async def mock_install_requirements(*args, **kwargs): + observed_calls.append(kwargs) + + monkeypatch.setattr( + "astrbot.core.star.star_manager.pip_installer.install", + mock_install_requirements, + ) + + await plugin_manager_pm._ensure_plugin_requirements( + str(local_updator), + TEST_PLUGIN_DIR, + ) + + assert len(observed_calls) == 1 + assert observed_calls[0]["requirements_path"] == str(requirements_path) + assert observed_calls[0]["allow_target_upgrade"] is True + + @pytest.mark.asyncio async def test_ensure_plugin_requirements_does_not_mask_install_error_when_cleanup_fails( plugin_manager_pm: PluginManager, local_updator: Path, monkeypatch, tmp_path