Skip to content

Commit 150894a

Browse files
committed
refactor(llm): update integration boundary for canonical tool calls
Add tool_calls_to_langchain_format() to convert canonical nested tool calls back to LangChain flat format at the output boundary. Update runnable_rails.py and guardrails.py to use the new types.
1 parent 428c1eb commit 150894a

6 files changed

Lines changed: 156 additions & 133 deletions

File tree

nemoguardrails/integrations/langchain/message_utils.py

Lines changed: 53 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
ToolMessage,
2727
)
2828

29+
from nemoguardrails.types import ChatMessage, Role
30+
2931

3032
def get_message_role(msg: BaseMessage) -> str:
3133
"""Get the role string for a BaseMessage."""
@@ -147,36 +149,6 @@ def is_base_message(obj: Any) -> bool:
147149
return isinstance(obj, BaseMessage)
148150

149151

150-
def chatmessage_to_langchain_message(msg: "ChatMessage") -> BaseMessage:
151-
from nemoguardrails.types import Role
152-
153-
content = msg.content or ""
154-
if msg.role == Role.USER:
155-
return HumanMessage(content=content)
156-
elif msg.role == Role.SYSTEM:
157-
return SystemMessage(content=content)
158-
elif msg.role == Role.TOOL:
159-
return ToolMessage(content=content, tool_call_id=msg.tool_call_id or "")
160-
elif msg.role == Role.ASSISTANT:
161-
kwargs: Dict[str, Any] = {}
162-
if msg.tool_calls:
163-
kwargs["tool_calls"] = [
164-
{
165-
"name": tc.function.name,
166-
"args": tc.function.arguments,
167-
"id": tc.id,
168-
"type": "tool_call",
169-
}
170-
for tc in msg.tool_calls
171-
]
172-
return AIMessage(content=content, **kwargs)
173-
return HumanMessage(content=content)
174-
175-
176-
def chatmessages_to_langchain_messages(msgs: List["ChatMessage"]) -> List[BaseMessage]:
177-
return [chatmessage_to_langchain_message(m) for m in msgs]
178-
179-
180152
def is_ai_message(obj: Any) -> bool:
181153
"""Check if an object is an AIMessage."""
182154
return isinstance(obj, AIMessage)
@@ -308,3 +280,54 @@ def create_tool_message(
308280
kwargs["status"] = status
309281

310282
return ToolMessage(content=content, **kwargs)
283+
284+
285+
_ROLE_TO_LANGCHAIN = {
286+
Role.USER: HumanMessage,
287+
Role.ASSISTANT: AIMessage,
288+
Role.SYSTEM: SystemMessage,
289+
Role.TOOL: ToolMessage,
290+
}
291+
292+
293+
def chatmessage_to_langchain_message(msg: ChatMessage) -> BaseMessage:
294+
cls = _ROLE_TO_LANGCHAIN.get(msg.role)
295+
if cls is None:
296+
raise ValueError(f"Unsupported role: {msg.role}")
297+
298+
kwargs: Dict[str, Any] = {}
299+
if msg.name is not None:
300+
kwargs["name"] = msg.name
301+
302+
if cls is AIMessage and msg.tool_calls:
303+
kwargs["tool_calls"] = [
304+
{"name": tc.function.name, "args": tc.function.arguments, "id": tc.id, "type": tc.type}
305+
for tc in msg.tool_calls
306+
]
307+
308+
if cls is ToolMessage:
309+
kwargs["tool_call_id"] = msg.tool_call_id or ""
310+
311+
return cls(content=msg.content, **kwargs)
312+
313+
314+
def chatmessages_to_langchain_messages(msgs: List[ChatMessage]) -> List[BaseMessage]:
315+
return [chatmessage_to_langchain_message(m) for m in msgs]
316+
317+
318+
def tool_calls_to_langchain_format(tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
319+
result = []
320+
for tc in tool_calls:
321+
func = tc.get("function")
322+
if func:
323+
result.append(
324+
{
325+
"name": func.get("name", ""),
326+
"args": func.get("arguments", {}),
327+
"id": tc.get("id", ""),
328+
"type": "tool_call",
329+
}
330+
)
331+
else:
332+
result.append(tc)
333+
return result

nemoguardrails/integrations/langchain/runnable_rails.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
create_ai_message_chunk,
3232
is_base_message,
3333
message_to_dict,
34+
tool_calls_to_langchain_format,
3435
)
3536
from nemoguardrails.integrations.langchain.utils import async_wrap
3637
from nemoguardrails.rails.llm.options import GenerationOptions
@@ -498,10 +499,12 @@ def _format_output(
498499
Raises:
499500
ValueError: If the input type cannot be handled.
500501
"""
501-
# Standardize result format if it's a list
502502
if isinstance(result, list) and len(result) > 0:
503503
result = result[0]
504504

505+
if tool_calls:
506+
tool_calls = tool_calls_to_langchain_format(tool_calls)
507+
505508
if self.passthrough and self.passthrough_runnable:
506509
return self._format_passthrough_output(result, context)
507510

tests/test_tool_calling_passthrough_only.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from nemoguardrails import LLMRails, RailsConfig
2424
from nemoguardrails.actions.llm.generation import LLMGenerationActions
2525
from nemoguardrails.context import tool_calls_var
26+
from nemoguardrails.integrations.langchain.llm_adapter import LangChainLLMAdapter
2627
from tests.utils import get_bound_llm_magic_mock
2728

2829

@@ -106,20 +107,21 @@ def test_config_passthrough_false(self, config_no_passthrough):
106107
@pytest.mark.asyncio
107108
async def test_tool_calls_work_in_passthrough_mode(self, config_passthrough, mock_llm_with_tool_calls):
108109
"""Test that tool calls create BotToolCalls events in passthrough mode."""
109-
# Set up context with tool calls
110110
tool_calls = [
111111
{
112112
"id": "call_123",
113-
"type": "tool_call",
114-
"name": "test_tool",
115-
"args": {"param": "value"},
113+
"type": "function",
114+
"function": {
115+
"name": "test_tool",
116+
"arguments": {"param": "value"},
117+
},
116118
}
117119
]
118120
tool_calls_var.set(tool_calls)
119121

120122
generation_actions = LLMGenerationActions(
121123
config=config_passthrough,
122-
llm=mock_llm_with_tool_calls,
124+
llm=LangChainLLMAdapter(mock_llm_with_tool_calls),
123125
llm_task_manager=MagicMock(),
124126
get_embedding_search_provider_instance=MagicMock(return_value=None),
125127
)
@@ -133,7 +135,11 @@ async def test_tool_calls_work_in_passthrough_mode(self, config_passthrough, moc
133135

134136
assert len(result.events) == 1
135137
assert result.events[0]["type"] == "BotToolCalls"
136-
assert result.events[0]["tool_calls"] == tool_calls
138+
stored = result.events[0]["tool_calls"]
139+
assert len(stored) == 1
140+
assert stored[0]["function"]["name"] == "test_tool"
141+
assert stored[0]["function"]["arguments"] == {"param": "value"}
142+
assert stored[0]["id"] == "call_123"
137143

138144
@pytest.mark.asyncio
139145
async def test_tool_calls_ignored_in_non_passthrough_mode(self, config_no_passthrough, mock_llm_with_tool_calls):
@@ -150,7 +156,7 @@ async def test_tool_calls_ignored_in_non_passthrough_mode(self, config_no_passth
150156

151157
generation_actions = LLMGenerationActions(
152158
config=config_no_passthrough,
153-
llm=mock_llm_with_tool_calls,
159+
llm=LangChainLLMAdapter(mock_llm_with_tool_calls),
154160
llm_task_manager=MagicMock(),
155161
get_embedding_search_provider_instance=MagicMock(return_value=None),
156162
)
@@ -177,7 +183,7 @@ async def test_no_tool_calls_creates_bot_message_in_passthrough(self, config_pas
177183

178184
generation_actions = LLMGenerationActions(
179185
config=config_passthrough,
180-
llm=mock_llm_with_tool_calls,
186+
llm=LangChainLLMAdapter(mock_llm_with_tool_calls),
181187
llm_task_manager=MagicMock(),
182188
get_embedding_search_provider_instance=MagicMock(return_value=None),
183189
)
@@ -194,12 +200,12 @@ async def test_no_tool_calls_creates_bot_message_in_passthrough(self, config_pas
194200

195201
def test_llm_rails_integration_passthrough_mode(self, config_passthrough, mock_llm_with_tool_calls):
196202
"""Test LLMRails with passthrough mode allows tool calls."""
197-
rails = LLMRails(config=config_passthrough, llm=mock_llm_with_tool_calls)
203+
rails = LLMRails(config=config_passthrough, llm=LangChainLLMAdapter(mock_llm_with_tool_calls))
198204

199205
assert rails.config.passthrough is True
200206

201207
def test_llm_rails_integration_non_passthrough_mode(self, config_no_passthrough, mock_llm_with_tool_calls):
202208
"""Test LLMRails without passthrough mode."""
203-
rails = LLMRails(config=config_no_passthrough, llm=mock_llm_with_tool_calls)
209+
rails = LLMRails(config=config_no_passthrough, llm=LangChainLLMAdapter(mock_llm_with_tool_calls))
204210

205211
assert rails.config.passthrough is False

0 commit comments

Comments
 (0)