|
| 1 | +# Copyright (c) Microsoft. All rights reserved. |
| 2 | + |
| 3 | +"""Benchmark CodeAct vs. traditional tool-calling for a multi-tool-call task. |
| 4 | +
|
| 5 | +This sample runs the same prompt against the same FoundryChatClient twice: |
| 6 | +
|
| 7 | +1. **Traditional tool-calling**: the five business tools are passed directly to |
| 8 | + the agent, so the model calls each tool individually via the LLM tool-call |
| 9 | + interface. |
| 10 | +2. **CodeAct**: the same tools are registered on a HyperlightCodeActProvider |
| 11 | + and the model sees a single ``execute_code`` tool that calls them from |
| 12 | + inside the Hyperlight sandbox via ``call_tool(...)``. |
| 13 | +
|
| 14 | +The task (computing grand totals per user) naturally requires many tool calls |
| 15 | +to complete. At the end, the sample prints elapsed time and token usage for |
| 16 | +each run so the two approaches can be compared. |
| 17 | +
|
| 18 | +Run with: |
| 19 | + cd python |
| 20 | + uv run --directory packages/hyperlight python samples/codeact_benchmark.py |
| 21 | +
|
| 22 | +Required environment variables (loaded from ``.env`` if present): |
| 23 | + FOUNDRY_PROJECT_ENDPOINT |
| 24 | + FOUNDRY_MODEL |
| 25 | +""" |
| 26 | + |
| 27 | +from __future__ import annotations |
| 28 | + |
| 29 | +import asyncio |
| 30 | +import os |
| 31 | +import time |
| 32 | +from typing import Annotated, Any, Literal |
| 33 | + |
| 34 | +from agent_framework import Agent, AgentResponse, UsageDetails |
| 35 | +from agent_framework.foundry import FoundryChatClient |
| 36 | +from azure.identity import AzureCliCredential |
| 37 | +from dotenv import load_dotenv |
| 38 | +from pydantic import BaseModel, Field |
| 39 | + |
| 40 | +from agent_framework_hyperlight import HyperlightCodeActProvider |
| 41 | + |
| 42 | +load_dotenv() |
| 43 | + |
| 44 | + |
| 45 | +# 1. Deterministic "business" data and tools. |
| 46 | + |
| 47 | +_USERS: list[dict[str, Any]] = [ |
| 48 | + {"id": 1, "name": "Alice", "region": "EU", "tier": "gold"}, |
| 49 | + {"id": 2, "name": "Bob", "region": "US", "tier": "silver"}, |
| 50 | + {"id": 3, "name": "Charlie", "region": "US", "tier": "gold"}, |
| 51 | + {"id": 4, "name": "Diana", "region": "APAC", "tier": "bronze"}, |
| 52 | + {"id": 5, "name": "Evan", "region": "EU", "tier": "silver"}, |
| 53 | + {"id": 6, "name": "Fiona", "region": "US", "tier": "gold"}, |
| 54 | + {"id": 7, "name": "George", "region": "APAC", "tier": "gold"}, |
| 55 | + {"id": 8, "name": "Hana", "region": "EU", "tier": "bronze"}, |
| 56 | +] |
| 57 | + |
| 58 | +_ORDERS: dict[int, list[dict[str, Any]]] = { |
| 59 | + 1: [{"product": "Widget", "qty": 3, "unit_price": 9.99}, {"product": "Gadget", "qty": 1, "unit_price": 19.99}], |
| 60 | + 2: [{"product": "Widget", "qty": 1, "unit_price": 9.99}], |
| 61 | + 3: [{"product": "Gadget", "qty": 2, "unit_price": 19.99}, {"product": "Thingamajig", "qty": 4, "unit_price": 4.50}], |
| 62 | + 4: [{"product": "Widget", "qty": 10, "unit_price": 9.99}], |
| 63 | + 5: [{"product": "Gadget", "qty": 1, "unit_price": 19.99}], |
| 64 | + 6: [{"product": "Widget", "qty": 2, "unit_price": 9.99}, {"product": "Thingamajig", "qty": 5, "unit_price": 4.50}], |
| 65 | + 7: [{"product": "Gadget", "qty": 3, "unit_price": 19.99}], |
| 66 | + 8: [{"product": "Thingamajig", "qty": 2, "unit_price": 4.50}], |
| 67 | +} |
| 68 | + |
| 69 | +_DISCOUNTS: dict[str, float] = {"gold": 0.20, "silver": 0.10, "bronze": 0.05} |
| 70 | +_TAX_RATES: dict[str, float] = {"EU": 0.21, "US": 0.08, "APAC": 0.10} |
| 71 | + |
| 72 | + |
| 73 | +def list_users() -> list[dict[str, Any]]: |
| 74 | + """Return all users as a list of dictionaries. |
| 75 | +
|
| 76 | + Each entry has keys: id (int), name (str), region (str), tier (str). |
| 77 | + """ |
| 78 | + return _USERS |
| 79 | + |
| 80 | + |
| 81 | +def get_orders_for_user( |
| 82 | + user_id: Annotated[int, "The user id whose orders to retrieve."], |
| 83 | +) -> list[dict[str, Any]]: |
| 84 | + """Return the user's orders as a list of dictionaries. |
| 85 | +
|
| 86 | + Each entry has keys: product (str), qty (int), unit_price (float). |
| 87 | + """ |
| 88 | + return _ORDERS.get(user_id, []) |
| 89 | + |
| 90 | + |
| 91 | +def get_discount_rate( |
| 92 | + tier: Annotated[Literal["gold", "silver", "bronze"], "The customer tier."], |
| 93 | +) -> float: |
| 94 | + """Return the discount rate as a float fraction (e.g. 0.2 for 20%).""" |
| 95 | + return _DISCOUNTS[tier] |
| 96 | + |
| 97 | + |
| 98 | +def get_tax_rate( |
| 99 | + region: Annotated[Literal["EU", "US", "APAC"], "The region code."], |
| 100 | +) -> float: |
| 101 | + """Return the tax rate as a float fraction (e.g. 0.21 for 21%).""" |
| 102 | + return _TAX_RATES[region] |
| 103 | + |
| 104 | + |
| 105 | +def compute_line_total( |
| 106 | + qty: Annotated[int, "Line item quantity."], |
| 107 | + unit_price: Annotated[float, "Line item unit price."], |
| 108 | + discount_rate: Annotated[float, "Discount rate as a fraction (e.g. 0.2 for 20%)."], |
| 109 | + tax_rate: Annotated[float, "Tax rate as a fraction (e.g. 0.21 for 21%)."], |
| 110 | +) -> float: |
| 111 | + """Compute a single order line total. |
| 112 | +
|
| 113 | + Formula: qty * unit_price * (1 - discount_rate) * (1 + tax_rate), rounded to 2 decimals. |
| 114 | + """ |
| 115 | + subtotal = qty * unit_price |
| 116 | + discounted = subtotal * (1.0 - discount_rate) |
| 117 | + return round(discounted * (1.0 + tax_rate), 2) |
| 118 | + |
| 119 | + |
| 120 | +TOOLS = [list_users, get_orders_for_user, get_discount_rate, get_tax_rate, compute_line_total] |
| 121 | + |
| 122 | + |
| 123 | +# 2. Structured output schema shared between both runs. |
| 124 | + |
| 125 | + |
| 126 | +class UserTotal(BaseModel): |
| 127 | + """A user's grand total of all their orders.""" |
| 128 | + |
| 129 | + user_id: int = Field(description="The user's id.") |
| 130 | + name: str = Field(description="The user's display name.") |
| 131 | + grand_total: float = Field(description="Sum of all line totals, rounded to 2 decimals.") |
| 132 | + |
| 133 | + |
| 134 | +class UserGrandTotals(BaseModel): |
| 135 | + """Structured output schema for both runs.""" |
| 136 | + |
| 137 | + results: list[UserTotal] = Field(description="One entry per user, sorted by grand_total descending.") |
| 138 | + |
| 139 | + |
| 140 | +INSTRUCTIONS = "You are a careful assistant. Use the provided tools for every lookup and computation." |
| 141 | + |
| 142 | +BENCHMARK_PROMPT = ( |
| 143 | + "For every user in our system (there are 8 of them), compute the grand total of all their orders. " |
| 144 | + "Use the compute_line_total tool for each user's orders, after looking up the relevant discount and " |
| 145 | + "tax rates for that user. " |
| 146 | + "Use the provided tools for EVERY data lookup (users, orders, discount rates, tax rates) and for EVERY " |
| 147 | + "line-total computation via compute_line_total — do not invent values or hardcode any numbers. " |
| 148 | + "The total per order item should apply the discount first and then the tax " |
| 149 | + "(e.g. total = qty * unit_price * (1-discount) * (1+tax)). " |
| 150 | + "Return one entry per user, sorted by grand_total descending." |
| 151 | +) |
| 152 | + |
| 153 | + |
| 154 | +def get_client() -> FoundryChatClient: |
| 155 | + """Create a FoundryChatClient from environment variables.""" |
| 156 | + return FoundryChatClient( |
| 157 | + project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], |
| 158 | + model=os.environ["FOUNDRY_MODEL"], |
| 159 | + credential=AzureCliCredential(), |
| 160 | + ) |
| 161 | + |
| 162 | + |
| 163 | +# 3. Two runners that share the same tools, prompt, and structured output schema. |
| 164 | + |
| 165 | + |
| 166 | +async def _run_traditional() -> tuple[float, AgentResponse]: |
| 167 | + agent = Agent( |
| 168 | + client=get_client(), |
| 169 | + name="TraditionalAgent", |
| 170 | + instructions=INSTRUCTIONS, |
| 171 | + tools=TOOLS, |
| 172 | + default_options={"response_format": UserGrandTotals}, |
| 173 | + ) |
| 174 | + start = time.perf_counter() |
| 175 | + result = await agent.run(BENCHMARK_PROMPT) |
| 176 | + elapsed = time.perf_counter() - start |
| 177 | + return elapsed, result |
| 178 | + |
| 179 | + |
| 180 | +async def _run_codeact() -> tuple[float, AgentResponse]: |
| 181 | + codeact = HyperlightCodeActProvider( |
| 182 | + tools=TOOLS, |
| 183 | + approval_mode="never_require", |
| 184 | + ) |
| 185 | + agent = Agent( |
| 186 | + client=get_client(), |
| 187 | + name="CodeActAgent", |
| 188 | + instructions=INSTRUCTIONS, |
| 189 | + context_providers=[codeact], |
| 190 | + default_options={"response_format": UserGrandTotals}, |
| 191 | + ) |
| 192 | + start = time.perf_counter() |
| 193 | + result = await agent.run(BENCHMARK_PROMPT) |
| 194 | + elapsed = time.perf_counter() - start |
| 195 | + return elapsed, result |
| 196 | + |
| 197 | + |
| 198 | +# 4. Report results side by side. |
| 199 | + |
| 200 | + |
| 201 | +def _print_section(title: str) -> None: |
| 202 | + bar = "=" * 70 |
| 203 | + print(f"\n{bar}\n{title}\n{bar}") |
| 204 | + |
| 205 | + |
| 206 | +def _format_usage(usage: UsageDetails | None) -> str: |
| 207 | + if usage is None: |
| 208 | + return "usage=<none>" |
| 209 | + return ( |
| 210 | + f"input={usage.get('input_token_count') or 0:>6} " |
| 211 | + f"output={usage.get('output_token_count') or 0:>6} " |
| 212 | + f"total={usage.get('total_token_count') or 0:>6}" |
| 213 | + ) |
| 214 | + |
| 215 | + |
| 216 | +def _print_results(result: AgentResponse) -> None: |
| 217 | + if result.value is not None: |
| 218 | + for row in result.value.results: |
| 219 | + print(f" user_id={row.user_id:>2} name={row.name:<8} grand_total={row.grand_total:>8.2f}") |
| 220 | + else: |
| 221 | + print(result.text) |
| 222 | + |
| 223 | + |
| 224 | +async def main() -> None: |
| 225 | + """Run the benchmark and print a comparison.""" |
| 226 | + trad_time, trad_result = await _run_traditional() |
| 227 | + code_time, code_result = await _run_codeact() |
| 228 | + |
| 229 | + _print_section("Traditional tool-calling") |
| 230 | + print(f"time={trad_time:7.2f}s {_format_usage(trad_result.usage_details)}") |
| 231 | + _print_results(trad_result) |
| 232 | + |
| 233 | + _print_section("CodeAct (HyperlightCodeActProvider)") |
| 234 | + print(f"time={code_time:7.2f}s {_format_usage(code_result.usage_details)}") |
| 235 | + _print_results(code_result) |
| 236 | + |
| 237 | + _print_section("Comparison") |
| 238 | + trad_total = (trad_result.usage_details or {}).get("total_token_count") or 0 |
| 239 | + code_total = (code_result.usage_details or {}).get("total_token_count") or 0 |
| 240 | + |
| 241 | + def pct(new: float, old: float) -> str: |
| 242 | + if old == 0: |
| 243 | + return "n/a" |
| 244 | + delta = (new - old) / old * 100 |
| 245 | + sign = "+" if delta >= 0 else "" |
| 246 | + return f"{sign}{delta:.1f}%" |
| 247 | + |
| 248 | + print(f"time : traditional={trad_time:7.2f}s codeact={code_time:7.2f}s delta={pct(code_time, trad_time)}") |
| 249 | + print(f"tokens : traditional={trad_total:7d} codeact={code_total:7d} delta={pct(code_total, trad_total)}") |
| 250 | + |
| 251 | + |
| 252 | +if __name__ == "__main__": |
| 253 | + asyncio.run(main()) |
0 commit comments