|
6 | 6 | import re |
7 | 7 | import math |
8 | 8 | import threading |
9 | | -from typing import Any, List, Literal, Mapping, Optional, Type, TypeVar, Tuple, Union, cast, get_args, get_origin |
| 9 | +from typing import Any, Dict, List, Literal, Mapping, Optional, Type, TypeVar, Tuple, Union, cast, get_args, get_origin |
10 | 10 |
|
11 | 11 | import nltk |
12 | 12 | from azure.storage.blob import ContainerClient |
@@ -962,3 +962,305 @@ def upload(path: str, container_client: ContainerClient, logger=None): |
962 | 962 | category=ErrorCategory.UPLOAD_ERROR, |
963 | 963 | blame=ErrorBlame.SYSTEM_ERROR, |
964 | 964 | ) from e |
| 965 | + |
| 966 | + |
| 967 | +# region Multi-turn utilities |
| 968 | + |
| 969 | + |
| 970 | +def _merge_query_response_messages(query, response): |
| 971 | + """Merge query and response message lists into a single conversation. |
| 972 | +
|
| 973 | + :param query: The query messages. |
| 974 | + :type query: List[dict] |
| 975 | + :param response: The response messages. |
| 976 | + :type response: List[dict] |
| 977 | + :return: The merged conversation messages. |
| 978 | + :rtype: List[dict] |
| 979 | + """ |
| 980 | + return [*query, *response] |
| 981 | + |
| 982 | + |
| 983 | +def _split_messages_at_latest_user(messages): |
| 984 | + """Split messages into query/response slices at the latest user turn. |
| 985 | +
|
| 986 | + :param messages: The conversation messages. |
| 987 | + :type messages: List[dict] |
| 988 | + :return: A tuple of (query_messages, response_messages). |
| 989 | + :rtype: Tuple[List[dict], List[dict]] |
| 990 | + """ |
| 991 | + latest_user_index = max(i for i, message in enumerate(messages) if message["role"] == "user") |
| 992 | + return messages[: latest_user_index + 1], messages[latest_user_index + 1 :] |
| 993 | + |
| 994 | + |
| 995 | +def _wrap_string_messages(query, response): |
| 996 | + """Wrap string query/response into separate message lists. |
| 997 | +
|
| 998 | + :param query: The query string. |
| 999 | + :type query: str |
| 1000 | + :param response: The response string. |
| 1001 | + :type response: str |
| 1002 | + :return: A tuple of (query_messages, response_messages). |
| 1003 | + :rtype: Tuple[List[dict], List[dict]] |
| 1004 | + """ |
| 1005 | + return ( |
| 1006 | + [{"role": "user", "content": [{"type": "text", "text": query}]}], |
| 1007 | + [{"role": "assistant", "content": [{"type": "text", "text": response}]}], |
| 1008 | + ) |
| 1009 | + |
| 1010 | + |
| 1011 | +def serialize_messages(messages): |
| 1012 | + """Serialize a list of chat messages into a labeled text transcript for multi-turn prompts. |
| 1013 | +
|
| 1014 | + **Input format:** List of message dicts, each with ``"role"`` (``user``, ``assistant``, ``tool``, |
| 1015 | + ``system``, ``developer``) and ``"content"`` (string or list of content-block dicts like |
| 1016 | + ``{"type": "text", "text": "..."}``). Tool messages may include ``tool_call_id`` and content |
| 1017 | + blocks of type ``tool_result``/``tool_call``. |
| 1018 | +
|
| 1019 | + **Output format:** Plain-text transcript with labeled turns:: |
| 1020 | +
|
| 1021 | + User turn 1: |
| 1022 | + <user text> |
| 1023 | +
|
| 1024 | + Agent turn 1: |
| 1025 | + <assistant text> |
| 1026 | + [TOOL_CALL] func_name({"arg": "val"}) |
| 1027 | + [TOOL_RESULT] <result> |
| 1028 | +
|
| 1029 | + User turn 2: |
| 1030 | + <user text> |
| 1031 | + ... |
| 1032 | +
|
| 1033 | + System/developer messages are included as a system preamble. Consecutive messages of the same |
| 1034 | + role are grouped into a single turn. Assistant string content is auto-normalized to content-block |
| 1035 | + format for consistent formatting. |
| 1036 | +
|
| 1037 | + :param messages: Chat messages with role and content. |
| 1038 | + :type messages: List[dict] |
| 1039 | + :return: Formatted text transcript. |
| 1040 | + :rtype: str |
| 1041 | + """ |
| 1042 | + if not messages: |
| 1043 | + return "" |
| 1044 | + |
| 1045 | + from azure.ai.evaluation._evaluators._common._validators._validation_constants import MessageRole |
| 1046 | + |
| 1047 | + all_user_queries = [] |
| 1048 | + all_agent_responses = [] |
| 1049 | + cur_user_query = [] |
| 1050 | + cur_agent_response = [] |
| 1051 | + system_message = None |
| 1052 | + |
| 1053 | + for msg in messages: |
| 1054 | + if not isinstance(msg, dict): |
| 1055 | + continue |
| 1056 | + role = msg.get("role") |
| 1057 | + if not role: |
| 1058 | + continue |
| 1059 | + |
| 1060 | + # _get_agent_response expects content as list of dicts, not a plain string |
| 1061 | + normalized = msg |
| 1062 | + if role == MessageRole.ASSISTANT and isinstance(msg.get("content"), str): |
| 1063 | + normalized = {**msg, "content": [{"type": "text", "text": msg["content"]}]} |
| 1064 | + |
| 1065 | + if role in (MessageRole.SYSTEM, MessageRole.DEVELOPER): |
| 1066 | + content = msg.get("content", "") |
| 1067 | + if isinstance(content, list): |
| 1068 | + system_message = "\n".join(_extract_text_from_content(content)) |
| 1069 | + else: |
| 1070 | + system_message = content |
| 1071 | + |
| 1072 | + elif role == MessageRole.USER and "content" in msg: |
| 1073 | + if cur_agent_response: |
| 1074 | + formatted = _get_agent_response(cur_agent_response, include_tool_messages=True) |
| 1075 | + all_agent_responses.append([formatted]) |
| 1076 | + cur_agent_response = [] |
| 1077 | + content = msg["content"] |
| 1078 | + if isinstance(content, str): |
| 1079 | + text_in_msg = [content] |
| 1080 | + else: |
| 1081 | + text_in_msg = _extract_text_from_content(content) |
| 1082 | + if text_in_msg: |
| 1083 | + cur_user_query.append(text_in_msg) |
| 1084 | + |
| 1085 | + elif role in (MessageRole.ASSISTANT, MessageRole.TOOL): |
| 1086 | + if cur_user_query: |
| 1087 | + all_user_queries.append(cur_user_query) |
| 1088 | + cur_user_query = [] |
| 1089 | + cur_agent_response.append(normalized) |
| 1090 | + |
| 1091 | + # Flush any remaining buffered turn |
| 1092 | + if cur_user_query: |
| 1093 | + all_user_queries.append(cur_user_query) |
| 1094 | + if cur_agent_response: |
| 1095 | + formatted = _get_agent_response(cur_agent_response, include_tool_messages=True) |
| 1096 | + all_agent_responses.append([formatted]) |
| 1097 | + |
| 1098 | + conversation_history: Dict = { |
| 1099 | + "user_queries": all_user_queries, |
| 1100 | + "agent_responses": all_agent_responses[: len(all_user_queries) - 1] if len(all_user_queries) > 0 else [], |
| 1101 | + } |
| 1102 | + if system_message: |
| 1103 | + conversation_history["system_message"] = system_message |
| 1104 | + |
| 1105 | + result = _pretty_format_conversation_history(conversation_history) |
| 1106 | + |
| 1107 | + # Append any trailing agent turn (the final response after the last user query) |
| 1108 | + start = max(len(all_user_queries) - 1, 0) |
| 1109 | + for i, agent_response in enumerate(all_agent_responses[start:], start=start): |
| 1110 | + result += f"Agent turn {i + 1}:\n" |
| 1111 | + for msg_text in agent_response: |
| 1112 | + if isinstance(msg_text, list): |
| 1113 | + for submsg in msg_text: |
| 1114 | + result += " " + "\n ".join(submsg.split("\n")) + "\n" |
| 1115 | + else: |
| 1116 | + result += " " + "\n ".join(msg_text.split("\n")) + "\n" |
| 1117 | + result += "\n" |
| 1118 | + |
| 1119 | + return result.rstrip("\n") |
| 1120 | + |
| 1121 | + |
| 1122 | +def _resolve_evaluation_level(evaluation_level, error_target): |
| 1123 | + """Validate and normalize the evaluation_level parameter. |
| 1124 | +
|
| 1125 | + :param evaluation_level: The evaluation level to resolve. |
| 1126 | + :type evaluation_level: Optional[Union[EvaluationLevel, str]] |
| 1127 | + :param error_target: The error target for exceptions. |
| 1128 | + :type error_target: ErrorTarget |
| 1129 | + :return: The resolved EvaluationLevel or None for auto-detect. |
| 1130 | + :rtype: Optional[EvaluationLevel] |
| 1131 | + """ |
| 1132 | + from .constants import EvaluationLevel |
| 1133 | + |
| 1134 | + valid = [level.value for level in EvaluationLevel] |
| 1135 | + if evaluation_level is None or evaluation_level == "": |
| 1136 | + return None |
| 1137 | + if isinstance(evaluation_level, EvaluationLevel): |
| 1138 | + return evaluation_level |
| 1139 | + if isinstance(evaluation_level, str): |
| 1140 | + try: |
| 1141 | + return EvaluationLevel(evaluation_level) |
| 1142 | + except ValueError: |
| 1143 | + raise EvaluationException( |
| 1144 | + message=(f"Invalid evaluation_level '{evaluation_level}'. " f"Must be one of: {valid}."), |
| 1145 | + blame=ErrorBlame.USER_ERROR, |
| 1146 | + category=ErrorCategory.INVALID_VALUE, |
| 1147 | + target=error_target, |
| 1148 | + ) |
| 1149 | + raise EvaluationException( |
| 1150 | + message=(f"Invalid evaluation_level '{evaluation_level}'. " f"Must be one of: {valid}."), |
| 1151 | + blame=ErrorBlame.USER_ERROR, |
| 1152 | + category=ErrorCategory.INVALID_VALUE, |
| 1153 | + target=error_target, |
| 1154 | + ) |
| 1155 | + |
| 1156 | + |
| 1157 | +def _is_intermediate_response(response): |
| 1158 | + """Check if response is intermediate (last content item is function_call or mcp_approval_request). |
| 1159 | +
|
| 1160 | + An intermediate response is one where the assistant's last message ends with a |
| 1161 | + function_call or mcp_approval_request content type, meaning the conversation is |
| 1162 | + still in progress and not yet ready for evaluation. |
| 1163 | +
|
| 1164 | + :param response: The response messages. |
| 1165 | + :type response: List[dict] |
| 1166 | + :return: True if the response is intermediate, False otherwise. |
| 1167 | + :rtype: bool |
| 1168 | + """ |
| 1169 | + if isinstance(response, list) and len(response) > 0: |
| 1170 | + last_msg = response[-1] |
| 1171 | + if isinstance(last_msg, dict) and last_msg.get("role") == "assistant": |
| 1172 | + content = last_msg.get("content", []) |
| 1173 | + if isinstance(content, list) and len(content) > 0: |
| 1174 | + last_content = content[-1] |
| 1175 | + if isinstance(last_content, dict) and last_content.get("type") in ( |
| 1176 | + "function_call", |
| 1177 | + "mcp_approval_request", |
| 1178 | + ): |
| 1179 | + return True |
| 1180 | + return False |
| 1181 | + |
| 1182 | + |
| 1183 | +def _drop_mcp_approval_messages(messages): |
| 1184 | + """Remove MCP approval request/response messages from a conversation. |
| 1185 | +
|
| 1186 | + MCP approval messages are protocol-level messages that should not be included |
| 1187 | + in the evaluation input. |
| 1188 | +
|
| 1189 | + :param messages: The conversation messages. |
| 1190 | + :type messages: List[dict] |
| 1191 | + :return: The filtered messages without MCP approval request/response messages. |
| 1192 | + :rtype: List[dict] |
| 1193 | + """ |
| 1194 | + if not isinstance(messages, list): |
| 1195 | + return messages |
| 1196 | + return [ |
| 1197 | + msg |
| 1198 | + for msg in messages |
| 1199 | + if not ( |
| 1200 | + isinstance(msg, dict) |
| 1201 | + and isinstance(msg.get("content"), list) |
| 1202 | + and ( |
| 1203 | + ( |
| 1204 | + msg.get("role") == "assistant" |
| 1205 | + and any(isinstance(c, dict) and c.get("type") == "mcp_approval_request" for c in msg["content"]) |
| 1206 | + ) |
| 1207 | + or ( |
| 1208 | + msg.get("role") == "tool" |
| 1209 | + and any(isinstance(c, dict) and c.get("type") == "mcp_approval_response" for c in msg["content"]) |
| 1210 | + ) |
| 1211 | + ) |
| 1212 | + ) |
| 1213 | + ] |
| 1214 | + |
| 1215 | + |
| 1216 | +def _normalize_function_call_types(messages): |
| 1217 | + """Normalize function_call/function_call_output/openapi_call/openapi_call_output types to tool_call/tool_result. |
| 1218 | +
|
| 1219 | + This ensures a consistent content type vocabulary for downstream evaluators |
| 1220 | + regardless of how the original messages were authored. |
| 1221 | +
|
| 1222 | + :param messages: The conversation messages. |
| 1223 | + :type messages: List[dict] |
| 1224 | + :return: The messages with normalized content types. |
| 1225 | + :rtype: List[dict] |
| 1226 | + """ |
| 1227 | + if not isinstance(messages, list): |
| 1228 | + return messages |
| 1229 | + for msg in messages: |
| 1230 | + if not isinstance(msg, dict) or not isinstance(msg.get("content"), list): |
| 1231 | + continue |
| 1232 | + for item in msg["content"]: |
| 1233 | + if not isinstance(item, dict): |
| 1234 | + continue |
| 1235 | + t = item.get("type") |
| 1236 | + if t == "function_call": |
| 1237 | + item["type"] = "tool_call" |
| 1238 | + elif t == "function_call_output": |
| 1239 | + item["type"] = "tool_result" |
| 1240 | + if "function_call_output" in item: |
| 1241 | + item["tool_result"] = item.pop("function_call_output") |
| 1242 | + elif t == "openapi_call": |
| 1243 | + item["type"] = "tool_call" |
| 1244 | + elif t == "openapi_call_output": |
| 1245 | + item["type"] = "tool_result" |
| 1246 | + if "openapi_call_output" in item: |
| 1247 | + item["tool_result"] = item.pop("openapi_call_output") |
| 1248 | + return messages |
| 1249 | + |
| 1250 | + |
| 1251 | +def _preprocess_messages(messages): |
| 1252 | + """Preprocess conversation messages by dropping MCP approval messages and normalizing function call types. |
| 1253 | +
|
| 1254 | + This should be called before passing messages to serialization or evaluation functions. |
| 1255 | +
|
| 1256 | + :param messages: The conversation messages. |
| 1257 | + :type messages: List[dict] |
| 1258 | + :return: The preprocessed messages. |
| 1259 | + :rtype: List[dict] |
| 1260 | + """ |
| 1261 | + messages = _drop_mcp_approval_messages(messages) |
| 1262 | + messages = _normalize_function_call_types(messages) |
| 1263 | + return messages |
| 1264 | + |
| 1265 | + |
| 1266 | +# endregion Multi-turn utilities |
0 commit comments