From d99a0791843ca1da71075a9969844e78d9020d97 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Tue, 21 Apr 2026 16:43:56 -0500 Subject: [PATCH 1/2] feat: expose background_tasks in PluginResult for fire-and-forget synchronization Adds a background_tasks field to PluginResult containing the asyncio.Task handles created for FIRE_AND_FORGET plugins. Callers can now await background tasks deterministically instead of relying on arbitrary sleep delays. Closes #25 Signed-off-by: Alex Bozarth --- cpex/framework/manager.py | 14 ++++++++++---- cpex/framework/models.py | 9 +++++++++ docs/specs/plugin-framework-spec.md | 8 ++++++++ tests/unit/cpex/framework/test_plugin_modes.py | 18 +++++++++--------- 4 files changed, 36 insertions(+), 13 deletions(-) diff --git a/cpex/framework/manager.py b/cpex/framework/manager.py index eb4a69c6..2c98d648 100644 --- a/cpex/framework/manager.py +++ b/cpex/framework/manager.py @@ -354,7 +354,7 @@ async def execute( ) # FIRE_AND_FORGET: fire-and-forget background tasks (fires last with final payload snapshot) - self._fire_and_forget_tasks( + bg_tasks = self._fire_and_forget_tasks( fire_and_forget_refs, payload, global_context, @@ -373,6 +373,7 @@ async def execute( modified_extensions=current_extensions, violation=None, metadata=combined_metadata, + background_tasks=bg_tasks, ), res_local_contexts, ) @@ -688,7 +689,7 @@ def _build_halt_result( extensions: Optional[Extensions] = None, ) -> tuple[PluginResult, dict]: """Schedule fire-and-forget tasks and build a pipeline-halting result.""" - self._fire_and_forget_tasks( + bg_tasks = self._fire_and_forget_tasks( fire_and_forget_refs, payload, global_context, @@ -704,6 +705,7 @@ def _build_halt_result( modified_payload=current_payload, violation=violation, metadata=combined_metadata, + background_tasks=bg_tasks, ), res_local_contexts, ) @@ -728,12 +730,14 @@ def _fire_and_forget_tasks( res_local_contexts: dict, semaphore: Optional[asyncio.Semaphore], extensions: Optional[Extensions] = None, - ) -> None: + ) -> list[asyncio.Task]: """Schedule all FIRE_AND_FORGET plugins as fire-and-forget background tasks. May be called from an early-exit path or from the normal completion path. Each FIRE_AND_FORGET plugin receives an isolated snapshot of the payload at call time. + Returns the list of asyncio.Task handles for all newly scheduled tasks. """ + tasks: list[asyncio.Task] = [] for ref in fire_and_forget_refs: local_context_key = global_context.request_id + ref.plugin_ref.uuid if local_context_key in res_local_contexts: @@ -752,9 +756,11 @@ def _fire_and_forget_tasks( ) local_context = PluginContext(global_context=tmp_gc) res_local_contexts[local_context_key] = local_context - asyncio.create_task( + task = asyncio.create_task( self._run_fire_and_forget_task(ref, task_input, local_context, semaphore, extensions=extensions) ) + tasks.append(task) + return tasks async def _run_fire_and_forget_task( self, diff --git a/cpex/framework/models.py b/cpex/framework/models.py index 7402801c..9cbfe88d 100644 --- a/cpex/framework/models.py +++ b/cpex/framework/models.py @@ -10,6 +10,7 @@ """ # Standard +import asyncio import logging import os import re @@ -1498,6 +1499,11 @@ class PluginResult(BaseModel, Generic[T]): (e.g., updated HTTP headers from token delegation, appended security labels). violation (Optional[PluginViolation]): violation object. metadata (Optional[dict[str, Any]]): additional metadata. + background_tasks (list[asyncio.Task]): asyncio.Task handles for any FIRE_AND_FORGET + plugins scheduled during this invocation. Use + ``await asyncio.gather(*result.background_tasks, return_exceptions=True)`` + to deterministically wait for all background tasks to complete (useful in tests). + This field is excluded from model serialization. Examples: >>> result = PluginResult() @@ -1522,11 +1528,14 @@ class PluginResult(BaseModel, Generic[T]): False """ + model_config = ConfigDict(arbitrary_types_allowed=True) + continue_processing: bool = True modified_payload: Optional[T] = None modified_extensions: Optional[Extensions] = None violation: Optional[PluginViolation] = None metadata: Optional[dict[str, Any]] = Field(default_factory=dict) + background_tasks: list[asyncio.Task] = Field(default_factory=list, exclude=True) class GlobalContext(BaseModel): diff --git a/docs/specs/plugin-framework-spec.md b/docs/specs/plugin-framework-spec.md index 1bbf87eb..df8ae27b 100644 --- a/docs/specs/plugin-framework-spec.md +++ b/docs/specs/plugin-framework-spec.md @@ -185,6 +185,14 @@ class PluginResult(Generic[T]): modified_payload: T | None = None violation: PluginViolation | None = None metadata: dict[str, Any] = {} + background_tasks: list[asyncio.Task] = [] # excluded from serialization +``` + +`background_tasks` contains the `asyncio.Task` handles for any `FIRE_AND_FORGET` plugins scheduled during the invocation. Use it to wait for background tasks without sleep delays: + +```python +result, _ = await manager.invoke_hook(...) +await asyncio.gather(*result.background_tasks, return_exceptions=True) ``` The `PluginViolation` type carries structured policy failure information: diff --git a/tests/unit/cpex/framework/test_plugin_modes.py b/tests/unit/cpex/framework/test_plugin_modes.py index 81c3b0c6..021af77f 100644 --- a/tests/unit/cpex/framework/test_plugin_modes.py +++ b/tests/unit/cpex/framework/test_plugin_modes.py @@ -101,8 +101,8 @@ async def prompt_pre_fetch(self, payload, context): assert result.continue_processing assert not finished.is_set() - # Let the background task complete - await asyncio.sleep(0.1) + # Wait deterministically for the background task to complete + await asyncio.gather(*result.background_tasks, return_exceptions=True) assert finished.is_set() await manager.shutdown() @@ -129,8 +129,8 @@ async def prompt_pre_fetch(self, payload, context): assert result.continue_processing - # Allow background task to run and silently fail - await asyncio.sleep(0.05) + # Wait deterministically for the background task to run and silently fail + await asyncio.gather(*result.background_tasks, return_exceptions=True) await manager.shutdown() @@ -356,8 +356,8 @@ async def prompt_pre_fetch(self, payload, context): assert result.continue_processing - # Allow all background FIRE_AND_FORGET tasks to complete - await asyncio.sleep(0.1) + # Wait deterministically for all background FIRE_AND_FORGET tasks to complete + await asyncio.gather(*result.background_tasks, return_exceptions=True) # With pool=1, max concurrency should be 1 assert concurrency_high_water <= 1 @@ -619,8 +619,8 @@ async def prompt_pre_fetch(self, payload, context): result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context) assert result.continue_processing - # F&F is async — wait for it - await asyncio.sleep(0.1) + # F&F is async — wait for it deterministically + await asyncio.gather(*result.background_tasks, return_exceptions=True) assert phase_log == ["seq", "xform", "audit", "conc", "fnf"] @@ -666,7 +666,7 @@ async def prompt_pre_fetch(self, payload, context): # FIRE_AND_FORGET has not yet completed (fire-and-forget) assert not fire_and_forget_started.is_set() - await asyncio.sleep(0.1) + await asyncio.gather(*result.background_tasks, return_exceptions=True) assert "fire_and_forget" in phase_log # FIRE_AND_FORGET always comes after sequential in the log assert phase_log.index("sequential") < phase_log.index("fire_and_forget") From bd3b22c79f158df97534d247111b159074137893 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Wed, 22 Apr 2026 15:15:40 -0500 Subject: [PATCH 2/2] feat: add PluginResult.wait_for_background_tasks() Signed-off-by: Alex Bozarth --- cpex/framework/manager.py | 7 +++- cpex/framework/models.py | 20 ++++++++-- docs/specs/plugin-framework-spec.md | 4 +- .../unit/cpex/framework/test_plugin_modes.py | 39 ++++++++++++++++--- 4 files changed, 56 insertions(+), 14 deletions(-) diff --git a/cpex/framework/manager.py b/cpex/framework/manager.py index 2c98d648..b4885112 100644 --- a/cpex/framework/manager.py +++ b/cpex/framework/manager.py @@ -769,9 +769,10 @@ async def _run_fire_and_forget_task( local_context: PluginContext, semaphore: Optional[asyncio.Semaphore], extensions: Optional[Extensions] = None, - ) -> None: + ) -> Optional[PluginErrorModel]: """Execute a plugin as a fire-and-forget background task. + Returns None on success, or a PluginErrorModel if the plugin raised. Errors are logged but never propagated — background tasks cannot halt the pipeline. If on_error=DISABLE, the plugin is added to the runtime-disabled set. """ @@ -781,11 +782,13 @@ async def _run_fire_and_forget_task( await self._execute_with_timeout(hook_ref, payload, local_context, extensions=extensions) else: await self._execute_with_timeout(hook_ref, payload, local_context, extensions=extensions) - except Exception: + return None + except Exception as exc: logger.error("Plugin %s failed in fire-and-forget mode (ignored)", hook_ref.plugin_ref.name) if hook_ref.plugin_ref.on_error == OnError.DISABLE: self._runtime_disabled.add(hook_ref.plugin_ref.name) # FAIL and IGNORE both just log for FIRE_AND_FORGET mode (background can't halt pipeline) + return PluginErrorModel(message=repr(exc), plugin_name=hook_ref.plugin_ref.name) async def execute_plugin( self, diff --git a/cpex/framework/models.py b/cpex/framework/models.py index 9cbfe88d..73522598 100644 --- a/cpex/framework/models.py +++ b/cpex/framework/models.py @@ -1500,10 +1500,8 @@ class PluginResult(BaseModel, Generic[T]): violation (Optional[PluginViolation]): violation object. metadata (Optional[dict[str, Any]]): additional metadata. background_tasks (list[asyncio.Task]): asyncio.Task handles for any FIRE_AND_FORGET - plugins scheduled during this invocation. Use - ``await asyncio.gather(*result.background_tasks, return_exceptions=True)`` - to deterministically wait for all background tasks to complete (useful in tests). - This field is excluded from model serialization. + plugins scheduled during this invocation. Use ``wait_for_background_tasks()`` + to await them and collect any errors. This field is excluded from model serialization. Examples: >>> result = PluginResult() @@ -1537,6 +1535,20 @@ class PluginResult(BaseModel, Generic[T]): metadata: Optional[dict[str, Any]] = Field(default_factory=dict) background_tasks: list[asyncio.Task] = Field(default_factory=list, exclude=True) + async def wait_for_background_tasks(self) -> "list[PluginErrorModel]": + """Await all FIRE_AND_FORGET background tasks and return any errors. + + Returns an empty list if all tasks completed without error. + + Examples: + >>> result = PluginResult() + >>> # errors = await result.wait_for_background_tasks() + """ + if not self.background_tasks: + return [] + results = await asyncio.gather(*self.background_tasks, return_exceptions=True) + return [r for r in results if isinstance(r, PluginErrorModel)] + class GlobalContext(BaseModel): """The global context, which shared across all plugins. diff --git a/docs/specs/plugin-framework-spec.md b/docs/specs/plugin-framework-spec.md index df8ae27b..7727c786 100644 --- a/docs/specs/plugin-framework-spec.md +++ b/docs/specs/plugin-framework-spec.md @@ -188,11 +188,11 @@ class PluginResult(Generic[T]): background_tasks: list[asyncio.Task] = [] # excluded from serialization ``` -`background_tasks` contains the `asyncio.Task` handles for any `FIRE_AND_FORGET` plugins scheduled during the invocation. Use it to wait for background tasks without sleep delays: +`background_tasks` contains the `asyncio.Task` handles for any `FIRE_AND_FORGET` plugins scheduled during the invocation. Use `wait_for_background_tasks()` to await them and collect any errors: ```python result, _ = await manager.invoke_hook(...) -await asyncio.gather(*result.background_tasks, return_exceptions=True) +errors = await result.wait_for_background_tasks() # list[PluginErrorModel], empty on success ``` The `PluginViolation` type carries structured policy failure information: diff --git a/tests/unit/cpex/framework/test_plugin_modes.py b/tests/unit/cpex/framework/test_plugin_modes.py index 021af77f..6de75413 100644 --- a/tests/unit/cpex/framework/test_plugin_modes.py +++ b/tests/unit/cpex/framework/test_plugin_modes.py @@ -102,7 +102,7 @@ async def prompt_pre_fetch(self, payload, context): assert not finished.is_set() # Wait deterministically for the background task to complete - await asyncio.gather(*result.background_tasks, return_exceptions=True) + await result.wait_for_background_tasks() assert finished.is_set() await manager.shutdown() @@ -129,8 +129,35 @@ async def prompt_pre_fetch(self, payload, context): assert result.continue_processing - # Wait deterministically for the background task to run and silently fail - await asyncio.gather(*result.background_tasks, return_exceptions=True) + # Wait for the background task; errors are returned, not raised + await result.wait_for_background_tasks() + + await manager.shutdown() + + +@pytest.mark.asyncio +async def test_wait_for_background_tasks_returns_errors(): + """wait_for_background_tasks() returns a PluginErrorModel for each failed task.""" + + class BrokenPlugin(Plugin): + async def prompt_pre_fetch(self, payload, context): + raise RuntimeError("boom") + + manager = await _make_manager() + cfg = make_plugin_config("BrokenFnF", PluginMode.FIRE_AND_FORGET) + plugin = BrokenPlugin(cfg) + + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + mock_get.return_value = [HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(plugin))] + payload = PromptPrehookPayload(prompt_id="test", args={}) + global_context = GlobalContext(request_id="wait_errors") + + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context) + + errors = await result.wait_for_background_tasks() + assert len(errors) == 1 + assert errors[0].plugin_name == "BrokenFnF" + assert "RuntimeError" in errors[0].message await manager.shutdown() @@ -357,7 +384,7 @@ async def prompt_pre_fetch(self, payload, context): assert result.continue_processing # Wait deterministically for all background FIRE_AND_FORGET tasks to complete - await asyncio.gather(*result.background_tasks, return_exceptions=True) + await result.wait_for_background_tasks() # With pool=1, max concurrency should be 1 assert concurrency_high_water <= 1 @@ -620,7 +647,7 @@ async def prompt_pre_fetch(self, payload, context): assert result.continue_processing # F&F is async — wait for it deterministically - await asyncio.gather(*result.background_tasks, return_exceptions=True) + await result.wait_for_background_tasks() assert phase_log == ["seq", "xform", "audit", "conc", "fnf"] @@ -666,7 +693,7 @@ async def prompt_pre_fetch(self, payload, context): # FIRE_AND_FORGET has not yet completed (fire-and-forget) assert not fire_and_forget_started.is_set() - await asyncio.gather(*result.background_tasks, return_exceptions=True) + await result.wait_for_background_tasks() assert "fire_and_forget" in phase_log # FIRE_AND_FORGET always comes after sequential in the log assert phase_log.index("sequential") < phase_log.index("fire_and_forget")