Skip to content

Commit 884f01b

Browse files
committed
feat: support expanded reasoning response in agent chat
Add the `data` field to `AgentReasoningItem` to surface tool-call details from the agent's reasoning trace. Introduces `ToolCallDetail`, `ReasoningDataItem`, `ToolCallReasoningDataItem`, and `UnknownReasoningDataItem` following the same type-dispatch pattern used by `MessageContent`.
1 parent 3326f55 commit 884f01b

3 files changed

Lines changed: 204 additions & 8 deletions

File tree

cognite/client/data_classes/agents/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,16 @@
3838
Message,
3939
MessageContent,
4040
MessageList,
41+
ReasoningDataItem,
4142
TextContent,
43+
ToolCallDetail,
44+
ToolCallReasoningDataItem,
4245
ToolConfirmationCall,
4346
ToolConfirmationResult,
4447
UnknownAction,
4548
UnknownActionCall,
4649
UnknownContent,
50+
UnknownReasoningDataItem,
4751
)
4852

4953
__all__ = [
@@ -81,14 +85,18 @@
8185
"QueryKnowledgeGraphAgentToolUpsert",
8286
"QueryTimeSeriesDatapointsAgentTool",
8387
"QueryTimeSeriesDatapointsAgentToolUpsert",
88+
"ReasoningDataItem",
8489
"SummarizeDocumentAgentTool",
8590
"SummarizeDocumentAgentToolUpsert",
8691
"TextContent",
92+
"ToolCallDetail",
93+
"ToolCallReasoningDataItem",
8794
"ToolConfirmationCall",
8895
"ToolConfirmationResult",
8996
"UnknownAction",
9097
"UnknownActionCall",
9198
"UnknownAgentTool",
9299
"UnknownAgentToolUpsert",
93100
"UnknownContent",
101+
"UnknownReasoningDataItem",
94102
]

cognite/client/data_classes/agents/chat.py

Lines changed: 109 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -464,25 +464,130 @@ def _load(cls, data: dict[str, Any]) -> AgentDataItem:
464464
return cls(type=item_type, data=item_data)
465465

466466

