Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions benchmarks/browsecomp/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ browsecomp_benchmark_agent:
max_context_tokens: 196608
context_reset_pct: 0.3
context_reset_keep_rounds: 3
max_reset_count: null
max_run_retries: 3
snap_dir: null
resources_server:
type: resources_servers
name: browsecomp_benchmark_resources_server
Expand Down
2 changes: 0 additions & 2 deletions resources_servers/browsecomp_advanced_harness/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ class TavilySearchMetrics(BaseModel):


class TavilySearchVerifyResponse(TavilySearchVerifyRequest, JudgeEvaluation):
num_tool_calls: int
metrics: TavilySearchMetrics


Expand Down Expand Up @@ -400,7 +399,6 @@ async def verify(self, request: Request, body: TavilySearchVerifyRequest) -> Tav
return TavilySearchVerifyResponse(
**body.model_dump(),
**judge_evaluation.model_dump(),
num_tool_calls=sum(o.type == "function_call" for o in body.response.output),
metrics=self._session_id_to_metrics[request.session[SESSION_ID_KEY]],
)

Expand Down
66 changes: 64 additions & 2 deletions responses_api_agents/browsecomp_agent/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# limitations under the License.
import json
import re
from typing import List
from pathlib import Path
from typing import List, Optional

from fastapi import Request, Response
from pydantic import ConfigDict, ValidationError
Expand Down Expand Up @@ -52,7 +53,9 @@ class BrowsecompAgentConfig(BaseResponsesAPIAgentConfig):
max_context_tokens: int = 196608
context_reset_pct: float = 0.3
context_reset_keep_rounds: int = 3
max_reset_count: Optional[int] = None
max_run_retries: int = 1
snap_dir: Optional[str] = None


class BrowsecompAgentRunRequest(BaseRunRequest):
Expand Down Expand Up @@ -81,13 +84,22 @@ async def responses(
if isinstance(body.input, str):
body.input = [NeMoGymEasyInputMessage(role="user", content=body.input)]

task_index, attempt = None, None
if self.config.snap_dir:
task_index = body.metadata.pop("task_index")
attempt = body.metadata.pop("attempt")
body.metadata = body.metadata or {}

new_outputs = []
usage = None
step = 0
num_tool_calls = 0
model_server_cookies = None # update the cookies on every model response
resources_server_cookies = request.cookies # update the cookies on every resources server response

reset_threshold = 0
reset_count = 0
max_reset_count = self.config.max_reset_count
if self.config.max_context_tokens and self.config.context_reset_pct:
reset_threshold = int(self.config.max_context_tokens * self.config.context_reset_pct)

Expand All @@ -98,6 +110,8 @@ async def responses(
new_outputs = self._compact_old_tool_messages(new_outputs)

new_body = body.model_copy(update={"input": body.input + new_outputs})
if not body.metadata:
new_body = new_body.model_dump(exclude={"metadata"}, exclude_none=True)

model_response = await self.server_client.post(
server_name=self.config.model_server.name,
Expand All @@ -118,7 +132,22 @@ async def responses(

# --- Check context reset threshold ---
prompt_tokens = model_response.usage.input_tokens if model_response.usage else 0
if reset_threshold and prompt_tokens > reset_threshold:
if (
reset_threshold
and prompt_tokens > reset_threshold
and (max_reset_count is None or reset_count < max_reset_count)
):
reset_count += 1
# record current context
if self.config.snap_dir:
self._save_snapshot(
messages=body.input + new_outputs,
task_index=task_index,
attempt=attempt,
reset_count=reset_count,
is_final=False,
)
# reset context
if self.config.context_reset_keep_rounds > 0:
new_outputs = self._extract_last_rounds(new_outputs)
else:
Expand Down Expand Up @@ -154,6 +183,7 @@ async def responses(

# --- Execute tool calls ---
for output_function_call in all_fn_calls:
num_tool_calls += 1
api_response = await self.server_client.post(
server_name=self.config.resources_server.name,
url_path=f"/{output_function_call.name}",
Expand Down Expand Up @@ -216,12 +246,24 @@ async def responses(
if self.config.max_steps and step >= self.config.max_steps:
break

# record final context
if self.config.snap_dir:
self._save_snapshot(
messages=body.input + new_outputs,
task_index=task_index,
attempt=attempt,
reset_count=None,
is_final=True,
)

# Propagate any extra cookies necessary for downstream verification
for k, v in (*resources_server_cookies.items(), *model_server_cookies.items()):
response.set_cookie(k, v)

model_response.output = new_outputs
model_response.usage = usage
model_response.reset_count = reset_count
model_response.num_tool_calls = num_tool_calls
return model_response

async def run(self, request: Request, body: BrowsecompAgentRunRequest) -> BrowsecompAgentVerifyResponse:
Expand All @@ -238,6 +280,12 @@ async def run(self, request: Request, body: BrowsecompAgentRunRequest) -> Browse

last_verify_response = None
for attempt in range(self.config.max_run_retries):
# prepare for recording
if self.config.snap_dir:
body.responses_create_params.metadata = dict(body.responses_create_params.metadata or {})
body.responses_create_params.metadata["task_index"] = str(body._ng_task_index)
body.responses_create_params.metadata["attempt"] = str(attempt)

response = await self.server_client.post(
server_name=self.config.name,
url_path="/v1/responses",
Expand Down Expand Up @@ -337,6 +385,20 @@ def _extract_last_rounds(self, new_outputs):
result.extend(tool_outputs)
return result

def _save_snapshot(self, messages, task_index, attempt, reset_count, is_final):
    """Write a snapshot of the conversation to disk as JSON Lines.

    The file is placed under ``<snap_dir>/sample_<task_index>/`` and named
    ``attempt_<attempt>_final.jsonl`` when ``is_final`` is true, otherwise
    ``attempt_<attempt>_reset_<reset_count>.jsonl``. An existing file with
    the same name is overwritten.

    Args:
        messages: Iterable of objects exposing ``model_dump_json()``; each
            one is written as a single line.
        task_index: Identifier used for the per-sample directory name.
        attempt: Identifier of the run attempt, used in the file name.
        reset_count: Number of the context reset being recorded; only used
            in the file name when ``is_final`` is false.
        is_final: Whether this is the end-of-run snapshot.
    """
    sample_dir = Path(self.config.snap_dir) / f"sample_{task_index}"
    # exist_ok=True replaces the exists()-then-mkdir() pair: that pair raises
    # FileExistsError if another worker creates the directory in between.
    sample_dir.mkdir(parents=True, exist_ok=True)

    suffix = "final" if is_final else f"reset_{reset_count}"
    sample_path = sample_dir / f"attempt_{attempt}_{suffix}.jsonl"

    with open(sample_path, "w", encoding="utf-8") as f:
        f.writelines(msg.model_dump_json() + "\n" for msg in messages)


# Script entry point: when this module is executed directly (not imported),
# start the web server via BrowsecompAgent.run_webserver().
if __name__ == "__main__":
    BrowsecompAgent.run_webserver()
Loading