Skip to content

Commit 381cf75

Browse files
jsonbaileyclaude
andcommitted
fix: Replace done_callback with coroutine chain for judge tracking
`_track_judge_results` previously used `add_done_callback` to fire `track_judge_result()` after evaluation completed, but callbacks run outside the asyncio task scheduler and can execute at unpredictable times. Replace with a single `_run_and_track` coroutine wrapped in a new `asyncio.create_task`, so that awaiting `response.evaluations` guarantees both evaluation and tracker calls complete in sequence. Add `test_managed_model.py` covering: invoke() returns before evaluations resolve; awaiting evaluations collects results; tracking fires inside the awaited chain (not before); failed judge results do not trigger tracking; noop evaluator returns an empty list. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 3d5a6a9 commit 381cf75

2 files changed

Lines changed: 244 additions & 10 deletions

File tree

packages/sdk/server-ai/src/ldai/managed_model.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,16 @@ def _track_judge_results(
6262
input_text: str,
6363
output_text: str,
6464
) -> asyncio.Task[List[JudgeResult]]:
65-
eval_task = self._ai_config.evaluator.evaluate(input_text, output_text)
66-
67-
def _on_done(task: asyncio.Task) -> None:
68-
if task.cancelled():
69-
return
70-
if task.exception() is not None:
71-
return
72-
for r in task.result():
65+
evaluator_task = self._ai_config.evaluator.evaluate(input_text, output_text)
66+
67+
async def _run_and_track(eval_task: asyncio.Task) -> List[JudgeResult]:
68+
results = await eval_task
69+
for r in results:
7370
if r.success:
7471
tracker.track_judge_result(r)
72+
return results
7573

76-
eval_task.add_done_callback(_on_done)
77-
return eval_task
74+
return asyncio.create_task(_run_and_track(evaluator_task))
7875

