33import sys
44import warnings
55from concurrent .futures import ThreadPoolExecutor , as_completed
6- from typing import Any , Dict , List , Optional , Tuple , Union
6+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
77
88import httpx
99from tenacity import retry , retry_if_exception , stop_after_attempt , wait_exponential
@@ -25,6 +25,16 @@ def _build_user_agent() -> str:
2525 return f"prime-evals/{ __version__ } python/{ python_version } "
2626
2727
def _samples_upload_headers(api_key: Optional[str]) -> Dict[str, str]:
    """Build the HTTP headers shared by every sample-upload request.

    Args:
        api_key: API key used for authentication. When falsy (``None`` or an
            empty string) no ``Authorization`` header is added.

    Returns:
        A new dict containing ``Content-Type`` and ``User-Agent``, plus an
        ``Authorization: Bearer <api_key>`` entry when ``api_key`` is set.
    """
    headers = {
        "Content-Type": "application/json",
        "User-Agent": _build_user_agent(),
    }
    if api_key:
        # The value must be exactly "Bearer <token>": no padding around the
        # token, since surrounding whitespace is not part of a valid header
        # field value.
        headers["Authorization"] = f"Bearer {api_key}"
    return headers
36+
37+
2838class EvalsClient :
2939 """
3040 Client for the Prime Evals API
@@ -214,6 +224,7 @@ def push_samples(
214224 samples : List [Dict [str , Any ]],
215225 max_payload_bytes : int = 25 * 1024 * 1024 ,
216226 max_workers : int = 4 ,
227+ progress_callback : Optional [Callable [[int ], None ]] = None ,
217228 ) -> Dict [str , Any ]:
218229 """Push evaluation samples in adaptive batches with concurrent uploads."""
219230 if not samples :
@@ -222,34 +233,41 @@ def push_samples(
222233 raise ValueError ("max_workers must be at least 1" )
223234
224235 batches , skipped_count = self ._build_batches (samples , max_payload_bytes )
236+ if skipped_count and progress_callback is not None :
237+ progress_callback (skipped_count )
238+
225239 total_samples_pushed = 0
226240 errors = []
227-
228- with ThreadPoolExecutor (max_workers = max_workers ) as executor :
229- futures = {
230- executor .submit (self ._upload_batch , evaluation_id , b ): i
231- for i , b in enumerate (batches )
232- }
233- for future in as_completed (futures ):
234- try :
235- total_samples_pushed += future .result ()
236- except Exception as e :
237- errors .append (f"Batch { futures [future ] + 1 } : { e } " )
241+ headers = _samples_upload_headers (self .client .api_key )
242+
243+ with httpx .Client (headers = headers , timeout = 300.0 ) as http_client :
244+ with ThreadPoolExecutor (max_workers = max_workers ) as executor :
245+ futures = {
246+ executor .submit (self ._upload_batch , http_client , evaluation_id , b ): i
247+ for i , b in enumerate (batches )
248+ }
249+ for future in as_completed (futures ):
250+ try :
251+ uploaded_count = future .result ()
252+ total_samples_pushed += uploaded_count
253+ if progress_callback is not None :
254+ progress_callback (uploaded_count )
255+ except Exception as e :
256+ errors .append (f"Batch { futures [future ] + 1 } : { e } " )
238257
239258 if errors :
240259 raise EvalsAPIError (f"Failed to push samples: { '; ' .join (errors )} " )
241260
242261 return {"samples_pushed" : total_samples_pushed , "samples_skipped" : skipped_count }
243262
244- def _upload_batch (self , evaluation_id : str , batch : List [Dict [str , Any ]]) -> int :
263+ def _upload_batch (
264+ self ,
265+ http_client : httpx .Client ,
266+ evaluation_id : str ,
267+ batch : List [Dict [str , Any ]],
268+ ) -> int :
245269 """Upload a single batch of samples with retry on rate limit."""
246270 url = f"{ self .client .base_url } /api/v1/evaluations/{ evaluation_id } /samples"
247- headers : Dict [str , str ] = {
248- "Content-Type" : "application/json" ,
249- "User-Agent" : _build_user_agent (),
250- }
251- if self .client .api_key :
252- headers ["Authorization" ] = f"Bearer { self .client .api_key } "
253271
254272 @retry (
255273 retry = retry_if_exception (_is_retryable ),
@@ -258,7 +276,7 @@ def _upload_batch(self, evaluation_id: str, batch: List[Dict[str, Any]]) -> int:
258276 reraise = True ,
259277 )
260278 def do_upload () -> int :
261- response = httpx .post (url , json = {"samples" : batch }, headers = headers , timeout = 300.0 )
279+ response = http_client .post (url , json = {"samples" : batch })
262280 response .raise_for_status ()
263281 return len (batch )
264282
@@ -564,6 +582,7 @@ async def push_samples(
564582 samples : List [Dict [str , Any ]],
565583 max_payload_bytes : int = 25 * 1024 * 1024 ,
566584 max_concurrent : int = 4 ,
585+ progress_callback : Optional [Callable [[int ], None ]] = None ,
567586 ) -> Dict [str , Any ]:
568587 """Push evaluation samples in adaptive batches with concurrent uploads."""
569588 if not samples :
@@ -572,18 +591,18 @@ async def push_samples(
572591 raise ValueError ("max_concurrent must be at least 1" )
573592
574593 batches , skipped_count = self ._build_batches (samples , max_payload_bytes )
594+ if skipped_count and progress_callback is not None :
595+ progress_callback (skipped_count )
596+
575597 semaphore = asyncio .Semaphore (max_concurrent )
576598 errors : List [str ] = []
577599
578600 base_url = self .client .base_url
579- headers : Dict [str , str ] = {
580- "Content-Type" : "application/json" ,
581- "User-Agent" : _build_user_agent (),
582- }
583- if self .client .api_key :
584- headers ["Authorization" ] = f"Bearer { self .client .api_key } "
601+ headers = _samples_upload_headers (self .client .api_key )
585602
586- async def upload_batch (idx : int , batch : List [Dict [str , Any ]]) -> int :
603+ async def upload_batch (
604+ http_client : httpx .AsyncClient , idx : int , batch : List [Dict [str , Any ]]
605+ ) -> int :
587606 url = f"{ base_url } /api/v1/evaluations/{ evaluation_id } /samples"
588607
589608 @retry (
@@ -593,22 +612,27 @@ async def upload_batch(idx: int, batch: List[Dict[str, Any]]) -> int:
593612 reraise = True ,
594613 )
595614 async def do_upload () -> int :
596- async with httpx .AsyncClient (timeout = 300.0 ) as client :
597- response = await client .post (url , json = {"samples" : batch }, headers = headers )
598- response .raise_for_status ()
599- return len (batch )
615+ response = await http_client .post (url , json = {"samples" : batch })
616+ response .raise_for_status ()
617+ return len (batch )
600618
601619 async with semaphore :
602620 try :
603- return await do_upload ()
621+ uploaded_count = await do_upload ()
622+ if progress_callback is not None :
623+ progress_callback (uploaded_count )
624+ return uploaded_count
604625 except httpx .HTTPStatusError as e :
605626 errors .append (f"Batch { idx + 1 } : HTTP { e .response .status_code } " )
606627 return 0
607628 except httpx .RequestError as e :
608629 errors .append (f"Batch { idx + 1 } : { e } " )
609630 return 0
610631
611- results = await asyncio .gather (* [upload_batch (i , b ) for i , b in enumerate (batches )])
632+ async with httpx .AsyncClient (headers = headers , timeout = 300.0 ) as http_client :
633+ results = await asyncio .gather (
634+ * [upload_batch (http_client , i , b ) for i , b in enumerate (batches )]
635+ )
612636
613637 if errors :
614638 raise EvalsAPIError (f"Failed to push samples: { '; ' .join (errors )} " )
0 commit comments