diff --git a/cognite/client/data_classes/agents/__init__.py b/cognite/client/data_classes/agents/__init__.py index 9531002d5e..b03eecc9cb 100644 --- a/cognite/client/data_classes/agents/__init__.py +++ b/cognite/client/data_classes/agents/__init__.py @@ -38,12 +38,16 @@ Message, MessageContent, MessageList, + ReasoningDataItem, TextContent, + ToolCallDetail, + ToolCallReasoningDataItem, ToolConfirmationCall, ToolConfirmationResult, UnknownAction, UnknownActionCall, UnknownContent, + UnknownReasoningDataItem, ) __all__ = [ @@ -81,9 +85,12 @@ "QueryKnowledgeGraphAgentToolUpsert", "QueryTimeSeriesDatapointsAgentTool", "QueryTimeSeriesDatapointsAgentToolUpsert", + "ReasoningDataItem", "SummarizeDocumentAgentTool", "SummarizeDocumentAgentToolUpsert", "TextContent", + "ToolCallDetail", + "ToolCallReasoningDataItem", "ToolConfirmationCall", "ToolConfirmationResult", "UnknownAction", @@ -91,4 +98,5 @@ "UnknownAgentTool", "UnknownAgentToolUpsert", "UnknownContent", + "UnknownReasoningDataItem", ] diff --git a/cognite/client/data_classes/agents/chat.py b/cognite/client/data_classes/agents/chat.py index 277877d179..f64d264203 100644 --- a/cognite/client/data_classes/agents/chat.py +++ b/cognite/client/data_classes/agents/chat.py @@ -464,25 +464,132 @@ def _load(cls, data: dict[str, Any]) -> AgentDataItem: return cls(type=item_type, data=item_data) +@dataclass +class ToolCallDetail(CogniteResource): + """Details of a tool call made during agent reasoning. + + Args: + id (str): The id of the tool call. + name (str): The name of the tool that was called. + tool_type (str): The type of the tool that was called. + input (dict[str, Any]): The parameters that were passed to the tool. + result (dict[str, Any]): The results that were returned by the tool. + """ + + id: str + name: str + tool_type: str + input: dict[str, Any] = field(default_factory=dict) + result: dict[str, Any] = field(default_factory=dict) + + @classmethod + def _load(cls, data: dict[str, Any]) -> ToolCallDetail: + return cls( + id=data["id"], + name=data["name"], + tool_type=data["toolType"], + input=data.get("input", {}), + result=data.get("result", {}), + ) + + +@dataclass +class ReasoningDataItem(CogniteResource, ABC): + """Base class for reasoning data item types.""" + + _type: ClassVar[str] + + def dump(self, camel_case: bool = True) -> dict[str, Any]: + output = super().dump(camel_case=camel_case) + output["type"] = self._type + return output + + @classmethod + def _load(cls, data: dict[str, Any]) -> ReasoningDataItem: + item_type = data["type"] + klass = _REASONING_DATA_CLS_BY_TYPE.get(item_type, UnknownReasoningDataItem) + return klass._load_item(data) + + @classmethod + @abstractmethod + def _load_item(cls, data: dict[str, Any]) -> ReasoningDataItem: ... + + +@dataclass +class ToolCallReasoningDataItem(ReasoningDataItem): + """Reasoning data item for a tool call. + + Args: + tool_call (ToolCallDetail | None): Details of the tool call. + """ + + _type: ClassVar[str] = "toolCall" + tool_call: ToolCallDetail | None = None + + def dump(self, camel_case: bool = True) -> dict[str, Any]: + key = "toolCall" if camel_case else "tool_call" + result: dict[str, Any] = {"type": self._type} + if self.tool_call is not None: + result[key] = self.tool_call.dump(camel_case=camel_case) + return result + + @classmethod + def _load_item(cls, data: dict[str, Any]) -> ToolCallReasoningDataItem: + return cls(tool_call=ToolCallDetail._load_if(data.get("toolCall"))) + + +@dataclass +class UnknownReasoningDataItem(ReasoningDataItem): + """Unknown reasoning data item type for forward compatibility. + + Args: + type (str): The item type. + data (dict[str, Any]): The raw item data. + """ + + type: str + data: dict[str, Any] = field(default_factory=dict) + + def dump(self, camel_case: bool = True) -> dict[str, Any]: + result = self.data.copy() + result["type"] = self.type + return result + + @classmethod + def _load_item(cls, data: dict[str, Any]) -> UnknownReasoningDataItem: + data = data.copy() + item_type = data.pop("type") + return cls(type=item_type, data=data) + + +_REASONING_DATA_CLS_BY_TYPE: dict[str, type[ReasoningDataItem]] = { + ToolCallReasoningDataItem._type: ToolCallReasoningDataItem, +} + + @dataclass class AgentReasoningItem(CogniteResource): """Reasoning item in agent response. Args: content (list[MessageContent]): The reasoning content. + data (list[ReasoningDataItem] | None): The data of the reasoning. """ content: list[MessageContent] + data: list[ReasoningDataItem] | None = None def dump(self, camel_case: bool = True) -> dict[str, Any]: - return { - "content": [item.dump(camel_case=camel_case) for item in self.content], - } + result: dict[str, Any] = {"content": [item.dump(camel_case=camel_case) for item in self.content]} + if self.data is not None: + result["data"] = [item.dump(camel_case=camel_case) for item in self.data] + return result @classmethod def _load(cls, data: dict[str, Any]) -> AgentReasoningItem: content = [MessageContent._load(item) for item in data.get("content", [])] - return cls(content=content) + data_items = [ReasoningDataItem._load(item) for item in data.get("data", [])] or None + return cls(content=content, data=data_items) @dataclass diff --git a/tests/tests_unit/test_api/test_agents_chat.py b/tests/tests_unit/test_api/test_agents_chat.py index 79deed3d57..445bec3730 100644 --- a/tests/tests_unit/test_api/test_agents_chat.py +++ b/tests/tests_unit/test_api/test_agents_chat.py @@ -14,6 +14,7 @@ AgentMessage, AgentReasoningItem, TextContent, + ToolCallReasoningDataItem, ) from tests.utils import get_url, jsgz_load @@ -51,7 +52,32 @@ def chat_response_body() -> dict: "text": "The user is asking about capabilities", "type": "text", } - ] + ], + "data": [ + { + "type": "toolCall", + "toolCall": { + "id": "tc_1", + "name": "search_instances", + "toolType": "query", + "input": { + "view_space": "cdf_cdm", + "view_external_id": "CogniteAsset", + "view_version": "v1", + "query": "pump", + "operator": "AND", + "return_properties": ["name", "externalId"], + }, + "result": { + "result": { + "items": [{"space": "my_space", "externalId": "pump_1"}], + "count": 1, + }, + "error": None, + }, + }, + } + ], } ], "role": "agent", @@ -113,11 +139,11 @@ def test_chat_simple_message( # Check reasoning assert agent_msg.reasoning is not None assert len(agent_msg.reasoning) == 1 - assert isinstance(agent_msg.reasoning[0], AgentReasoningItem) - assert len(agent_msg.reasoning[0].content) == 1 - content = agent_msg.reasoning[0].content[0] - assert isinstance(content, TextContent) - assert content.text == "The user is asking about capabilities" + reasoning_item = agent_msg.reasoning[0] + assert isinstance(reasoning_item, AgentReasoningItem) + assert isinstance(reasoning_item.content[0], TextContent) + assert reasoning_item.data is not None + assert isinstance(reasoning_item.data[0], ToolCallReasoningDataItem) # Test convenience properties assert response.text == "I can help you with various tasks related to your industrial data."