diff --git a/mellea/plugins/manager.py b/mellea/plugins/manager.py
index f29b7eb2a..e196a0a25 100644
--- a/mellea/plugins/manager.py
+++ b/mellea/plugins/manager.py
@@ -26,11 +26,79 @@
 _plugin_manager: Any | None = None
 _plugins_enabled: bool = False
 _session_tags: dict[str, set[str]] = {}  # session_id -> set of plugin names
+_pending_background_results: list[Any] = []
+_collect_background_results: bool = False  # opt-in; only tests enable this
 
 DEFAULT_PLUGIN_TIMEOUT: int = 5  # seconds
 DEFAULT_HOOK_POLICY: Literal["allow"] | Literal["deny"] = "deny"
 
 
+def enable_background_collection() -> None:
+    """Enable collection of hook results that carry background tasks.
+
+    While enabled, ``invoke_hook`` appends every result whose
+    ``background_tasks`` is truthy to a module-level pending list so that
+    ``drain_background_tasks`` can await them later. Collection is off by
+    default; call this in test fixtures before each test.
+    """
+    global _collect_background_results
+    _collect_background_results = True
+
+
+def disable_background_collection() -> None:
+    """Disable background-result collection and drop anything accumulated.
+
+    Accumulated results are dropped without being awaited.
+    """
+    global _collect_background_results
+    _collect_background_results = False
+    # Clear in place rather than rebinding, matching
+    # discard_background_tasks() and shutdown_plugins(), so any existing
+    # reference to the list keeps observing the live state.
+    _pending_background_results.clear()
+
+
+async def drain_background_tasks() -> None:
+    """Await every collected FIRE_AND_FORGET result, then surface failures.
+
+    Call this in tests after any operation that may have triggered
+    fire-and-forget plugins, so their side effects (metrics recording,
+    etc.) complete before assertions run.
+
+    Every pending result is awaited even if an earlier one fails, so no
+    background task is left untracked. Results appended while draining
+    stay queued for the next call.
+
+    Raises:
+        Exception: The first error raised by ``wait_for_background_tasks()``;
+            later errors are logged as warnings.
+    """
+    # Snapshot and clear before the first await so the pending list is
+    # already empty if a plugin triggers more hooks while we wait.
+    pending = list(_pending_background_results)
+    _pending_background_results.clear()
+    first_error: Exception | None = None
+    for result in pending:
+        try:
+            await result.wait_for_background_tasks()
+        except Exception as exc:
+            if first_error is None:
+                first_error = exc
+            else:
+                logger.warning("Additional background task failure: %r", exc)
+    if first_error is not None:
+        raise first_error
+
+
+def discard_background_tasks() -> None:
+    """Discard all collected FIRE_AND_FORGET results without awaiting them.
+
+    Call this in test fixtures to clear stale results from a previous event
+    loop before running the next test.
+    """
+    _pending_background_results.clear()
+
+
 def has_plugins(hook_type: HookType | None = None) -> bool:
     """Fast check: are plugins configured and available for the given hook type.
 
@@ -143,6 +179,7 @@ async def shutdown_plugins() -> None: _plugin_manager = None _plugins_enabled = False _session_tags.clear() + _pending_background_results.clear() def track_session_plugin(session_id: str, plugin_name: str) -> None: @@ -229,6 +266,9 @@ async def invoke_hook( violations_as_exceptions=False, ) + if _collect_background_results and result and result.background_tasks: + _pending_background_results.append(result) + if result and not result.continue_processing and result.violation: v = result.violation logger.warning( diff --git a/pyproject.toml b/pyproject.toml index 0b6d3c9ff..5e30a2436 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,7 +105,7 @@ sandbox = [ backends = ["mellea[watsonx,hf,litellm]"] hooks = [ - "cpex>=0.1.0.dev10; python_version >= '3.11'", + "cpex>=0.1.0.dev12; python_version >= '3.11'", "grpcio>=1.78.0", ] diff --git a/test/plugins/test_execution_modes.py b/test/plugins/test_execution_modes.py index 5167943f4..6083c3b36 100644 --- a/test/plugins/test_execution_modes.py +++ b/test/plugins/test_execution_modes.py @@ -21,8 +21,6 @@ from __future__ import annotations -import asyncio - import pytest pytest.importorskip("cpex.framework") @@ -344,11 +342,10 @@ async def faf_observer(payload, ctx): register(faf_observer) - await invoke_hook(HookType.SESSION_PRE_INIT, _session_payload()) + result, _ = await invoke_hook(HookType.SESSION_PRE_INIT, _session_payload()) - # The hook runs as a background asyncio task; yield to the event loop to - # allow it to complete before asserting. 
- await asyncio.sleep(0.05) + assert result is not None + await result.wait_for_background_tasks() assert invocations == ["fired"] @pytest.mark.asyncio @@ -413,9 +410,10 @@ async def enforce_second(payload, ctx): register(faf_first) register(enforce_second) - await invoke_hook(HookType.SESSION_PRE_INIT, _session_payload()) + result, _ = await invoke_hook(HookType.SESSION_PRE_INIT, _session_payload()) - await asyncio.sleep(0.05) + assert result is not None + await result.wait_for_background_tasks() assert order == ["enforce", "faf"] @pytest.mark.asyncio diff --git a/test/telemetry/test_metrics_backend.py b/test/telemetry/test_metrics_backend.py index 95a40b5f7..5ce1db66b 100644 --- a/test/telemetry/test_metrics_backend.py +++ b/test/telemetry/test_metrics_backend.py @@ -3,7 +3,6 @@ Tests that backends correctly record token metrics through the telemetry system. """ -import asyncio import os import pytest @@ -12,6 +11,12 @@ IBM_GRANITE_4_HYBRID_MICRO, IBM_GRANITE_4_HYBRID_SMALL, ) +from mellea.plugins.manager import ( + disable_background_collection, + discard_background_tasks, + drain_background_tasks, + enable_background_collection, +) from mellea.stdlib.components import Message from mellea.stdlib.context import SimpleContext from test.predicates import require_api_key, require_gpu @@ -41,6 +46,8 @@ def metric_reader(): @pytest.fixture def enable_metrics(monkeypatch): """Enable metrics for tests.""" + enable_background_collection() + discard_background_tasks() monkeypatch.setenv("MELLEA_METRICS_ENABLED", "true") # Force reload of metrics module to pick up env vars import importlib @@ -52,6 +59,7 @@ def enable_metrics(monkeypatch): # Reset after test monkeypatch.setenv("MELLEA_METRICS_ENABLED", "false") importlib.reload(mellea.telemetry.metrics) + disable_background_collection() @pytest.fixture(scope="module") @@ -169,8 +177,7 @@ async def test_ollama_token_metrics_integration(enable_metrics, metric_reader, s await mot.avalue() # Force metrics export and collection 
- # Yield to event loop so FIRE_AND_FORGET plugin tasks complete - await asyncio.sleep(0.05) + await drain_background_tasks() provider.force_flush() metrics_data = metric_reader.get_metrics_data() @@ -235,8 +242,7 @@ async def test_openai_token_metrics_integration(enable_metrics, metric_reader, s await mot.astream() await mot.avalue() - # Yield to event loop so FIRE_AND_FORGET plugin tasks complete - await asyncio.sleep(0.05) + await drain_background_tasks() provider.force_flush() metrics_data = metric_reader.get_metrics_data() @@ -290,8 +296,7 @@ async def test_watsonx_token_metrics_integration(enable_metrics, metric_reader): ) await mot.avalue() - # Yield to event loop so FIRE_AND_FORGET plugin tasks complete - await asyncio.sleep(0.05) + await drain_background_tasks() provider.force_flush() metrics_data = metric_reader.get_metrics_data() @@ -354,8 +359,7 @@ async def test_litellm_token_metrics_integration( await mot.astream() await mot.avalue() - # Yield to event loop so FIRE_AND_FORGET plugin tasks complete - await asyncio.sleep(0.05) + await drain_background_tasks() provider.force_flush() metrics_data = metric_reader.get_metrics_data() @@ -413,8 +417,7 @@ async def test_huggingface_token_metrics_integration( await mot.astream() await mot.avalue() - # Yield to event loop so FIRE_AND_FORGET plugin tasks complete - await asyncio.sleep(0.05) + await drain_background_tasks() provider.force_flush() metrics_data = metric_reader.get_metrics_data() @@ -470,8 +473,7 @@ async def test_error_metrics_on_backend_failure(enable_metrics, metric_reader): with pytest.raises(Exception): await mot.avalue() - # Yield to event loop so FIRE_AND_FORGET plugin task completes - await asyncio.sleep(0.05) + await drain_background_tasks() provider.force_flush() metrics_data = metric_reader.get_metrics_data() @@ -508,8 +510,7 @@ async def test_ollama_sampling_metrics_integration(enable_metrics, metric_reader action=Instruction("Say hello"), context=ctx, backend=backend, requirements=None ) 
- # Yield to event loop so FIRE_AND_FORGET plugin tasks complete - await asyncio.sleep(0.05) + await drain_background_tasks() provider.force_flush() metrics_data = metric_reader.get_metrics_data() diff --git a/uv.lock b/uv.lock index 18e87d7cb..ef9d2c29f 100644 --- a/uv.lock +++ b/uv.lock @@ -749,7 +749,7 @@ toml = [ [[package]] name = "cpex" -version = "0.1.0.dev11" +version = "0.1.0.dev12" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "fastapi" }, @@ -757,15 +757,16 @@ dependencies = [ { name = "jinja2" }, { name = "mcp" }, { name = "orjson" }, + { name = "packaging" }, { name = "prometheus-client" }, { name = "prometheus-fastapi-instrumentator" }, { name = "pydantic" }, { name = "pydantic-settings" }, { name = "pyyaml" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/84/4c/c462d23f98b7b388dd6642d71702f5f47d120c0df471a22add22ad99fe69/cpex-0.1.0.dev11.tar.gz", hash = "sha256:c7c30650fd49fdae7ec67f46c1d57486db090ae79ec39a6ed2f8ed990b78f6e5", size = 754645, upload-time = "2026-04-13T21:02:26.901Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0d/f6/d5a194b338b3d55b1b9b8619baafa504ae8146168cf4b91fcefa95811a16/cpex-0.1.0.dev12.tar.gz", hash = "sha256:9fb08e0fa27236747c26c841260951a83252029c0e55a7550c65a060473f200c", size = 3475629, upload-time = "2026-04-23T17:34:14.434Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1f/5e/d051f25b6bb2241f61f710ef0a127cbdd4b8ca2612d8b777d4164347e1f6/cpex-0.1.0.dev11-py3-none-any.whl", hash = "sha256:c34d23f16191744e98415cdf850d981e377399a2d802abda565c276f217ae380", size = 167516, upload-time = "2026-04-13T21:02:25.687Z" }, + { url = "https://files.pythonhosted.org/packages/39/ed/f70537bd8adbf1f847a703e02c3abd2cdc53dfa87e44977aa25dd163774b/cpex-0.1.0.dev12-py3-none-any.whl", hash = "sha256:5c10688b6f7ca8c3673fce9dfd94d0b3a348e0e63566546ced5068574a38403e", size = 236654, upload-time = "2026-04-23T17:34:12.592Z" }, ] [[package]] @@ -3334,7 +3335,7 @@ typecheck = 
[ requires-dist = [ { name = "accelerate", marker = "extra == 'hf'", specifier = ">=1.9.0" }, { name = "boto3", marker = "extra == 'litellm'" }, - { name = "cpex", marker = "python_full_version >= '3.11' and extra == 'hooks'", specifier = ">=0.1.0.dev10" }, + { name = "cpex", marker = "python_full_version >= '3.11' and extra == 'hooks'", specifier = ">=0.1.0.dev12" }, { name = "datasets", marker = "extra == 'hf'", specifier = ">=4.0.0" }, { name = "docling", marker = "extra == 'docling'", specifier = ">=2.45.0" }, { name = "elasticsearch", marker = "extra == 'granite-retriever'", specifier = ">=8.0.0,<9.0.0" },