Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions tests/unit/vertexai/genai/replays/test_run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,63 @@ def test_inference_with_eval_cases_multi_turn_agent_data(client):
assert "agent_data" in inference_result.eval_dataset_df.columns


def test_inference_with_eval_cases_agent_engine_agent_data(client):
    """Tests N+1 inference with agent_data via remote Agent Engine.

    Builds an ``EvalCase`` whose ``agent_data`` carries two prior
    conversation turns plus a final user question, then runs
    ``client.evals.run_inference`` against a remote (deployed) Agent
    Engine resource rather than a local agent. Verifies that the result
    is an ``EvaluationDataset`` whose DataFrame retains the
    ``agent_data`` column.

    Args:
        client: The replay-enabled GenAI client fixture provided by
            ``pytest_helper.setup`` (see ``pytestmark`` below).
    """
    # NOTE(review): hard-coded project/engine resource name — presumably a
    # fixture recorded for replay tests; confirm it matches the replay data.
    agent_engine = client.agent_engines.get(
        name="projects/977012026409/locations/us-central1"
        "/reasoningEngines/7188347537655332864"
    )

    # Turn 0 establishes context ("My name is Bob."); turn 1 asks a question
    # ("What is my name?") that is only answerable if the engine replays the
    # turn-0 history — the N+1 inference scenario under test.
    eval_case = types.EvalCase(
        agent_data=types.evals.AgentData(
            turns=[
                types.evals.ConversationTurn(
                    turn_index=0,
                    events=[
                        types.evals.AgentEvent(
                            author="user",
                            content=genai_types.Content(
                                role="user",
                                parts=[genai_types.Part(text="My name is Bob.")],
                            ),
                        ),
                        types.evals.AgentEvent(
                            author="model",
                            content=genai_types.Content(
                                role="model",
                                parts=[
                                    genai_types.Part(text="Hi Bob! Nice to meet you.")
                                ],
                            ),
                        ),
                    ],
                ),
                # The final turn has no model event: its user message is the
                # new prompt for which inference should produce a response.
                types.evals.ConversationTurn(
                    turn_index=1,
                    events=[
                        types.evals.AgentEvent(
                            author="user",
                            content=genai_types.Content(
                                role="user",
                                parts=[genai_types.Part(text="What is my name?")],
                            ),
                        ),
                    ],
                ),
            ],
        ),
    )
    eval_dataset = types.EvaluationDataset(eval_cases=[eval_case])

    inference_result = client.evals.run_inference(
        agent=agent_engine,
        src=eval_dataset,
    )
    # Only structural assertions: the replayed response content itself is not
    # pinned here, just that agent_data survives into the output DataFrame.
    assert isinstance(inference_result, types.EvaluationDataset)
    assert inference_result.eval_dataset_df is not None
    assert "agent_data" in inference_result.eval_dataset_df.columns


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
Expand Down
101 changes: 91 additions & 10 deletions vertexai/_genai/_evals_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2136,19 +2136,100 @@ def _execute_agent_run_with_retry(
max_retries: int = 3,
) -> Union[list[dict[str, Any]], dict[str, Any]]:
"""Executes agent run over agent engine for a single prompt."""
# TODO(b/507976585): Support agent_data history replay for Agent Engine
# sessions. Requires appending history events to the remote session via
# the Sessions API before calling stream_query.
# Agent data with conversation history — pre-populate the remote session
# with prior turns via the Sessions API, then query with the last user
# message only.
if AGENT_DATA in row.index and row.get(AGENT_DATA) is not None:
raise NotImplementedError(
"Conversation history replay from agent_data is not yet supported"
" for remote Agent Engine inference. Use a local ADK agent"
" (LlmAgent) instead, or provide a DataFrame with a 'prompt'"
" column."
)
agent_data_obj = row[AGENT_DATA]
if isinstance(agent_data_obj, dict):
agent_data_obj = types.evals.AgentData.model_validate(agent_data_obj)
if isinstance(agent_data_obj, types.evals.AgentData) and agent_data_obj.turns:
try:
last_user_content, history_events = _extract_prompt_from_agent_data(
agent_data_obj
)
except ValueError as e:
return {"error": f"Invalid agent_data for inference: {e}"}

user_id = str(uuid.uuid4())
session_state = None
if "session_inputs" in row.index and row.get("session_inputs") is not None:
si = _get_session_inputs(row)
user_id = si.user_id or user_id
session_state = si.state

try:
session_id = _create_agent_engine_session(
agent_engine=agent_engine,
user_id=user_id,
session_state=session_state,
)
except Exception as e: # pylint: disable=broad-exception-caught
return {"error": f"Failed to create session: {e}"}

# Pre-populate remote session with history events.
if agent_engine.api_resource is None:
return {"error": "agent_engine.api_resource is None."}
if agent_engine.api_client is None:
return {"error": "agent_engine.api_client is None."}
session_name = f"{agent_engine.api_resource.name}/sessions/{session_id}"
# Use a fixed base timestamp for history events so that
# replay tests produce deterministic request bodies.
base_ts = datetime.datetime(2000, 1, 1, tzinfo=datetime.timezone.utc)
for i, ag_event in enumerate(history_events):
agent_engine.api_client.sessions.events.append(
name=session_name,
author=ag_event.author or "user",
invocation_id="history",
timestamp=base_ts + datetime.timedelta(seconds=i),
config=types.AppendAgentEngineSessionEventConfig(
content=ag_event.content,
),
)

last_user_text = _evals_data_converters._get_content_text(last_user_content)
for attempt in range(max_retries):
try:
responses = []
for event in agent_engine.stream_query( # type: ignore[attr-defined]
user_id=user_id,
session_id=session_id,
message=last_user_text,
):
if event and CONTENT in event and PARTS in event[CONTENT]:
responses.append(event)
return responses
except api_exceptions.ResourceExhausted as e:
logger.warning(
"Resource Exhausted error on attempt %d/%d: %s."
" Retrying in %s seconds...",
attempt + 1,
max_retries,
e,
2**attempt,
)
if attempt == max_retries - 1:
return {"error": (f"Resource exhausted after retries: {e}")}
time.sleep(2**attempt)
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Unexpected error during agent engine run on attempt %d/%d: %s",
attempt + 1,
max_retries,
e,
)
if attempt == max_retries - 1:
return {"error": f"Failed after retries: {e}"}
time.sleep(1)
return {
"error": (
f"Failed to get agent run results after {max_retries} retries"
)
}

try:
session_inputs = _get_session_inputs(row)
user_id = session_inputs.user_id
user_id = session_inputs.user_id or str(uuid.uuid4())
session_state = session_inputs.state
session_id = _create_agent_engine_session(
agent_engine=agent_engine,
Expand Down
Loading