Skip to content

Commit 1e1f36b

Browse files
authored
fix: Replace done_callback with coroutine chain for judge tracking (#147)
1 parent 2c30d75 commit 1e1f36b

2 files changed

Lines changed: 239 additions & 12 deletions

File tree

packages/sdk/server-ai/src/ldai/managed_model.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
from typing import List, Optional
33

4+
from ldai import log
45
from ldai.models import AICompletionConfig, LDMessage
56
from ldai.providers.model_runner import ModelRunner
67
from ldai.providers.types import JudgeResult, ModelResponse
@@ -62,19 +63,21 @@ def _track_judge_results(
6263
input_text: str,
6364
output_text: str,
6465
) -> asyncio.Task[List[JudgeResult]]:
65-
eval_task = self._ai_config.evaluator.evaluate(input_text, output_text)
66-
67-
def _on_done(task: asyncio.Task) -> None:
68-
if task.cancelled():
69-
return
70-
if task.exception() is not None:
71-
return
72-
for r in task.result():
73-
if r.success:
74-
tracker.track_judge_result(r)
66+
evaluator_task = self._ai_config.evaluator.evaluate(input_text, output_text)
7567

76-
eval_task.add_done_callback(_on_done)
77-
return eval_task
68+
async def _run_and_track(eval_task: asyncio.Task) -> List[JudgeResult]:
69+
results = await eval_task
70+
for r in results:
71+
if r.success:
72+
try:
73+
tracker.track_judge_result(r)
74+
except Exception as exc:
75+
log.warning("Judge evaluation failed: %s", exc)
76+
else:
77+
log.warning("Judge evaluation failed: %s", r.error_message)
78+
return results
79+
80+
return asyncio.create_task(_run_and_track(evaluator_task))
7881

7982
def get_messages(self, include_config_messages: bool = False) -> List[LDMessage]:
8083
"""
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
"""Tests for ManagedModel — specifically the evaluations tracking chain."""
2+
3+
import asyncio
4+
from typing import List
5+
from unittest.mock import AsyncMock, MagicMock, patch
6+
7+
import pytest
8+
9+
from ldai.evaluator import Evaluator
10+
from ldai.managed_model import ManagedModel
11+
from ldai.models import AICompletionConfig, LDMessage, ModelConfig, ProviderConfig
12+
from ldai.providers.types import JudgeResult, LDAIMetrics, ModelResponse
13+
from ldai.tracker import LDAIConfigTracker
14+
15+
16+
17+
def _make_model_response(content: str = 'response text') -> ModelResponse:
18+
return ModelResponse(
19+
message=LDMessage(role='assistant', content=content),
20+
metrics=LDAIMetrics(success=True, usage=None),
21+
)
22+
23+
24+
class TestManagedModelInvokeReturnsImmediately:
25+
"""invoke() must return before the evaluations task resolves."""
26+
27+
@pytest.mark.asyncio
28+
async def test_invoke_returns_before_evaluations_resolve(self):
29+
"""invoke() should return a ModelResponse before evaluations complete."""
30+
# Set up a barrier so the evaluation coroutine doesn't complete until we release it
31+
barrier = asyncio.Event()
32+
33+
async def _slow_evaluate(input_text: str, output_text: str) -> List[JudgeResult]:
34+
await barrier.wait()
35+
return []
36+
37+
evaluator = MagicMock(spec=Evaluator)
38+
evaluator.evaluate = MagicMock(
39+
side_effect=lambda i, o: asyncio.create_task(_slow_evaluate(i, o))
40+
)
41+
42+
mock_runner = MagicMock()
43+
mock_runner.invoke_model = AsyncMock(return_value=_make_model_response())
44+
45+
mock_tracker = MagicMock(spec=LDAIConfigTracker)
46+
mock_tracker.track_metrics_of_async = AsyncMock(return_value=_make_model_response())
47+
config = AICompletionConfig(
48+
key='test-config',
49+
enabled=True,
50+
create_tracker=MagicMock(return_value=mock_tracker),
51+
model=ModelConfig('gpt-4'),
52+
provider=ProviderConfig('openai'),
53+
messages=[],
54+
evaluator=evaluator,
55+
)
56+
57+
model = ManagedModel(config, mock_runner)
58+
response = await model.invoke('Hello')
59+
60+
# invoke() returned — evaluations task should still be pending
61+
assert response is not None
62+
assert response.evaluations is not None
63+
assert not response.evaluations.done(), "evaluations task should still be pending"
64+
65+
# Release the barrier and let it finish cleanly
66+
barrier.set()
67+
await response.evaluations
68+
69+
@pytest.mark.asyncio
70+
async def test_await_evaluations_collects_results(self):
71+
"""await response.evaluations should return the list of JudgeResult instances."""
72+
judge_result = JudgeResult(
73+
judge_config_key='judge-key',
74+
success=True,
75+
sampled=True,
76+
metric_key='$ld:ai:judge:relevance',
77+
score=0.9,
78+
reasoning='Good response',
79+
)
80+
81+
async def _evaluate_coro(input_text: str, output_text: str) -> List[JudgeResult]:
82+
return [judge_result]
83+
84+
evaluator = MagicMock(spec=Evaluator)
85+
evaluator.evaluate = MagicMock(
86+
side_effect=lambda i, o: asyncio.create_task(_evaluate_coro(i, o))
87+
)
88+
89+
mock_runner = MagicMock()
90+
mock_runner.invoke_model = AsyncMock(return_value=_make_model_response())
91+
92+
mock_tracker = MagicMock(spec=LDAIConfigTracker)
93+
mock_tracker.track_metrics_of_async = AsyncMock(return_value=_make_model_response())
94+
config = AICompletionConfig(
95+
key='test-config',
96+
enabled=True,
97+
create_tracker=MagicMock(return_value=mock_tracker),
98+
model=ModelConfig('gpt-4'),
99+
provider=ProviderConfig('openai'),
100+
messages=[],
101+
evaluator=evaluator,
102+
)
103+
104+
model = ManagedModel(config, mock_runner)
105+
response = await model.invoke('Hello')
106+
107+
results = await response.evaluations # type: ignore[misc]
108+
assert results == [judge_result]
109+
110+
@pytest.mark.asyncio
111+
async def test_tracking_fires_inside_awaited_chain(self):
112+
"""tracker.track_judge_result() must be called when evaluations are awaited."""
113+
judge_result = JudgeResult(
114+
judge_config_key='judge-key',
115+
success=True,
116+
sampled=True,
117+
metric_key='$ld:ai:judge:relevance',
118+
score=0.85,
119+
reasoning='Relevant answer',
120+
)
121+
122+
async def _evaluate_coro(input_text: str, output_text: str) -> List[JudgeResult]:
123+
return [judge_result]
124+
125+
evaluator = MagicMock(spec=Evaluator)
126+
evaluator.evaluate = MagicMock(
127+
side_effect=lambda i, o: asyncio.create_task(_evaluate_coro(i, o))
128+
)
129+
130+
mock_runner = MagicMock()
131+
mock_runner.invoke_model = AsyncMock(return_value=_make_model_response())
132+
133+
mock_tracker = MagicMock(spec=LDAIConfigTracker)
134+
mock_tracker.track_metrics_of_async = AsyncMock(return_value=_make_model_response())
135+
mock_tracker.track_judge_result = MagicMock()
136+
137+
config = AICompletionConfig(
138+
key='test-config',
139+
enabled=True,
140+
create_tracker=MagicMock(return_value=mock_tracker),
141+
model=ModelConfig('gpt-4'),
142+
provider=ProviderConfig('openai'),
143+
messages=[],
144+
evaluator=evaluator,
145+
)
146+
147+
model = ManagedModel(config, mock_runner)
148+
response = await model.invoke('Hello')
149+
150+
# Tracking should NOT have fired yet (before we await evaluations)
151+
mock_tracker.track_judge_result.assert_not_called()
152+
153+
# Now await the evaluations task — tracking fires inside the chain
154+
await response.evaluations # type: ignore[misc]
155+
156+
mock_tracker.track_judge_result.assert_called_once_with(judge_result)
157+
158+
@pytest.mark.asyncio
159+
async def test_tracking_not_called_for_failed_judge_result(self):
160+
"""tracker.track_judge_result() should NOT be called for unsuccessful judge results."""
161+
failed_result = JudgeResult(
162+
success=False,
163+
sampled=True,
164+
metric_key='$ld:ai:judge:relevance',
165+
error_message='Judge evaluation failed',
166+
)
167+
168+
async def _evaluate_coro(input_text: str, output_text: str) -> List[JudgeResult]:
169+
return [failed_result]
170+
171+
evaluator = MagicMock(spec=Evaluator)
172+
evaluator.evaluate = MagicMock(
173+
side_effect=lambda i, o: asyncio.create_task(_evaluate_coro(i, o))
174+
)
175+
176+
mock_runner = MagicMock()
177+
mock_runner.invoke_model = AsyncMock(return_value=_make_model_response())
178+
179+
mock_tracker = MagicMock(spec=LDAIConfigTracker)
180+
mock_tracker.track_metrics_of_async = AsyncMock(return_value=_make_model_response())
181+
mock_tracker.track_judge_result = MagicMock()
182+
183+
config = AICompletionConfig(
184+
key='test-config',
185+
enabled=True,
186+
create_tracker=MagicMock(return_value=mock_tracker),
187+
model=ModelConfig('gpt-4'),
188+
provider=ProviderConfig('openai'),
189+
messages=[],
190+
evaluator=evaluator,
191+
)
192+
193+
model = ManagedModel(config, mock_runner)
194+
response = await model.invoke('Hello')
195+
await response.evaluations # type: ignore[misc]
196+
197+
mock_tracker.track_judge_result.assert_not_called()
198+
199+
@pytest.mark.asyncio
200+
async def test_noop_evaluator_returns_empty_list(self):
201+
"""With a noop evaluator, awaiting evaluations should return an empty list."""
202+
evaluator = Evaluator.noop()
203+
204+
mock_runner = MagicMock()
205+
mock_runner.invoke_model = AsyncMock(return_value=_make_model_response())
206+
207+
mock_tracker = MagicMock(spec=LDAIConfigTracker)
208+
mock_tracker.track_metrics_of_async = AsyncMock(return_value=_make_model_response())
209+
210+
config = AICompletionConfig(
211+
key='test-config',
212+
enabled=True,
213+
create_tracker=MagicMock(return_value=mock_tracker),
214+
model=ModelConfig('gpt-4'),
215+
provider=ProviderConfig('openai'),
216+
messages=[],
217+
evaluator=evaluator,
218+
)
219+
220+
model = ManagedModel(config, mock_runner)
221+
response = await model.invoke('Hello')
222+
results = await response.evaluations # type: ignore[misc]
223+
224+
assert results == []

0 commit comments

Comments
 (0)