Commit 69e78d0

Refactor ToolFailureEvaluator: privatize models, rename attempts to failed_attempts
Signed-off-by: Eric Evans <194135482+ericevans-nv@users.noreply.github.com>
1 parent 13dc6bc commit 69e78d0

7 files changed: 62 additions & 66 deletions

packages/nvidia_nat_core/tests/nat/utils/test_atif_converter.py

Lines changed: 0 additions & 5 deletions
@@ -708,11 +708,6 @@ def test_stream_matches_batch(
         assert len(s_step.tool_calls) == len(b_step.tool_calls)


-# ---------------------------------------------------------------------------
-# Tool error → ATIF conversion tests
-# ---------------------------------------------------------------------------
-
-
 @pytest.fixture(name="error_trajectory")
 def fixture_error_trajectory() -> list[IntermediateStep]:
     """Trajectory with one successful and one failed tool call."""

packages/nvidia_nat_eval/src/nat/plugins/eval/tool_failure_evaluator/__init__.py

Lines changed: 0 additions & 4 deletions
@@ -14,13 +14,9 @@
 # limitations under the License.

 from .evaluator import ToolFailureEvaluator
-from .models import ToolFailureReasoning
-from .models import ToolSummary
 from .register import ToolFailureEvaluatorConfig

 __all__ = [
     "ToolFailureEvaluator",
     "ToolFailureEvaluatorConfig",
-    "ToolFailureReasoning",
-    "ToolSummary",
 ]
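
With these exports removed, the reasoning models leave the package's public surface; only the evaluator and its config remain importable from the package root. A quick sketch of the import-level effect on downstream code (hypothetical caller, not from this repo):

# Still public:
from nat.plugins.eval.tool_failure_evaluator import ToolFailureEvaluator
from nat.plugins.eval.tool_failure_evaluator import ToolFailureEvaluatorConfig

# No longer exported after this commit (and renamed to _ToolFailureReasoning /
# _ToolSummary in models.py further down):
# from nat.plugins.eval.tool_failure_evaluator import ToolFailureReasoning  # ImportError
# from nat.plugins.eval.tool_failure_evaluator import ToolSummary           # ImportError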

packages/nvidia_nat_eval/src/nat/plugins/eval/tool_failure_evaluator/evaluator.py

Lines changed: 11 additions & 11 deletions
@@ -29,9 +29,9 @@
 from nat.plugins.eval.evaluator.atif_evaluator import AtifEvalSample
 from nat.plugins.eval.evaluator.base_evaluator import BaseEvaluator

-from .models import ToolFailureReasoning
-from .models import ToolSummary
 from .models import _ToolCall
+from .models import _ToolFailureReasoning
+from .models import _ToolSummary


 class ToolFailureEvaluator(BaseEvaluator, AtifBaseEvaluator):
@@ -47,7 +47,7 @@ def __init__(self, max_concurrency: int = 8):
     async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:
         """Evaluate a single item's legacy trajectory for tool failures."""
         if not item.trajectory:
-            return EvalOutputItem(id=item.id, score=1.0, reasoning=ToolFailureReasoning())
+            return EvalOutputItem(id=item.id, score=1.0, reasoning=_ToolFailureReasoning())

         total_tool_calls: int = 0
         failed_tool_calls: int = 0
@@ -71,17 +71,17 @@ async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:
                 failed_tool_calls += 1

         score: float = self._success_rate(total_tool_calls, failed_tool_calls)
-        per_tool_summary: list[ToolSummary] = [
-            ToolSummary(
+        per_tool_summary: list[_ToolSummary] = [
+            _ToolSummary(
                 tool_name=name,
                 total_calls=len(attempts),
                 failed_calls=failed_count,
-                attempts=[a for a in attempts if a.error is not None],
+                failed_attempts=[a for a in attempts if a.error is not None],
            ) for name, attempts in calls_by_tool.items()
            if (failed_count := sum(1 for a in attempts if a.error is not None)) > 0
        ]
        failed_tools: list[str] = [ts.tool_name for ts in per_tool_summary]
-        reasoning: ToolFailureReasoning = ToolFailureReasoning(
+        reasoning: _ToolFailureReasoning = _ToolFailureReasoning(
             total_tool_calls=total_tool_calls,
             failed_tool_calls=failed_tool_calls,
             failed_tools=failed_tools,
@@ -154,17 +154,17 @@ async def evaluate_atif_item(self, sample: AtifEvalSample) -> EvalOutputItem:
                 failed_tool_calls += 1

         score: float = self._success_rate(total_tool_calls, failed_tool_calls)
-        per_tool_summary: list[ToolSummary] = [
-            ToolSummary(
+        per_tool_summary: list[_ToolSummary] = [
+            _ToolSummary(
                 tool_name=name,
                 total_calls=len(attempts),
                 failed_calls=failed_count,
-                attempts=[a for a in attempts if a.error is not None],
+                failed_attempts=[a for a in attempts if a.error is not None],
            ) for name, attempts in calls_by_tool.items()
            if (failed_count := sum(1 for a in attempts if a.error is not None)) > 0
        ]
        failed_tools: list[str] = [ts.tool_name for ts in per_tool_summary]
-        reasoning: ToolFailureReasoning = ToolFailureReasoning(
+        reasoning: _ToolFailureReasoning = _ToolFailureReasoning(
             total_tool_calls=total_tool_calls,
             failed_tool_calls=failed_tool_calls,
             failed_tools=failed_tools,
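
Both lanes build the per-tool summary with an assignment expression in the comprehension's filter: the `if` clause is evaluated before the element expression, so `failed_count` is computed once per tool and fully healthy tools are dropped in the same pass. A minimal standalone sketch of that pattern, with (input, error) tuples standing in for `_ToolCall` objects and invented sample data:

# Plain tuples stand in for _ToolCall; the error slot is None on success.
calls_by_tool = {
    "lookup": [("q1", "ValueError: bad input"), ("q2", None)],
    "search": [("q3", None)],
}

# `failed_count` is bound by the walrus operator in the `if` clause, which runs
# before the element expression, so only tools with at least one failure
# produce a summary entry and the count is never computed twice.
summaries = [
    {"tool_name": name, "total_calls": len(attempts), "failed_calls": failed_count}
    for name, attempts in calls_by_tool.items()
    if (failed_count := sum(1 for _, err in attempts if err is not None)) > 0
]

assert summaries == [{"tool_name": "lookup", "total_calls": 2, "failed_calls": 1}]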

packages/nvidia_nat_eval/src/nat/plugins/eval/tool_failure_evaluator/models.py

Lines changed: 5 additions & 5 deletions
@@ -30,26 +30,26 @@ class _ToolCall(BaseModel):
     error: str | None = Field(default=None, description="Error string if failed, None if succeeded.")


-class ToolSummary(BaseModel):
+class _ToolSummary(BaseModel):
     """Complete health and attempt data for a single tool."""

     tool_name: str = Field(description="Name of the tool.")
     total_calls: int = Field(default=0, description="Total number of calls to this tool.")
     failed_calls: int = Field(default=0, description="Number of calls that returned an error.")
-    attempts: list[_ToolCall] = Field(
+    failed_attempts: list[_ToolCall] = Field(
         default_factory=list,
-        description="Ordered list of every call to this tool.",
+        description="Ordered list of failed calls to this tool.",
     )


-class ToolFailureReasoning(BaseModel):
+class _ToolFailureReasoning(BaseModel):
     """Complete reasoning payload returned by the tool failure evaluator."""

     total_tool_calls: int = Field(default=0, description="Total tool calls in the trajectory.")
     failed_tool_calls: int = Field(default=0, description="Total tool calls that errored.")
     failed_tools: list[str] = Field(default_factory=list, description="Names of tools that had at least one failure.")
     score: float = Field(default=1.0, description="Overall success rate (0.0-1.0).")
-    per_tool_summary: list[ToolSummary] = Field(
+    per_tool_summary: list[_ToolSummary] = Field(
         default_factory=list,
         description="Per-tool health summary with attempt history.",
     )
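
For reference, a short sketch of how the renamed models compose, assuming the Field definitions shown above plus the `input`/`output` fields that `_ToolCall` carries outside this hunk (the values here are invented, and `model_dump_json` assumes pydantic v2):

from nat.plugins.eval.tool_failure_evaluator.models import _ToolCall
from nat.plugins.eval.tool_failure_evaluator.models import _ToolFailureReasoning
from nat.plugins.eval.tool_failure_evaluator.models import _ToolSummary

# failed_attempts now holds only the failing calls; total_calls still counts all.
summary = _ToolSummary(
    tool_name="lookup",
    total_calls=2,
    failed_calls=1,
    failed_attempts=[_ToolCall(input={"q": "bad"}, error="ValueError: bad input")],
)
reasoning = _ToolFailureReasoning(
    total_tool_calls=2,
    failed_tool_calls=1,
    failed_tools=["lookup"],
    score=0.5,  # 1 failure out of 2 calls
    per_tool_summary=[summary],
)
print(reasoning.model_dump_json(indent=2))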

packages/nvidia_nat_eval/tests/eval/evaluator/test_tool_failure_evaluator.py

Lines changed: 22 additions & 31 deletions
@@ -12,12 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Unit tests for ToolFailureEvaluator model population.
-
-Validates that ToolFailureReasoning, ToolSummary, and _ToolCall are correctly
-populated from both the legacy IntermediateStep lane and the ATIF lane, and
-that error detection correctly distinguishes failures from successes.
-"""
+"""Unit tests for ToolFailureEvaluator."""

 from __future__ import annotations

@@ -39,7 +34,7 @@
 from nat.data_models.invocation_node import InvocationNode
 from nat.plugins.eval.evaluator.atif_evaluator import AtifEvalSample
 from nat.plugins.eval.tool_failure_evaluator.evaluator import ToolFailureEvaluator
-from nat.plugins.eval.tool_failure_evaluator.models import ToolFailureReasoning
+from nat.plugins.eval.tool_failure_evaluator.models import _ToolFailureReasoning

 _DUMMY_ANCESTRY: InvocationNode = InvocationNode(function_id="f-0", function_name="test_fn")

@@ -107,18 +102,16 @@ def evaluator_fixture() -> ToolFailureEvaluator:
     return ToolFailureEvaluator()


-class TestLegacyLaneModelPopulation:
-    """Verify ToolFailureReasoning, ToolSummary, and _ToolCall are correctly
-    populated from legacy IntermediateStep trajectories.
-    """
+class TestEvaluateIntermediateStepTrajectory:
+    """Tests for evaluating IntermediateStep trajectories."""

     async def test_empty_trajectory_produces_default_reasoning(self, evaluator: ToolFailureEvaluator):
         """An empty trajectory should yield default ToolFailureReasoning with
         zero counts, no failed tools, and a perfect score.
         """
         result = await evaluator.evaluate_item(_eval_input("empty", []))

-        reasoning: ToolFailureReasoning = result.reasoning
+        reasoning: _ToolFailureReasoning = result.reasoning
         assert reasoning.total_tool_calls == 0
         assert reasoning.failed_tool_calls == 0
         assert reasoning.failed_tools == []
@@ -136,7 +129,7 @@ async def test_all_failed_calls_populate_summary_with_error_details(self, evaluator: ToolFailureEvaluator):
         ]
         result = await evaluator.evaluate_item(_eval_input("fail", trajectory))

-        reasoning: ToolFailureReasoning = result.reasoning
+        reasoning: _ToolFailureReasoning = result.reasoning
         assert reasoning.total_tool_calls == 2
         assert reasoning.failed_tool_calls == 2
         assert reasoning.failed_tools == ["lookup"]
@@ -146,8 +139,8 @@ async def test_all_failed_calls_populate_summary_with_error_details(self, evaluator: ToolFailureEvaluator):
         assert summary.tool_name == "lookup"
         assert summary.total_calls == 2
         assert summary.failed_calls == 2
-        assert len(summary.attempts) == 2
-        for attempt in summary.attempts:
+        assert len(summary.failed_attempts) == 2
+        for attempt in summary.failed_attempts:
             assert attempt.error == "ValueError: bad input"
             assert attempt.output is None

@@ -163,7 +156,7 @@ async def test_mixed_results_split_correctly_across_models(self, evaluator: ToolFailureEvaluator):
         ]
         result = await evaluator.evaluate_item(_eval_input("mixed", trajectory))

-        reasoning: ToolFailureReasoning = result.reasoning
+        reasoning: _ToolFailureReasoning = result.reasoning
         assert reasoning.total_tool_calls == 2
         assert reasoning.failed_tool_calls == 1
         assert reasoning.failed_tools == ["lookup"]
@@ -173,7 +166,7 @@ async def test_mixed_results_split_correctly_across_models(self, evaluator: ToolFailureEvaluator):
         assert reasoning.per_tool_summary[0].tool_name == "lookup"

     async def test_same_tool_mixed_results_filters_attempts_to_failures_only(self, evaluator: ToolFailureEvaluator):
-        """When a single tool has both successes and failures, ToolSummary.attempts
+        """When a single tool has both successes and failures, ToolSummary.failed_attempts
         should contain only the failed _ToolCall entries while total_calls reflects all.
         """
         trajectory = [
@@ -184,13 +177,13 @@ async def test_same_tool_mixed_results_filters_attempts_to_failures_only(self, evaluator: ToolFailureEvaluator):
         ]
         result = await evaluator.evaluate_item(_eval_input("filter", trajectory))

-        reasoning: ToolFailureReasoning = result.reasoning
+        reasoning: _ToolFailureReasoning = result.reasoning
         summary = reasoning.per_tool_summary[0]
         assert summary.total_calls == 2
         assert summary.failed_calls == 1
-        assert len(summary.attempts) == 1
-        assert summary.attempts[0].error == "boom"
-        assert summary.attempts[0].input == {"q": "bad"}
+        assert len(summary.failed_attempts) == 1
+        assert summary.failed_attempts[0].error == "boom"
+        assert summary.failed_attempts[0].input == {"q": "bad"}

     async def test_none_data_on_step_is_not_treated_as_error(self, evaluator: ToolFailureEvaluator):
         """A TOOL_END step with data=None should count as a call but not a failure."""
@@ -220,10 +213,8 @@ async def test_missing_tool_name_recorded_as_unknown(self, evaluator: ToolFailureEvaluator):
         assert result.reasoning.per_tool_summary[0].tool_name == "unknown"


-class TestAtifLaneModelPopulation:
-    """Verify ToolFailureReasoning, ToolSummary, and _ToolCall are correctly
-    populated from ATIF trajectories using each error detection path.
-    """
+class TestEvaluateAtifTrajectory:
+    """Tests for evaluating ATIF trajectories."""

     async def test_error_detected_via_extra_tool_errors(self, evaluator: ToolFailureEvaluator):
         """Structured error metadata in step.extra['tool_errors'] should populate
@@ -242,11 +233,11 @@ async def test_error_detected_via_extra_tool_errors(self, evaluator: ToolFailureEvaluator):
         ]
         result = await evaluator.evaluate_atif_item(_atif_sample("extra", steps))

-        reasoning: ToolFailureReasoning = result.reasoning
+        reasoning: _ToolFailureReasoning = result.reasoning
         assert reasoning.failed_tool_calls == 1
         assert reasoning.failed_tools == ["lookup"]
-        assert reasoning.per_tool_summary[0].attempts[0].error == "ValueError: Column not found"
-        assert reasoning.per_tool_summary[0].attempts[0].input == {"query": "q1"}
+        assert reasoning.per_tool_summary[0].failed_attempts[0].error == "ValueError: Column not found"
+        assert reasoning.per_tool_summary[0].failed_attempts[0].input == {"query": "q1"}

     async def test_error_detected_via_stringified_tool_message_dict(self, evaluator: ToolFailureEvaluator):
         """A Python dict literal with status='error' in the observation content
@@ -262,7 +253,7 @@ async def test_error_detected_via_stringified_tool_message_dict(self, evaluator: ToolFailureEvaluator):
         result = await evaluator.evaluate_atif_item(_atif_sample("parsed", steps))

         assert result.reasoning.failed_tool_calls == 1
-        assert result.reasoning.per_tool_summary[0].attempts[0].error == "TimeoutError: timed out"
+        assert result.reasoning.per_tool_summary[0].failed_attempts[0].error == "TimeoutError: timed out"

     async def test_error_detected_via_raw_error_pattern(self, evaluator: ToolFailureEvaluator):
         """Observation content matching 'XyzError: ...' should be detected as a
@@ -274,7 +265,7 @@ async def test_error_detected_via_raw_error_pattern(self, evaluator: ToolFailureEvaluator):
         result = await evaluator.evaluate_atif_item(_atif_sample("pattern", steps))

         assert result.reasoning.failed_tool_calls == 1
-        assert result.reasoning.per_tool_summary[0].attempts[0].error == "RuntimeError: internal failure"
+        assert result.reasoning.per_tool_summary[0].failed_attempts[0].error == "RuntimeError: internal failure"

     async def test_extra_tool_errors_takes_priority_over_observation_pattern(self, evaluator: ToolFailureEvaluator):
         """When both extra['tool_errors'] and a raw error pattern match, the
@@ -292,7 +283,7 @@ async def test_extra_tool_errors_takes_priority_over_observation_pattern(self, evaluator: ToolFailureEvaluator):
         ]
         result = await evaluator.evaluate_atif_item(_atif_sample("priority", steps))

-        assert result.reasoning.per_tool_summary[0].attempts[0].error == "ValueError: from extra"
+        assert result.reasoning.per_tool_summary[0].failed_attempts[0].error == "ValueError: from extra"

     async def test_mixed_success_and_failure_populates_only_failing_tool(self, evaluator: ToolFailureEvaluator):
         """With one successful and one failing tool, only the failing tool

packages/nvidia_nat_langchain/src/nat/plugins/langchain/callback_handler.py

Lines changed: 24 additions & 5 deletions
@@ -36,6 +36,8 @@
 from nat.data_models.intermediate_step import IntermediateStepType
 from nat.data_models.intermediate_step import ServerToolUseSchema
 from nat.data_models.intermediate_step import StreamEventData
+from nat.data_models.intermediate_step import ToolDetails
+from nat.data_models.intermediate_step import ToolParameters
 from nat.data_models.intermediate_step import ToolSchema
 from nat.data_models.intermediate_step import TraceMetadata
 from nat.data_models.intermediate_step import UsageInfo
@@ -53,11 +55,28 @@ def _extract_tools_schema(invocation_params: dict) -> list:
         try:
             tools_schema.append(ToolSchema(**tool))
         except Exception:
-            logger.debug(
-                "Failed to parse tool schema from invocation params: %s. \n This "
-                "can occur when the LLM server has native tools and can be ignored if "
-                "using the responses API.",
-                tool)
+            # Handle non-OpenAI tool formats (e.g. Anthropic: top-level name/description/input_schema)
+            try:
+                input_schema = tool.get("input_schema") or {}
+                tools_schema.append(
+                    ToolSchema(
+                        type="function",
+                        function=ToolDetails(
+                            name=tool["name"],
+                            description=tool.get("description", ""),
+                            parameters=ToolParameters(
+                                properties=input_schema.get("properties", {}),
+                                required=input_schema.get("required", []),
+                                additionalProperties=input_schema.get("additionalProperties", False),
+                            ),
+                        ),
+                    ))
+            except (KeyError, TypeError, AttributeError):
+                logger.debug(
+                    "Failed to parse tool schema from invocation params: %s. \n This "
+                    "can occur when the LLM server has native tools and can be ignored if "
+                    "using the responses API.",
+                    tool)

     return tools_schema

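In effect, the new fallback maps Anthropic-style tool definitions (top-level name/description/input_schema) onto the same OpenAI-style function schema the handler already emits. A rough standalone sketch of that mapping, with plain dicts standing in for the ToolSchema/ToolDetails/ToolParameters models and a hypothetical tool definition as input:

# Hypothetical Anthropic-format tool, as it might appear in invocation_params["tools"].
anthropic_tool = {
    "name": "lookup",
    "description": "Look up a record by query.",
    "input_schema": {
        "type": "object",
        "properties": {"query": {"type": "string"}},
        "required": ["query"],
    },
}

# Same field-by-field mapping as the fallback above, using plain dicts.
input_schema = anthropic_tool.get("input_schema") or {}
normalized = {
    "type": "function",
    "function": {
        "name": anthropic_tool["name"],
        "description": anthropic_tool.get("description", ""),
        "parameters": {
            "properties": input_schema.get("properties", {}),
            "required": input_schema.get("required", []),
            "additionalProperties": input_schema.get("additionalProperties", False),
        },
    },
}
assert normalized["function"]["parameters"]["required"] == ["query"]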

packages/nvidia_nat_langchain/tests/test_langchain_callback_handler.py

Lines changed: 0 additions & 5 deletions
@@ -284,11 +284,6 @@ def test_extract_tools_schema_empty_and_none():
     assert _extract_tools_schema(None) == []


-# ---------------------------------------------------------------------------
-# on_tool_error tests
-# ---------------------------------------------------------------------------
-
-
 @pytest.fixture(name="handler_and_stats")
 def fixture_handler_and_stats(
         reactive_stream: Subject, ) -> tuple[LangchainProfilerHandler, list[IntermediateStepPayload]]:
