Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions cpex/framework/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ async def execute(
)

# FIRE_AND_FORGET: fire-and-forget background tasks (fires last with final payload snapshot)
self._fire_and_forget_tasks(
bg_tasks = self._fire_and_forget_tasks(
fire_and_forget_refs,
payload,
global_context,
Expand All @@ -373,6 +373,7 @@ async def execute(
modified_extensions=current_extensions,
violation=None,
metadata=combined_metadata,
background_tasks=bg_tasks,
),
res_local_contexts,
)
Expand Down Expand Up @@ -688,7 +689,7 @@ def _build_halt_result(
extensions: Optional[Extensions] = None,
) -> tuple[PluginResult, dict]:
"""Schedule fire-and-forget tasks and build a pipeline-halting result."""
self._fire_and_forget_tasks(
bg_tasks = self._fire_and_forget_tasks(
fire_and_forget_refs,
payload,
global_context,
Expand All @@ -704,6 +705,7 @@ def _build_halt_result(
modified_payload=current_payload,
violation=violation,
metadata=combined_metadata,
background_tasks=bg_tasks,
),
res_local_contexts,
)
Expand All @@ -728,12 +730,14 @@ def _fire_and_forget_tasks(
res_local_contexts: dict,
semaphore: Optional[asyncio.Semaphore],
extensions: Optional[Extensions] = None,
) -> None:
) -> list[asyncio.Task]:
"""Schedule all FIRE_AND_FORGET plugins as fire-and-forget background tasks.

May be called from an early-exit path or from the normal completion path.
Each FIRE_AND_FORGET plugin receives an isolated snapshot of the payload at call time.
Returns the list of asyncio.Task handles for all newly scheduled tasks.
"""
tasks: list[asyncio.Task] = []
for ref in fire_and_forget_refs:
local_context_key = global_context.request_id + ref.plugin_ref.uuid
if local_context_key in res_local_contexts:
Expand All @@ -752,9 +756,11 @@ def _fire_and_forget_tasks(
)
local_context = PluginContext(global_context=tmp_gc)
res_local_contexts[local_context_key] = local_context
asyncio.create_task(
task = asyncio.create_task(
self._run_fire_and_forget_task(ref, task_input, local_context, semaphore, extensions=extensions)
)
tasks.append(task)
return tasks

async def _run_fire_and_forget_task(
self,
Expand All @@ -763,9 +769,10 @@ async def _run_fire_and_forget_task(
local_context: PluginContext,
semaphore: Optional[asyncio.Semaphore],
extensions: Optional[Extensions] = None,
) -> None:
) -> Optional[PluginErrorModel]:
"""Execute a plugin as a fire-and-forget background task.

Returns None on success, or a PluginErrorModel if the plugin raised.
Errors are logged but never propagated — background tasks cannot halt the pipeline.
If on_error=DISABLE, the plugin is added to the runtime-disabled set.
"""
Expand All @@ -775,11 +782,13 @@ async def _run_fire_and_forget_task(
await self._execute_with_timeout(hook_ref, payload, local_context, extensions=extensions)
else:
await self._execute_with_timeout(hook_ref, payload, local_context, extensions=extensions)
except Exception:
return None
except Exception as exc:
logger.error("Plugin %s failed in fire-and-forget mode (ignored)", hook_ref.plugin_ref.name)
if hook_ref.plugin_ref.on_error == OnError.DISABLE:
self._runtime_disabled.add(hook_ref.plugin_ref.name)
# FAIL and IGNORE both just log for FIRE_AND_FORGET mode (background can't halt pipeline)
return PluginErrorModel(message=repr(exc), plugin_name=hook_ref.plugin_ref.name)

async def execute_plugin(
self,
Expand Down
21 changes: 21 additions & 0 deletions cpex/framework/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"""

# Standard
import asyncio
import logging
import os
import re
Expand Down Expand Up @@ -1498,6 +1499,9 @@ class PluginResult(BaseModel, Generic[T]):
(e.g., updated HTTP headers from token delegation, appended security labels).
violation (Optional[PluginViolation]): violation object.
metadata (Optional[dict[str, Any]]): additional metadata.
background_tasks (list[asyncio.Task]): asyncio.Task handles for any FIRE_AND_FORGET
plugins scheduled during this invocation. Use ``wait_for_background_tasks()``
to await them and collect any errors. This field is excluded from model serialization.

Examples:
>>> result = PluginResult()
Expand All @@ -1522,11 +1526,28 @@ class PluginResult(BaseModel, Generic[T]):
False
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

continue_processing: bool = True
modified_payload: Optional[T] = None
modified_extensions: Optional[Extensions] = None
violation: Optional[PluginViolation] = None
metadata: Optional[dict[str, Any]] = Field(default_factory=dict)
background_tasks: list[asyncio.Task] = Field(default_factory=list, exclude=True)

async def wait_for_background_tasks(self) -> "list[PluginErrorModel]":
"""Await all FIRE_AND_FORGET background tasks and return any errors.

Returns an empty list if all tasks completed without error.

Examples:
>>> result = PluginResult()
>>> # errors = await result.wait_for_background_tasks()
"""
if not self.background_tasks:
return []
results = await asyncio.gather(*self.background_tasks, return_exceptions=True)
return [r for r in results if isinstance(r, PluginErrorModel)]


class GlobalContext(BaseModel):
Expand Down
8 changes: 8 additions & 0 deletions docs/specs/plugin-framework-spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,14 @@ class PluginResult(Generic[T]):
modified_payload: T | None = None
violation: PluginViolation | None = None
metadata: dict[str, Any] = {}
background_tasks: list[asyncio.Task] = [] # excluded from serialization
```

`background_tasks` contains the `asyncio.Task` handles for any `FIRE_AND_FORGET` plugins scheduled during the invocation. Use `wait_for_background_tasks()` to await them and collect any errors:

```python
result, _ = await manager.invoke_hook(...)
errors = await result.wait_for_background_tasks() # list[PluginErrorModel], empty on success
```

The `PluginViolation` type carries structured policy failure information:
Expand Down
45 changes: 36 additions & 9 deletions tests/unit/cpex/framework/test_plugin_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ async def prompt_pre_fetch(self, payload, context):
assert result.continue_processing
assert not finished.is_set()

# Let the background task complete
await asyncio.sleep(0.1)
# Wait deterministically for the background task to complete
await result.wait_for_background_tasks()
assert finished.is_set()

await manager.shutdown()
Expand All @@ -129,8 +129,35 @@ async def prompt_pre_fetch(self, payload, context):

assert result.continue_processing

# Allow background task to run and silently fail
await asyncio.sleep(0.05)
# Wait for the background task; errors are returned, not raised
await result.wait_for_background_tasks()

await manager.shutdown()


@pytest.mark.asyncio
async def test_wait_for_background_tasks_returns_errors():
"""wait_for_background_tasks() returns a PluginErrorModel for each failed task."""

class BrokenPlugin(Plugin):
async def prompt_pre_fetch(self, payload, context):
raise RuntimeError("boom")

manager = await _make_manager()
cfg = make_plugin_config("BrokenFnF", PluginMode.FIRE_AND_FORGET)
plugin = BrokenPlugin(cfg)

with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get:
mock_get.return_value = [HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(plugin))]
payload = PromptPrehookPayload(prompt_id="test", args={})
global_context = GlobalContext(request_id="wait_errors")

result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context)

errors = await result.wait_for_background_tasks()
assert len(errors) == 1
assert errors[0].plugin_name == "BrokenFnF"
assert "RuntimeError" in errors[0].message

await manager.shutdown()

Expand Down Expand Up @@ -356,8 +383,8 @@ async def prompt_pre_fetch(self, payload, context):

assert result.continue_processing

# Allow all background FIRE_AND_FORGET tasks to complete
await asyncio.sleep(0.1)
# Wait deterministically for all background FIRE_AND_FORGET tasks to complete
await result.wait_for_background_tasks()

# With pool=1, max concurrency should be 1
assert concurrency_high_water <= 1
Expand Down Expand Up @@ -619,8 +646,8 @@ async def prompt_pre_fetch(self, payload, context):
result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context)

assert result.continue_processing
# F&F is async — wait for it
await asyncio.sleep(0.1)
# F&F is async — wait for it deterministically
await result.wait_for_background_tasks()

assert phase_log == ["seq", "xform", "audit", "conc", "fnf"]

Expand Down Expand Up @@ -666,7 +693,7 @@ async def prompt_pre_fetch(self, payload, context):
# FIRE_AND_FORGET has not yet completed (fire-and-forget)
assert not fire_and_forget_started.is_set()

await asyncio.sleep(0.1)
await result.wait_for_background_tasks()
assert "fire_and_forget" in phase_log
# FIRE_AND_FORGET always comes after sequential in the log
assert phase_log.index("sequential") < phase_log.index("fire_and_forget")
Expand Down
Loading