-
Notifications
You must be signed in to change notification settings - Fork 177
Expand file tree
/
Copy pathutils.py
More file actions
130 lines (106 loc) · 4.63 KB
/
utils.py
File metadata and controls
130 lines (106 loc) · 4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Literal
@dataclass
class ChatMessage:
    """A chat message in an LLM conversation.

    This dataclass represents messages exchanged in a conversation with an LLM,
    supporting various message types including user prompts, assistant responses,
    system instructions, and tool interactions.

    Attributes:
        role: The role of the message sender. One of 'user', 'assistant',
            'system', or 'tool'.
        content: The message content. Either a plain string or a list of
            content-block dicts for multimodal messages (e.g. text + images).
        reasoning_content: Optional reasoning/thinking content from the
            assistant, typically from extended-thinking or chain-of-thought models.
        tool_calls: Tool calls requested by the assistant; empty when there are
            none. Expected to be API-style dicts (e.g. with 'id', 'type' and
            'function' keys) -- their structure is not validated here.
        tool_call_id: Optional ID linking a tool response to the tool call it
            answers. Tool messages need one; ``as_tool`` always sets it, but
            constructing the dataclass directly does not enforce it.
    """

    role: Literal["user", "assistant", "system", "tool"]
    content: str | list[dict[str, Any]] = ""
    reasoning_content: str | None = None
    # default_factory gives every instance its own list instead of a shared mutable default.
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    tool_call_id: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert the message to a dictionary format for API calls.

        Content is always normalized to a list of ChatML-style blocks (a plain
        string becomes a single text block) so that traces and API payloads
        share one consistent schema.

        Returns:
            A dict with ``role`` and ``content``, plus ``reasoning_content``,
            ``tool_calls`` and ``tool_call_id`` only when they are truthy.
        """
        result: dict[str, Any] = {"role": self.role, "content": _normalize_content_blocks(self.content)}
        # Optional fields are emitted only when non-empty to keep the payload minimal.
        if self.reasoning_content:
            result["reasoning_content"] = self.reasoning_content
        if self.tool_calls:
            result["tool_calls"] = self.tool_calls
        if self.tool_call_id:
            result["tool_call_id"] = self.tool_call_id
        return result

    @classmethod
    def as_user(cls, content: str | list[dict[str, Any]]) -> ChatMessage:
        """Create a user message from text or a list of content blocks."""
        return cls(role="user", content=content)

    @classmethod
    def as_assistant(
        cls,
        content: str | list[dict[str, Any]] = "",
        reasoning_content: str | None = None,
        tool_calls: list[dict[str, Any]] | None = None,
    ) -> ChatMessage:
        """Create an assistant message.

        Args:
            content: Response text or a list of content blocks; accepts the same
                types as the ``content`` field.
            reasoning_content: Optional reasoning/thinking text.
            tool_calls: Optional tool calls requested by the assistant.

        Returns:
            A ``ChatMessage`` with ``role="assistant"``.
        """
        return cls(
            role="assistant",
            content=content,
            reasoning_content=reasoning_content,
            # None (or an empty list) becomes a fresh empty list.
            tool_calls=tool_calls or [],
        )

    @classmethod
    def as_system(cls, content: str | list[dict[str, Any]]) -> ChatMessage:
        """Create a system message from text or a list of content blocks."""
        return cls(role="system", content=content)

    @classmethod
    def as_tool(cls, content: str | list[dict[str, Any]], tool_call_id: str) -> ChatMessage:
        """Create a tool response message answering the call ``tool_call_id``."""
        return cls(role="tool", content=content, tool_call_id=tool_call_id)
def prompt_to_messages(
    *,
    user_prompt: str,
    system_prompt: str | None = None,
    multi_modal_context: list[dict[str, Any]] | None = None,
) -> list[ChatMessage]:
    """Build the message list for a single-turn prompt.

    Args:
        user_prompt: Text of the user's prompt.
        system_prompt: Optional system instructions; skipped when empty or None.
        multi_modal_context: Optional content blocks (e.g. images) placed before
            the user's text. When non-empty, the user message content becomes a
            block list ending in a text block holding ``user_prompt``.

    Returns:
        ``[system, user]`` when a system prompt is given, otherwise ``[user]``.
    """
    messages: list[ChatMessage] = []
    if system_prompt:
        messages.append(ChatMessage.as_system(system_prompt))

    # Context blocks come first so the prompt text is the final block.
    user_content: str | list[dict[str, Any]] = (
        [*multi_modal_context, {"type": "text", "text": user_prompt}]
        if multi_modal_context
        else user_prompt
    )
    messages.append(ChatMessage.as_user(user_content))
    return messages
def _normalize_content_blocks(content: Any) -> list[dict[str, Any]]:
if isinstance(content, list):
return [_normalize_content_block(block) for block in content]
if content is None:
return []
return [_text_block(content)]
def _normalize_content_block(block: Any) -> dict[str, Any]:
if isinstance(block, dict) and "type" in block:
return block
if isinstance(block, dict) and "text" in block:
return _text_block(block["text"])
return _text_block(block)
def _text_block(value: Any) -> dict[str, Any]:
if value is None:
text_value = ""
elif isinstance(value, str):
text_value = value
else:
text_value = str(value)
return {"type": "text", "text": text_value}