Skip to content

Commit e0a5728

Browse files
authored
feat: expose background_tasks in PluginResult for fire-and-forget synchronization (#33)
* feat: expose background_tasks in PluginResult for fire-and-forget synchronization Adds a background_tasks field to PluginResult containing the asyncio.Task handles created for FIRE_AND_FORGET plugins. Callers can now await background tasks deterministically instead of relying on arbitrary sleep delays. Signed-off-by: Alex Bozarth <ajbozart@us.ibm.com> * feat: add PluginResult.wait_for_background_tasks() Signed-off-by: Alex Bozarth <ajbozart@us.ibm.com> --------- Signed-off-by: Alex Bozarth <ajbozart@us.ibm.com>
1 parent 3750d22 commit e0a5728

4 files changed

Lines changed: 80 additions & 15 deletions

File tree

cpex/framework/manager.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ async def execute(
354354
)
355355

356356
# FIRE_AND_FORGET: fire-and-forget background tasks (fires last with final payload snapshot)
357-
self._fire_and_forget_tasks(
357+
bg_tasks = self._fire_and_forget_tasks(
358358
fire_and_forget_refs,
359359
payload,
360360
global_context,
@@ -373,6 +373,7 @@ async def execute(
373373
modified_extensions=current_extensions,
374374
violation=None,
375375
metadata=combined_metadata,
376+
background_tasks=bg_tasks,
376377
),
377378
res_local_contexts,
378379
)
@@ -688,7 +689,7 @@ def _build_halt_result(
688689
extensions: Optional[Extensions] = None,
689690
) -> tuple[PluginResult, dict]:
690691
"""Schedule fire-and-forget tasks and build a pipeline-halting result."""
691-
self._fire_and_forget_tasks(
692+
bg_tasks = self._fire_and_forget_tasks(
692693
fire_and_forget_refs,
693694
payload,
694695
global_context,
@@ -704,6 +705,7 @@ def _build_halt_result(
704705
modified_payload=current_payload,
705706
violation=violation,
706707
metadata=combined_metadata,
708+
background_tasks=bg_tasks,
707709
),
708710
res_local_contexts,
709711
)
@@ -728,12 +730,14 @@ def _fire_and_forget_tasks(
728730
res_local_contexts: dict,
729731
semaphore: Optional[asyncio.Semaphore],
730732
extensions: Optional[Extensions] = None,
731-
) -> None:
733+
) -> list[asyncio.Task]:
732734
"""Schedule all FIRE_AND_FORGET plugins as fire-and-forget background tasks.
733735
734736
May be called from an early-exit path or from the normal completion path.
735737
Each FIRE_AND_FORGET plugin receives an isolated snapshot of the payload at call time.
738+
Returns the list of asyncio.Task handles for all newly scheduled tasks.
736739
"""
740+
tasks: list[asyncio.Task] = []
737741
for ref in fire_and_forget_refs:
738742
local_context_key = global_context.request_id + ref.plugin_ref.uuid
739743
if local_context_key in res_local_contexts:
@@ -752,9 +756,11 @@ def _fire_and_forget_tasks(
752756
)
753757
local_context = PluginContext(global_context=tmp_gc)
754758
res_local_contexts[local_context_key] = local_context
755-
asyncio.create_task(
759+
task = asyncio.create_task(
756760
self._run_fire_and_forget_task(ref, task_input, local_context, semaphore, extensions=extensions)
757761
)
762+
tasks.append(task)
763+
return tasks
758764

759765
async def _run_fire_and_forget_task(
760766
self,
@@ -763,9 +769,10 @@ async def _run_fire_and_forget_task(
763769
local_context: PluginContext,
764770
semaphore: Optional[asyncio.Semaphore],
765771
extensions: Optional[Extensions] = None,
766-
) -> None:
772+
) -> Optional[PluginErrorModel]:
767773
"""Execute a plugin as a fire-and-forget background task.
768774
775+
Returns None on success, or a PluginErrorModel if the plugin raised.
769776
Errors are logged but never propagated — background tasks cannot halt the pipeline.
770777
If on_error=DISABLE, the plugin is added to the runtime-disabled set.
771778
"""
@@ -775,11 +782,13 @@ async def _run_fire_and_forget_task(
775782
await self._execute_with_timeout(hook_ref, payload, local_context, extensions=extensions)
776783
else:
777784
await self._execute_with_timeout(hook_ref, payload, local_context, extensions=extensions)
778-
except Exception:
785+
return None
786+
except Exception as exc:
779787
logger.error("Plugin %s failed in fire-and-forget mode (ignored)", hook_ref.plugin_ref.name)
780788
if hook_ref.plugin_ref.on_error == OnError.DISABLE:
781789
self._runtime_disabled.add(hook_ref.plugin_ref.name)
782790
# FAIL and IGNORE both just log for FIRE_AND_FORGET mode (background can't halt pipeline)
791+
return PluginErrorModel(message=repr(exc), plugin_name=hook_ref.plugin_ref.name)
783792

784793
async def execute_plugin(
785794
self,

cpex/framework/models.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"""
1111

1212
# Standard
13+
import asyncio
1314
import logging
1415
import os
1516
import re
@@ -1498,6 +1499,9 @@ class PluginResult(BaseModel, Generic[T]):
14981499
(e.g., updated HTTP headers from token delegation, appended security labels).
14991500
violation (Optional[PluginViolation]): violation object.
15001501
metadata (Optional[dict[str, Any]]): additional metadata.
1502+
background_tasks (list[asyncio.Task]): asyncio.Task handles for any FIRE_AND_FORGET
1503+
plugins scheduled during this invocation. Use ``wait_for_background_tasks()``
1504+
to await them and collect any errors. This field is excluded from model serialization.
15011505
15021506
Examples:
15031507
>>> result = PluginResult()
@@ -1522,11 +1526,28 @@ class PluginResult(BaseModel, Generic[T]):
15221526
False
15231527
"""
15241528

1529+
model_config = ConfigDict(arbitrary_types_allowed=True)
1530+
15251531
continue_processing: bool = True
15261532
modified_payload: Optional[T] = None
15271533
modified_extensions: Optional[Extensions] = None
15281534
violation: Optional[PluginViolation] = None
15291535
metadata: Optional[dict[str, Any]] = Field(default_factory=dict)
1536+
background_tasks: list[asyncio.Task] = Field(default_factory=list, exclude=True)
1537+
1538+
async def wait_for_background_tasks(self) -> "list[PluginErrorModel]":
1539+
"""Await all FIRE_AND_FORGET background tasks and return any errors.
1540+
1541+
Returns an empty list if all tasks completed without error.
1542+
1543+
Examples:
1544+
>>> result = PluginResult()
1545+
>>> # errors = await result.wait_for_background_tasks()
1546+
"""
1547+
if not self.background_tasks:
1548+
return []
1549+
results = await asyncio.gather(*self.background_tasks, return_exceptions=True)
1550+
return [r for r in results if isinstance(r, PluginErrorModel)]
15301551

15311552

15321553
class GlobalContext(BaseModel):

docs/specs/plugin-framework-spec.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,14 @@ class PluginResult(Generic[T]):
185185
modified_payload: T | None = None
186186
violation: PluginViolation | None = None
187187
metadata: dict[str, Any] = {}
188+
background_tasks: list[asyncio.Task] = [] # excluded from serialization
189+
```
190+
191+
`background_tasks` contains the `asyncio.Task` handles for any `FIRE_AND_FORGET` plugins scheduled during the invocation. Use `wait_for_background_tasks()` to await them and collect any errors:
192+
193+
```python
194+
result, _ = await manager.invoke_hook(...)
195+
errors = await result.wait_for_background_tasks() # list[PluginErrorModel], empty on success
188196
```
189197

190198
The `PluginViolation` type carries structured policy failure information:

tests/unit/cpex/framework/test_plugin_modes.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ async def prompt_pre_fetch(self, payload, context):
101101
assert result.continue_processing
102102
assert not finished.is_set()
103103

104-
# Let the background task complete
105-
await asyncio.sleep(0.1)
104+
# Wait deterministically for the background task to complete
105+
await result.wait_for_background_tasks()
106106
assert finished.is_set()
107107

108108
await manager.shutdown()
@@ -129,8 +129,35 @@ async def prompt_pre_fetch(self, payload, context):
129129

130130
assert result.continue_processing
131131

132-
# Allow background task to run and silently fail
133-
await asyncio.sleep(0.05)
132+
# Wait for the background task; errors are returned, not raised
133+
await result.wait_for_background_tasks()
134+
135+
await manager.shutdown()
136+
137+
138+
@pytest.mark.asyncio
139+
async def test_wait_for_background_tasks_returns_errors():
140+
"""wait_for_background_tasks() returns a PluginErrorModel for each failed task."""
141+
142+
class BrokenPlugin(Plugin):
143+
async def prompt_pre_fetch(self, payload, context):
144+
raise RuntimeError("boom")
145+
146+
manager = await _make_manager()
147+
cfg = make_plugin_config("BrokenFnF", PluginMode.FIRE_AND_FORGET)
148+
plugin = BrokenPlugin(cfg)
149+
150+
with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get:
151+
mock_get.return_value = [HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(plugin))]
152+
payload = PromptPrehookPayload(prompt_id="test", args={})
153+
global_context = GlobalContext(request_id="wait_errors")
154+
155+
result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context)
156+
157+
errors = await result.wait_for_background_tasks()
158+
assert len(errors) == 1
159+
assert errors[0].plugin_name == "BrokenFnF"
160+
assert "RuntimeError" in errors[0].message
134161

135162
await manager.shutdown()
136163

@@ -356,8 +383,8 @@ async def prompt_pre_fetch(self, payload, context):
356383

357384
assert result.continue_processing
358385

359-
# Allow all background FIRE_AND_FORGET tasks to complete
360-
await asyncio.sleep(0.1)
386+
# Wait deterministically for all background FIRE_AND_FORGET tasks to complete
387+
await result.wait_for_background_tasks()
361388

362389
# With pool=1, max concurrency should be 1
363390
assert concurrency_high_water <= 1
@@ -619,8 +646,8 @@ async def prompt_pre_fetch(self, payload, context):
619646
result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context)
620647

621648
assert result.continue_processing
622-
# F&F is async — wait for it
623-
await asyncio.sleep(0.1)
649+
# F&F is async — wait for it deterministically
650+
await result.wait_for_background_tasks()
624651

625652
assert phase_log == ["seq", "xform", "audit", "conc", "fnf"]
626653

@@ -666,7 +693,7 @@ async def prompt_pre_fetch(self, payload, context):
666693
# FIRE_AND_FORGET has not yet completed (fire-and-forget)
667694
assert not fire_and_forget_started.is_set()
668695

669-
await asyncio.sleep(0.1)
696+
await result.wait_for_background_tasks()
670697
assert "fire_and_forget" in phase_log
671698
# FIRE_AND_FORGET always comes after sequential in the log
672699
assert phase_log.index("sequential") < phase_log.index("fire_and_forget")

0 commit comments

Comments
 (0)