|
15 | 15 | """Tests for the ATIF converter.""" |
16 | 16 |
|
17 | 17 | import pytest |
| 18 | +from langchain_core.messages import ToolMessage |
18 | 19 |
|
19 | 20 | from nat.data_models.atif import ATIFTrajectory |
20 | 21 | from nat.data_models.intermediate_step import IntermediateStep |
@@ -505,3 +506,129 @@ def test_stream_matches_batch( |
505 | 506 | assert s_step.message == b_step.message |
506 | 507 | if b_step.tool_calls: |
507 | 508 | assert len(s_step.tool_calls) == len(b_step.tool_calls) |
| 509 | + |
| 510 | + |
| 511 | +# --------------------------------------------------------------------------- |
| 512 | +# Tool error → ATIF conversion tests |
| 513 | +# --------------------------------------------------------------------------- |
| 514 | + |
| 515 | + |
| 516 | +@pytest.fixture(name="error_trajectory") |
| 517 | +def fixture_error_trajectory() -> list[IntermediateStep]: |
| 518 | + """Return a five-step trajectory: workflow start, an LLM end, a successful good_tool end, a failing_tool end whose output is an error-status ToolMessage, and workflow end.""" |
| 519 | + error_output: ToolMessage = ToolMessage( |
| 520 | + content="ValueError: bad input", |
| 521 | + name="failing_tool", |
| 522 | + tool_call_id="failing_tool", |
| 523 | + status="error", |
| 524 | + ) |
| 525 | + return [ |
| 526 | + _make_step(IntermediateStepType.WORKFLOW_START, input_data="Do something", timestamp_offset=0.0), |
| 527 | + _make_step(IntermediateStepType.LLM_END, |
| 528 | + name="gpt-4", |
| 529 | + output_data="calling tools", |
| 530 | + timestamp_offset=1.0, |
| 531 | + usage=_make_usage(100, 20)), |
| 532 | + _make_step(IntermediateStepType.TOOL_END, |
| 533 | + name="good_tool", |
| 534 | + input_data={"q": "hello"}, |
| 535 | + output_data="success", |
| 536 | + timestamp_offset=2.0, |
| 537 | + step_uuid="tool-good"), |
| 538 | + _make_step(IntermediateStepType.TOOL_END, |
| 539 | + name="failing_tool", |
| 540 | + input_data={"q": "fail"}, |
| 541 | + output_data=error_output, |
| 542 | + timestamp_offset=3.0, |
| 543 | + step_uuid="tool-fail"), |
| 544 | + _make_step(IntermediateStepType.WORKFLOW_END, output_data="partial", timestamp_offset=4.0), |
| 545 | + ] |
| 546 | + |
| 547 | + |
| 548 | +class TestToolErrorATIFConversion: |
| 549 | + """Verify that TOOL_END steps whose output is an error-status ToolMessage surface as entries in ATIF step.extra['tool_errors'].""" |
| 550 | + |
| 551 | + def test_error_dict_has_all_required_keys( |
| 552 | + self, |
| 553 | + batch_converter: IntermediateStepToATIFConverter, |
| 554 | + error_trajectory: list[IntermediateStep], |
| 555 | + ): |
| 556 | + """result.steps[1] carries exactly one tool_errors entry whose keys are tool, error, error_type, error_message and status.""" |
| 557 | + result: ATIFTrajectory = batch_converter.convert(error_trajectory) |
| 558 | + agent_step = result.steps[1] |
| 559 | + errors: list = agent_step.extra["tool_errors"] |
| 560 | + assert len(errors) == 1 |
| 561 | + assert set(errors[0].keys()) == {"tool", "error", "error_type", "error_message", "status"} |
| 562 | + |
| 563 | + def test_error_dict_values_are_parsed_from_content( |
| 564 | + self, |
| 565 | + batch_converter: IntermediateStepToATIFConverter, |
| 566 | + error_trajectory: list[IntermediateStep], |
| 567 | + ): |
| 568 | + """The entry keeps the full error string in 'error' and splits 'ValueError: bad input' into error_type 'ValueError' and error_message 'bad input'.""" |
| 569 | + result: ATIFTrajectory = batch_converter.convert(error_trajectory) |
| 570 | + entry: dict = result.steps[1].extra["tool_errors"][0] |
| 571 | + assert entry["tool"] == "failing_tool" |
| 572 | + assert entry["status"] == "error" |
| 573 | + assert entry["error"] == "ValueError: bad input" |
| 574 | + assert entry["error_type"] == "ValueError" |
| 575 | + assert entry["error_message"] == "bad input" |
| 576 | + |
| 577 | + def test_error_dict_falls_back_to_unknown_type(self): |
| 578 | + """Error content with no 'Type: message' prefix ('something went wrong') yields error_type 'Unknown' and the full content as error_message; builds its own converter instead of using the batch_converter fixture.""" |
| 579 | + error_output: ToolMessage = ToolMessage( |
| 580 | + content="something went wrong", |
| 581 | + name="broken_tool", |
| 582 | + tool_call_id="broken_tool", |
| 583 | + status="error", |
| 584 | + ) |
| 585 | + trajectory: list[IntermediateStep] = [ |
| 586 | + _make_step(IntermediateStepType.WORKFLOW_START, input_data="q", timestamp_offset=0.0), |
| 587 | + _make_step(IntermediateStepType.LLM_END, |
| 588 | + name="gpt-4", |
| 589 | + output_data="calling", |
| 590 | + timestamp_offset=1.0, |
| 591 | + usage=_make_usage(10, 5)), |
| 592 | + _make_step(IntermediateStepType.TOOL_END, |
| 593 | + name="broken_tool", |
| 594 | + input_data={}, |
| 595 | + output_data=error_output, |
| 596 | + timestamp_offset=2.0, |
| 597 | + step_uuid="tool-broken"), |
| 598 | + _make_step(IntermediateStepType.WORKFLOW_END, output_data="done", timestamp_offset=3.0), |
| 599 | + ] |
| 600 | + result: ATIFTrajectory = IntermediateStepToATIFConverter().convert(trajectory) |
| 601 | + entry: dict = result.steps[1].extra["tool_errors"][0] |
| 602 | + assert entry["error_type"] == "Unknown" |
| 603 | + assert entry["error_message"] == "something went wrong" |
| 604 | + |
| 605 | + def test_successful_tool_has_no_tool_errors( |
| 606 | + self, |
| 607 | + batch_converter: IntermediateStepToATIFConverter, |
| 608 | + simple_trajectory: list[IntermediateStep], |
| 609 | + ): |
| 610 | + """No step converted from simple_trajectory (presumably all-successful tool calls; fixture defined outside this view) carries a non-empty tool_errors list.""" |
| 611 | + result: ATIFTrajectory = batch_converter.convert(simple_trajectory) |
| 612 | + for step in result.steps: |
| 613 | + assert not (step.extra or {}).get("tool_errors") |
| 614 | + |
| 615 | + def test_stream_and_batch_produce_same_errors( |
| 616 | + self, |
| 617 | + batch_converter: IntermediateStepToATIFConverter, |
| 618 | + error_trajectory: list[IntermediateStep], |
| 619 | + ): |
| 620 | + """Batch and stream converters yield equal tool_errors lists for error_trajectory. NOTE(review): equality alone passes even if both lists are empty.""" |
| 621 | + batch_result: ATIFTrajectory = batch_converter.convert(error_trajectory) |
| 622 | + stream_conv: ATIFStreamConverter = ATIFStreamConverter() |
| 623 | + for ist in error_trajectory: |
| 624 | + stream_conv.push(ist) |
| 625 | + stream_conv.finalize() |
| 626 | + stream_result: ATIFTrajectory = stream_conv.get_trajectory() |
| 627 | + |
| 628 | + def _collect_errors(trajectory: ATIFTrajectory) -> list[dict]: |
| 629 | + errors: list[dict] = [] |
| 630 | + for step in trajectory.steps: |
| 631 | + errors.extend((step.extra or {}).get("tool_errors", [])) |
| 632 | + return errors |
| 633 | + |
| 634 | + assert _collect_errors(batch_result) == _collect_errors(stream_result) |
0 commit comments