Skip to content

feat(service): add tau2 agent+inference service rollout example#1226

Closed
nuzant wants to merge 1 commit intomainfrom
mzy/inf-agent-service-example
Closed

feat(service): add tau2 agent+inference service rollout example#1226
nuzant wants to merge 1 commit intomainfrom
mzy/inf-agent-service-example

Conversation

@nuzant
Copy link
Copy Markdown
Collaborator

@nuzant nuzant commented Apr 22, 2026

Description

Add data-collection APIs (new_session, step, set_reward) to the agent-service controller and introduce two new Tau2 rollout workflows for the inference-service examples: Tau2InferenceWorkflow (lightweight, no agent service) and Tau2AgentServiceWorkflow (full agent-service integration). Includes a PydanticAI-based Tau2Agent and a new tau2_agent_service_rollout.py example script.

Related Issue

N/A

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update
  • Code refactoring (no functional changes)
  • Performance improvement
  • Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • I have run formatting tools (pre-commit or manual)
  • I have run relevant unit tests and they pass
  • I have added tests for new functionality
  • I have updated documentation if needed
  • My branch is up to date with main
  • This PR introduces breaking changes (if yes, fill out details below)
  • If this PR changes documentation, I have built and previewed it locally with jb build docs
  • No critical issues raised by AI reviewers (/gemini review)

Breaking Change Details (if applicable):

N/A

Additional Context

Key changes:

  • areal/experimental/agent_service/controller/controller.py: session lifecycle APIs (new_session, step, set_reward) and data-collection-only mode (num_pairs=0)
  • areal/experimental/inference_service/controller/workflow.py: refactored to delegate session management to agent's run() return value
  • examples/experimental/inference_service/tau2_agent.py: PydanticAI-based Tau2Agent
  • examples/experimental/inference_service/tau2_workflow.py: two new workflow classes
  • examples/experimental/inference_service/tau2_agent_service_rollout.py: example script
  • tests/experimental/agent_service/test_controller.py: tests for new controller APIs

Requires: pydantic-ai, tau2-bench

Combine the agent-service controller refactor with the Tau2 rollout examples so the example stack uses a single session lifecycle and data-collection path.

Key changes:
- move session lifecycle and collection APIs into the agent controller
- add Tau2 agent-service rollout example workflow and docs
- align controller tests and gateway/session handling
@nuzant nuzant changed the title feat(archon): add tau2 agent-service rollout workflow feat(service): add tau2 agent+inference service rollout example Apr 22, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds RL data collection capabilities to the agent service by integrating it with the inference service, specifically for Tau2 benchmarks. Key changes include new session management and reward-setting APIs in the AgentServiceController, refactored inference workflows, and new PydanticAI agent examples. Reviewers identified several issues: the removal of a safety net for finalizing sessions on failure, a memory leak in session tracking, and performance concerns due to synchronous network and tool calls blocking the asynchronous event loop. A potential TypeError in reward processing was also flagged.

