Skip to content

Commit 5715bca

Browse files
authored
ENG-3697: Improve eval sample upload progress (#670)
1 parent ba1b801 commit 5715bca

3 files changed

Lines changed: 174 additions & 35 deletions

File tree

packages/prime-evals/src/prime_evals/evals.py

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44
import warnings
55
from concurrent.futures import ThreadPoolExecutor, as_completed
6-
from typing import Any, Dict, List, Optional, Tuple, Union
6+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
77

88
import httpx
99
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
@@ -25,6 +25,16 @@ def _build_user_agent() -> str:
2525
return f"prime-evals/{__version__} python/{python_version}"
2626

2727

28+
def _samples_upload_headers(api_key: Optional[str]) -> Dict[str, str]:
29+
headers = {
30+
"Content-Type": "application/json",
31+
"User-Agent": _build_user_agent(),
32+
}
33+
if api_key:
34+
headers["Authorization"] = f"Bearer {api_key}"
35+
return headers
36+
37+
2838
class EvalsClient:
2939
"""
3040
Client for the Prime Evals API
@@ -214,6 +224,7 @@ def push_samples(
214224
samples: List[Dict[str, Any]],
215225
max_payload_bytes: int = 25 * 1024 * 1024,
216226
max_workers: int = 4,
227+
progress_callback: Optional[Callable[[int], None]] = None,
217228
) -> Dict[str, Any]:
218229
"""Push evaluation samples in adaptive batches with concurrent uploads."""
219230
if not samples:
@@ -222,34 +233,41 @@ def push_samples(
222233
raise ValueError("max_workers must be at least 1")
223234

224235
batches, skipped_count = self._build_batches(samples, max_payload_bytes)
236+
if skipped_count and progress_callback is not None:
237+
progress_callback(skipped_count)
238+
225239
total_samples_pushed = 0
226240
errors = []
227-
228-
with ThreadPoolExecutor(max_workers=max_workers) as executor:
229-
futures = {
230-
executor.submit(self._upload_batch, evaluation_id, b): i
231-
for i, b in enumerate(batches)
232-
}
233-
for future in as_completed(futures):
234-
try:
235-
total_samples_pushed += future.result()
236-
except Exception as e:
237-
errors.append(f"Batch {futures[future] + 1}: {e}")
241+
headers = _samples_upload_headers(self.client.api_key)
242+
243+
with httpx.Client(headers=headers, timeout=300.0) as http_client:
244+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
245+
futures = {
246+
executor.submit(self._upload_batch, http_client, evaluation_id, b): i
247+
for i, b in enumerate(batches)
248+
}
249+
for future in as_completed(futures):
250+
try:
251+
uploaded_count = future.result()
252+
total_samples_pushed += uploaded_count
253+
if progress_callback is not None:
254+
progress_callback(uploaded_count)
255+
except Exception as e:
256+
errors.append(f"Batch {futures[future] + 1}: {e}")
238257

239258
if errors:
240259
raise EvalsAPIError(f"Failed to push samples: {'; '.join(errors)}")
241260

242261
return {"samples_pushed": total_samples_pushed, "samples_skipped": skipped_count}
243262

244-
def _upload_batch(self, evaluation_id: str, batch: List[Dict[str, Any]]) -> int:
263+
def _upload_batch(
264+
self,
265+
http_client: httpx.Client,
266+
evaluation_id: str,
267+
batch: List[Dict[str, Any]],
268+
) -> int:
245269
"""Upload a single batch of samples with retry on rate limit."""
246270
url = f"{self.client.base_url}/api/v1/evaluations/{evaluation_id}/samples"
247-
headers: Dict[str, str] = {
248-
"Content-Type": "application/json",
249-
"User-Agent": _build_user_agent(),
250-
}
251-
if self.client.api_key:
252-
headers["Authorization"] = f"Bearer {self.client.api_key}"
253271

254272
@retry(
255273
retry=retry_if_exception(_is_retryable),
@@ -258,7 +276,7 @@ def _upload_batch(self, evaluation_id: str, batch: List[Dict[str, Any]]) -> int:
258276
reraise=True,
259277
)
260278
def do_upload() -> int:
261-
response = httpx.post(url, json={"samples": batch}, headers=headers, timeout=300.0)
279+
response = http_client.post(url, json={"samples": batch})
262280
response.raise_for_status()
263281
return len(batch)
264282

@@ -564,6 +582,7 @@ async def push_samples(
564582
samples: List[Dict[str, Any]],
565583
max_payload_bytes: int = 25 * 1024 * 1024,
566584
max_concurrent: int = 4,
585+
progress_callback: Optional[Callable[[int], None]] = None,
567586
) -> Dict[str, Any]:
568587
"""Push evaluation samples in adaptive batches with concurrent uploads."""
569588
if not samples:
@@ -572,18 +591,18 @@ async def push_samples(
572591
raise ValueError("max_concurrent must be at least 1")
573592

574593
batches, skipped_count = self._build_batches(samples, max_payload_bytes)
594+
if skipped_count and progress_callback is not None:
595+
progress_callback(skipped_count)
596+
575597
semaphore = asyncio.Semaphore(max_concurrent)
576598
errors: List[str] = []
577599

578600
base_url = self.client.base_url
579-
headers: Dict[str, str] = {
580-
"Content-Type": "application/json",
581-
"User-Agent": _build_user_agent(),
582-
}
583-
if self.client.api_key:
584-
headers["Authorization"] = f"Bearer {self.client.api_key}"
601+
headers = _samples_upload_headers(self.client.api_key)
585602

586-
async def upload_batch(idx: int, batch: List[Dict[str, Any]]) -> int:
603+
async def upload_batch(
604+
http_client: httpx.AsyncClient, idx: int, batch: List[Dict[str, Any]]
605+
) -> int:
587606
url = f"{base_url}/api/v1/evaluations/{evaluation_id}/samples"
588607

589608
@retry(
@@ -593,22 +612,27 @@ async def upload_batch(idx: int, batch: List[Dict[str, Any]]) -> int:
593612
reraise=True,
594613
)
595614
async def do_upload() -> int:
596-
async with httpx.AsyncClient(timeout=300.0) as client:
597-
response = await client.post(url, json={"samples": batch}, headers=headers)
598-
response.raise_for_status()
599-
return len(batch)
615+
response = await http_client.post(url, json={"samples": batch})
616+
response.raise_for_status()
617+
return len(batch)
600618

601619
async with semaphore:
602620
try:
603-
return await do_upload()
621+
uploaded_count = await do_upload()
622+
if progress_callback is not None:
623+
progress_callback(uploaded_count)
624+
return uploaded_count
604625
except httpx.HTTPStatusError as e:
605626
errors.append(f"Batch {idx + 1}: HTTP {e.response.status_code}")
606627
return 0
607628
except httpx.RequestError as e:
608629
errors.append(f"Batch {idx + 1}: {e}")
609630
return 0
610631

611-
results = await asyncio.gather(*[upload_batch(i, b) for i, b in enumerate(batches)])
632+
async with httpx.AsyncClient(headers=headers, timeout=300.0) as http_client:
633+
results = await asyncio.gather(
634+
*[upload_batch(http_client, i, b) for i, b in enumerate(batches)]
635+
)
612636

613637
if errors:
614638
raise EvalsAPIError(f"Failed to push samples: {'; '.join(errors)}")

packages/prime-evals/tests/test_evals.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
"""Tests for Prime Evals SDK"""
22

3+
import asyncio
4+
from types import SimpleNamespace
5+
36
import pytest
47

5-
from prime_evals.evals import EvalsClient
8+
from prime_evals.evals import AsyncEvalsClient, EvalsClient
69
from prime_evals.models import (
710
CreateEvaluationRequest,
811
Evaluation,
@@ -122,6 +125,101 @@ def test_sample_model_with_metadata():
122125
assert sample.info == {"batch": 1}
123126

124127

128+
def test_push_samples_reports_progress_and_reuses_http_client(monkeypatch):
129+
posts = []
130+
created_clients = []
131+
132+
class FakeResponse:
133+
def raise_for_status(self):
134+
return None
135+
136+
class FakeHttpClient:
137+
def __init__(self, **kwargs):
138+
self.kwargs = kwargs
139+
created_clients.append(self)
140+
141+
def __enter__(self):
142+
return self
143+
144+
def __exit__(self, *_args):
145+
return None
146+
147+
def post(self, url, json):
148+
posts.append({"url": url, "json": json, "headers": self.kwargs["headers"]})
149+
return FakeResponse()
150+
151+
monkeypatch.setattr("prime_evals.evals.httpx.Client", FakeHttpClient)
152+
api_client = SimpleNamespace(
153+
base_url="https://api.example",
154+
api_key="secret-token",
155+
)
156+
client = EvalsClient(api_client)
157+
progress = []
158+
159+
with pytest.warns(UserWarning, match="exceeds maximum payload size"):
160+
result = client.push_samples(
161+
"eval-1",
162+
[{"x": "a"}, {"x": "b" * 50}, {"x": "c"}],
163+
max_payload_bytes=35,
164+
max_workers=1,
165+
progress_callback=progress.append,
166+
)
167+
168+
assert result == {"samples_pushed": 2, "samples_skipped": 1}
169+
assert progress == [1, 1, 1]
170+
assert len(posts) == 2
171+
assert len(created_clients) == 1
172+
assert posts[0]["headers"]["Authorization"] == "Bearer secret-token"
173+
174+
175+
def test_async_push_samples_reports_progress_and_reuses_http_client(monkeypatch):
176+
posts = []
177+
created_clients = []
178+
179+
class FakeResponse:
180+
def raise_for_status(self):
181+
return None
182+
183+
class FakeAsyncHttpClient:
184+
def __init__(self, **kwargs):
185+
self.kwargs = kwargs
186+
created_clients.append(self)
187+
188+
async def __aenter__(self):
189+
return self
190+
191+
async def __aexit__(self, *_args):
192+
return None
193+
194+
async def post(self, url, json):
195+
posts.append({"url": url, "json": json, "headers": self.kwargs["headers"]})
196+
return FakeResponse()
197+
198+
monkeypatch.setattr("prime_evals.evals.httpx.AsyncClient", FakeAsyncHttpClient)
199+
client = AsyncEvalsClient.__new__(AsyncEvalsClient)
200+
client.client = SimpleNamespace(
201+
base_url="https://api.example",
202+
api_key="secret-token",
203+
)
204+
progress = []
205+
206+
result = asyncio.run(
207+
client.push_samples(
208+
"eval-1",
209+
[{"x": "a"}, {"x": "b"}],
210+
max_payload_bytes=35,
211+
max_concurrent=1,
212+
progress_callback=progress.append,
213+
)
214+
)
215+
216+
assert result == {"samples_pushed": 2, "samples_skipped": 0}
217+
assert progress == [1, 1]
218+
assert len(posts) == 2
219+
assert len(created_clients) == 1
220+
assert posts[0]["headers"]["Authorization"] == "Bearer secret-token"
221+
222+
125223
def test_evals_client_context_manager():
126224
"""Test EvalsClient can be used as context manager"""
127225
try:

packages/prime/src/prime_cli/commands/evals.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import typer
1010
from click.core import ParameterSource
1111
from prime_evals import EvalsAPIError, EvalsClient, InvalidEvaluationError
12+
from rich.progress import Progress
1213
from rich.syntax import Syntax
1314
from rich.table import Table
1415

@@ -1008,6 +1009,22 @@ def _resolve_eval_viewer_url(evaluation_id: str, response: Optional[dict[str, An
10081009
return get_eval_viewer_url(evaluation_id)
10091010

10101011

1012+
def _push_samples_with_progress(
1013+
client: EvalsClient, evaluation_id: str, samples: list[dict[str, Any]]
1014+
) -> None:
1015+
if not console.is_terminal:
1016+
client.push_samples(evaluation_id, samples)
1017+
return
1018+
1019+
with Progress(console=console, transient=True) as progress:
1020+
task_id = progress.add_task("Uploading samples", total=len(samples))
1021+
client.push_samples(
1022+
evaluation_id,
1023+
samples,
1024+
progress_callback=lambda uploaded: progress.update(task_id, advance=uploaded),
1025+
)
1026+
1027+
10111028
def _require_published_environment_for_eval_push(env_name: str, eval_path: Path) -> None:
10121029
console.print("[red]Error:[/red] Evaluation uploads require a pushed environment.")
10131030
console.print(
@@ -1095,7 +1112,7 @@ def _push_single_eval(
10951112
results = eval_data.get("results", [])
10961113
if results:
10971114
console.print(f"[blue]Pushing {len(results)} samples...[/blue]")
1098-
client.push_samples(eval_id, results)
1115+
_push_samples_with_progress(client, eval_id, results)
10991116
console.print("[green]✓ Samples pushed successfully[/green]")
11001117
console.print()
11011118

0 commit comments

Comments
 (0)