Skip to content

Commit 6611c05

Browse files
refactor(litellm_tracer): streamline chunk processing and enhance tracing with structured logging
1 parent: d6f3859 · commit: 6611c05

File tree

2 files changed

+108
-122
lines changed

2 files changed

+108
-122
lines changed

src/openlayer/lib/integrations/litellm_tracer.py

Lines changed: 108 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import litellm
1818

1919
from ..tracing import tracer
20+
from ..tracing import enums as tracer_enums
2021

2122
logger = logging.getLogger(__name__)
2223

@@ -154,121 +155,120 @@ def stream_chunks(
154155
latest_usage_data = {"total_tokens": None, "prompt_tokens": None, "completion_tokens": None}
155156
provider = "unknown"
156157
latest_chunk_metadata = {}
157-
captured_context = contextvars.copy_context()
158158

159-
try:
160-
i = 0
161-
for i, chunk in enumerate(chunks):
162-
raw_outputs.append(chunk.model_dump() if hasattr(chunk, 'model_dump') else str(chunk))
163-
164-
if i == 0:
165-
first_token_time = time.time()
166-
# Try to detect provider from the first chunk
167-
provider = detect_provider_from_chunk(chunk, model_name)
168-
169-
# Extract usage data from this chunk if available (usually in final chunks)
170-
chunk_usage = extract_usage_from_chunk(chunk)
171-
if any(v is not None for v in chunk_usage.values()):
172-
latest_usage_data = chunk_usage
159+
# Create step immediately so it's added to parent trace before parent publishes
160+
with tracer.create_step(
161+
name="LiteLLM Chat Completion",
162+
step_type=tracer_enums.StepType.CHAT_COMPLETION,
163+
inputs={"prompt": kwargs.get("messages", [])},
164+
) as step:
165+
try:
166+
i = 0
167+
for i, chunk in enumerate(chunks):
168+
raw_outputs.append(chunk.model_dump() if hasattr(chunk, 'model_dump') else str(chunk))
173169

174-
# Always update metadata from latest chunk (for cost, headers, etc.)
175-
chunk_metadata = extract_litellm_metadata(chunk, model_name)
176-
if chunk_metadata:
177-
latest_chunk_metadata.update(chunk_metadata)
170+
if i == 0:
171+
first_token_time = time.time()
172+
# Try to detect provider from the first chunk
173+
provider = detect_provider_from_chunk(chunk, model_name)
178174

179-
if i > 0:
180-
num_of_completion_tokens = i + 1
181-
182-
# Handle different chunk formats based on provider
183-
delta = get_delta_from_chunk(chunk)
184-
185-
if delta and hasattr(delta, 'content') and delta.content:
186-
collected_output_data.append(delta.content)
187-
elif delta and hasattr(delta, 'function_call') and delta.function_call:
188-
if delta.function_call.name:
189-
collected_function_call["name"] += delta.function_call.name
190-
if delta.function_call.arguments:
191-
collected_function_call["arguments"] += delta.function_call.arguments
192-
elif delta and hasattr(delta, 'tool_calls') and delta.tool_calls:
193-
if delta.tool_calls[0].function.name:
194-
collected_function_call["name"] += delta.tool_calls[0].function.name
195-
if delta.tool_calls[0].function.arguments:
196-
collected_function_call["arguments"] += delta.tool_calls[0].function.arguments
197-
198-
yield chunk
199-
200-
end_time = time.time()
201-
latency = (end_time - start_time) * 1000
202-
203-
# pylint: disable=broad-except
204-
except Exception as e:
205-
logger.error("Failed to yield chunk. %s", e)
206-
finally:
207-
# #region agent log - Debug: trace finally block execution
208-
_parent = tracer.get_current_step()
209-
_trace = tracer.get_current_trace()
210-
print(f"[OPENLAYER_DEBUG] litellm_tracer.py:finally | has_parent_step={_parent is not None} | parent_step_name={_parent.name if _parent else None} | has_trace={_trace is not None} | trace_steps_count={len(_trace.steps) if _trace else 0}", flush=True)
211-
# #endregion
212-
# Try to add step to the trace
213-
try:
214-
collected_output_data = [message for message in collected_output_data if message is not None]
215-
if collected_output_data:
216-
output_data = "".join(collected_output_data)
217-
else:
218-
if collected_function_call["arguments"]:
219-
try:
220-
collected_function_call["arguments"] = json.loads(collected_function_call["arguments"])
221-
except json.JSONDecodeError:
222-
pass
223-
output_data = collected_function_call
224-
225-
# Post-streaming calculations (after streaming is finished)
226-
completion_tokens_calculated, prompt_tokens_calculated, total_tokens_calculated, cost_calculated = calculate_streaming_usage_and_cost(
227-
chunks=raw_outputs,
228-
messages=kwargs.get("messages", []),
229-
output_content=output_data,
230-
model_name=model_name,
231-
latest_usage_data=latest_usage_data,
232-
latest_chunk_metadata=latest_chunk_metadata
233-
)
234-
235-
# Use calculated values (fall back to extracted data if calculation fails)
236-
usage_data = latest_usage_data if any(v is not None for v in latest_usage_data.values()) else {}
237-
238-
final_prompt_tokens = prompt_tokens_calculated if prompt_tokens_calculated is not None else usage_data.get("prompt_tokens", 0)
239-
final_completion_tokens = completion_tokens_calculated if completion_tokens_calculated is not None else usage_data.get("completion_tokens", num_of_completion_tokens)
240-
final_total_tokens = total_tokens_calculated if total_tokens_calculated is not None else usage_data.get("total_tokens", final_prompt_tokens + final_completion_tokens)
241-
final_cost = cost_calculated if cost_calculated is not None else latest_chunk_metadata.get('cost', None)
175+
# Extract usage data from this chunk if available (usually in final chunks)
176+
chunk_usage = extract_usage_from_chunk(chunk)
177+
if any(v is not None for v in chunk_usage.values()):
178+
latest_usage_data = chunk_usage
179+
180+
# Always update metadata from latest chunk (for cost, headers, etc.)
181+
chunk_metadata = extract_litellm_metadata(chunk, model_name)
182+
if chunk_metadata:
183+
latest_chunk_metadata.update(chunk_metadata)
184+
185+
if i > 0:
186+
num_of_completion_tokens = i + 1
187+
188+
# Handle different chunk formats based on provider
189+
delta = get_delta_from_chunk(chunk)
190+
191+
if delta and hasattr(delta, 'content') and delta.content:
192+
collected_output_data.append(delta.content)
193+
elif delta and hasattr(delta, 'function_call') and delta.function_call:
194+
if delta.function_call.name:
195+
collected_function_call["name"] += delta.function_call.name
196+
if delta.function_call.arguments:
197+
collected_function_call["arguments"] += delta.function_call.arguments
198+
elif delta and hasattr(delta, 'tool_calls') and delta.tool_calls:
199+
if delta.tool_calls[0].function.name:
200+
collected_function_call["name"] += delta.tool_calls[0].function.name
201+
if delta.tool_calls[0].function.arguments:
202+
collected_function_call["arguments"] += delta.tool_calls[0].function.arguments
203+
204+
yield chunk
205+
206+
end_time = time.time()
207+
latency = (end_time - start_time) * 1000
242208

243-
trace_args = create_trace_args(
244-
end_time=end_time,
245-
inputs={"prompt": kwargs.get("messages", [])},
246-
output=output_data,
247-
latency=latency,
248-
tokens=final_total_tokens,
249-
prompt_tokens=final_prompt_tokens,
250-
completion_tokens=final_completion_tokens,
251-
model=model_name,
252-
model_parameters=get_model_parameters(kwargs),
253-
raw_output=raw_outputs,
254-
id=inference_id,
255-
cost=final_cost, # Use calculated cost
256-
metadata={
257-
"timeToFirstToken": ((first_token_time - start_time) * 1000 if first_token_time else None),
258-
"provider": provider,
259-
"litellm_model": model_name,
260-
**latest_chunk_metadata, # Add all LiteLLM-specific metadata
261-
},
262-
)
263-
captured_context.run(add_to_trace, **trace_args)
264-
265209
# pylint: disable=broad-except
266210
except Exception as e:
267-
if logger is not None:
268-
logger.error(
269-
"Failed to trace the LiteLLM completion request with Openlayer. %s",
270-
e,
211+
logger.error("Failed to yield chunk. %s", e)
212+
finally:
213+
# Update step with final data before context manager exits
214+
try:
215+
collected_output_data = [message for message in collected_output_data if message is not None]
216+
if collected_output_data:
217+
output_data = "".join(collected_output_data)
218+
else:
219+
if collected_function_call["arguments"]:
220+
try:
221+
collected_function_call["arguments"] = json.loads(collected_function_call["arguments"])
222+
except json.JSONDecodeError:
223+
pass
224+
output_data = collected_function_call
225+
226+
# Post-streaming calculations (after streaming is finished)
227+
completion_tokens_calculated, prompt_tokens_calculated, total_tokens_calculated, cost_calculated = calculate_streaming_usage_and_cost(
228+
chunks=raw_outputs,
229+
messages=kwargs.get("messages", []),
230+
output_content=output_data,
231+
model_name=model_name,
232+
latest_usage_data=latest_usage_data,
233+
latest_chunk_metadata=latest_chunk_metadata
271234
)
235+
236+
# Use calculated values (fall back to extracted data if calculation fails)
237+
usage_data = latest_usage_data if any(v is not None for v in latest_usage_data.values()) else {}
238+
239+
final_prompt_tokens = prompt_tokens_calculated if prompt_tokens_calculated is not None else usage_data.get("prompt_tokens", 0)
240+
final_completion_tokens = completion_tokens_calculated if completion_tokens_calculated is not None else usage_data.get("completion_tokens", num_of_completion_tokens)
241+
final_total_tokens = total_tokens_calculated if total_tokens_calculated is not None else usage_data.get("total_tokens", final_prompt_tokens + final_completion_tokens)
242+
final_cost = cost_calculated if cost_calculated is not None else latest_chunk_metadata.get('cost', None)
243+
244+
# Update the step with final trace data
245+
step.log(
246+
output=output_data,
247+
latency=latency,
248+
tokens=final_total_tokens,
249+
prompt_tokens=final_prompt_tokens,
250+
completion_tokens=final_completion_tokens,
251+
model=model_name,
252+
model_parameters=get_model_parameters(kwargs),
253+
raw_output=raw_outputs,
254+
id=inference_id,
255+
cost=final_cost,
256+
provider=provider,
257+
metadata={
258+
"timeToFirstToken": ((first_token_time - start_time) * 1000 if first_token_time else None),
259+
"provider": provider,
260+
"litellm_model": model_name,
261+
**latest_chunk_metadata,
262+
},
263+
)
264+
265+
# pylint: disable=broad-except
266+
except Exception as e:
267+
if logger is not None:
268+
logger.error(
269+
"Failed to trace the LiteLLM completion request with Openlayer. %s",
270+
e,
271+
)
272272

273273

274274
def handle_non_streaming_completion(

src/openlayer/lib/tracing/tracer.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,10 +1322,6 @@ def _create_and_initialize_step(
13221322
parent_step = get_current_step()
13231323
is_root_step = parent_step is None
13241324

1325-
# #region agent log - Debug: step creation
1326-
print(f"[OPENLAYER_DEBUG] tracer.py:_create_and_initialize_step | step_name={step_name} | step_type={step_type} | is_root_step={is_root_step} | parent_step_name={parent_step.name if parent_step else None}", flush=True)
1327-
# #endregion
1328-
13291325
if parent_step is None:
13301326
logger.debug("Starting a new trace...")
13311327
current_trace = traces.Trace()
@@ -1349,16 +1345,6 @@ def _handle_trace_completion(
13491345
on_flush_failure: Optional[OnFlushFailureCallback] = None,
13501346
) -> None:
13511347
"""Handle trace completion and data streaming."""
1352-
# #region agent log - Debug: trace completion
1353-
_trace = get_current_trace()
1354-
_steps = [s.name for s in _trace.steps] if _trace and _trace.steps else []
1355-
_nested = []
1356-
if _trace and _trace.steps:
1357-
for s in _trace.steps:
1358-
if hasattr(s, 'steps') and s.steps:
1359-
_nested.extend([ns.name for ns in s.steps])
1360-
print(f"[OPENLAYER_DEBUG] tracer.py:_handle_trace_completion | step_name={step_name} | is_root_step={is_root_step} | has_trace={_trace is not None} | root_steps={_steps} | nested_steps={_nested}", flush=True)
1361-
# #endregion
13621348
if is_root_step:
13631349
logger.debug("Ending the trace...")
13641350
current_trace = get_current_trace()

0 commit comments

Comments (0)