Comment on lines +94 to 100
result = await self.agent.run(
data,
base_url=self.gateway_addr,
http_client=http_client,
api_key=self._admin_api_key,
task_id=str(task_id),
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This refactoring removed the try...except block that previously ensured a 0.0 reward was set on the inference service if the agent failed. While the new design delegates session management to the agent, a crash in agent.run() before it can report a reward will now leave a dangling RL session on the inference gateway. Consider adding a safety net to ensure sessions are always finalized even on failure.


self._forked_services: list[tuple[str, str, int]] = []

self._sessions: dict[str, dict[str, Any]] = {}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _sessions dictionary grows indefinitely as new sessions are created via new_session(), but there is no mechanism to remove them once they are completed (e.g., after set_reward()). This will lead to a memory leak in long-running controller instances. Consider adding a way to prune old sessions or at least clearing them in the destroy() method.

Comment on lines +431 to +436
resp = requests.post(
f"{inf_addr}/rl/start_session",
json={"task_id": task_id},
headers={"Authorization": f"Bearer {cfg.inference_api_key}"},
timeout=cfg.request_timeout,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use of synchronous requests.post here (and in step and set_reward) will block the event loop when called from asynchronous contexts, such as the Tau2AgentServiceWorkflow introduced in this PR. This can significantly degrade performance and scalability when handling multiple concurrent rollouts. Consider using an asynchronous client like httpx.AsyncClient or wrapping these calls in asyncio.to_thread.

raise
session_id = result["session_id"]
trajectory_id = result["trajectory_id"]
agent_reward = float(result["reward"])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This line will raise a TypeError if result["reward"] is None. It is safer to handle potential null values explicitly, especially since the validation on line 107 only checks for the presence of the key.

Suggested change
agent_reward = float(result["reward"])
agent_reward = float(result["reward"]) if result["reward"] is not None else 0.0

Comment on lines +41 to +48
async def _wrapper(**kwargs: Any) -> str:
try:
result = fn(**kwargs)
except Exception as exc:
result = f"Tool error: {exc}"
if not isinstance(result, str):
result = json.dumps(result, default=str)
return result
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calling a synchronous tool function directly inside an async wrapper blocks the event loop. Since this agent is used in an async rollout workflow, this can impact performance. Consider using asyncio.to_thread to run synchronous tool functions.

Suggested change
async def _wrapper(**kwargs: Any) -> str:
try:
result = fn(**kwargs)
except Exception as exc:
result = f"Tool error: {exc}"
if not isinstance(result, str):
result = json.dumps(result, default=str)
return result
async def _wrapper(**kwargs: Any) -> str:
try:
import asyncio
if inspect.iscoroutinefunction(fn):
result = await fn(**kwargs)
else:
result = await asyncio.to_thread(fn, **kwargs)
except Exception as exc:
result = f"Tool error: {exc}"
if not isinstance(result, str):
result = json.dumps(result, default=str)
return result

finished = False
try:
for i in range(self.max_turns):
response = ctrl.step(next_user_message, session["session_id"])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calling the synchronous ctrl.step() method from an async function blocks the event loop. This will limit the ability to run multiple rollouts in parallel efficiently. Consider wrapping this call in asyncio.to_thread().

@nuzant nuzant marked this pull request as draft April 22, 2026 08:00
Comment on lines +256 to +260
result = inf_ctrl.rollout_batch(
data=data,
workflow="examples.experimental.inference_service.tau2_workflow.Tau2AgentServiceWorkflow",
workflow_kwargs=workflow_kwargs,
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rollout_batch should only be used for debugging RL training. For pure rollout, we should orchestrate the agent controller. await agent_ctrl.chat_completion(task) should suffice, given that the agent controller exposing a similar /chat/completion endpoint that runs the agent rather than the model.

In other words, in this example script, inf_ctrl should not be touched except for the initialize and destroy calls.

Comment on lines +168 to +169
ctrl_config = GatewayControllerConfig(
tokenizer_path=config.tokenizer_path,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new rollout controller should have the same API as the current one, both for backward compatibility and cleaness - we don't need to the the dummy config conversion.

else:
raise ValueError(f"Unsupported rollout backend: {rollout_alloc.backend}")

inf_ctrl = GatewayInferenceController(config=ctrl_config, scheduler=scheduler)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Name too long. Should be revised in the near future.

inference_api_key=ctrl_config.admin_api_key or "areal-admin-key",
)

agent_ctrl = AgentServiceController(config=agent_ctrl_config, scheduler=scheduler)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Name too long. Should be revised in the near future.


logger.info("Inference service ready at %s", inf_ctrl.gateway_addr)

agent_ctrl_config = AgentServiceControllerConfig(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(1) Name too long. Should be revised in the near future.
(2) Should expose it in CLI args.

logger = logging.getLogger("Tau2Workflow")


class Tau2InferenceWorkflow:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class should be removed IMO

@nuzant
Copy link
Copy Markdown
Collaborator Author

nuzant commented Apr 26, 2026

Separated into #1265 and #1266 .

@nuzant nuzant closed this Apr 26, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants