|
2 | 2 | import asyncio |
3 | 3 | import pytest |
4 | 4 |
|
5 | | -from agentops.sdk.decorators import agent, operation, session, workflow, task |
| 5 | +from agentops.sdk.decorators import agent, operation, session, workflow, task, tool |
6 | 6 | from agentops.semconv import SpanKind |
7 | 7 | from agentops.semconv.span_attributes import SpanAttributes |
8 | 8 | from tests.unit.sdk.instrumentation_tester import InstrumentationTester |
@@ -624,3 +624,140 @@ def __init__(self): |
624 | 624 | with pytest.raises(ValueError): |
625 | 625 | async with TestClass() as instance: |
626 | 626 | raise ValueError("Trigger exception for __aexit__ coverage") |
| 627 | + |
| 628 | + |
| 629 | +class TestToolDecorator: |
| 630 | + """Tests for the tool decorator functionality.""" |
| 631 | + |
| 632 | + @pytest.fixture |
| 633 | + def agent_class(self): |
| 634 | + @agent |
| 635 | + class TestAgent: |
| 636 | + @tool(cost=0.01) |
| 637 | + def process_item(self, item): |
| 638 | + return f"Processed {item}" |
| 639 | + |
| 640 | + @tool(cost=0.02) |
| 641 | + async def async_process_item(self, item): |
| 642 | + await asyncio.sleep(0.1) |
| 643 | + return f"Async processed {item}" |
| 644 | + |
| 645 | + @tool(cost=0.03) |
| 646 | + def generator_process_items(self, items): |
| 647 | + for item in items: |
| 648 | + yield self.process_item(item) |
| 649 | + |
| 650 | + @tool(cost=0.04) |
| 651 | + async def async_generator_process_items(self, items): |
| 652 | + for item in items: |
| 653 | + await asyncio.sleep(0.1) |
| 654 | + yield await self.async_process_item(item) |
| 655 | + |
| 656 | + return TestAgent() |
| 657 | + |
| 658 | + def test_sync_tool_cost(self, agent_class, instrumentation: InstrumentationTester): |
| 659 | + """Test synchronous tool with cost attribute.""" |
| 660 | + result = agent_class.process_item("test") |
| 661 | + |
| 662 | + assert result == "Processed test" |
| 663 | + |
| 664 | + spans = instrumentation.get_finished_spans() |
| 665 | + tool_span = next( |
| 666 | + span for span in spans if span.attributes.get(SpanAttributes.AGENTOPS_SPAN_KIND) == SpanKind.TOOL |
| 667 | + ) |
| 668 | + assert tool_span.attributes.get(SpanAttributes.LLM_USAGE_TOOL_COST) == 0.01 |
| 669 | + |
| 670 | + @pytest.mark.asyncio |
| 671 | + async def test_async_tool_cost(self, agent_class, instrumentation: InstrumentationTester): |
| 672 | + """Test asynchronous tool with cost attribute.""" |
| 673 | + result = await agent_class.async_process_item("test") |
| 674 | + |
| 675 | + assert result == "Async processed test" |
| 676 | + |
| 677 | + spans = instrumentation.get_finished_spans() |
| 678 | + tool_span = next( |
| 679 | + span for span in spans if span.attributes.get(SpanAttributes.AGENTOPS_SPAN_KIND) == SpanKind.TOOL |
| 680 | + ) |
| 681 | + assert tool_span.attributes.get(SpanAttributes.LLM_USAGE_TOOL_COST) == 0.02 |
| 682 | + |
| 683 | + def test_generator_tool_cost(self, agent_class, instrumentation: InstrumentationTester): |
| 684 | + """Test generator tool with cost attribute.""" |
| 685 | + items = ["item1", "item2", "item3"] |
| 686 | + results = list(agent_class.generator_process_items(items)) |
| 687 | + |
| 688 | + assert len(results) == 3 |
| 689 | + assert results[0] == "Processed item1" |
| 690 | + assert results[1] == "Processed item2" |
| 691 | + assert results[2] == "Processed item3" |
| 692 | + |
| 693 | + spans = instrumentation.get_finished_spans() |
| 694 | + tool_spans = [span for span in spans if span.attributes.get(SpanAttributes.AGENTOPS_SPAN_KIND) == SpanKind.TOOL] |
| 695 | + assert len(tool_spans) == 4 # Only one span for the generator |
| 696 | + assert tool_spans[0].attributes.get(SpanAttributes.LLM_USAGE_TOOL_COST) == 0.01 |
| 697 | + assert tool_spans[3].attributes.get(SpanAttributes.LLM_USAGE_TOOL_COST) == 0.03 |
| 698 | + |
| 699 | + @pytest.mark.asyncio |
| 700 | + async def test_async_generator_tool_cost(self, agent_class, instrumentation: InstrumentationTester): |
| 701 | + """Test async generator tool with cost attribute.""" |
| 702 | + items = ["item1", "item2", "item3"] |
| 703 | + results = [result async for result in agent_class.async_generator_process_items(items)] |
| 704 | + |
| 705 | + assert len(results) == 3 |
| 706 | + assert results[0] == "Async processed item1" |
| 707 | + assert results[1] == "Async processed item2" |
| 708 | + assert results[2] == "Async processed item3" |
| 709 | + |
| 710 | + spans = instrumentation.get_finished_spans() |
| 711 | + tool_span = [span for span in spans if span.attributes.get(SpanAttributes.AGENTOPS_SPAN_KIND) == SpanKind.TOOL] |
| 712 | + assert len(tool_span) == 4 # Only one span for the generator |
| 713 | + assert tool_span[0].attributes.get(SpanAttributes.LLM_USAGE_TOOL_COST) == 0.02 |
| 714 | + assert tool_span[3].attributes.get(SpanAttributes.LLM_USAGE_TOOL_COST) == 0.04 |
| 715 | + |
| 716 | + def test_multiple_tool_calls(self, agent_class, instrumentation: InstrumentationTester): |
| 717 | + """Test multiple calls to the same tool.""" |
| 718 | + for i in range(3): |
| 719 | + result = agent_class.process_item(f"item{i}") |
| 720 | + assert result == f"Processed item{i}" |
| 721 | + |
| 722 | + spans = instrumentation.get_finished_spans() |
| 723 | + tool_spans = [span for span in spans if span.attributes.get(SpanAttributes.AGENTOPS_SPAN_KIND) == SpanKind.TOOL] |
| 724 | + assert len(tool_spans) == 3 |
| 725 | + for span in tool_spans: |
| 726 | + assert span.attributes.get(SpanAttributes.LLM_USAGE_TOOL_COST) == 0.01 |
| 727 | + |
| 728 | + @pytest.mark.asyncio |
| 729 | + async def test_parallel_tool_calls(self, agent_class, instrumentation: InstrumentationTester): |
| 730 | + """Test parallel execution of async tools.""" |
| 731 | + results = await asyncio.gather( |
| 732 | + agent_class.async_process_item("item1"), |
| 733 | + agent_class.async_process_item("item2"), |
| 734 | + agent_class.async_process_item("item3"), |
| 735 | + ) |
| 736 | + |
| 737 | + assert len(results) == 3 |
| 738 | + assert results[0] == "Async processed item1" |
| 739 | + assert results[1] == "Async processed item2" |
| 740 | + assert results[2] == "Async processed item3" |
| 741 | + |
| 742 | + spans = instrumentation.get_finished_spans() |
| 743 | + tool_spans = [span for span in spans if span.attributes.get(SpanAttributes.AGENTOPS_SPAN_KIND) == SpanKind.TOOL] |
| 744 | + assert len(tool_spans) == 3 |
| 745 | + for span in tool_spans: |
| 746 | + assert span.attributes.get(SpanAttributes.LLM_USAGE_TOOL_COST) == 0.02 |
| 747 | + |
| 748 | + def test_tool_without_cost(self, agent_class, instrumentation: InstrumentationTester): |
| 749 | + """Test tool without cost parameter.""" |
| 750 | + |
| 751 | + @tool |
| 752 | + def no_cost_tool(self): |
| 753 | + return "No cost tool result" |
| 754 | + |
| 755 | + result = no_cost_tool(agent_class) |
| 756 | + |
| 757 | + assert result == "No cost tool result" |
| 758 | + |
| 759 | + spans = instrumentation.get_finished_spans() |
| 760 | + tool_span = next( |
| 761 | + span for span in spans if span.attributes.get(SpanAttributes.AGENTOPS_SPAN_KIND) == SpanKind.TOOL |
| 762 | + ) |
| 763 | + assert SpanAttributes.LLM_USAGE_TOOL_COST not in tool_span.attributes |
0 commit comments