467+
@dataclass
468+
class ToolCallDetail(CogniteResource):
469+
"""Details of a tool call made during agent reasoning.
470+
471+
Args:
472+
id (str): The id of the tool call.
473+
name (str): The name of the tool that was called.
474+
tool_type (str): The type of the tool that was called.
475+
input (dict[str, Any]): The parameters that were passed to the tool.
476+
result (dict[str, Any]): The results that were returned by the tool.
477+
"""
478+
479+
id: str
480+
name: str
481+
tool_type: str
482+
input: dict[str, Any] = field(default_factory=dict)
483+
result: dict[str, Any] = field(default_factory=dict)
484+
485+
def dump(self, camel_case: bool = True) -> dict[str, Any]:
486+
key = "toolType" if camel_case else "tool_type"
487+
return {"id": self.id, "name": self.name, key: self.tool_type, "input": self.input, "result": self.result}
488+
489+
@classmethod
490+
def _load(cls, data: dict[str, Any]) -> ToolCallDetail:
491+
return cls(
492+
id=data["id"],
493+
name=data["name"],
494+
tool_type=data["toolType"],
495+
input=data.get("input", {}),
496+
result=data.get("result", {}),
497+
)
498+
499+
500+
@dataclass
501+
class ReasoningDataItem(CogniteResource, ABC):
502+
"""Base class for reasoning data item types."""
503+
504+
_type: ClassVar[str]
505+
506+
@classmethod
507+
def _load(cls, data: dict[str, Any]) -> ReasoningDataItem:
508+
item_type = data.get("type", "")
509+
klass = _REASONING_DATA_CLS_BY_TYPE.get(item_type, UnknownReasoningDataItem)
510+
return klass._load_item(data)
511+
512+
@classmethod
513+
@abstractmethod
514+
def _load_item(cls, data: dict[str, Any]) -> ReasoningDataItem: ...
515+
516+
517+
@dataclass
518+
class ToolCallReasoningDataItem(ReasoningDataItem):
519+
"""Reasoning data item for a tool call.
520+
521+
Args:
522+
tool_call (ToolCallDetail | None): Details of the tool call.
523+
"""
524+
525+
_type: ClassVar[str] = "toolCall"
526+
tool_call: ToolCallDetail | None = None
527+
528+
def dump(self, camel_case: bool = True) -> dict[str, Any]:
529+
key = "toolCall" if camel_case else "tool_call"
530+
result: dict[str, Any] = {"type": self._type}
531+
if self.tool_call is not None:
532+
result[key] = self.tool_call.dump(camel_case=camel_case)
533+
return result
534+
535+
@classmethod
536+
def _load_item(cls, data: dict[str, Any]) -> ToolCallReasoningDataItem:
537+
tool_call_data = data.get("toolCall")
538+
return cls(tool_call=ToolCallDetail._load(tool_call_data) if tool_call_data else None)
539+
540+
541+
@dataclass
542+
class UnknownReasoningDataItem(ReasoningDataItem):
543+
"""Unknown reasoning data item type for forward compatibility.
544+
545+
Args:
546+
type (str): The item type.
547+
data (dict[str, Any]): The raw item data.
548+
"""
549+
550+
type: str = ""
551+
data: dict[str, Any] = field(default_factory=dict)
552+
553+
def dump(self, camel_case: bool = True) -> dict[str, Any]:
554+
result = self.data.copy()
555+
result["type"] = self.type
556+
return result
557+
558+
@classmethod
559+
def _load_item(cls, data: dict[str, Any]) -> UnknownReasoningDataItem:
560+
return cls(type=data.get("type", ""), data=data)
561+
562+
563+
_REASONING_DATA_CLS_BY_TYPE: dict[str, type[ReasoningDataItem]] = {
564+
ToolCallReasoningDataItem._type: ToolCallReasoningDataItem,
565+
}
566+
567+
467568
@dataclass
468569
class AgentReasoningItem(CogniteResource):
469570
"""Reasoning item in agent response.
470571
471572
Args:
472573
content (list[MessageContent]): The reasoning content.
574+
data (list[ReasoningDataItem] | None): The data of the reasoning.
473575
"""
474576

475577
content: list[MessageContent]
578+
data: list[ReasoningDataItem] | None = None
476579

477580
def dump(self, camel_case: bool = True) -> dict[str, Any]:
478-
return {
479-
"content": [item.dump(camel_case=camel_case) for item in self.content],
480-
}
581+
result: dict[str, Any] = {"content": [item.dump(camel_case=camel_case) for item in self.content]}
582+
if self.data is not None:
583+
result["data"] = [item.dump(camel_case=camel_case) for item in self.data]
584+
return result
481585

482586
@classmethod
483587
def _load(cls, data: dict[str, Any]) -> AgentReasoningItem:
484588
content = [MessageContent._load(item) for item in data.get("content", [])]
485-
return cls(content=content)
589+
data_items = [ReasoningDataItem._load(item) for item in data.get("data", [])] or None
590+
return cls(content=content, data=data_items)
486591

487592

488593
@dataclass

tests/tests_unit/test_api/test_agents_chat.py

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
AgentMessage,
1515
AgentReasoningItem,
1616
TextContent,
17+
ToolCallDetail,
18+
ToolCallReasoningDataItem,
19+
UnknownReasoningDataItem,
1720
)
1821
from tests.utils import get_url, jsgz_load
1922