7976
def get_messages(self, include_config_messages: bool = False) -> List[LDMessage]:
8077
"""
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
"""Tests for ManagedModel — specifically the evaluations tracking chain."""
2+
3+
import asyncio
4+
from typing import List
5+
from unittest.mock import AsyncMock, MagicMock, patch
6+
7+
import pytest
8+
9+
from ldai.evaluator import Evaluator
10+
from ldai.managed_model import ManagedModel
11+
from ldai.models import AICompletionConfig, LDMessage, ModelConfig, ProviderConfig
12+
from ldai.providers.types import JudgeResult, LDAIMetrics, ModelResponse
13+
from ldai.tracker import LDAIConfigTracker
14+
15+
16+
def _make_ai_completion_config(evaluator: Evaluator) -> AICompletionConfig:
17+
"""Build a minimal AICompletionConfig wired to the given evaluator."""
18+
return AICompletionConfig(
19+
key='test-config',
20+
enabled=True,
21+
create_tracker=MagicMock(return_value=MagicMock(spec=LDAIConfigTracker)),
22+
model=ModelConfig('gpt-4'),
23+
provider=ProviderConfig('openai'),
24+
messages=[LDMessage(role='system', content='You are helpful.')],
25+
evaluator=evaluator,
26+
)
27+
28+
29+
def _make_model_response(content: str = 'response text') -> ModelResponse:
30+
return ModelResponse(
31+
message=LDMessage(role='assistant', content=content),
32+
metrics=LDAIMetrics(success=True, usage=None),
33+
)
34+
35+
36+
class TestManagedModelInvokeReturnsImmediately:
37+
"""invoke() must return before the evaluations task resolves."""
38+
39+
@pytest.mark.asyncio
40+
async def test_invoke_returns_before_evaluations_resolve(self):
41+
"""invoke() should return a ModelResponse before evaluations complete."""
42+
# Set up a barrier so the evaluation coroutine doesn't complete until we release it
43+
barrier = asyncio.Event()
44+
45+
async def _slow_evaluate(input_text: str, output_text: str) -> List[JudgeResult]:
46+
await barrier.wait()
47+
return []
48+
49+
evaluator = MagicMock(spec=Evaluator)
50+
evaluator.evaluate = MagicMock(
51+
side_effect=lambda i, o: asyncio.create_task(_slow_evaluate(i, o))
52+
)
53+
54+
mock_runner = MagicMock()
55+
mock_runner.invoke_model = AsyncMock(return_value=_make_model_response())
56+
57+
config = _make_ai_completion_config(evaluator)
58+
mock_tracker = MagicMock(spec=LDAIConfigTracker)
59+
mock_tracker.track_metrics_of_async = AsyncMock(return_value=_make_model_response())
60+
config = AICompletionConfig(
61+
key='test-config',
62+
enabled=True,
63+
create_tracker=MagicMock(return_value=mock_tracker),
64+
model=ModelConfig('gpt-4'),
65+
provider=ProviderConfig('openai'),
66+
messages=[],
67+
evaluator=evaluator,
68+
)
69+
70+
model = ManagedModel(config, mock_runner)
71+
response = await model.invoke('Hello')
72+
73+
# invoke() returned — evaluations task should still be pending
74+
assert response is not None
75+
assert response.evaluations is not None
76+
assert not response.evaluations.done(), "evaluations task should still be pending"
77+
78+
# Release the barrier and let it finish cleanly
79+
barrier.set()
80+
await response.evaluations
81+
82+
@pytest.mark.asyncio
83+
async def test_await_evaluations_collects_results(self):
84+
"""await response.evaluations should return the list of JudgeResult instances."""
85+
judge_result = JudgeResult(
86+
judge_config_key='judge-key',
87+
success=True,
88+
sampled=True,
89+
metric_key='$ld:ai:judge:relevance',
90+
score=0.9,
91+
reasoning='Good response',
92+
)
93+
94+
async def _evaluate_coro(input_text: str, output_text: str) -> List[JudgeResult]:
95+
return [judge_result]
96+
97+
evaluator = MagicMock(spec=Evaluator)
98+
evaluator.evaluate = MagicMock(
99+
side_effect=lambda i, o: asyncio.create_task(_evaluate_coro(i, o))
100+
)
101+
102+
mock_runner = MagicMock()
103+
mock_runner.invoke_model = AsyncMock(return_value=_make_model_response())
104+
105+
mock_tracker = MagicMock(spec=LDAIConfigTracker)
106+
mock_tracker.track_metrics_of_async = AsyncMock(return_value=_make_model_response())
107+
config = AICompletionConfig(
108+
key='test-config',
109+
enabled=True,
110+
create_tracker=MagicMock(return_value=mock_tracker),
111+
model=ModelConfig('gpt-4'),
112+
provider=ProviderConfig('openai'),
113+
messages=[],
114+
evaluator=evaluator,
115+
)
116+
117+
model = ManagedModel(config, mock_runner)
118+
response = await model.invoke('Hello')
119+
120+
results = await response.evaluations # type: ignore[misc]
121+
assert results == [judge_result]
122+
123+
@pytest.mark.asyncio
124+
async def test_tracking_fires_inside_awaited_chain(self):
125+
"""tracker.track_judge_result() must be called when evaluations are awaited."""
126+
judge_result = JudgeResult(
127+
judge_config_key='judge-key',
128+
success=True,
129+
sampled=True,
130+
metric_key='$ld:ai:judge:relevance',
131+
score=0.85,
132+
reasoning='Relevant answer',
133+
)
134+
135+
async def _evaluate_coro(input_text: str, output_text: str) -> List[JudgeResult]:
136+
return [judge_result]
137+
138+
evaluator = MagicMock(spec=Evaluator)
139+
evaluator.evaluate = MagicMock(
140+
side_effect=lambda i, o: asyncio.create_task(_evaluate_coro(i, o))
141+
)
142+
143+
mock_runner = MagicMock()
144+
mock_runner.invoke_model = AsyncMock(return_value=_make_model_response())
145+
146+
mock_tracker = MagicMock(spec=LDAIConfigTracker)
147+
mock_tracker.track_metrics_of_async = AsyncMock(return_value=_make_model_response())
148+
mock_tracker.track_judge_result = MagicMock()
149+
150+
config = AICompletionConfig(
151+
key='test-config',
152+
enabled=True,
153+
create_tracker=MagicMock(return_value=mock_tracker),
154+
model=ModelConfig('gpt-4'),
155+
provider=ProviderConfig('openai'),
156+
messages=[],
157+
evaluator=evaluator,
158+
)
159+
160+
model = ManagedModel(config, mock_runner)
161+
response = await model.invoke('Hello')
162+
163+
# Tracking should NOT have fired yet (before we await evaluations)
164+
mock_tracker.track_judge_result.assert_not_called()
165+
166+
# Now await the evaluations task — tracking fires inside the chain
167+
await response.evaluations # type: ignore[misc]
168+
169+
mock_tracker.track_judge_result.assert_called_once_with(judge_result)
170+
171+
@pytest.mark.asyncio
172+
async def test_tracking_not_called_for_failed_judge_result(self):
173+
"""tracker.track_judge_result() should NOT be called for unsuccessful judge results."""
174+
failed_result = JudgeResult(
175+
success=False,
176+
sampled=True,
177+
metric_key='$ld:ai:judge:relevance',
178+
error_message='Judge evaluation failed',
179+
)
180+
181+
async def _evaluate_coro(input_text: str, output_text: str) -> List[JudgeResult]:
182+
return [failed_result]
183+
184+
evaluator = MagicMock(spec=Evaluator)
185+
evaluator.evaluate = MagicMock(
186+
side_effect=lambda i, o: asyncio.create_task(_evaluate_coro(i, o))
187+
)
188+
189+
mock_runner = MagicMock()
190+
mock_runner.invoke_model = AsyncMock(return_value=_make_model_response())
191+
192+
mock_tracker = MagicMock(spec=LDAIConfigTracker)
193+
mock_tracker.track_metrics_of_async = AsyncMock(return_value=_make_model_response())
194+
mock_tracker.track_judge_result = MagicMock()
195+
196+
config = AICompletionConfig(
197+
key='test-config',
198+
enabled=True,
199+
create_tracker=MagicMock(return_value=mock_tracker),
200+
model=ModelConfig('gpt-4'),
201+
provider=ProviderConfig('openai'),
202+
messages=[],
203+
evaluator=evaluator,
204+
)
205+
206+
model = ManagedModel(config, mock_runner)
207+
response = await model.invoke('Hello')
208+
await response.evaluations # type: ignore[misc]
209+
210+
mock_tracker.track_judge_result.assert_not_called()
211+
212+
@pytest.mark.asyncio
213+
async def test_noop_evaluator_returns_empty_list(self):
214+
"""With a noop evaluator, awaiting evaluations should return an empty list."""
215+
evaluator = Evaluator.noop()
216+
217+
mock_runner = MagicMock()
218+
mock_runner.invoke_model = AsyncMock(return_value=_make_model_response())
219+
220+
mock_tracker = MagicMock(spec=LDAIConfigTracker)
221+
mock_tracker.track_metrics_of_async = AsyncMock(return_value=_make_model_response())
222+
223+
config = AICompletionConfig(
224+
key='test-config',
225+
enabled=True,
226+
create_tracker=MagicMock(return_value=mock_tracker),
227+
model=ModelConfig('gpt-4'),
228+
provider=ProviderConfig('openai'),
229+
messages=[],
230+
evaluator=evaluator,
231+
)
232+
233+
model = ManagedModel(config, mock_runner)
234+
response = await model.invoke('Hello')
235+
results = await response.evaluations # type: ignore[misc]
236+
237+
assert results == []

0 commit comments

Comments
 (0)