11import asyncio
22from typing import List
33
4- from litellm import acompletion
5- from openai . types . chat . chat_completion_message import ChatCompletionMessageToolCall
4+ import logging
5+ import os
66
77from eval_protocol .dataset_logger import default_logger
8- from eval_protocol .models import EvaluationRow , Message
8+ from eval_protocol .models import EvaluationRow , Message , ChatCompletionMessageToolCall
99from eval_protocol .pytest .types import RolloutProcessorConfig
1010
1111
@@ -14,6 +14,20 @@ async def default_single_turn_rollout_processor(
1414) -> List [EvaluationRow ]:
1515 """Generate a single response from any supported model provider using LiteLLM."""
1616
# Best-effort silencing of the "LiteLLM" logger for test runs. A LITELLM_LOG
# value that is already present in the environment is left untouched.
try:
    os.environ.setdefault("LITELLM_LOG", "ERROR")
    litellm_logger = logging.getLogger("LiteLLM")
    litellm_logger.propagate = False
    litellm_logger.setLevel(logging.CRITICAL)
    # Iterate over a snapshot: removeHandler mutates the handler list.
    for handler in tuple(litellm_logger.handlers):
        litellm_logger.removeHandler(handler)
except Exception:
    # Logging configuration must never prevent the rollout from running.
    pass
28+
29+ # Do not modify global LiteLLM cache. Disable caching per-request instead.
30+
1731 async def process_row (row : EvaluationRow ) -> EvaluationRow :
1832 """Process a single row asynchronously."""
1933 if len (row .messages ) == 0 :
@@ -22,10 +36,21 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
# Forward only the role and content of each message in the row.
messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]

# Caller-supplied input_params are layered on top of the model/messages pair.
request_params = {"model": config.model, "messages": messages_payload, **config.input_params}
# Ensure caching is disabled only for this request (the global LiteLLM cache
# is deliberately left alone).
request_params["cache"] = {"no-cache": True}
# Allow passing reasoning effort to Fireworks via LiteLLM using extra_body.
# Expected: config.input_params may contain {"reasoning": {"effort": "low|medium|high"}}
if "reasoning" in config.input_params:
    # Build a fresh extra_body instead of writing into the existing one: the
    # ** unpacking above is shallow, so an in-place update would mutate the
    # caller's config.input_params["extra_body"], which is shared by every
    # row. An explicit extra_body of None is treated as empty.
    request_params["extra_body"] = {
        **(request_params.get("extra_body") or {}),
        "reasoning": config.input_params["reasoning"],
    }
2546
2647 if row .tools is not None :
2748 request_params ["tools" ] = row .tools
2849
# Resolve LiteLLM by name at call time rather than through a static import
# (avoids static dependency/lint errors if LiteLLM isn't installed yet).
import importlib
acompletion = importlib.import_module("litellm").acompletion
2954 response = await acompletion (** request_params )
3055
3156 assistant_content = response .choices [0 ].message .content or ""
@@ -57,8 +82,15 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
5782 default_logger .log (row )
5883 return row
5984
60- # Process all rows concurrently
61- tasks = [process_row (row ) for row in rows ]
# Cap how many rows are processed at once. A missing or falsy
# ``max_concurrent_rollouts`` on the config falls back to 8.
limit = getattr(config, "max_concurrent_rollouts", 8) or 8
semaphore = asyncio.Semaphore(limit)

async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
    """Run ``process_row`` on ``r`` while holding one semaphore slot."""
    await semaphore.acquire()
    try:
        return await process_row(r)
    finally:
        # Free the slot whether the row succeeded or raised.
        semaphore.release()

tasks = list(map(_sem_wrapper, rows))
6294 dataset = list (await asyncio .gather (* tasks ))
6395
6496 return dataset
0 commit comments