@@ -51,7 +54,19 @@ def chat_response_body() -> dict:
5154
"text": "The user is asking about capabilities",
5255
"type": "text",
5356
}
54-
]
57+
],
58+
"data": [
59+
{
60+
"type": "toolCall",
61+
"toolCall": {
62+
"id": "tc_1",
63+
"name": "analyzeData",
64+
"toolType": "analyzeData",
65+
"input": {"query": "SELECT *"},
66+
"result": {"rows": 42},
67+
},
68+
}
69+
],
5570
}
5671
],
5772
"role": "agent",
@@ -113,12 +128,25 @@ def test_chat_simple_message(
113128
# Check reasoning
114129
assert agent_msg.reasoning is not None
115130
assert len(agent_msg.reasoning) == 1
116-
assert isinstance(agent_msg.reasoning[0], AgentReasoningItem)
117-
assert len(agent_msg.reasoning[0].content) == 1
118-
content = agent_msg.reasoning[0].content[0]
131+
reasoning_item = agent_msg.reasoning[0]
132+
assert isinstance(reasoning_item, AgentReasoningItem)
133+
assert len(reasoning_item.content) == 1
134+
content = reasoning_item.content[0]
119135
assert isinstance(content, TextContent)
120136
assert content.text == "The user is asking about capabilities"
121137

138+
# Check reasoning data
139+
assert reasoning_item.data is not None
140+
assert len(reasoning_item.data) == 1
141+
data_item = reasoning_item.data[0]
142+
assert isinstance(data_item, ToolCallReasoningDataItem)
143+
assert data_item.tool_call is not None
144+
assert data_item.tool_call.id == "tc_1"
145+
assert data_item.tool_call.name == "analyzeData"
146+
assert data_item.tool_call.tool_type == "analyzeData"
147+
assert data_item.tool_call.input == {"query": "SELECT *"}
148+
assert data_item.tool_call.result == {"rows": 42}
149+
122150
# Test convenience properties
123151
assert response.text == "I can help you with various tasks related to your industrial data."
124152

@@ -196,3 +224,58 @@ def test_message_with_explicit_content(self) -> None:
196224
assert msg.content is content
197225
assert isinstance(msg.content, TextContent)
198226
assert msg.content.text == "Hello world"
227+
228+
229+
class TestAgentReasoningItem:
230+
def test_load_and_dump_with_data(self) -> None:
231+
raw = {
232+
"content": [{"type": "text", "text": "thinking..."}],
233+
"data": [
234+
{
235+
"type": "toolCall",
236+
"toolCall": {
237+
"id": "tc_1",
238+
"name": "analyzeData",
239+
"toolType": "analyzeData",
240+
"input": {"query": "SELECT *"},
241+
"result": {"rows": 42},
242+
},
243+
}
244+
],
245+
}
246+
item = AgentReasoningItem._load(raw)
247+
assert item.data is not None
248+
assert len(item.data) == 1
249+
data_item = item.data[0]
250+
assert isinstance(data_item, ToolCallReasoningDataItem)
251+
assert data_item.tool_call is not None
252+
assert data_item.tool_call.id == "tc_1"
253+
assert data_item.tool_call.name == "analyzeData"
254+
assert data_item.tool_call.tool_type == "analyzeData"
255+
assert data_item.tool_call.input == {"query": "SELECT *"}
256+
assert data_item.tool_call.result == {"rows": 42}
257+
assert item.dump() == raw
258+
259+
def test_load_without_data_field(self) -> None:
260+
raw = {"content": [{"type": "text", "text": "thinking..."}]}
261+
item = AgentReasoningItem._load(raw)
262+
assert item.data is None
263+
assert item.dump() == raw
264+
265+
def test_tool_call_reasoning_data_item_with_null_tool_call(self) -> None:
266+
raw = {"type": "toolCall"}
267+
data_item = ToolCallReasoningDataItem._load_item(raw)
268+
assert data_item.tool_call is None
269+
assert data_item.dump() == raw
270+
271+
def test_unknown_reasoning_data_item(self) -> None:
272+
raw = {"type": "futureType", "someField": "someValue"}
273+
data_item = UnknownReasoningDataItem._load_item(raw)
274+
assert data_item.type == "futureType"
275+
assert data_item.dump() == raw
276+
277+
def test_tool_call_detail_snake_case_dump(self) -> None:
278+
detail = ToolCallDetail(id="tc_1", name="analyzeData", tool_type="analyzeData", input={"q": 1}, result={"r": 2})
279+
dumped = detail.dump(camel_case=False)
280+
assert "tool_type" in dumped
281+
assert "toolType" not in dumped

0 commit comments

Comments
 (0)