|
42 | 42 | ClientV2, |
43 | 43 | StreamedChatResponseV2, |
44 | 44 | SystemChatMessageV2, |
45 | | - TextAssistantMessageContentItem, |
| 45 | + TextAssistantMessageV2ContentItem, |
46 | 46 | TextContent, |
47 | | - TextSystemMessageContentItem, |
| 47 | + TextSystemMessageV2ContentItem, |
48 | 48 | ToolCallV2, |
49 | 49 | ToolCallV2Function, |
50 | 50 | ToolChatMessageV2, |
@@ -127,9 +127,9 @@ def _format_message( |
127 | 127 | if message.role.value == "user": |
128 | 128 | return UserChatMessageV2(content=[TextContent(text=message.texts[0])]) |
129 | 129 | elif message.role.value == "assistant": |
130 | | - return AssistantChatMessageV2(content=[TextAssistantMessageContentItem(text=message.texts[0])]) |
| 130 | + return AssistantChatMessageV2(content=[TextAssistantMessageV2ContentItem(text=message.texts[0])]) |
131 | 131 | elif message.role.value == "system": |
132 | | - return SystemChatMessageV2(content=[TextSystemMessageContentItem(text=message.texts[0])]) |
| 132 | + return SystemChatMessageV2(content=[TextSystemMessageV2ContentItem(text=message.texts[0])]) |
133 | 133 | else: |
134 | 134 | msg = f"Unsupported message role: {message.role.value}" |
135 | 135 | raise ValueError(msg) |
@@ -274,21 +274,25 @@ def _process_cohere_chunk(cohere_chunk: StreamedChatResponseV2, state: Dict[str, |
274 | 274 | state["current_tool_call"] = None |
275 | 275 | state["current_tool_arguments"] = "" |
276 | 276 |
|
| 277 | + usage_data = getattr(cohere_chunk.delta, "usage", None) |
| 278 | + finish_reason = getattr(cohere_chunk.delta, "finish_reason", None) |
| 279 | + |
277 | 280 | if ( |
278 | | - cohere_chunk.delta.finish_reason is not None |
279 | | - and cohere_chunk.delta.usage |
280 | | - and cohere_chunk.delta.usage.billed_units |
281 | | - and cohere_chunk.delta.usage.billed_units.input_tokens is not None |
282 | | - and cohere_chunk.delta.usage.billed_units.output_tokens is not None |
| 281 | + finish_reason is not None |
| 282 | + and usage_data is not None |
| 283 | + and isinstance(usage_data, dict) |
| 284 | + and "billed_units" in usage_data |
| 285 | + and "input_tokens" in usage_data["billed_units"] |
| 286 | + and "output_tokens" in usage_data["billed_units"] |
283 | 287 | ): |
284 | 288 | state["captured_meta"].update( |
285 | 289 | { |
286 | 290 | "model": model, |
287 | 291 | "index": 0, |
288 | | - "finish_reason": cohere_chunk.delta.finish_reason, |
| 292 | + "finish_reason": finish_reason, |
289 | 293 | "usage": { |
290 | | - "prompt_tokens": cohere_chunk.delta.usage.billed_units.input_tokens, |
291 | | - "completion_tokens": cohere_chunk.delta.usage.billed_units.output_tokens, |
| 294 | + "prompt_tokens": usage_data["billed_units"]["input_tokens"], |
| 295 | + "completion_tokens": usage_data["billed_units"]["output_tokens"], |
292 | 296 | }, |
293 | 297 | } |
294 | 298 | ) |
|
0 commit comments