Skip to content

Commit 2710dc9

Browse files
fix(closes OPEN-8569): handle missing active trace and improve error logging in LiteLLM integration
1 parent 5997612 commit 2710dc9

File tree

2 files changed

+165
-119
lines changed

2 files changed

+165
-119
lines changed

src/openlayer/lib/integrations/litellm_tracer.py

Lines changed: 110 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Module with methods used to trace LiteLLM completions."""
22

3+
import contextvars
34
import json
45
import logging
56
import time
@@ -16,6 +17,7 @@
1617
import litellm
1718

1819
from ..tracing import tracer
20+
from ..tracing import enums as tracer_enums
1921

2022
logger = logging.getLogger(__name__)
2123

@@ -154,113 +156,119 @@ def stream_chunks(
154156
provider = "unknown"
155157
latest_chunk_metadata = {}
156158

157-
try:
158-
i = 0
159-
for i, chunk in enumerate(chunks):
160-
raw_outputs.append(chunk.model_dump() if hasattr(chunk, 'model_dump') else str(chunk))
161-
162-
if i == 0:
163-
first_token_time = time.time()
164-
# Try to detect provider from the first chunk
165-
provider = detect_provider_from_chunk(chunk, model_name)
166-
167-
# Extract usage data from this chunk if available (usually in final chunks)
168-
chunk_usage = extract_usage_from_chunk(chunk)
169-
if any(v is not None for v in chunk_usage.values()):
170-
latest_usage_data = chunk_usage
159+
# Create step immediately so it's added to parent trace before parent publishes
160+
with tracer.create_step(
161+
name="LiteLLM Chat Completion",
162+
step_type=tracer_enums.StepType.CHAT_COMPLETION,
163+
inputs={"prompt": kwargs.get("messages", [])},
164+
) as step:
165+
try:
166+
i = 0
167+
for i, chunk in enumerate(chunks):
168+
raw_outputs.append(chunk.model_dump() if hasattr(chunk, 'model_dump') else str(chunk))
171169

172-
# Always update metadata from latest chunk (for cost, headers, etc.)
173-
chunk_metadata = extract_litellm_metadata(chunk, model_name)
174-
if chunk_metadata:
175-
latest_chunk_metadata.update(chunk_metadata)
170+
if i == 0:
171+
first_token_time = time.time()
172+
# Try to detect provider from the first chunk
173+
provider = detect_provider_from_chunk(chunk, model_name)
176174

177-
if i > 0:
178-
num_of_completion_tokens = i + 1
179-
180-
# Handle different chunk formats based on provider
181-
delta = get_delta_from_chunk(chunk)
182-
183-
if delta and hasattr(delta, 'content') and delta.content:
184-
collected_output_data.append(delta.content)
185-
elif delta and hasattr(delta, 'function_call') and delta.function_call:
186-
if delta.function_call.name:
187-
collected_function_call["name"] += delta.function_call.name
188-
if delta.function_call.arguments:
189-
collected_function_call["arguments"] += delta.function_call.arguments
190-
elif delta and hasattr(delta, 'tool_calls') and delta.tool_calls:
191-
if delta.tool_calls[0].function.name:
192-
collected_function_call["name"] += delta.tool_calls[0].function.name
193-
if delta.tool_calls[0].function.arguments:
194-
collected_function_call["arguments"] += delta.tool_calls[0].function.arguments
195-
196-
yield chunk
197-
198-
end_time = time.time()
199-
latency = (end_time - start_time) * 1000
200-
201-
# pylint: disable=broad-except
202-
except Exception as e:
203-
logger.error("Failed to yield chunk. %s", e)
204-
finally:
205-
# Try to add step to the trace
206-
try:
207-
collected_output_data = [message for message in collected_output_data if message is not None]
208-
if collected_output_data:
209-
output_data = "".join(collected_output_data)
210-
else:
211-
if collected_function_call["arguments"]:
212-
try:
213-
collected_function_call["arguments"] = json.loads(collected_function_call["arguments"])
214-
except json.JSONDecodeError:
215-
pass
216-
output_data = collected_function_call
217-
218-
# Post-streaming calculations (after streaming is finished)
219-
completion_tokens_calculated, prompt_tokens_calculated, total_tokens_calculated, cost_calculated = calculate_streaming_usage_and_cost(
220-
chunks=raw_outputs,
221-
messages=kwargs.get("messages", []),
222-
output_content=output_data,
223-
model_name=model_name,
224-
latest_usage_data=latest_usage_data,
225-
latest_chunk_metadata=latest_chunk_metadata
226-
)
227-
228-
# Use calculated values (fall back to extracted data if calculation fails)
229-
usage_data = latest_usage_data if any(v is not None for v in latest_usage_data.values()) else {}
230-
231-
final_prompt_tokens = prompt_tokens_calculated if prompt_tokens_calculated is not None else usage_data.get("prompt_tokens", 0)
232-
final_completion_tokens = completion_tokens_calculated if completion_tokens_calculated is not None else usage_data.get("completion_tokens", num_of_completion_tokens)
233-
final_total_tokens = total_tokens_calculated if total_tokens_calculated is not None else usage_data.get("total_tokens", final_prompt_tokens + final_completion_tokens)
234-
final_cost = cost_calculated if cost_calculated is not None else latest_chunk_metadata.get('cost', None)
175+
# Extract usage data from this chunk if available (usually in final chunks)
176+
chunk_usage = extract_usage_from_chunk(chunk)
177+
if any(v is not None for v in chunk_usage.values()):
178+
latest_usage_data = chunk_usage
179+
180+
# Always update metadata from latest chunk (for cost, headers, etc.)
181+
chunk_metadata = extract_litellm_metadata(chunk, model_name)
182+
if chunk_metadata:
183+
latest_chunk_metadata.update(chunk_metadata)
184+
185+
if i > 0:
186+
num_of_completion_tokens = i + 1
187+
188+
# Handle different chunk formats based on provider
189+
delta = get_delta_from_chunk(chunk)
190+
191+
if delta and hasattr(delta, 'content') and delta.content:
192+
collected_output_data.append(delta.content)
193+
elif delta and hasattr(delta, 'function_call') and delta.function_call:
194+
if delta.function_call.name:
195+
collected_function_call["name"] += delta.function_call.name
196+
if delta.function_call.arguments:
197+
collected_function_call["arguments"] += delta.function_call.arguments
198+
elif delta and hasattr(delta, 'tool_calls') and delta.tool_calls:
199+
if delta.tool_calls[0].function.name:
200+
collected_function_call["name"] += delta.tool_calls[0].function.name
201+
if delta.tool_calls[0].function.arguments:
202+
collected_function_call["arguments"] += delta.tool_calls[0].function.arguments
203+
204+
yield chunk
205+
206+
end_time = time.time()
207+
latency = (end_time - start_time) * 1000
235208

236-
trace_args = create_trace_args(
237-
end_time=end_time,
238-
inputs={"prompt": kwargs.get("messages", [])},
239-
output=output_data,
240-
latency=latency,
241-
tokens=final_total_tokens,
242-
prompt_tokens=final_prompt_tokens,
243-
completion_tokens=final_completion_tokens,
244-
model=model_name,
245-
model_parameters=get_model_parameters(kwargs),
246-
raw_output=raw_outputs,
247-
id=inference_id,
248-
cost=final_cost, # Use calculated cost
249-
metadata={
250-
"timeToFirstToken": ((first_token_time - start_time) * 1000 if first_token_time else None),
251-
"provider": provider,
252-
"litellm_model": model_name,
253-
**latest_chunk_metadata, # Add all LiteLLM-specific metadata
254-
},
255-
)
256-
add_to_trace(**trace_args)
257-
258209
# pylint: disable=broad-except
259210
except Exception as e:
260-
logger.error(
261-
"Failed to trace the LiteLLM completion request with Openlayer. %s",
262-
e,
263-
)
211+
logger.error("Failed to yield chunk. %s", e)
212+
finally:
213+
# Update step with final data before context manager exits
214+
try:
215+
collected_output_data = [message for message in collected_output_data if message is not None]
216+
if collected_output_data:
217+
output_data = "".join(collected_output_data)
218+
else:
219+
if collected_function_call["arguments"]:
220+
try:
221+
collected_function_call["arguments"] = json.loads(collected_function_call["arguments"])
222+
except json.JSONDecodeError:
223+
pass
224+
output_data = collected_function_call
225+
226+
# Post-streaming calculations (after streaming is finished)
227+
completion_tokens_calculated, prompt_tokens_calculated, total_tokens_calculated, cost_calculated = calculate_streaming_usage_and_cost(
228+
chunks=raw_outputs,
229+
messages=kwargs.get("messages", []),
230+
output_content=output_data,
231+
model_name=model_name,
232+
latest_usage_data=latest_usage_data,
233+
latest_chunk_metadata=latest_chunk_metadata
234+
)
235+
236+
# Use calculated values (fall back to extracted data if calculation fails)
237+
usage_data = latest_usage_data if any(v is not None for v in latest_usage_data.values()) else {}
238+
239+
final_prompt_tokens = prompt_tokens_calculated if prompt_tokens_calculated is not None else usage_data.get("prompt_tokens", 0)
240+
final_completion_tokens = completion_tokens_calculated if completion_tokens_calculated is not None else usage_data.get("completion_tokens", num_of_completion_tokens)
241+
final_total_tokens = total_tokens_calculated if total_tokens_calculated is not None else usage_data.get("total_tokens", final_prompt_tokens + final_completion_tokens)
242+
final_cost = cost_calculated if cost_calculated is not None else latest_chunk_metadata.get('cost', None)
243+
244+
# Update the step with final trace data
245+
step.log(
246+
output=output_data,
247+
latency=latency,
248+
tokens=final_total_tokens,
249+
prompt_tokens=final_prompt_tokens,
250+
completion_tokens=final_completion_tokens,
251+
model=model_name,
252+
model_parameters=get_model_parameters(kwargs),
253+
raw_output=raw_outputs,
254+
id=inference_id,
255+
cost=final_cost,
256+
provider=provider,
257+
metadata={
258+
"timeToFirstToken": ((first_token_time - start_time) * 1000 if first_token_time else None),
259+
"provider": provider,
260+
"litellm_model": model_name,
261+
**latest_chunk_metadata,
262+
},
263+
)
264+
265+
# pylint: disable=broad-except
266+
except Exception as e:
267+
if logger is not None:
268+
logger.error(
269+
"Failed to trace the LiteLLM completion request with Openlayer. %s",
270+
e,
271+
)
264272

265273

266274
def handle_non_streaming_completion(

src/openlayer/lib/tracing/tracer.py

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,7 @@ def __init__(self):
496496
self._token = None
497497
self._output_chunks = []
498498
self._trace_initialized = False
499+
self._captured_context = None # Capture context for ASGI compatibility
499500

500501
def __iter__(self):
501502
return self
@@ -522,26 +523,26 @@ def __next__(self):
522523
try:
523524
chunk = next(self._original_gen)
524525
self._output_chunks.append(chunk)
526+
if self._captured_context is None:
527+
self._captured_context = contextvars.copy_context()
525528
return chunk
526529
except StopIteration:
527530
# Finalize trace when generator is exhausted
531+
# Use captured context to ensure we have access to the trace
528532
output = _join_output_chunks(self._output_chunks)
529-
_finalize_sync_generator_step(
530-
step=self._step,
531-
token=self._token,
532-
is_root_step=self._is_root_step,
533-
step_name=step_name,
534-
inputs=self._inputs,
535-
output=output,
536-
inference_pipeline_id=inference_pipeline_id,
537-
on_flush_failure=on_flush_failure,
538-
)
539-
raise
540-
except Exception as exc:
541-
# Handle exceptions
542-
if self._step:
543-
_log_step_exception(self._step, exc)
544-
output = _join_output_chunks(self._output_chunks)
533+
if self._captured_context:
534+
self._captured_context.run(
535+
_finalize_sync_generator_step,
536+
step=self._step,
537+
token=self._token,
538+
is_root_step=self._is_root_step,
539+
step_name=step_name,
540+
inputs=self._inputs,
541+
output=output,
542+
inference_pipeline_id=inference_pipeline_id,
543+
on_flush_failure=on_flush_failure,
544+
)
545+
else:
545546
_finalize_sync_generator_step(
546547
step=self._step,
547548
token=self._token,
@@ -553,6 +554,35 @@ def __next__(self):
553554
on_flush_failure=on_flush_failure,
554555
)
555556
raise
557+
except Exception as exc:
558+
# Handle exceptions
559+
if self._step:
560+
_log_step_exception(self._step, exc)
561+
output = _join_output_chunks(self._output_chunks)
562+
if self._captured_context:
563+
self._captured_context.run(
564+
_finalize_sync_generator_step,
565+
step=self._step,
566+
token=self._token,
567+
is_root_step=self._is_root_step,
568+
step_name=step_name,
569+
inputs=self._inputs,
570+
output=output,
571+
inference_pipeline_id=inference_pipeline_id,
572+
on_flush_failure=on_flush_failure,
573+
)
574+
else:
575+
_finalize_sync_generator_step(
576+
step=self._step,
577+
token=self._token,
578+
is_root_step=self._is_root_step,
579+
step_name=step_name,
580+
inputs=self._inputs,
581+
output=output,
582+
inference_pipeline_id=inference_pipeline_id,
583+
on_flush_failure=on_flush_failure,
584+
)
585+
raise
556586

557587
return TracedSyncGenerator()
558588

@@ -1349,6 +1379,14 @@ def _handle_trace_completion(
13491379
logger.debug("Ending the trace...")
13501380
current_trace = get_current_trace()
13511381

1382+
if current_trace is None:
1383+
logger.warning(
1384+
"Cannot complete trace for step '%s': no active trace found. "
1385+
"This can happen when OPENLAYER_DISABLE_PUBLISH=true or trace context was lost.",
1386+
step_name,
1387+
)
1388+
return
1389+
13521390
trace_data, input_variable_names = post_process_trace(current_trace)
13531391

13541392
config = dict(
@@ -1644,7 +1682,7 @@ async def _invoke_with_context(
16441682

16451683

16461684
def post_process_trace(
1647-
trace_obj: traces.Trace,
1685+
trace_obj: Optional[traces.Trace],
16481686
) -> Tuple[Dict[str, Any], List[str]]:
16491687
"""Post processing of the trace data before uploading to Openlayer.
16501688

0 commit comments

Comments
 (0)