Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions mellea/plugins/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,47 @@
# Module-level plugin state; shutdown_plugins() resets all four of these.
_plugin_manager: Any | None = None
_plugins_enabled: bool = False
_session_tags: dict[str, set[str]] = {} # session_id -> set of plugin names
# Hook results that still carry background tasks. invoke_hook() appends here
# only while background collection is enabled (see _collect_background_results).
_pending_background_results: list[Any] = []
Comment thread
ajbozarth marked this conversation as resolved.
_collect_background_results: bool = False # opt-in; only tests enable this

# NOTE(review): presumably the per-plugin execution timeout — confirm where it is consumed.
DEFAULT_PLUGIN_TIMEOUT: int = 5 # seconds
# NOTE(review): presumably the policy used when a hook has no explicit one — confirm.
DEFAULT_HOOK_POLICY: Literal["allow"] | Literal["deny"] = "deny"


def enable_background_collection() -> None:
    """Turn on collection of fire-and-forget hook results.

    Intended for test fixtures: call it before each test so that hook results
    carrying background tasks are retained and can be drained afterwards.
    """
    # Write the flag straight into this module's namespace; this has the same
    # effect as rebinding it through a ``global`` declaration.
    globals()["_collect_background_results"] = True


def disable_background_collection() -> None:
    """Disable fire-and-forget result collection and drop accumulated results.

    The pending list is emptied in place rather than rebound, matching
    ``discard_background_tasks()`` and the shutdown path. This keeps the list
    object stable, so any other reference to it observes the cleared state
    instead of holding on to stale results.
    """
    global _collect_background_results
    _collect_background_results = False
    # In-place clear: no ``global`` needed for the list because it is not rebound.
    _pending_background_results.clear()


async def drain_background_tasks() -> None:
    """Await all accumulated FIRE_AND_FORGET tasks and clear the pending list.

    Call this in tests after any operation that may have triggered
    fire-and-forget plugins, to ensure side effects (metrics recording, etc.)
    complete before assertions.

    Every pending result is awaited, in order, even if an earlier one fails.
    Otherwise a single failing background task would leave the remaining
    results un-awaited and already removed from the pending list.

    Raises:
        Exception: The first exception raised by any
            ``wait_for_background_tasks()`` call, re-raised after all pending
            results have been awaited.
    """
    global _pending_background_results
    # Swap the list out before awaiting: results appended while draining land
    # in the fresh list and are left for the next drain.
    pending, _pending_background_results = _pending_background_results, []

    first_error: Exception | None = None
    for result in pending:
        try:
            await result.wait_for_background_tasks()
        except Exception as exc:  # cancellation (BaseException) still propagates
            if first_error is None:
                first_error = exc

    if first_error is not None:
        raise first_error


def discard_background_tasks() -> None:
    """Drop every accumulated FIRE_AND_FORGET result without awaiting it.

    Meant for test fixtures: clears results left over from a previous event
    loop so they are not carried into the next test.
    """
    # Slice deletion empties the existing list object in place.
    del _pending_background_results[:]


def has_plugins(hook_type: HookType | None = None) -> bool:
"""Fast check: are plugins configured and available for the given hook type.

Expand Down Expand Up @@ -143,6 +179,7 @@ async def shutdown_plugins() -> None:
_plugin_manager = None
_plugins_enabled = False
_session_tags.clear()
_pending_background_results.clear()


def track_session_plugin(session_id: str, plugin_name: str) -> None:
Expand Down Expand Up @@ -229,6 +266,9 @@ async def invoke_hook(
violations_as_exceptions=False,
)

if _collect_background_results and result and result.background_tasks:
_pending_background_results.append(result)
Comment thread
ajbozarth marked this conversation as resolved.

if result and not result.continue_processing and result.violation:
v = result.violation
logger.warning(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ sandbox = [
backends = ["mellea[watsonx,hf,litellm]"]

hooks = [
"cpex>=0.1.0.dev10; python_version >= '3.11'",
"cpex>=0.1.0.dev12; python_version >= '3.11'",
"grpcio>=1.78.0",
]

Expand Down
14 changes: 6 additions & 8 deletions test/plugins/test_execution_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

from __future__ import annotations

import asyncio

import pytest

pytest.importorskip("cpex.framework")
Expand Down Expand Up @@ -344,11 +342,10 @@ async def faf_observer(payload, ctx):

register(faf_observer)

await invoke_hook(HookType.SESSION_PRE_INIT, _session_payload())
result, _ = await invoke_hook(HookType.SESSION_PRE_INIT, _session_payload())

# The hook runs as a background asyncio task; yield to the event loop to
# allow it to complete before asserting.
await asyncio.sleep(0.05)
assert result is not None
await result.wait_for_background_tasks()
assert invocations == ["fired"]

@pytest.mark.asyncio
Expand Down Expand Up @@ -413,9 +410,10 @@ async def enforce_second(payload, ctx):
register(faf_first)
register(enforce_second)

await invoke_hook(HookType.SESSION_PRE_INIT, _session_payload())
result, _ = await invoke_hook(HookType.SESSION_PRE_INIT, _session_payload())

await asyncio.sleep(0.05)
assert result is not None
await result.wait_for_background_tasks()
assert order == ["enforce", "faf"]

@pytest.mark.asyncio
Expand Down
31 changes: 16 additions & 15 deletions test/telemetry/test_metrics_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
Tests that backends correctly record token metrics through the telemetry system.
"""

