Skip to content

Commit d9a5516

Browse files
authored
Trajectory collection fix (#93)
* fix err msg * fix trajectory collection
1 parent 3144a76 commit d9a5516

1 file changed

Lines changed: 19 additions & 14 deletions

File tree

eval_protocol/mcp/execution/manager.py

Lines changed: 19 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, List, Optional, Union
1616

1717
import anyio
18+
import httpx
1819
from openai.types import CompletionUsage
1920

2021
from vendor.tau2.data_model.message import AssistantMessage, UserMessage
@@ -221,7 +222,7 @@ async def _execute_rollout(
221222
current_observation = user_message.content if user_message.content else ""
222223

223224
user_prompt = envs.format_user_prompt(rollout_idx, current_observation)
224-
conversation_history = [
225+
trajectory.conversation_history = [
225226
{"role": "system", "content": system_prompt},
226227
{"role": "user", "content": user_prompt},
227228
]
@@ -241,7 +242,7 @@ async def _execute_rollout(
241242

242243
if user_simulator and user_simulator_state:
243244
# Get user simulator messages and find the last assistant message
244-
user_simulator_messages = self._get_user_simulator_messages(conversation_history)
245+
user_simulator_messages = self._get_user_simulator_messages(trajectory.conversation_history)
245246

246247
# Last message was agent, simulated user response
247248
if user_simulator_messages and isinstance(user_simulator_messages[-1], AssistantMessage):
@@ -252,7 +253,7 @@ async def _execute_rollout(
252253
user_content = user_message.content if user_message.content else ""
253254

254255
user_prompt = envs.format_user_prompt(rollout_idx, user_content)
255-
conversation_history.append({"role": "user", "content": user_prompt})
256+
trajectory.conversation_history.append({"role": "user", "content": user_prompt})
256257

257258
# Check if user simulator signaled termination
258259
if UserSimulator.is_stop(user_message):
@@ -262,7 +263,7 @@ async def _execute_rollout(
262263
# In each turn: keep looping until assistant is ready to provide final response
263264
while not turn_completed and not trajectory.terminated:
264265
tool_calls, usage_stats, finish_reason = await policy(
265-
tool_schema, rollout_idx, conversation_history
266+
tool_schema, rollout_idx, trajectory.conversation_history
266267
)
267268

268269
# calc LLM usage stats incurred in this turn, if there are any
@@ -294,7 +295,7 @@ async def _execute_rollout(
294295
rollout_idx,
295296
tool_call,
296297
tool_response,
297-
conversation_history,
298+
trajectory.conversation_history,
298299
reward,
299300
env_end,
300301
info,
@@ -325,12 +326,14 @@ async def _execute_rollout(
325326
"num_tool_calls": 1,
326327
}
327328
print(f"🔍 control_plane_step: {control_plane_step}")
328-
conversation_history[-1]["control_plane_step"] = control_plane_step
329+
trajectory.conversation_history[-1]["control_plane_step"] = control_plane_step
329330
trajectory.control_plane_steps.append(control_plane_step)
330331

331332
# Log conversation state for playback if in recording mode
332333
if recording_mode:
333-
policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history)
334+
policy.log_conversation_state_for_playback(
335+
rollout_idx, step - 1, trajectory.conversation_history
336+
)
334337

335338
if env_end:
336339
# if the env marks the end of the rollout, break the tool call loop
@@ -364,17 +367,21 @@ async def _execute_rollout(
364367
"tool_calls": tool_calls_summary,
365368
"num_tool_calls": len(tool_calls),
366369
}
367-
conversation_history[-1]["control_plane_step"] = control_plane_step
370+
trajectory.conversation_history[-1]["control_plane_step"] = control_plane_step
368371
trajectory.control_plane_steps.append(control_plane_step)
369372

370373
# Log conversation state for playback if in recording mode
371374
if recording_mode:
372-
policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history)
375+
policy.log_conversation_state_for_playback(
376+
rollout_idx, step - 1, trajectory.conversation_history
377+
)
373378

374379
# if the env marks end, update control plane summary and do one last policy call, then break the agent loop
375380
# this is to ensure each turn ends with an assistant message, which will align with the actual agentic llm behavior
376381
if env_end:
377-
_, usage_stats, finish_reason = await policy(tool_schema, rollout_idx, conversation_history)
382+
_, usage_stats, finish_reason = await policy(
383+
tool_schema, rollout_idx, trajectory.conversation_history
384+
)
378385
if usage_stats:
379386
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
380387
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
@@ -392,10 +399,10 @@ async def _execute_rollout(
392399

393400
# Log final OpenAI conversation for terminated trajectories only
394401
if openai_logger:
395-
if conversation_history and len(conversation_history) > 0:
402+
if trajectory.conversation_history and len(trajectory.conversation_history) > 0:
396403
openai_logger(
397404
{
398-
"messages": conversation_history,
405+
"messages": trajectory.conversation_history,
399406
"metadata": {
400407
"session_id": session.session_id,
401408
"seed": session.seed,
@@ -421,8 +428,6 @@ async def _execute_rollout(
421428
if not trajectory.termination_reason and step >= steps:
422429
trajectory.termination_reason = TerminationReason.MAX_STEPS
423430

424-
trajectory.conversation_history = conversation_history
425-
426431
# Add termination_reason to the final control_plane_step
427432
for msg in reversed(trajectory.conversation_history):
428433
if msg.get("control_plane_step"):

0 commit comments

Comments
 (0)