Skip to content

Commit fd176fc

Browse files
authored
feat(iorails): Add tool-calling rails to RailsManager and ModelRegistry (#2030)
1 parent 140aead commit fd176fc

16 files changed

Lines changed: 1839 additions & 3 deletions
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tool-call safety rail for IORails.
17+
18+
Validates the tool calls a model emitted against the request's declared
19+
``Toolset``: every call must name an allowed tool, and its arguments must
20+
satisfy that tool's JSON Schema. The rail is local and model-free -- it runs
21+
through :meth:`ToolRailAction._guarded`, so a malformed call or an unexpected
22+
error fails closed (blocks) rather than propagating.
23+
"""
24+
25+
from __future__ import annotations
26+
27+
from typing import TYPE_CHECKING, List
28+
29+
from nemoguardrails.guardrails.guardrails_types import RailResult
30+
from nemoguardrails.guardrails.tool_rail_action import ToolRailAction
31+
from nemoguardrails.guardrails.tool_schema import validate_arguments
32+
33+
if TYPE_CHECKING:
34+
from nemoguardrails.guardrails.tool_schema import Toolset
35+
from nemoguardrails.types import ToolCall
36+
37+
38+
class ToolCallRailAction(ToolRailAction):
    """Rail that vets model-emitted tool calls against the request's declared toolset.

    A call passes only when its name resolves to a tool in the toolset and
    ``validate_arguments`` reports no problem with its arguments. The first
    failing call decides the outcome; later calls are not inspected.
    """

    action_name = "tool call validation"

    async def run(self, toolset: "Toolset", tool_calls: List["ToolCall"]) -> RailResult:
        """Validate ``tool_calls`` against ``toolset`` via the inherited ``_guarded`` wrapper.

        Args:
            toolset: The tools declared on the request.
            tool_calls: The tool calls emitted by the model.

        Returns:
            Whatever ``_guarded`` produces for the validation callable.
        """

        def check() -> RailResult:
            return self._validate(toolset, tool_calls)

        return self._guarded(check)

    def _validate(self, toolset: "Toolset", tool_calls: List["ToolCall"]) -> RailResult:
        """Return a blocking result for the first disallowed or invalid call, else a safe one."""
        for tool_call in tool_calls:
            function = tool_call.function
            # Hosted/server tools (e.g. web_search) carry no function name, so the
            # call type is used as the lookup key instead (mirrors Tool.key).
            tool_name = function.name or tool_call.type
            declared = toolset.get(tool_name)
            if declared is None:
                return RailResult(is_safe=False, reason=f"tool call '{tool_name}' is not an allowed tool")
            problem = validate_arguments(declared, function.arguments)
            if problem is not None:
                return RailResult(is_safe=False, reason=problem)
        return RailResult(is_safe=True)
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tool-result validation rail for IORails.
17+
18+
Structurally validates the tool results carried on an incoming request against
19+
the tool calls the model previously made: every result must link to a prior
20+
call by ``call_id``, name a tool consistent with that call, and carry
21+
well-formed content. Validation is currently structural only -- there are no declared
22+
response schemas yet. The rail is local and model-free; it runs through
23+
:meth:`ToolRailAction._guarded`, so a malformed result or an unexpected error
24+
fails closed (blocks) rather than propagating.
25+
"""
26+
27+
from __future__ import annotations
28+
29+
from typing import TYPE_CHECKING, List
30+
31+
from nemoguardrails.guardrails.guardrails_types import RailResult
32+
from nemoguardrails.guardrails.tool_rail_action import ToolRailAction
33+
34+
if TYPE_CHECKING:
35+
from nemoguardrails.guardrails.tool_schema import ToolResult
36+
from nemoguardrails.types import ToolCall
37+
38+
39+
def _is_well_formed_content(content: object) -> bool:
40+
"""Tool-result content is a string, or a list of content-block dicts.
41+
42+
Matches the declared ``ToolResult.content`` type (``str | list[dict] | None``);
43+
a list of non-dict values (e.g. ``[1, 2, 3]``) is not well-formed.
44+
"""
45+
if isinstance(content, str):
46+
return True
47+
return isinstance(content, list) and all(isinstance(block, dict) for block in content)
48+
49+
50+
class ToolResultRailAction(ToolRailAction):
    """Check incoming tool results link to a prior call and are structurally well-formed.

    Each result must carry a ``call_id`` that matches exactly one prior tool
    call, must not contradict that call's function name, and must have content
    that is ``None``, a string, or a list of dicts. The first violation found
    produces a blocking ``RailResult``.
    """

    action_name = "tool result validation"

    async def run(self, tool_results: List["ToolResult"], prior_calls: List["ToolCall"]) -> RailResult:
        """Block unless every tool result links to a prior call with a consistent name and valid content.

        Args:
            tool_results: Tool results carried on the incoming request.
            prior_calls: Tool calls the model previously made.

        Returns:
            Whatever the inherited ``_guarded`` wrapper produces for the validation callable.
        """
        return self._guarded(lambda: self._validate(tool_results, prior_calls))

    def _validate(self, tool_results: List["ToolResult"], prior_calls: List["ToolCall"]) -> RailResult:
        """Check call_id linkage, name consistency, and content shape for each result."""
        calls_by_id = self._validate_prior_calls(prior_calls)
        # The helper returns either the index or a blocking result; only the
        # latter is a RailResult instance.
        if isinstance(calls_by_id, RailResult):
            return calls_by_id
        return self._validate_results(tool_results, calls_by_id)

    def _validate_prior_calls(self, prior_calls: List["ToolCall"]) -> "RailResult | dict[str, ToolCall]":
        """Build a call_id index from prior_calls; return a blocking RailResult on duplicate IDs.

        Calls with a falsy ``id`` are left out of the index.
        """
        calls_by_id: dict[str, "ToolCall"] = {}
        for call in prior_calls:
            if not call.id:
                continue
            if call.id in calls_by_id:
                return RailResult(
                    is_safe=False,
                    reason=f"duplicate prior tool call id '{call.id}' makes tool-result linkage ambiguous",
                )
            calls_by_id[call.id] = call
        return calls_by_id

    def _validate_results(self, tool_results: List["ToolResult"], calls_by_id: "dict[str, ToolCall]") -> RailResult:
        """Check each result links to a prior call with a consistent name and well-formed content.

        Helper outcomes are compared against ``None`` explicitly rather than by
        truthiness: a blocking ``RailResult`` must always be returned, no matter
        how ``RailResult`` defines ``__bool__`` or ``__len__``.
        """
        rail_result = self._validate_tool_result_ids(tool_results)
        if rail_result is not None:
            return rail_result

        for result in tool_results:
            rail_result = self._validate_result_call_id(result, calls_by_id)
            if rail_result is not None:
                return rail_result

            # Safe to index: the call_id check above confirmed the key exists.
            prior = calls_by_id[result.call_id]  # type: ignore[index]
            rail_result = self._validate_result_name(result, prior)
            if rail_result is not None:
                return rail_result

            rail_result = self._validate_result_content(result)
            if rail_result is not None:
                return rail_result

        return RailResult(is_safe=True)

    def _validate_tool_result_ids(self, tool_results: List["ToolResult"]) -> "RailResult | None":
        """Return a blocking RailResult if any call_id appears more than once in the result list.

        Results with a falsy ``call_id`` are ignored here; they are rejected
        later by ``_validate_result_call_id``.
        """
        seen: set[str] = set()
        for result in tool_results:
            if not result.call_id:
                continue
            if result.call_id in seen:
                return RailResult(
                    is_safe=False,
                    reason=f"duplicate tool result for call_id '{result.call_id}': each tool call must have exactly one result",
                )
            seen.add(result.call_id)
        return None

    def _validate_result_call_id(self, result: "ToolResult", calls_by_id: "dict[str, ToolCall]") -> "RailResult | None":
        """Return a blocking RailResult if the result is missing a call_id or it has no prior call."""
        call_id = result.call_id
        if not call_id:
            return RailResult(is_safe=False, reason="tool result is missing a call_id")
        if calls_by_id.get(call_id) is None:
            return RailResult(
                is_safe=False,
                reason=f"tool result for call_id '{call_id}' does not correspond to a prior tool call",
            )
        return None

    def _validate_result_name(self, result: "ToolResult", prior: "ToolCall") -> "RailResult | None":
        """Return a blocking RailResult if the result name conflicts with the prior call's function name.

        The comparison only happens when both names are non-empty; a result or
        call without a name is not treated as a mismatch.
        """
        if result.name and prior.function.name and result.name != prior.function.name:
            return RailResult(
                is_safe=False,
                reason=(
                    f"tool result name '{result.name}' does not match the called tool "
                    f"'{prior.function.name}' for call_id '{result.call_id}'"
                ),
            )
        return None

    def _validate_result_content(self, result: "ToolResult") -> "RailResult | None":
        """Return a blocking RailResult if the result content is not a string or list of dicts.

        ``None`` content is accepted without further checks.
        """
        if result.content is not None and not _is_well_formed_content(result.content):
            return RailResult(
                is_safe=False,
                reason=f"tool result for call_id '{result.call_id}' has malformed content",
            )
        return None

nemoguardrails/guardrails/engine_registry.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
set_llm_request_attributes,
3737
set_llm_response_attributes,
3838
)
39+
from nemoguardrails.guardrails.tool_schema import ToolResult, Toolset
3940
from nemoguardrails.rails.llm.config import Model, RailsConfigData
4041
from nemoguardrails.tracing.constants import (
4142
llm_operation_duration,
@@ -370,6 +371,33 @@ async def stream_model_call(
370371
if self._metrics_enabled:
371372
record_token_usage(engine.model_name, provider_name, operation_name, captured_usage)
372373

374+
def parse_tools(self, model_type: str, llm_params: Optional[dict]) -> Toolset:
    """Normalize the tool block in ``llm_params`` into a ``Toolset`` for one model engine.

    The engine's ``body_param_defaults`` are merged with ``llm_params`` (request
    values win on key collisions) and the merged dict is handed to the engine's
    own ``parse_tools``, which knows the provider-specific tool shape.

    Raises:
        KeyError: If no engine is registered with the given name.
        TypeError: If the named engine is not a ModelEngine.
    """
    model_engine = self._get_engine(model_type, ModelEngine)
    defaults = model_engine.body_param_defaults
    overrides = llm_params or {}
    merged_params = {**defaults, **overrides}
    return model_engine.parse_tools(merged_params)
386+
387+
def extract_tool_results(self, model_type: str, messages: list[dict]) -> list[ToolResult]:
    """Pull the tool results out of ``messages`` using the named model engine.

    The work is done by the engine's own ``extract_tool_results``, which turns
    the provider's tool-result messages into the ``ToolResult`` list consumed
    by the ToolResultRail.

    Raises:
        KeyError: If no engine is registered with the given name.
        TypeError: If the named engine is not a ModelEngine.
    """
    return self._get_engine(model_type, ModelEngine).extract_tool_results(messages)
400+
373401
async def api_call(self, api_name: str, message: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
374402
"""Route an API request to the named API engine.
375403

nemoguardrails/guardrails/model_engine.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
)
4040
from nemoguardrails.guardrails.base_engine import BaseEngine
4141
from nemoguardrails.guardrails.guardrails_types import LLMMessages, get_request_id, truncate
42+
from nemoguardrails.guardrails.tool_schema import Tool, ToolResult, Toolset
4243
from nemoguardrails.rails.llm.config import Model
4344
from nemoguardrails.types import ChatMessage, LLMResponse, LLMResponseChunk, ToolCall, ToolCallFunction, UsageInfo
4445

@@ -296,6 +297,79 @@ def _finalize_tool_calls(tool_calls: dict[int, dict]) -> list[ToolCall]:
296297
return result
297298

298299

300+
def _parse_tools_openai(tools: list) -> list[Tool]:
301+
"""Parse OpenAI Chat Completions tool definitions into ``Tool`` objects.
302+
303+
Each entry has the nested shape ``{"type": "function", "function": {"name",
304+
"description", "parameters", "strict"}}``; ``function.parameters`` (the JSON
305+
Schema) maps to ``Tool.arguments_schema``. Entries that are not a dict, lack a
306+
``function`` block, or whose function has no non-empty ``name`` are skipped.
307+
"""
308+
parsed: list[Tool] = []
309+
for entry in tools:
310+
if not isinstance(entry, dict):
311+
continue
312+
function = entry.get("function")
313+
if not isinstance(function, dict):
314+
continue
315+
name = function.get("name")
316+
if not isinstance(name, str) or not name:
317+
continue
318+
parsed.append(
319+
Tool(
320+
name=name,
321+
type=entry.get("type", "function"),
322+
description=function.get("description"),
323+
arguments_schema=function.get("parameters"),
324+
strict=function.get("strict"),
325+
)
326+
)
327+
return parsed
328+
329+
330+
def _parse_tools_nim(tools: list) -> list[Tool]:
    """Convert NIM tool definitions into ``Tool`` objects.

    NIM uses the OpenAI Chat Completions tool shape, so the OpenAI parser is reused as-is.
    """
    parsed = _parse_tools_openai(tools)
    return parsed
333+
334+
335+
# Tool-definition parser for each supported engine name. An engine missing from
# this table is handled by the lookup site, which supplies its own default.
_TOOL_PARSERS = {
    "openai": _parse_tools_openai,
    "nim": _parse_tools_nim,
}
339+
340+
341+
def _extract_tool_results_openai(messages: LLMMessages) -> list[ToolResult]:
342+
"""Extract OpenAI Chat Completions tool results into ``ToolResult`` objects.
343+
344+
Chat Completions carries each tool result as a top-level ``{"role": "tool",
345+
"tool_call_id", "content"}`` message (optionally ``name``). This shape has no
346+
error flag, so ``is_error`` is always ``False``.
347+
"""
348+
results: list[ToolResult] = []
349+
for message in messages:
350+
if not isinstance(message, dict) or message.get("role") != "tool":
351+
continue
352+
results.append(
353+
ToolResult(
354+
call_id=message.get("tool_call_id"),
355+
name=message.get("name"),
356+
content=message.get("content"),
357+
)
358+
)
359+
return results
360+
361+
362+
def _extract_tool_results_nim(messages: LLMMessages) -> list[ToolResult]:
    """Collect NIM tool results as ``ToolResult`` objects.

    NIM uses the OpenAI Chat Completions message shape, so the OpenAI extractor is reused as-is.
    """
    extracted = _extract_tool_results_openai(messages)
    return extracted
365+
366+
367+
# Tool-result extractor for each supported engine name. An engine missing from
# this table is handled by the lookup site, which supplies its own default.
_RESULT_EXTRACTORS = {
    "openai": _extract_tool_results_openai,
    "nim": _extract_tool_results_nim,
}
371+
372+
299373
class ModelEngineError(Exception):
300374
"""Raised when a model engine call fails."""
301375

@@ -658,3 +732,32 @@ async def stream_chat_completion(
658732
"""
659733
async for chunk in self.stream_call(messages, **kwargs):
660734
yield chunk
735+
736+
def parse_tools(self, llm_params: Optional[dict]) -> Toolset:
    """Turn the ``tools`` entry of ``llm_params`` into a ``Toolset``.

    The parser is chosen from ``_TOOL_PARSERS`` by ``self.model_config.engine``;
    an engine without a registered parser uses the OpenAI Chat Completions
    parser. When ``llm_params`` is ``None`` or its ``tools`` entry is missing
    or empty, an empty ``Toolset`` is returned and no parser runs.
    """
    params = llm_params or {}
    declared_tools = params.get("tools")
    if declared_tools:
        engine_name = self.model_config.engine
        parse = _TOOL_PARSERS.get(engine_name, _parse_tools_openai)
        return Toolset(tools=parse(declared_tools))
    return Toolset()
751+
752+
def extract_tool_results(self, messages: LLMMessages) -> list[ToolResult]:
    """Normalize the tool results found in ``messages`` into ``ToolResult`` objects.

    The extractor is chosen from ``_RESULT_EXTRACTORS`` by
    ``self.model_config.engine``; an engine without a registered extractor uses
    the OpenAI Chat Completions one (``role:"tool"`` messages). The list is
    empty when the conversation carries no tool results.
    """
    engine_name = self.model_config.engine
    extract = _RESULT_EXTRACTORS.get(engine_name, _extract_tool_results_openai)
    return extract(messages)

0 commit comments

Comments
 (0)