import asyncio
import os

import pytest
Expand All @@ -12,6 +11,12 @@
IBM_GRANITE_4_HYBRID_MICRO,
IBM_GRANITE_4_HYBRID_SMALL,
)
from mellea.plugins.manager import (
disable_background_collection,
discard_background_tasks,
drain_background_tasks,
enable_background_collection,
)
from mellea.stdlib.components import Message
from mellea.stdlib.context import SimpleContext
from test.predicates import require_api_key, require_gpu
Expand Down Expand Up @@ -41,6 +46,8 @@ def metric_reader():
@pytest.fixture
def enable_metrics(monkeypatch):
"""Enable metrics for tests."""
enable_background_collection()
discard_background_tasks()
monkeypatch.setenv("MELLEA_METRICS_ENABLED", "true")
# Force reload of metrics module to pick up env vars
import importlib
Expand All @@ -52,6 +59,7 @@ def enable_metrics(monkeypatch):
# Reset after test
monkeypatch.setenv("MELLEA_METRICS_ENABLED", "false")
importlib.reload(mellea.telemetry.metrics)
disable_background_collection()


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -169,8 +177,7 @@ async def test_ollama_token_metrics_integration(enable_metrics, metric_reader, s
await mot.avalue()

# Force metrics export and collection
# Yield to event loop so FIRE_AND_FORGET plugin tasks complete
await asyncio.sleep(0.05)
await drain_background_tasks()
provider.force_flush()
metrics_data = metric_reader.get_metrics_data()

Expand Down Expand Up @@ -235,8 +242,7 @@ async def test_openai_token_metrics_integration(enable_metrics, metric_reader, s
await mot.astream()
await mot.avalue()

# Yield to event loop so FIRE_AND_FORGET plugin tasks complete
await asyncio.sleep(0.05)
await drain_background_tasks()
provider.force_flush()
metrics_data = metric_reader.get_metrics_data()

Expand Down Expand Up @@ -290,8 +296,7 @@ async def test_watsonx_token_metrics_integration(enable_metrics, metric_reader):
)
await mot.avalue()

# Yield to event loop so FIRE_AND_FORGET plugin tasks complete
await asyncio.sleep(0.05)
await drain_background_tasks()
provider.force_flush()
metrics_data = metric_reader.get_metrics_data()

Expand Down Expand Up @@ -354,8 +359,7 @@ async def test_litellm_token_metrics_integration(
await mot.astream()
await mot.avalue()

# Yield to event loop so FIRE_AND_FORGET plugin tasks complete
await asyncio.sleep(0.05)
await drain_background_tasks()
provider.force_flush()
metrics_data = metric_reader.get_metrics_data()

Expand Down Expand Up @@ -413,8 +417,7 @@ async def test_huggingface_token_metrics_integration(
await mot.astream()
await mot.avalue()

# Yield to event loop so FIRE_AND_FORGET plugin tasks complete
await asyncio.sleep(0.05)
await drain_background_tasks()
provider.force_flush()
metrics_data = metric_reader.get_metrics_data()

Expand Down Expand Up @@ -470,8 +473,7 @@ async def test_error_metrics_on_backend_failure(enable_metrics, metric_reader):
with pytest.raises(Exception):
await mot.avalue()

# Yield to event loop so FIRE_AND_FORGET plugin task completes
await asyncio.sleep(0.05)
await drain_background_tasks()
provider.force_flush()
metrics_data = metric_reader.get_metrics_data()

Expand Down Expand Up @@ -508,8 +510,7 @@ async def test_ollama_sampling_metrics_integration(enable_metrics, metric_reader
action=Instruction("Say hello"), context=ctx, backend=backend, requirements=None
)

# Yield to event loop so FIRE_AND_FORGET plugin tasks complete
await asyncio.sleep(0.05)
await drain_background_tasks()
provider.force_flush()
metrics_data = metric_reader.get_metrics_data()

Expand Down
9 changes: 5 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading