33import math
44import os
55import logging
6- from typing import Dict , Union , List , Optional
6+ from typing import Dict , Union , List , Optional , Tuple
77
88from typing_extensions import overload , override
99
@@ -52,6 +52,78 @@ class MessageRole(str, Enum):
5252 DEVELOPER = "developer"
5353
5454
class EvaluationLevel(str, Enum):
    """Granularity at which CustomerSatisfactionEvaluator scores its input.

    Members are plain strings, so a member compares equal to its lowercase
    value (for example ``EvaluationLevel.TRACE == "trace"``).
    """

    # Score the whole exchange through the multi-turn (messages) path.
    CONVERSATION = "conversation"
    # Score a single query/response pair through the single-turn path.
    TRACE = "trace"
64+
65+
66+ def _merge_query_response_messages (query : List [dict ], response : List [dict ]) -> List [dict ]:
67+ """Merge query and response message lists into a single conversation."""
68+ return [* query , * response ]
69+
70+
def _split_messages_at_latest_user(messages: List[dict]) -> Tuple[List[dict], List[dict]]:
    """Split messages into query/response slices at the latest user turn.

    The query slice runs from the start of ``messages`` through the last
    message whose role equals ``MessageRole.USER`` (inclusive); the response
    slice is everything after it.

    If no message has the user role (including when ``messages`` is empty),
    the query slice is empty and the response slice is a copy of ``messages``.

    :param messages: Conversation messages, each a dict that may carry a ``role`` key.
    :type messages: List[dict]
    :return: A ``(query_messages, response_messages)`` pair of new lists.
    :rtype: Tuple[List[dict], List[dict]]
    """
    # ``default=-1`` avoids a bare ValueError from max() on an empty sequence,
    # and ``.get`` tolerates messages that lack a "role" key.
    latest_user_index = max(
        (index for index, message in enumerate(messages) if message.get("role") == MessageRole.USER),
        default=-1,
    )
    split_at = latest_user_index + 1
    return messages[:split_at], messages[split_at:]
75+
76+
77+ def _wrap_string_messages (query : str , response : str ) -> Tuple [List [dict ], List [dict ]]:
78+ """Wrap string query/response into separate message lists."""
79+ return (
80+ [{"role" : "user" , "content" : [{"type" : "text" , "text" : query }]}],
81+ [{"role" : "assistant" , "content" : [{"type" : "text" , "text" : response }]}],
82+ )
83+
84+
def _resolve_evaluation_level(
    evaluation_level: Optional[Union[EvaluationLevel, str]],
    error_target: ErrorTarget,
) -> Optional[EvaluationLevel]:
    """Validate and normalize the evaluation_level parameter.

    :param evaluation_level: The evaluation level to resolve. ``None`` means
        "auto-detect"; an ``EvaluationLevel`` member is returned unchanged; a
        string must match one of the ``EvaluationLevel`` values exactly.
    :type evaluation_level: Optional[Union[EvaluationLevel, str]]
    :param error_target: The error target attached to any raised exception.
    :type error_target: ErrorTarget
    :return: The resolved EvaluationLevel, or None for auto-detect.
    :rtype: Optional[EvaluationLevel]
    :raises EvaluationException: If ``evaluation_level`` is a string that is not a
        valid level, or is neither ``None``, a string, nor an ``EvaluationLevel``.
    """
    if evaluation_level is None:
        return None
    if isinstance(evaluation_level, EvaluationLevel):
        return evaluation_level

    # Remember the lookup failure (if any) so the single raise below can chain it.
    lookup_error: Optional[ValueError] = None
    if isinstance(evaluation_level, str):
        try:
            return EvaluationLevel(evaluation_level)
        except ValueError as exc:
            lookup_error = exc

    # Built only on the failure path; the success paths never need it.
    valid = [level.value for level in EvaluationLevel]
    raise EvaluationException(
        message=f"Invalid evaluation_level '{evaluation_level}'. Must be one of: {valid}.",
        blame=ErrorBlame.USER_ERROR,
        category=ErrorCategory.INVALID_VALUE,
        target=error_target,
    ) from lookup_error
125+
126+
55127class ContentType (str , Enum ):
56128 """Valid content types in messages."""
57129
@@ -723,13 +795,13 @@ def serialize_messages(messages: List[dict]) -> str:
723795
724796 # Normalize string content to list format for _get_agent_response
725797 normalized = msg
726- if role == "assistant" and isinstance (msg .get ("content" ), str ):
798+ if role == MessageRole . ASSISTANT and isinstance (msg .get ("content" ), str ):
727799 normalized = {** msg , "content" : [{"type" : "text" , "text" : msg ["content" ]}]}
728800
729- if role in ("system" , "developer" ):
801+ if role in (MessageRole . SYSTEM , MessageRole . DEVELOPER ):
730802 system_message = msg .get ("content" , "" )
731803
732- elif role == "user" and "content" in msg :
804+ elif role == MessageRole . USER and "content" in msg :
733805 if cur_agent_response :
734806 formatted = _get_agent_response (cur_agent_response , include_tool_messages = True )
735807 all_agent_responses .append ([formatted ])
@@ -742,7 +814,7 @@ def serialize_messages(messages: List[dict]) -> str:
742814 if text_in_msg :
743815 cur_user_query .append (text_in_msg )
744816
745- elif role in ("assistant" , "tool" ):
817+ elif role in (MessageRole . ASSISTANT , MessageRole . TOOL ):
746818 if cur_user_query :
747819 all_user_queries .append (cur_user_query )
748820 cur_user_query = []
@@ -845,7 +917,7 @@ class CustomerSatisfactionEvaluator(PromptyEvaluatorBase[Union[str, float]]):
845917 """Evaluator identifier, experimental and to be used only with evaluation in cloud."""
846918
847919 @override
848- def __init__ (self , model_config , * , credential = None , threshold = 3 , ** kwargs ):
920+ def __init__ (self , model_config , * , credential = None , threshold = 3 , evaluation_level = None , ** kwargs ):
849921 """Initialize the CustomerSatisfactionEvaluator.
850922
851923 :param model_config: Configuration for the Azure OpenAI model.
@@ -854,13 +926,23 @@ def __init__(self, model_config, *, credential=None, threshold=3, **kwargs):
854926 :type credential: Optional[TokenCredential]
855927 :keyword threshold: The threshold for the evaluator. Default is 3.
856928 :type threshold: int
929+ :keyword evaluation_level: Force a specific evaluation level for this invocation. When ``None``
930+ (default), the level is auto-detected from input shape (``messages`` -> conversation,
931+ ``query``/``response`` -> trace). Set to ``EvaluationLevel.CONVERSATION`` or
932+ ``EvaluationLevel.TRACE`` to override auto-detection.
933+ :type evaluation_level: Optional[Union[EvaluationLevel, str]]
857934 :keyword kwargs: Additional keyword arguments.
858935 """
859936 current_dir = os .path .dirname (__file__ )
860937 prompty_path = os .path .join (current_dir , self ._PROMPTY_FILE )
861938 self ._threshold = threshold
862939 self ._higher_is_better = True
863940
941+ # Validate and store evaluation level
942+ self ._evaluation_level = _resolve_evaluation_level (
943+ evaluation_level , ExtendedErrorTarget .CUSTOMER_SATISFACTION_EVALUATOR
944+ )
945+
864946 # Initialize input validator
865947 self ._validator = ConversationValidator (
866948 error_target = ExtendedErrorTarget .CUSTOMER_SATISFACTION_EVALUATOR ,
@@ -987,6 +1069,24 @@ def _not_applicable_result(
9871069 f"{ self ._result_key } _sample_output" : "" ,
9881070 }
9891071
def _should_use_conversation_level(self, eval_input: Dict) -> bool:
    """Decide between the conversation-level and trace-level paths.

    An explicitly configured ``_evaluation_level`` wins. When it is not one of
    the two known levels (i.e. it is ``None``), the choice is auto-detected:
    conversation-level is used when ``eval_input`` carries a non-``None``
    ``messages`` entry.

    :param eval_input: The evaluation input.
    :type eval_input: Dict
    :return: True if conversation-level evaluation should be used.
    :rtype: bool
    """
    forced_level = self._evaluation_level
    if forced_level in (EvaluationLevel.CONVERSATION, EvaluationLevel.TRACE):
        return forced_level == EvaluationLevel.CONVERSATION

    # No forced level: fall back to the shape of the input.
    has_messages = eval_input.get("messages") is not None
    return has_messages
1089+
9901090 @override
9911091 async def _real_call (self , ** kwargs ):
9921092 """Perform asynchronous call where real end-to-end evaluation logic is executed.
@@ -996,7 +1096,20 @@ async def _real_call(self, **kwargs):
9961096 :return: The evaluation result.
9971097 :rtype: Union[DoEvalResult[T_EvalValue], AggregateResult[T_EvalValue]]
9981098 """
999- # Validate input before processing
1099+ # Reshape inputs based on evaluation level before validation
1100+ if self ._evaluation_level == EvaluationLevel .CONVERSATION and not kwargs .get ("messages" ):
1101+ query = kwargs .get ("query" )
1102+ response = kwargs .get ("response" )
1103+ if isinstance (query , str ) and isinstance (response , str ) and query and response :
1104+ query , response = _wrap_string_messages (query , response )
1105+ if isinstance (query , list ) and isinstance (response , list ):
1106+ kwargs ["messages" ] = _merge_query_response_messages (query , response )
1107+ elif self ._evaluation_level == EvaluationLevel .TRACE and kwargs .get ("messages" ):
1108+ if any (m .get ("role" ) == MessageRole .USER for m in kwargs ["messages" ]):
1109+ query_messages , response_messages = _split_messages_at_latest_user (kwargs ["messages" ])
1110+ kwargs ["query" ] = query_messages
1111+ kwargs ["response" ] = response_messages
1112+
10001113 self ._validator .validate_eval_input (kwargs )
10011114
10021115 return await super ()._real_call (** kwargs )
@@ -1005,16 +1118,17 @@ async def _real_call(self, **kwargs):
10051118 async def _do_eval (self , eval_input : Dict ) -> Dict [str , Union [float , str ]]: # type: ignore[override]
10061119 """Do Customer Satisfaction evaluation.
10071120
1008- Routes to the multi-turn path when ``messages`` is provided,
1009- otherwise falls through to the single-turn query/response path.
1121+ Routes to conversation-level or trace-level evaluation based on
1122+ ``_evaluation_level`` (if set)
1123+ or auto-detects from input shape (default).
10101124
10111125 :param eval_input: The input to the evaluator.
10121126 :type eval_input: Dict
10131127 :return: The evaluation result.
10141128 :rtype: Dict
10151129 """
10161130 # Multi-turn path (messages)
1017- if eval_input . get ( "messages" ) is not None :
1131+ if self . _should_use_conversation_level ( eval_input ) :
10181132 return await self ._do_eval_multi_turn (eval_input )
10191133
10201134 # Single-turn path (query/response)
@@ -1033,6 +1147,11 @@ async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]: # t
10331147 self ._threshold ,
10341148 )
10351149
1150+ if isinstance (eval_input .get ("response" ), list ):
1151+ eval_input ["response" ] = _preprocess_messages (eval_input ["response" ])
1152+ if isinstance (eval_input .get ("query" ), list ):
1153+ eval_input ["query" ] = _preprocess_messages (eval_input ["query" ])
1154+
10361155 # Reformat inputs if they are lists of messages
10371156 if isinstance (eval_input .get ("query" ), list ):
10381157 eval_input ["query" ] = reformat_conversation_history (
0 commit comments