|
17 | 17 | import litellm |
18 | 18 |
|
19 | 19 | from ..tracing import tracer |
| 20 | +from ..tracing import enums as tracer_enums |
20 | 21 |
|
21 | 22 | logger = logging.getLogger(__name__) |
22 | 23 |
|
@@ -154,121 +155,120 @@ def stream_chunks( |
154 | 155 | latest_usage_data = {"total_tokens": None, "prompt_tokens": None, "completion_tokens": None} |
155 | 156 | provider = "unknown" |
156 | 157 | latest_chunk_metadata = {} |
157 | | - captured_context = contextvars.copy_context() |
158 | 158 |
|
159 | | - try: |
160 | | - i = 0 |
161 | | - for i, chunk in enumerate(chunks): |
162 | | - raw_outputs.append(chunk.model_dump() if hasattr(chunk, 'model_dump') else str(chunk)) |
163 | | - |
164 | | - if i == 0: |
165 | | - first_token_time = time.time() |
166 | | - # Try to detect provider from the first chunk |
167 | | - provider = detect_provider_from_chunk(chunk, model_name) |
168 | | - |
169 | | - # Extract usage data from this chunk if available (usually in final chunks) |
170 | | - chunk_usage = extract_usage_from_chunk(chunk) |
171 | | - if any(v is not None for v in chunk_usage.values()): |
172 | | - latest_usage_data = chunk_usage |
| 159 | + # Create step immediately so it's added to parent trace before parent publishes |
| 160 | + with tracer.create_step( |
| 161 | + name="LiteLLM Chat Completion", |
| 162 | + step_type=tracer_enums.StepType.CHAT_COMPLETION, |
| 163 | + inputs={"prompt": kwargs.get("messages", [])}, |
| 164 | + ) as step: |
| 165 | + try: |
| 166 | + i = 0 |
| 167 | + for i, chunk in enumerate(chunks): |
| 168 | + raw_outputs.append(chunk.model_dump() if hasattr(chunk, 'model_dump') else str(chunk)) |
173 | 169 |
|
174 | | - # Always update metadata from latest chunk (for cost, headers, etc.) |
175 | | - chunk_metadata = extract_litellm_metadata(chunk, model_name) |
176 | | - if chunk_metadata: |
177 | | - latest_chunk_metadata.update(chunk_metadata) |
| 170 | + if i == 0: |
| 171 | + first_token_time = time.time() |
| 172 | + # Try to detect provider from the first chunk |
| 173 | + provider = detect_provider_from_chunk(chunk, model_name) |
178 | 174 |
|
179 | | - if i > 0: |
180 | | - num_of_completion_tokens = i + 1 |
181 | | - |
182 | | - # Handle different chunk formats based on provider |
183 | | - delta = get_delta_from_chunk(chunk) |
184 | | - |
185 | | - if delta and hasattr(delta, 'content') and delta.content: |
186 | | - collected_output_data.append(delta.content) |
187 | | - elif delta and hasattr(delta, 'function_call') and delta.function_call: |
188 | | - if delta.function_call.name: |
189 | | - collected_function_call["name"] += delta.function_call.name |
190 | | - if delta.function_call.arguments: |
191 | | - collected_function_call["arguments"] += delta.function_call.arguments |
192 | | - elif delta and hasattr(delta, 'tool_calls') and delta.tool_calls: |
193 | | - if delta.tool_calls[0].function.name: |
194 | | - collected_function_call["name"] += delta.tool_calls[0].function.name |
195 | | - if delta.tool_calls[0].function.arguments: |
196 | | - collected_function_call["arguments"] += delta.tool_calls[0].function.arguments |
197 | | - |
198 | | - yield chunk |
199 | | - |
200 | | - end_time = time.time() |
201 | | - latency = (end_time - start_time) * 1000 |
202 | | - |
203 | | - # pylint: disable=broad-except |
204 | | - except Exception as e: |
205 | | - logger.error("Failed to yield chunk. %s", e) |
206 | | - finally: |
207 | | - # #region agent log - Debug: trace finally block execution |
208 | | - _parent = tracer.get_current_step() |
209 | | - _trace = tracer.get_current_trace() |
210 | | - print(f"[OPENLAYER_DEBUG] litellm_tracer.py:finally | has_parent_step={_parent is not None} | parent_step_name={_parent.name if _parent else None} | has_trace={_trace is not None} | trace_steps_count={len(_trace.steps) if _trace else 0}", flush=True) |
211 | | - # #endregion |
212 | | - # Try to add step to the trace |
213 | | - try: |
214 | | - collected_output_data = [message for message in collected_output_data if message is not None] |
215 | | - if collected_output_data: |
216 | | - output_data = "".join(collected_output_data) |
217 | | - else: |
218 | | - if collected_function_call["arguments"]: |
219 | | - try: |
220 | | - collected_function_call["arguments"] = json.loads(collected_function_call["arguments"]) |
221 | | - except json.JSONDecodeError: |
222 | | - pass |
223 | | - output_data = collected_function_call |
224 | | - |
225 | | - # Post-streaming calculations (after streaming is finished) |
226 | | - completion_tokens_calculated, prompt_tokens_calculated, total_tokens_calculated, cost_calculated = calculate_streaming_usage_and_cost( |
227 | | - chunks=raw_outputs, |
228 | | - messages=kwargs.get("messages", []), |
229 | | - output_content=output_data, |
230 | | - model_name=model_name, |
231 | | - latest_usage_data=latest_usage_data, |
232 | | - latest_chunk_metadata=latest_chunk_metadata |
233 | | - ) |
234 | | - |
235 | | - # Use calculated values (fall back to extracted data if calculation fails) |
236 | | - usage_data = latest_usage_data if any(v is not None for v in latest_usage_data.values()) else {} |
237 | | - |
238 | | - final_prompt_tokens = prompt_tokens_calculated if prompt_tokens_calculated is not None else usage_data.get("prompt_tokens", 0) |
239 | | - final_completion_tokens = completion_tokens_calculated if completion_tokens_calculated is not None else usage_data.get("completion_tokens", num_of_completion_tokens) |
240 | | - final_total_tokens = total_tokens_calculated if total_tokens_calculated is not None else usage_data.get("total_tokens", final_prompt_tokens + final_completion_tokens) |
241 | | - final_cost = cost_calculated if cost_calculated is not None else latest_chunk_metadata.get('cost', None) |
| 175 | + # Extract usage data from this chunk if available (usually in final chunks) |
| 176 | + chunk_usage = extract_usage_from_chunk(chunk) |
| 177 | + if any(v is not None for v in chunk_usage.values()): |
| 178 | + latest_usage_data = chunk_usage |
| 179 | + |
| 180 | + # Always update metadata from latest chunk (for cost, headers, etc.) |
| 181 | + chunk_metadata = extract_litellm_metadata(chunk, model_name) |
| 182 | + if chunk_metadata: |
| 183 | + latest_chunk_metadata.update(chunk_metadata) |
| 184 | + |
| 185 | + if i > 0: |
| 186 | + num_of_completion_tokens = i + 1 |
| 187 | + |
| 188 | + # Handle different chunk formats based on provider |
| 189 | + delta = get_delta_from_chunk(chunk) |
| 190 | + |
| 191 | + if delta and hasattr(delta, 'content') and delta.content: |
| 192 | + collected_output_data.append(delta.content) |
| 193 | + elif delta and hasattr(delta, 'function_call') and delta.function_call: |
| 194 | + if delta.function_call.name: |
| 195 | + collected_function_call["name"] += delta.function_call.name |
| 196 | + if delta.function_call.arguments: |
| 197 | + collected_function_call["arguments"] += delta.function_call.arguments |
| 198 | + elif delta and hasattr(delta, 'tool_calls') and delta.tool_calls: |
| 199 | + if delta.tool_calls[0].function.name: |
| 200 | + collected_function_call["name"] += delta.tool_calls[0].function.name |
| 201 | + if delta.tool_calls[0].function.arguments: |
| 202 | + collected_function_call["arguments"] += delta.tool_calls[0].function.arguments |
| 203 | + |
| 204 | + yield chunk |
| 205 | + |
| 206 | + end_time = time.time() |
| 207 | + latency = (end_time - start_time) * 1000 |
242 | 208 |
|
243 | | - trace_args = create_trace_args( |
244 | | - end_time=end_time, |
245 | | - inputs={"prompt": kwargs.get("messages", [])}, |
246 | | - output=output_data, |
247 | | - latency=latency, |
248 | | - tokens=final_total_tokens, |
249 | | - prompt_tokens=final_prompt_tokens, |
250 | | - completion_tokens=final_completion_tokens, |
251 | | - model=model_name, |
252 | | - model_parameters=get_model_parameters(kwargs), |
253 | | - raw_output=raw_outputs, |
254 | | - id=inference_id, |
255 | | - cost=final_cost, # Use calculated cost |
256 | | - metadata={ |
257 | | - "timeToFirstToken": ((first_token_time - start_time) * 1000 if first_token_time else None), |
258 | | - "provider": provider, |
259 | | - "litellm_model": model_name, |
260 | | - **latest_chunk_metadata, # Add all LiteLLM-specific metadata |
261 | | - }, |
262 | | - ) |
263 | | - captured_context.run(add_to_trace, **trace_args) |
264 | | - |
265 | 209 | # pylint: disable=broad-except |
266 | 210 | except Exception as e: |
267 | | - if logger is not None: |
268 | | - logger.error( |
269 | | - "Failed to trace the LiteLLM completion request with Openlayer. %s", |
270 | | - e, |
| 211 | + logger.error("Failed to yield chunk. %s", e) |
| 212 | + finally: |
| 213 | + # Update step with final data before context manager exits |
| 214 | + try: |
| 215 | + collected_output_data = [message for message in collected_output_data if message is not None] |
| 216 | + if collected_output_data: |
| 217 | + output_data = "".join(collected_output_data) |
| 218 | + else: |
| 219 | + if collected_function_call["arguments"]: |
| 220 | + try: |
| 221 | + collected_function_call["arguments"] = json.loads(collected_function_call["arguments"]) |
| 222 | + except json.JSONDecodeError: |
| 223 | + pass |
| 224 | + output_data = collected_function_call |
| 225 | + |
| 226 | + # Post-streaming calculations (after streaming is finished) |
| 227 | + completion_tokens_calculated, prompt_tokens_calculated, total_tokens_calculated, cost_calculated = calculate_streaming_usage_and_cost( |
| 228 | + chunks=raw_outputs, |
| 229 | + messages=kwargs.get("messages", []), |
| 230 | + output_content=output_data, |
| 231 | + model_name=model_name, |
| 232 | + latest_usage_data=latest_usage_data, |
| 233 | + latest_chunk_metadata=latest_chunk_metadata |
271 | 234 | ) |
| 235 | + |
| 236 | + # Use calculated values (fall back to extracted data if calculation fails) |
| 237 | + usage_data = latest_usage_data if any(v is not None for v in latest_usage_data.values()) else {} |
| 238 | + |
| 239 | + final_prompt_tokens = prompt_tokens_calculated if prompt_tokens_calculated is not None else usage_data.get("prompt_tokens", 0) |
| 240 | + final_completion_tokens = completion_tokens_calculated if completion_tokens_calculated is not None else usage_data.get("completion_tokens", num_of_completion_tokens) |
| 241 | + final_total_tokens = total_tokens_calculated if total_tokens_calculated is not None else usage_data.get("total_tokens", final_prompt_tokens + final_completion_tokens) |
| 242 | + final_cost = cost_calculated if cost_calculated is not None else latest_chunk_metadata.get('cost', None) |
| 243 | + |
| 244 | + # Update the step with final trace data |
| 245 | + step.log( |
| 246 | + output=output_data, |
| 247 | + latency=latency, |
| 248 | + tokens=final_total_tokens, |
| 249 | + prompt_tokens=final_prompt_tokens, |
| 250 | + completion_tokens=final_completion_tokens, |
| 251 | + model=model_name, |
| 252 | + model_parameters=get_model_parameters(kwargs), |
| 253 | + raw_output=raw_outputs, |
| 254 | + id=inference_id, |
| 255 | + cost=final_cost, |
| 256 | + provider=provider, |
| 257 | + metadata={ |
| 258 | + "timeToFirstToken": ((first_token_time - start_time) * 1000 if first_token_time else None), |
| 259 | + "provider": provider, |
| 260 | + "litellm_model": model_name, |
| 261 | + **latest_chunk_metadata, |
| 262 | + }, |
| 263 | + ) |
| 264 | + |
| 265 | + # pylint: disable=broad-except |
| 266 | + except Exception as e: |
| 267 | + if logger is not None: |
| 268 | + logger.error( |
| 269 | + "Failed to trace the LiteLLM completion request with Openlayer. %s", |
| 270 | + e, |
| 271 | + ) |
272 | 272 |
|
273 | 273 |
|
274 | 274 | def handle_non_streaming_completion( |
|
0 commit comments