Skip to content

Commit ceebff9

Browse files
Add Utils for Multi-turn Evaluation (#46325)
* Add multi-turn eval utils * run black * update utils to use constants * move import * rename trace to turn * update messages validation and utils * run black
1 parent 89c1029 commit ceebff9

6 files changed

Lines changed: 1222 additions & 1 deletion

File tree

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_common/constants.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ class EvaluatorScoringPattern(Enum):
6161
SCALE_1_5 = "scale_1_5" # 1-5 scale (quality evaluators)
6262

6363

64+
class EvaluationLevel(str, Enum):
65+
"""Supported evaluation levels for multi-turn evaluators.
66+
67+
- ``CONVERSATION``: Force conversation-level evaluation using the multi-turn path.
68+
- ``TURN``: Force turn-level evaluation using the single-turn query/response path.
69+
"""
70+
71+
CONVERSATION = "conversation"
72+
TURN = "turn"
73+
74+
6475
class Tasks:
6576
"""Defines types of annotation tasks supported by RAI Service."""
6677

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_common/utils.py

Lines changed: 303 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import re
77
import math
88
import threading
9-
from typing import Any, List, Literal, Mapping, Optional, Type, TypeVar, Tuple, Union, cast, get_args, get_origin
9+
from typing import Any, Dict, List, Literal, Mapping, Optional, Type, TypeVar, Tuple, Union, cast, get_args, get_origin
1010

1111
import nltk
1212
from azure.storage.blob import ContainerClient
@@ -962,3 +962,305 @@ def upload(path: str, container_client: ContainerClient, logger=None):
962962
category=ErrorCategory.UPLOAD_ERROR,
963963
blame=ErrorBlame.SYSTEM_ERROR,
964964
) from e
965+
966+
967+
# region Multi-turn utilities
968+
969+
970+
def _merge_query_response_messages(query, response):
971+
"""Merge query and response message lists into a single conversation.
972+
973+
:param query: The query messages.
974+
:type query: List[dict]
975+
:param response: The response messages.
976+
:type response: List[dict]
977+
:return: The merged conversation messages.
978+
:rtype: List[dict]
979+
"""
980+
return [*query, *response]
981+
982+
983+
def _split_messages_at_latest_user(messages):
984+
"""Split messages into query/response slices at the latest user turn.
985+
986+
:param messages: The conversation messages.
987+
:type messages: List[dict]
988+
:return: A tuple of (query_messages, response_messages).
989+
:rtype: Tuple[List[dict], List[dict]]
990+
"""
991+
latest_user_index = max(i for i, message in enumerate(messages) if message["role"] == "user")
992+
return messages[: latest_user_index + 1], messages[latest_user_index + 1 :]
993+
994+
995+
def _wrap_string_messages(query, response):
996+
"""Wrap string query/response into separate message lists.
997+
998+
:param query: The query string.
999+
:type query: str
1000+
:param response: The response string.
1001+
:type response: str
1002+
:return: A tuple of (query_messages, response_messages).
1003+
:rtype: Tuple[List[dict], List[dict]]
1004+
"""
1005+
return (
1006+
[{"role": "user", "content": [{"type": "text", "text": query}]}],
1007+
[{"role": "assistant", "content": [{"type": "text", "text": response}]}],
1008+
)
1009+
1010+
1011+
def serialize_messages(messages):
1012+
"""Serialize a list of chat messages into a labeled text transcript for multi-turn prompts.
1013+
1014+
**Input format:** List of message dicts, each with ``"role"`` (``user``, ``assistant``, ``tool``,
1015+
``system``, ``developer``) and ``"content"`` (string or list of content-block dicts like
1016+
``{"type": "text", "text": "..."}``). Tool messages may include ``tool_call_id`` and content
1017+
blocks of type ``tool_result``/``tool_call``.
1018+
1019+
**Output format:** Plain-text transcript with labeled turns::
1020+
1021+
User turn 1:
1022+
<user text>
1023+
1024+
Agent turn 1:
1025+
<assistant text>
1026+
[TOOL_CALL] func_name({"arg": "val"})
1027+
[TOOL_RESULT] <result>
1028+
1029+
User turn 2:
1030+
<user text>
1031+
...
1032+
1033+
System/developer messages are included as a system preamble. Consecutive messages of the same
1034+
role are grouped into a single turn. Assistant string content is auto-normalized to content-block
1035+
format for consistent formatting.
1036+
1037+
:param messages: Chat messages with role and content.
1038+
:type messages: List[dict]
1039+
:return: Formatted text transcript.
1040+
:rtype: str
1041+
"""
1042+
if not messages:
1043+
return ""
1044+
1045+
from azure.ai.evaluation._evaluators._common._validators._validation_constants import MessageRole
1046+
1047+
all_user_queries = []
1048+
all_agent_responses = []
1049+
cur_user_query = []
1050+
cur_agent_response = []
1051+
system_message = None
1052+
1053+
for msg in messages:
1054+
if not isinstance(msg, dict):
1055+
continue
1056+
role = msg.get("role")
1057+
if not role:
1058+
continue
1059+
1060+
# _get_agent_response expects content as list of dicts, not a plain string
1061+
normalized = msg
1062+
if role == MessageRole.ASSISTANT and isinstance(msg.get("content"), str):
1063+
normalized = {**msg, "content": [{"type": "text", "text": msg["content"]}]}
1064+
1065+
if role in (MessageRole.SYSTEM, MessageRole.DEVELOPER):
1066+
content = msg.get("content", "")
1067+
if isinstance(content, list):
1068+
system_message = "\n".join(_extract_text_from_content(content))
1069+
else:
1070+
system_message = content
1071+
1072+
elif role == MessageRole.USER and "content" in msg:
1073+
if cur_agent_response:
1074+
formatted = _get_agent_response(cur_agent_response, include_tool_messages=True)
1075+
all_agent_responses.append([formatted])
1076+
cur_agent_response = []
1077+
content = msg["content"]
1078+
if isinstance(content, str):
1079+
text_in_msg = [content]
1080+
else:
1081+
text_in_msg = _extract_text_from_content(content)
1082+
if text_in_msg:
1083+
cur_user_query.append(text_in_msg)
1084+
1085+
elif role in (MessageRole.ASSISTANT, MessageRole.TOOL):
1086+
if cur_user_query:
1087+
all_user_queries.append(cur_user_query)
1088+
cur_user_query = []
1089+
cur_agent_response.append(normalized)
1090+
1091+
# Flush any remaining buffered turn
1092+
if cur_user_query:
1093+
all_user_queries.append(cur_user_query)
1094+
if cur_agent_response:
1095+
formatted = _get_agent_response(cur_agent_response, include_tool_messages=True)
1096+
all_agent_responses.append([formatted])
1097+
1098+
conversation_history: Dict = {
1099+
"user_queries": all_user_queries,
1100+
"agent_responses": all_agent_responses[: len(all_user_queries) - 1] if len(all_user_queries) > 0 else [],
1101+
}
1102+
if system_message:
1103+
conversation_history["system_message"] = system_message
1104+
1105+
result = _pretty_format_conversation_history(conversation_history)
1106+
1107+
# Append any trailing agent turn (the final response after the last user query)
1108+
start = max(len(all_user_queries) - 1, 0)
1109+
for i, agent_response in enumerate(all_agent_responses[start:], start=start):
1110+
result += f"Agent turn {i + 1}:\n"
1111+
for msg_text in agent_response:
1112+
if isinstance(msg_text, list):
1113+
for submsg in msg_text:
1114+
result += " " + "\n ".join(submsg.split("\n")) + "\n"
1115+
else:
1116+
result += " " + "\n ".join(msg_text.split("\n")) + "\n"
1117+
result += "\n"
1118+
1119+
return result.rstrip("\n")
1120+
1121+
1122+
def _resolve_evaluation_level(evaluation_level, error_target):
1123+
"""Validate and normalize the evaluation_level parameter.
1124+
1125+
:param evaluation_level: The evaluation level to resolve.
1126+
:type evaluation_level: Optional[Union[EvaluationLevel, str]]
1127+
:param error_target: The error target for exceptions.
1128+
:type error_target: ErrorTarget
1129+
:return: The resolved EvaluationLevel or None for auto-detect.
1130+
:rtype: Optional[EvaluationLevel]
1131+
"""
1132+
from .constants import EvaluationLevel
1133+
1134+
valid = [level.value for level in EvaluationLevel]
1135+
if evaluation_level is None or evaluation_level == "":
1136+
return None
1137+
if isinstance(evaluation_level, EvaluationLevel):
1138+
return evaluation_level
1139+
if isinstance(evaluation_level, str):
1140+
try:
1141+
return EvaluationLevel(evaluation_level)
1142+
except ValueError:
1143+
raise EvaluationException(
1144+
message=(f"Invalid evaluation_level '{evaluation_level}'. " f"Must be one of: {valid}."),
1145+
blame=ErrorBlame.USER_ERROR,
1146+
category=ErrorCategory.INVALID_VALUE,
1147+
target=error_target,
1148+
)
1149+
raise EvaluationException(
1150+
message=(f"Invalid evaluation_level '{evaluation_level}'. " f"Must be one of: {valid}."),
1151+
blame=ErrorBlame.USER_ERROR,
1152+
category=ErrorCategory.INVALID_VALUE,
1153+
target=error_target,
1154+
)
1155+
1156+
1157+
def _is_intermediate_response(response):
1158+
"""Check if response is intermediate (last content item is function_call or mcp_approval_request).
1159+
1160+
An intermediate response is one where the assistant's last message ends with a
1161+
function_call or mcp_approval_request content type, meaning the conversation is
1162+
still in progress and not yet ready for evaluation.
1163+
1164+
:param response: The response messages.
1165+
:type response: List[dict]
1166+
:return: True if the response is intermediate, False otherwise.
1167+
:rtype: bool
1168+
"""
1169+
if isinstance(response, list) and len(response) > 0:
1170+
last_msg = response[-1]
1171+
if isinstance(last_msg, dict) and last_msg.get("role") == "assistant":
1172+
content = last_msg.get("content", [])
1173+
if isinstance(content, list) and len(content) > 0:
1174+
last_content = content[-1]
1175+
if isinstance(last_content, dict) and last_content.get("type") in (
1176+
"function_call",
1177+
"mcp_approval_request",
1178+
):
1179+
return True
1180+
return False
1181+
1182+
1183+
def _drop_mcp_approval_messages(messages):
1184+
"""Remove MCP approval request/response messages from a conversation.
1185+
1186+
MCP approval messages are protocol-level messages that should not be included
1187+
in the evaluation input.
1188+
1189+
:param messages: The conversation messages.
1190+
:type messages: List[dict]
1191+
:return: The filtered messages without MCP approval request/response messages.
1192+
:rtype: List[dict]
1193+
"""
1194+
if not isinstance(messages, list):
1195+
return messages
1196+
return [
1197+
msg
1198+
for msg in messages
1199+
if not (
1200+
isinstance(msg, dict)
1201+
and isinstance(msg.get("content"), list)
1202+
and (
1203+
(
1204+
msg.get("role") == "assistant"
1205+
and any(isinstance(c, dict) and c.get("type") == "mcp_approval_request" for c in msg["content"])
1206+
)
1207+
or (
1208+
msg.get("role") == "tool"
1209+
and any(isinstance(c, dict) and c.get("type") == "mcp_approval_response" for c in msg["content"])
1210+
)
1211+
)
1212+
)
1213+
]
1214+
1215+
1216+
def _normalize_function_call_types(messages):
1217+
"""Normalize function_call/function_call_output/openapi_call/openapi_call_output types to tool_call/tool_result.
1218+
1219+
This ensures a consistent content type vocabulary for downstream evaluators
1220+
regardless of how the original messages were authored.
1221+
1222+
:param messages: The conversation messages.
1223+
:type messages: List[dict]
1224+
:return: The messages with normalized content types.
1225+
:rtype: List[dict]
1226+
"""
1227+
if not isinstance(messages, list):
1228+
return messages
1229+
for msg in messages:
1230+
if not isinstance(msg, dict) or not isinstance(msg.get("content"), list):
1231+
continue
1232+
for item in msg["content"]:
1233+
if not isinstance(item, dict):
1234+
continue
1235+
t = item.get("type")
1236+
if t == "function_call":
1237+
item["type"] = "tool_call"
1238+
elif t == "function_call_output":
1239+
item["type"] = "tool_result"
1240+
if "function_call_output" in item:
1241+
item["tool_result"] = item.pop("function_call_output")
1242+
elif t == "openapi_call":
1243+
item["type"] = "tool_call"
1244+
elif t == "openapi_call_output":
1245+
item["type"] = "tool_result"
1246+
if "openapi_call_output" in item:
1247+
item["tool_result"] = item.pop("openapi_call_output")
1248+
return messages
1249+
1250+
1251+
def _preprocess_messages(messages):
1252+
"""Preprocess conversation messages by dropping MCP approval messages and normalizing function call types.
1253+
1254+
This should be called before passing messages to serialization or evaluation functions.
1255+
1256+
:param messages: The conversation messages.
1257+
:type messages: List[dict]
1258+
:return: The preprocessed messages.
1259+
:rtype: List[dict]
1260+
"""
1261+
messages = _drop_mcp_approval_messages(messages)
1262+
messages = _normalize_function_call_types(messages)
1263+
return messages
1264+
1265+
1266+
# endregion Multi-turn utilities

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_validators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
from ._tool_definitions_validator import ToolDefinitionsValidator
99
from ._tool_calls_validator import ToolCallsValidator
1010
from ._task_navigation_efficiency_validator import TaskNavigationEfficiencyValidator
11+
from ._messages_validator import MessagesOrQueryResponseInputValidator
1112

1213
__all__ = [
1314
"ValidatorInterface",
1415
"ConversationValidator",
1516
"ToolDefinitionsValidator",
1617
"ToolCallsValidator",
1718
"TaskNavigationEfficiencyValidator",
19+
"MessagesOrQueryResponseInputValidator",
1820
]

0 commit comments

Comments
 (0)