diff --git a/benchmarks/browsecomp/config.yaml b/benchmarks/browsecomp/config.yaml index c494f5c37..635766846 100644 --- a/benchmarks/browsecomp/config.yaml +++ b/benchmarks/browsecomp/config.yaml @@ -32,7 +32,9 @@ browsecomp_benchmark_agent: max_context_tokens: 196608 context_reset_pct: 0.3 context_reset_keep_rounds: 3 + max_reset_count: null max_run_retries: 3 + snap_dir: null resources_server: type: resources_servers name: browsecomp_benchmark_resources_server diff --git a/resources_servers/browsecomp_advanced_harness/app.py b/resources_servers/browsecomp_advanced_harness/app.py index 14b97a725..643b34c33 100644 --- a/resources_servers/browsecomp_advanced_harness/app.py +++ b/resources_servers/browsecomp_advanced_harness/app.py @@ -159,7 +159,6 @@ class TavilySearchMetrics(BaseModel): class TavilySearchVerifyResponse(TavilySearchVerifyRequest, JudgeEvaluation): - num_tool_calls: int metrics: TavilySearchMetrics @@ -400,7 +399,6 @@ async def verify(self, request: Request, body: TavilySearchVerifyRequest) -> Tav return TavilySearchVerifyResponse( **body.model_dump(), **judge_evaluation.model_dump(), - num_tool_calls=sum(o.type == "function_call" for o in body.response.output), metrics=self._session_id_to_metrics[request.session[SESSION_ID_KEY]], ) diff --git a/responses_api_agents/browsecomp_agent/app.py b/responses_api_agents/browsecomp_agent/app.py index 0b1956de0..79685a2c7 100644 --- a/responses_api_agents/browsecomp_agent/app.py +++ b/responses_api_agents/browsecomp_agent/app.py @@ -14,7 +14,8 @@ # limitations under the License. 
import json import re -from typing import List +from pathlib import Path +from typing import List, Optional from fastapi import Request, Response from pydantic import ConfigDict, ValidationError @@ -52,7 +53,9 @@ class BrowsecompAgentConfig(BaseResponsesAPIAgentConfig): max_context_tokens: int = 196608 context_reset_pct: float = 0.3 context_reset_keep_rounds: int = 3 + max_reset_count: Optional[int] = None max_run_retries: int = 1 + snap_dir: Optional[str] = None class BrowsecompAgentRunRequest(BaseRunRequest): @@ -81,13 +84,22 @@ async def responses( if isinstance(body.input, str): body.input = [NeMoGymEasyInputMessage(role="user", content=body.input)] + task_index, attempt = None, None + if self.config.snap_dir: + task_index = body.metadata.pop("task_index") + attempt = body.metadata.pop("attempt") + body.metadata = body.metadata or {} + new_outputs = [] usage = None step = 0 + num_tool_calls = 0 model_server_cookies = None # update the cookies on every model response resources_server_cookies = request.cookies # update the cookies on every resources server response reset_threshold = 0 + reset_count = 0 + max_reset_count = self.config.max_reset_count if self.config.max_context_tokens and self.config.context_reset_pct: reset_threshold = int(self.config.max_context_tokens * self.config.context_reset_pct) @@ -98,6 +110,8 @@ async def responses( new_outputs = self._compact_old_tool_messages(new_outputs) new_body = body.model_copy(update={"input": body.input + new_outputs}) + if not body.metadata: + new_body = new_body.model_dump(exclude={"metadata"}, exclude_none=True) model_response = await self.server_client.post( server_name=self.config.model_server.name, @@ -118,7 +132,22 @@ async def responses( # --- Check context reset threshold --- prompt_tokens = model_response.usage.input_tokens if model_response.usage else 0 - if reset_threshold and prompt_tokens > reset_threshold: + if ( + reset_threshold + and prompt_tokens > reset_threshold + and (max_reset_count is None 
or reset_count < max_reset_count) + ): + reset_count += 1 + # record current context + if self.config.snap_dir: + self._save_snapshot( + messages=body.input + new_outputs, + task_index=task_index, + attempt=attempt, + reset_count=reset_count, + is_final=False, + ) + # reset context if self.config.context_reset_keep_rounds > 0: new_outputs = self._extract_last_rounds(new_outputs) else: @@ -154,6 +183,7 @@ async def responses( # --- Execute tool calls --- for output_function_call in all_fn_calls: + num_tool_calls += 1 api_response = await self.server_client.post( server_name=self.config.resources_server.name, url_path=f"/{output_function_call.name}", @@ -216,12 +246,24 @@ async def responses( if self.config.max_steps and step >= self.config.max_steps: break + # record final context + if self.config.snap_dir: + self._save_snapshot( + messages=body.input + new_outputs, + task_index=task_index, + attempt=attempt, + reset_count=None, + is_final=True, + ) + # Propogate any extra cookies necessary for downstream verification for k, v in (*resources_server_cookies.items(), *model_server_cookies.items()): response.set_cookie(k, v) model_response.output = new_outputs model_response.usage = usage + model_response.reset_count = reset_count + model_response.num_tool_calls = num_tool_calls return model_response async def run(self, request: Request, body: BrowsecompAgentRunRequest) -> BrowsecompAgentVerifyResponse: @@ -238,6 +280,12 @@ async def run(self, request: Request, body: BrowsecompAgentRunRequest) -> Browse last_verify_response = None for attempt in range(self.config.max_run_retries): + # prepare for recording + if self.config.snap_dir: + body.responses_create_params.metadata = dict(body.responses_create_params.metadata or {}) + body.responses_create_params.metadata["task_index"] = str(body._ng_task_index) + body.responses_create_params.metadata["attempt"] = str(attempt) + response = await self.server_client.post( server_name=self.config.name, url_path="/v1/responses", @@ 
def _save_snapshot(self, messages, task_index, attempt, reset_count, is_final):
    """Write one conversation snapshot to disk in JSON Lines format.

    The snapshot is stored under ``<snap_dir>/sample_<task_index>/`` where
    ``snap_dir`` is read from ``self.config.snap_dir``. The file is named
    ``attempt_<attempt>_final.jsonl`` when ``is_final`` is true, otherwise
    ``attempt_<attempt>_reset_<reset_count>.jsonl``. A file that already
    exists under the same name is overwritten.

    Args:
        messages: Iterable of objects exposing ``model_dump_json()``; each
            object is serialized onto its own line of the output file.
        task_index: Value interpolated into the per-sample directory name.
        attempt: Value interpolated into the snapshot file name.
        reset_count: Value interpolated into the file name of a non-final
            snapshot; unused when ``is_final`` is true.
        is_final: Selects the "final" file name instead of the "reset" one.
    """
    sample_dir = Path(self.config.snap_dir) / f"sample_{task_index}"
    # exist_ok=True replaces the former exists()-then-mkdir() sequence: if
    # another writer creates the directory between the check and the call,
    # a plain mkdir() raises FileExistsError.
    sample_dir.mkdir(parents=True, exist_ok=True)

    if is_final:
        file_name = f"attempt_{attempt}_final.jsonl"
    else:
        file_name = f"attempt_{attempt}_reset_{reset_count}.jsonl"

    with open(sample_dir / file_name, "w", encoding="utf-8") as f:
        # One JSON document per line; batched into a single writelines call.
        f.writelines(msg.model_dump_json() + "\n" for msg in messages)