3333from models .rlsapi .requests import RlsapiV1InferRequest , RlsapiV1SystemInfo
3434from models .rlsapi .responses import RlsapiV1InferData , RlsapiV1InferResponse
3535from observability import InferenceEventData , build_inference_event , send_splunk_event
36- from utils .query import handle_known_apistatus_errors
36+ from utils .query import (
37+ extract_provider_and_model_from_model_id ,
38+ handle_known_apistatus_errors ,
39+ )
3740from utils .responses import (
3841 extract_text_from_response_items ,
42+ extract_token_usage ,
3943 get_mcp_tools ,
4044)
4145from utils .suid import get_suid
@@ -191,6 +195,7 @@ async def retrieve_simple_response(
191195 store = False ,
192196 )
193197 response = cast (OpenAIResponseObject , response )
198+ extract_token_usage (response .usage , model_id )
194199
195200 return extract_text_from_response_items (response .output )
196201
@@ -242,6 +247,8 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
242247 request_id : str ,
243248 error : Exception ,
244249 start_time : float ,
250+ model : str ,
251+ provider : str ,
245252) -> float :
246253 """Record metrics and queue Splunk event for an inference failure.
247254
@@ -257,7 +264,7 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
257264 The total inference time in seconds.
258265 """
259266 inference_time = time .monotonic () - start_time
260- metrics .llm_calls_failures_total .inc ()
267+ metrics .llm_calls_failures_total .labels ( provider , model ). inc ()
261268 _queue_splunk_event (
262269 background_tasks ,
263270 infer_request ,
@@ -272,7 +279,7 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
272279
273280@router .post ("/infer" , responses = infer_responses )
274281@authorize (Action .RLSAPI_V1_INFER )
275- async def infer_endpoint (
282+ async def infer_endpoint ( # pylint: disable=R0914
276283 infer_request : RlsapiV1InferRequest ,
277284 request : Request ,
278285 background_tasks : BackgroundTasks ,
@@ -307,6 +314,7 @@ async def infer_endpoint(
307314 input_source = infer_request .get_input_source ()
308315 instructions = _build_instructions (infer_request .context .systeminfo )
309316 model_id = _get_default_model_id ()
317+ provider , model = extract_provider_and_model_from_model_id (model_id )
310318 mcp_tools = await get_mcp_tools (request_headers = request .headers )
311319 logger .debug (
312320 "Request %s: Combined input source length: %d" , request_id , len (input_source )
@@ -321,19 +329,40 @@ async def infer_endpoint(
321329 except RuntimeError as e :
322330 if "context_length" in str (e ).lower ():
323331 _record_inference_failure (
324- background_tasks , infer_request , request , request_id , e , start_time
332+ background_tasks ,
333+ infer_request ,
334+ request ,
335+ request_id ,
336+ e ,
337+ start_time ,
338+ model ,
339+ provider ,
325340 )
326341 logger .error ("Prompt too long for request %s: %s" , request_id , e )
327342 error_response = PromptTooLongResponse (model = model_id )
328343 raise HTTPException (** error_response .model_dump ()) from e
329344 _record_inference_failure (
330- background_tasks , infer_request , request , request_id , e , start_time
345+ background_tasks ,
346+ infer_request ,
347+ request ,
348+ request_id ,
349+ e ,
350+ start_time ,
351+ model ,
352+ provider ,
331353 )
332354 logger .error ("Unexpected RuntimeError for request %s: %s" , request_id , e )
333355 raise
334356 except APIConnectionError as e :
335357 _record_inference_failure (
336- background_tasks , infer_request , request , request_id , e , start_time
358+ background_tasks ,
359+ infer_request ,
360+ request ,
361+ request_id ,
362+ e ,
363+ start_time ,
364+ model ,
365+ provider ,
337366 )
338367 logger .error (
339368 "Unable to connect to Llama Stack for request %s: %s" , request_id , e
@@ -345,7 +374,14 @@ async def infer_endpoint(
345374 raise HTTPException (** error_response .model_dump ()) from e
346375 except RateLimitError as e :
347376 _record_inference_failure (
348- background_tasks , infer_request , request , request_id , e , start_time
377+ background_tasks ,
378+ infer_request ,
379+ request ,
380+ request_id ,
381+ e ,
382+ start_time ,
383+ model ,
384+ provider ,
349385 )
350386 logger .error ("Rate limit exceeded for request %s: %s" , request_id , e )
351387 error_response = QuotaExceededResponse (
@@ -355,7 +391,14 @@ async def infer_endpoint(
355391 raise HTTPException (** error_response .model_dump ()) from e
356392 except (APIStatusError , OpenAIAPIStatusError ) as e :
357393 _record_inference_failure (
358- background_tasks , infer_request , request , request_id , e , start_time
394+ background_tasks ,
395+ infer_request ,
396+ request ,
397+ request_id ,
398+ e ,
399+ start_time ,
400+ model ,
401+ provider ,
359402 )
360403 logger .exception ("API error for request %s: %s" , request_id , e )
361404 error_response = handle_known_apistatus_errors (e , model_id )
0 commit comments