feat(service): add tau2 agent+inference service rollout example#1226
feat(service): add tau2 agent+inference service rollout example#1226
Conversation
Combine the agent-service controller refactor with the Tau2 rollout examples so the example stack uses a single session lifecycle and data-collection path. Key changes: - move session lifecycle and collection APIs into the agent controller - add Tau2 agent-service rollout example workflow and docs - align controller tests and gateway/session handling
There was a problem hiding this comment.
Code Review
This pull request adds RL data collection capabilities to the agent service by integrating it with the inference service, specifically for Tau2 benchmarks. Key changes include new session management and reward-setting APIs in the AgentServiceController, refactored inference workflows, and new PydanticAI agent examples. Reviewers identified several issues: the removal of a safety net for finalizing sessions on failure, a memory leak in session tracking, and performance concerns due to synchronous network and tool calls blocking the asynchronous event loop. A potential TypeError in reward processing was also flagged.
| result = await self.agent.run( | ||
| data, | ||
| base_url=self.gateway_addr, | ||
| http_client=http_client, | ||
| api_key=self._admin_api_key, | ||
| task_id=str(task_id), | ||
| ) |
There was a problem hiding this comment.
This refactoring removed the try...except block that previously ensured a 0.0 reward was set on the inference service if the agent failed. While the new design delegates session management to the agent, a crash in agent.run() before it can report a reward will now leave a dangling RL session on the inference gateway. Consider adding a safety net to ensure sessions are always finalized even on failure.
|
|
||
| self._forked_services: list[tuple[str, str, int]] = [] | ||
|
|
||
| self._sessions: dict[str, dict[str, Any]] = {} |
There was a problem hiding this comment.
The _sessions dictionary grows indefinitely as new sessions are created via new_session(), but there is no mechanism to remove them once they are completed (e.g., after set_reward()). This will lead to a memory leak in long-running controller instances. Consider adding a way to prune old sessions or at least clearing them in the destroy() method.
| resp = requests.post( | ||
| f"{inf_addr}/rl/start_session", | ||
| json={"task_id": task_id}, | ||
| headers={"Authorization": f"Bearer {cfg.inference_api_key}"}, | ||
| timeout=cfg.request_timeout, | ||
| ) |
There was a problem hiding this comment.
The use of synchronous requests.post here (and in step and set_reward) will block the event loop when called from asynchronous contexts, such as the Tau2AgentServiceWorkflow introduced in this PR. This can significantly degrade performance and scalability when handling multiple concurrent rollouts. Consider using an asynchronous client like httpx.AsyncClient or wrapping these calls in asyncio.to_thread.
| raise | ||
| session_id = result["session_id"] | ||
| trajectory_id = result["trajectory_id"] | ||
| agent_reward = float(result["reward"]) |
There was a problem hiding this comment.
This line will raise a TypeError if result["reward"] is None. It is safer to handle potential null values explicitly, especially since the validation on line 107 only checks for the presence of the key.
| agent_reward = float(result["reward"]) | |
| agent_reward = float(result["reward"]) if result["reward"] is not None else 0.0 |
| async def _wrapper(**kwargs: Any) -> str: | ||
| try: | ||
| result = fn(**kwargs) | ||
| except Exception as exc: | ||
| result = f"Tool error: {exc}" | ||
| if not isinstance(result, str): | ||
| result = json.dumps(result, default=str) | ||
| return result |
There was a problem hiding this comment.
Calling a synchronous tool function directly inside an async wrapper blocks the event loop. Since this agent is used in an async rollout workflow, this can impact performance. Consider using asyncio.to_thread to run synchronous tool functions.
| async def _wrapper(**kwargs: Any) -> str: | |
| try: | |
| result = fn(**kwargs) | |
| except Exception as exc: | |
| result = f"Tool error: {exc}" | |
| if not isinstance(result, str): | |
| result = json.dumps(result, default=str) | |
| return result | |
| async def _wrapper(**kwargs: Any) -> str: | |
| try: | |
| import asyncio | |
| if inspect.iscoroutinefunction(fn): | |
| result = await fn(**kwargs) | |
| else: | |
| result = await asyncio.to_thread(fn, **kwargs) | |
| except Exception as exc: | |
| result = f"Tool error: {exc}" | |
| if not isinstance(result, str): | |
| result = json.dumps(result, default=str) | |
| return result |
| finished = False | ||
| try: | ||
| for i in range(self.max_turns): | ||
| response = ctrl.step(next_user_message, session["session_id"]) |
| result = inf_ctrl.rollout_batch( | ||
| data=data, | ||
| workflow="examples.experimental.inference_service.tau2_workflow.Tau2AgentServiceWorkflow", | ||
| workflow_kwargs=workflow_kwargs, | ||
| ) |
There was a problem hiding this comment.
rollout_batch should only be used for debugging RL training. For pure rollout, we should orchestrate the agent controller. await agent_ctrl.chat_completion(task) should suffice, given that the agent controller exposing a similar /chat/completion endpoint that runs the agent rather than the model.
In other words, in this example script, inf_ctrl should not be touched except for the initialize and destroy calls.
| ctrl_config = GatewayControllerConfig( | ||
| tokenizer_path=config.tokenizer_path, |
There was a problem hiding this comment.
The new rollout controller should have the same API as the current one, both for backward compatibility and cleaness - we don't need to the the dummy config conversion.
| else: | ||
| raise ValueError(f"Unsupported rollout backend: {rollout_alloc.backend}") | ||
|
|
||
| inf_ctrl = GatewayInferenceController(config=ctrl_config, scheduler=scheduler) |
There was a problem hiding this comment.
Name too long. Should be revised in the near future.
| inference_api_key=ctrl_config.admin_api_key or "areal-admin-key", | ||
| ) | ||
|
|
||
| agent_ctrl = AgentServiceController(config=agent_ctrl_config, scheduler=scheduler) |
There was a problem hiding this comment.
Name too long. Should be revised in the near future.
|
|
||
| logger.info("Inference service ready at %s", inf_ctrl.gateway_addr) | ||
|
|
||
| agent_ctrl_config = AgentServiceControllerConfig( |
There was a problem hiding this comment.
(1) Name too long. Should be revised in the near future.
(2) Should expose it in CLI args.
| logger = logging.getLogger("Tau2Workflow") | ||
|
|
||
|
|
||
| class Tau2InferenceWorkflow: |
There was a problem hiding this comment.
This class should be removed IMO
Description
Add data-collection APIs (
new_session,step,set_reward) to the agent-service controller and introduce two new Tau2 rollout workflows for the inference-service examples:Tau2InferenceWorkflow(lightweight, no agent service) andTau2AgentServiceWorkflow(full agent-service integration). Includes a PydanticAI-basedTau2Agentand a newtau2_agent_service_rollout.pyexample script.Related Issue
N/A
Type of Change
Checklist
jb build docs/gemini review)Breaking Change Details (if applicable):
N/A
Additional Context
Key changes:
areal/experimental/agent_service/controller/controller.py: session lifecycle APIs (new_session,step,set_reward) and data-collection-only mode (num_pairs=0)areal/experimental/inference_service/controller/workflow.py: refactored to delegate session management to agent'srun()return valueexamples/experimental/inference_service/tau2_agent.py: PydanticAI-based Tau2Agentexamples/experimental/inference_service/tau2_workflow.py: two new workflow classesexamples/experimental/inference_service/tau2_agent_service_rollout.py: example scripttests/experimental/agent_service/test_controller.py: tests for new controller APIsRequires:
pydantic-ai,tau2-bench