|
141 | 141 | " task_indices: list[int] | None = None,\n", |
142 | 142 | ") -> tuple[TaskCollection, list[Dict[str, Any]]]:\n", |
143 | 143 | " \"\"\"Load tasks and agent configurations.\n", |
144 | | - " \n", |
| 144 | + "\n", |
145 | 145 | " Args:\n", |
146 | 146 | " config_type: 'single' or 'multi' agent configuration\n", |
147 | 147 | " framework: Agent framework to use\n", |
|
150 | 150 | " limit: Optional limit on number of tasks (None = all 5)\n", |
151 | 151 | " seed: Random seed for reproducibility\n", |
152 | 152 | " task_indices: Optional list of task indices to load (e.g., [0, 2, 4])\n", |
153 | | - " \n", |
| 153 | + "\n", |
154 | 154 | " Returns:\n", |
155 | 155 | " Tuple of (TaskCollection, list of agent configs)\n", |
156 | 156 | " \"\"\"\n", |
157 | 157 | " data_dir = Path(\"examples/five_a_day_benchmark/data\")\n", |
158 | | - " \n", |
| 158 | + "\n", |
159 | 159 | " with open(data_dir / \"tasks.json\", \"r\") as f:\n", |
160 | 160 | " tasks_raw = json.load(f)\n", |
161 | 161 | " with open(data_dir / f\"{config_type}agent.json\", \"r\") as f:\n", |
162 | 162 | " configs_raw = json.load(f)\n", |
163 | | - " \n", |
| 163 | + "\n", |
164 | 164 | " # Apply limit first\n", |
165 | 165 | " if limit:\n", |
166 | 166 | " tasks_raw = tasks_raw[:limit]\n", |
167 | 167 | " configs_raw = configs_raw[:limit]\n", |
168 | | - " \n", |
| 168 | + "\n", |
169 | 169 | " # Then apply task_indices filter if specified\n", |
170 | 170 | " if task_indices is not None:\n", |
171 | 171 | " tasks_raw = [tasks_raw[i] for i in task_indices if i < len(tasks_raw)]\n", |
172 | 172 | " configs_raw = [configs_raw[i] for i in task_indices if i < len(configs_raw)]\n", |
173 | | - " \n", |
| 173 | + "\n", |
174 | 174 | " tasks_data = []\n", |
175 | 175 | " configs_data = []\n", |
176 | | - " \n", |
| 176 | + "\n", |
177 | 177 | " for task_dict, config in zip(tasks_raw, configs_raw):\n", |
178 | 178 | " task_id = task_dict[\"metadata\"][\"task_id\"]\n", |
179 | 179 | " task_dict[\"environment_data\"][\"agent_framework\"] = framework\n", |
180 | | - " \n", |
| 180 | + "\n", |
181 | 181 | " # Create Task object\n", |
182 | 182 | " tasks_data.append(\n", |
183 | 183 | " Task(\n", |
|
187 | 187 | " metadata=task_dict[\"metadata\"],\n", |
188 | 188 | " )\n", |
189 | 189 | " )\n", |
190 | | - " \n", |
| 190 | + "\n", |
191 | 191 | " # Enrich config with framework and model info\n", |
192 | 192 | " config[\"framework\"] = framework\n", |
193 | 193 | " config[\"model_config\"] = {\"model_id\": model_id, \"temperature\": temperature}\n", |
194 | | - " \n", |
| 194 | + "\n", |
195 | 195 | " # Derive seeds for reproducibility\n", |
196 | 196 | " if seed is not None:\n", |
197 | 197 | " for agent_spec in config[\"agents\"]:\n", |
198 | 198 | " agent_spec[\"seed\"] = derive_seed(seed, task_id, agent_spec[\"agent_id\"])\n", |
199 | | - " \n", |
| 199 | + "\n", |
200 | 200 | " configs_data.append(config)\n", |
201 | | - " \n", |
| 201 | + "\n", |
202 | 202 | " return TaskCollection(tasks_data), configs_data" |
203 | 203 | ] |
204 | 204 | }, |
|
224 | 224 | "# Tell litellm to drop unsupported params (like 'seed' for Gemini)\n", |
225 | 225 | "litellm.drop_params = True\n", |
226 | 226 | "\n", |
| 227 | + "\n", |
227 | 228 | "def get_model(model_id: str, temperature: float = 0.7, seed: int | None = None):\n", |
228 | 229 | " \"\"\"Create a model instance compatible with smolagents.\n", |
229 | | - " \n", |
| 230 | + "\n", |
230 | 231 | " Args:\n", |
231 | 232 | " model_id: Model name (e.g., 'gemini-2.5-flash', 'gpt-4')\n", |
232 | 233 | " temperature: Randomness (0.0 = deterministic, 1.0 = creative)\n", |
233 | 234 | " seed: Random seed for reproducible outputs (ignored for models that don't support it)\n", |
234 | | - " \n", |
| 235 | + "\n", |
235 | 236 | " Returns:\n", |
236 | 237 | " LiteLLMModel configured for smolagents\n", |
237 | 238 | " \"\"\"\n", |
|
242 | 243 | " seed=seed, # Will be dropped by litellm for providers that don't support it\n", |
243 | 244 | " )\n", |
244 | 245 | "\n", |
| 246 | + "\n", |
245 | 247 | "# Test the model factory\n", |
246 | 248 | "model = get_model(\"gemini-2.5-flash\", temperature=0.7, seed=42)\n", |
247 | 249 | "print(f\"Created model: {model.model_id}\")" |
|
275 | 277 | "\n", |
276 | 278 | "# Extract the first (and only) task and config\n", |
277 | 279 | "task_0: Task = task_data[0]\n", |
278 | | - "config_0: Dict[str,Any] = agent_configs[0]\n", |
| 280 | + "config_0: Dict[str, Any] = agent_configs[0]\n", |
279 | 281 | "\n", |
280 | 282 | "print(\"=\" * 60)\n", |
281 | 283 | "print(\"TASK 0: Email & Banking\")\n", |
|
309 | 311 | "print(f\"Agent Type: {config_0['agent_type']}\")\n", |
310 | 312 | "print(f\"Primary Agent: {config_0['primary_agent_id']}\\n\")\n", |
311 | 313 | "\n", |
312 | | - "for i, agent_spec in enumerate(config_0['agents'], 1):\n", |
| 314 | + "for i, agent_spec in enumerate(config_0[\"agents\"], 1):\n", |
313 | 315 | " print(f\"{i}. {agent_spec['agent_name']} (ID: {agent_spec['agent_id']})\")\n", |
314 | 316 | " print(f\" Tools: {agent_spec['tools'] if agent_spec['tools'] else 'None (delegates only)'}\")\n", |
315 | 317 | " print(f\" Role: {agent_spec['agent_instruction'][:80]}...\")\n", |
|
384 | 386 | " specialist_agents = []\n", |
385 | 387 | "\n", |
386 | 388 | " temperature = agent_data[\"model_config\"][\"temperature\"]\n", |
387 | | - " \n", |
| 389 | + "\n", |
388 | 390 | " primary_agent_id = agent_data[\"primary_agent_id\"]\n", |
389 | 391 | " agents_specs = agent_data[\"agents\"]\n", |
390 | 392 | " all_tool_adapters = environment.get_tools()\n", |
391 | | - " \n", |
| 393 | + "\n", |
392 | 394 | " # Build specialists first\n", |
393 | 395 | " specialist_agents = []\n", |
394 | 396 | " for agent_spec in agents_specs:\n", |
395 | 397 | " if agent_spec[\"agent_id\"] == primary_agent_id:\n", |
396 | 398 | " continue\n", |
397 | | - " \n", |
| 399 | + "\n", |
398 | 400 | " seed = agent_spec.get(\"seed\")\n", |
399 | 401 | " model = get_model(model_id, temperature, seed)\n", |
400 | 402 | " spec_tool_adapters = filter_tool_adapters_by_prefix(all_tool_adapters, agent_spec[\"tools\"])\n", |
401 | 403 | " spec_tools = [adapter.tool for adapter in spec_tool_adapters]\n", |
402 | 404 | " spec_tools.append(FinalAnswerTool())\n", |
403 | | - " \n", |
| 405 | + "\n", |
404 | 406 | " specialist = ToolCallingAgent(\n", |
405 | 407 | " model=model,\n", |
406 | 408 | " tools=spec_tools,\n", |
|
410 | 412 | " verbosity_level=0,\n", |
411 | 413 | " )\n", |
412 | 414 | " specialist_agents.append(specialist)\n", |
413 | | - " \n", |
| 415 | + "\n", |
414 | 416 | " # Build orchestrator\n", |
415 | 417 | " primary_spec = next(a for a in agents_specs if a[\"agent_id\"] == primary_agent_id)\n", |
416 | 418 | " primary_seed = primary_spec.get(\"seed\")\n", |
417 | 419 | " primary_model = get_model(model_id, temperature, primary_seed)\n", |
418 | | - " \n", |
| 420 | + "\n", |
419 | 421 | " orchestrator = ToolCallingAgent(\n", |
420 | 422 | " model=primary_model,\n", |
421 | 423 | " tools=[FinalAnswerTool()],\n", |
|
425 | 427 | " verbosity_level=0,\n", |
426 | 428 | " )\n", |
427 | 429 | "\n", |
428 | | - " return [orchestrator], {agent.name: agent for agent in specialist_agents}\n" |
| 430 | + " return [orchestrator], {agent.name: agent for agent in specialist_agents}" |
429 | 431 | ] |
430 | 432 | }, |
431 | 433 | { |
|
475 | 477 | " \"\"\"Initialize environment state from task data.\"\"\"\n", |
476 | 478 | " env_data = task_data[\"environment_data\"].copy()\n", |
477 | 479 | " tool_names = env_data.get(\"tools\", [])\n", |
478 | | - " \n", |
| 480 | + "\n", |
479 | 481 | " # Create state objects (e.g., email inboxes, bank accounts)\n", |
480 | 482 | " states = get_states(tool_names, env_data)\n", |
481 | 483 | " env_data.update(states)\n", |
482 | | - " \n", |
| 484 | + "\n", |
483 | 485 | " return env_data\n", |
484 | 486 | "\n", |
485 | 487 | " def create_tools(self) -> list:\n", |
486 | 488 | " \"\"\"Create and convert tools to framework-specific format.\"\"\"\n", |
487 | 489 | " tools_list = []\n", |
488 | | - " \n", |
| 490 | + "\n", |
489 | 491 | " # Map tool names to their collection classes\n", |
490 | 492 | " tool_mapping = {\n", |
491 | 493 | " \"email\": (EmailToolCollection, lambda: (self.state[\"email_state\"],)),\n", |
|
499 | 501 | " \"my_calendar_mcp\": (MCPCalendarToolCollection, lambda: (self.state[\"my_calendar_mcp_state\"],)),\n", |
500 | 502 | " \"other_calendar_mcp\": (MCPCalendarToolCollection, lambda: (self.state[\"other_calendar_mcp_state\"],)),\n", |
501 | 503 | " }\n", |
502 | | - " \n", |
| 504 | + "\n", |
503 | 505 | " for tool_name in self.state[\"tools\"]:\n", |
504 | 506 | " if tool_name in tool_mapping:\n", |
505 | 507 | " ToolClass, get_init_args = tool_mapping[tool_name]\n", |
506 | 508 | " tool_instance = ToolClass(*get_init_args())\n", |
507 | | - " \n", |
| 509 | + "\n", |
508 | 510 | " # Get base tools and convert to framework format\n", |
509 | 511 | " for base_tool in tool_instance.get_sub_tools():\n", |
510 | 512 | " framework_tool = base_tool.to_smolagents()\n", |
511 | 513 | " tools_list.append(framework_tool)\n", |
512 | | - " \n", |
| 514 | + "\n", |
513 | 515 | " return tools_list" |
514 | 516 | ] |
515 | 517 | }, |
|
534 | 536 | "source": [ |
535 | 537 | "print(f\"{config_0['task_description']}\")\n", |
536 | 538 | "\n", |
537 | | - "for i, agent_spec in enumerate(config_0['agents'], 1):\n", |
| 539 | + "for i, agent_spec in enumerate(config_0[\"agents\"], 1):\n", |
538 | 540 | " print(f\"{i}. {agent_spec['agent_name']} (ID: {agent_spec['agent_id']})\")\n", |
539 | 541 | " print(f\" Tools: {agent_spec['tools'] if agent_spec['tools'] else 'None (delegates only)'}\")\n", |
540 | 542 | " print(f\" Role: {agent_spec['agent_instruction'][:80]}...\")\n", |
|
560 | 562 | "# Note: model_config is already set by load_benchmark_data()\n", |
561 | 563 | "\n", |
562 | 564 | "# Create environment from task data\n", |
563 | | - "environment_0 = FiveADayEnvironment({\n", |
564 | | - " \"environment_data\": task_0.environment_data,\n", |
565 | | - " \"query\": task_0.query,\n", |
566 | | - " \"evaluation_data\": task_0.evaluation_data,\n", |
567 | | - " \"metadata\": task_0.metadata,\n", |
568 | | - "})\n", |
| 565 | + "environment_0 = FiveADayEnvironment(\n", |
| 566 | + " {\n", |
| 567 | + " \"environment_data\": task_0.environment_data,\n", |
| 568 | + " \"query\": task_0.query,\n", |
| 569 | + " \"evaluation_data\": task_0.evaluation_data,\n", |
| 570 | + " \"metadata\": task_0.metadata,\n", |
| 571 | + " }\n", |
| 572 | + ")\n", |
569 | 573 | "\n", |
570 | 574 | "# Build agents using the build_agents function\n", |
571 | 575 | "agents_to_run, agents_to_monitor = build_agents(config_0, environment_0)\n", |
572 | 576 | "\n", |
573 | 577 | "print(f\"\\nBuilt Agents for Task: {task_0.metadata['task_id']}\")\n", |
574 | | - "print(f\"{'='*60}\")\n", |
| 578 | + "print(f\"{'=' * 60}\")\n", |
575 | 579 | "print(f\"\\nAgents to run: {[agent.name for agent in agents_to_run]}\")\n", |
576 | 580 | "print(f\"Agents to monitor: {list(agents_to_monitor.keys())}\")\n", |
577 | 581 | "\n", |
|
580 | 584 | " print(f\"\\n Agent: {agent.name}\")\n", |
581 | 585 | " # smolagents stores tools as a dict with string keys\n", |
582 | 586 | " print(f\" Tools: {list(agent.tools.keys())}\")\n", |
583 | | - " if hasattr(agent, 'managed_agents') and agent.managed_agents:\n", |
| 587 | + " if hasattr(agent, \"managed_agents\") and agent.managed_agents:\n", |
584 | 588 | " # managed_agents is also a dict with string keys\n", |
585 | 589 | " print(f\" Managed agents: {list(agent.managed_agents.keys())}\")\n", |
586 | 590 | " for agent_name, managed in agent.managed_agents.items():\n", |
|
623 | 627 | " \"evaluation_data\": task.evaluation_data,\n", |
624 | 628 | " \"metadata\": task.metadata,\n", |
625 | 629 | " }\n", |
626 | | - " \n", |
| 630 | + "\n", |
627 | 631 | " environment = FiveADayEnvironment(task_data)\n", |
628 | | - " \n", |
| 632 | + "\n", |
629 | 633 | " # Register all tools for tracing\n", |
630 | 634 | " for tool_adapter in environment.get_tools():\n", |
631 | 635 | " tool_name = getattr(tool_adapter, \"name\", str(type(tool_adapter).__name__))\n", |
632 | 636 | " self.register(\"tools\", tool_name, tool_adapter)\n", |
633 | | - " \n", |
| 637 | + "\n", |
634 | 638 | " return environment\n", |
635 | 639 | "\n", |
636 | 640 | " def setup_agents(\n", |
637 | 641 | " self, agent_data: Dict[str, Any], environment: Environment, task: Task, user=None\n", |
638 | 642 | " ) -> tuple[list[SmolAgentAdapter], Dict[str, SmolAgentAdapter]]:\n", |
639 | 643 | " \"\"\"Create multi-agent system with orchestrator and specialists.\"\"\"\n", |
640 | 644 | " agents_to_run, agents_to_monitor = build_agents(agent_data, environment)\n", |
641 | | - " \n", |
| 645 | + "\n", |
642 | 646 | " # Create adapters for the primary agent(s) to run\n", |
643 | 647 | " adapters_to_run = [SmolAgentAdapter(agent, agent.name) for agent in agents_to_run]\n", |
644 | | - " \n", |
| 648 | + "\n", |
645 | 649 | " # This ensures all agent traces are collected by the benchmark\n", |
646 | 650 | " all_agents = {agent.name: agent for agent in agents_to_run} | agents_to_monitor\n", |
647 | 651 | " adapters_to_monitor = {name: SmolAgentAdapter(agent, name) for name, agent in all_agents.items()}\n", |
|
651 | 655 | " \"\"\"Create evaluators based on task's evaluation criteria.\"\"\"\n", |
652 | 656 | " if not task.evaluation_data[\"evaluators\"]:\n", |
653 | 657 | " return []\n", |
654 | | - " \n", |
| 658 | + "\n", |
655 | 659 | " evaluator_instances = []\n", |
656 | 660 | " for name in task.evaluation_data[\"evaluators\"]:\n", |
657 | 661 | " evaluator_class = getattr(evaluators, name)\n", |
658 | 662 | " evaluator_instances.append(evaluator_class(task, environment, user))\n", |
659 | | - " \n", |
| 663 | + "\n", |
660 | 664 | " return evaluator_instances\n", |
661 | 665 | "\n", |
662 | 666 | " def run_agents(self, agents: Sequence[AgentAdapter], task: Task, environment: Environment) -> Sequence[Any]:\n", |
|
741 | 745 | " fail_on_evaluation_error=True,\n", |
742 | 746 | ")\n", |
743 | 747 | "\n", |
744 | | - "results = benchmark.run(tasks=tasks)\n" |
| 748 | + "results = benchmark.run(tasks=tasks)" |
745 | 749 | ] |
746 | 750 | }, |
747 | 751 | { |
|
764 | 768 | "console = Console()\n", |
765 | 769 | "\n", |
766 | 770 | "for task in results[:2]:\n", |
767 | | - " task_id = task['task_id']\n", |
| 771 | + " task_id = task[\"task_id\"]\n", |
768 | 772 | " print(\"=\" * 60)\n", |
769 | 773 | " print(f\"Results for Task ID: {task_id}\")\n", |
770 | 774 | " print(\"=\" * 60)\n", |
771 | | - " traces = task['traces']\n", |
772 | | - " agent_traces = traces['agents']\n", |
| 775 | + " traces = task[\"traces\"]\n", |
| 776 | + " agent_traces = traces[\"agents\"]\n", |
773 | 777 | " print(f\"Traces available for agents: {list(agent_traces.keys())}\")\n", |
774 | 778 | " orchestrator_name = list(traces[\"agents\"].keys())[0]\n", |
775 | 779 | " print(f\"Last 5 messages for '{orchestrator_name}'\")\n", |
776 | 780 | " print(traces[\"agents\"].keys())\n", |
777 | 781 | " messages = traces[\"agents\"][orchestrator_name][\"messages\"]\n", |
778 | 782 | " for msg in messages[-5:]:\n", |
779 | 783 | " role = msg.get(\"role\", \"unknown\")\n", |
780 | | - " content = msg.get(\"content\", [])[0].get(\"text\", '')\n", |
| 784 | + " content = msg.get(\"content\", [])[0].get(\"text\", \"\")\n", |
781 | 785 | " panel = Panel.fit(\n", |
782 | 786 | " content,\n", |
783 | 787 | " title=f\" {role} \",\n", |
784 | 788 | " title_align=\"left\",\n", |
785 | 789 | " )\n", |
786 | | - " console.print(panel)\n" |
| 790 | + " console.print(panel)" |
787 | 791 | ] |
788 | 792 | }, |
789 | 793 | { |
|
795 | 799 | "source": [ |
796 | 800 | "# print results for first two tasks\n", |
797 | 801 | "for task in results[:2]:\n", |
798 | | - " task_id = task['task_id']\n", |
| 802 | + " task_id = task[\"task_id\"]\n", |
799 | 803 | " print(\"=\" * 60)\n", |
800 | 804 | " print(f\"Results for Task ID: {task_id}\")\n", |
801 | 805 | " print(\"=\" * 60)\n", |
802 | | - " eval_results = task['eval']\n", |
| 806 | + " eval_results = task[\"eval\"]\n", |
803 | 807 | " for evals in eval_results:\n", |
804 | | - " for k,v in evals.items():\n", |
| 808 | + " for k, v in evals.items():\n", |
805 | 809 | " print(f\"{k:<35} {v}\")" |
806 | 810 | ] |
807 | 811 | }, |
|
0 commit comments