@@ -132,13 +132,19 @@ def get_formatted_template(data: dict, annotation_task: str) -> str:
132132 return user_text .replace ("'" , '\\ "' )
133133
134134
135- def get_common_headers (token : str , evaluator_name : Optional [str ] = None ) -> Dict :
135+ def get_common_headers (
136+ token : str ,
137+ evaluator_name : Optional [str ] = None ,
138+ extra_headers : Optional [Dict [str , str ]] = None ,
139+ ) -> Dict :
136140 """Get common headers for the HTTP request
137141
138142 :param token: The Azure authentication token.
139143 :type token: str
140144 :param evaluator_name: The evaluator name. Default is None.
141145 :type evaluator_name: str
146+ :param extra_headers: Additional headers to include in the request. Default is None.
147+ :type extra_headers: Optional[Dict[str, str]]
142148 :return: The common headers.
143149 :rtype: Dict
144150 """
@@ -147,10 +153,12 @@ def get_common_headers(token: str, evaluator_name: Optional[str] = None) -> Dict
147153 if evaluator_name
148154 else UserAgentSingleton ().value
149155 )
150- return {
151- "Authorization" : f"Bearer { token } " ,
152- "User-Agent" : user_agent ,
153- }
156+ # Apply extra_headers first, then SDK-owned headers on top so that
157+ # Authorization, User-Agent, etc. can never be silently overridden.
158+ headers = dict (extra_headers ) if extra_headers else {}
159+ headers ["Authorization" ] = f"Bearer { token } "
160+ headers ["User-Agent" ] = user_agent
161+ return headers
154162
155163
156164def get_async_http_client_with_timeout () -> AsyncHttpPipeline :
@@ -203,7 +211,12 @@ async def ensure_service_availability_onedp(
203211 )
204212
205213
206- async def ensure_service_availability (rai_svc_url : str , token : str , capability : Optional [str ] = None ) -> None :
214+ async def ensure_service_availability (
215+ rai_svc_url : str ,
216+ token : str ,
217+ capability : Optional [str ] = None ,
218+ extra_headers : Optional [Dict [str , str ]] = None ,
219+ ) -> None :
207220 """Check if the Responsible AI service is available in the region and has the required capability, if relevant.
208221
209222 :param rai_svc_url: The Responsible AI service URL.
@@ -212,9 +225,11 @@ async def ensure_service_availability(rai_svc_url: str, token: str, capability:
212225 :type token: str
213226 :param capability: The capability to check. Default is None.
214227 :type capability: str
228+ :param extra_headers: Additional headers to include in the request. Default is None.
229+ :type extra_headers: Optional[Dict[str, str]]
215230 :raises Exception: If the service is not available in the region or the capability is not available.
216231 """
217- headers = get_common_headers (token )
232+ headers = get_common_headers (token , extra_headers = extra_headers )
218233 svc_liveness_url = rai_svc_url + "/checkannotation"
219234
220235 async with get_async_http_client () as client :
@@ -285,7 +300,13 @@ def generate_payload(normalized_user_text: str, metric: str, annotation_task: st
285300
286301
287302async def submit_request (
288- data : dict , metric : str , rai_svc_url : str , token : str , annotation_task : str , evaluator_name : str
303+ data : dict ,
304+ metric : str ,
305+ rai_svc_url : str ,
306+ token : str ,
307+ annotation_task : str ,
308+ evaluator_name : str ,
309+ extra_headers : Optional [Dict [str , str ]] = None ,
289310) -> str :
290311 """Submit request to Responsible AI service for evaluation and return operation ID
291312
@@ -308,7 +329,7 @@ async def submit_request(
308329 payload = generate_payload (normalized_user_text , metric , annotation_task = annotation_task )
309330
310331 url = rai_svc_url + "/submitannotation"
311- headers = get_common_headers (token , evaluator_name )
332+ headers = get_common_headers (token , evaluator_name , extra_headers = extra_headers )
312333
313334 async with get_async_http_client_with_timeout () as client :
314335 http_response = await client .post (url , json = payload , headers = headers )
@@ -360,7 +381,13 @@ async def submit_request_onedp(
360381 return operation_id
361382
362383
363- async def fetch_result (operation_id : str , rai_svc_url : str , credential : TokenCredential , token : str ) -> Dict :
384+ async def fetch_result (
385+ operation_id : str ,
386+ rai_svc_url : str ,
387+ credential : TokenCredential ,
388+ token : str ,
389+ extra_headers : Optional [Dict [str , str ]] = None ,
390+ ) -> Dict :
364391 """Fetch the annotation result from Responsible AI service
365392
366393 :param operation_id: The operation ID.
@@ -380,7 +407,7 @@ async def fetch_result(operation_id: str, rai_svc_url: str, credential: TokenCre
380407 url = rai_svc_url + "/operations/" + operation_id
381408 while True :
382409 token = await fetch_or_reuse_token (credential , token )
383- headers = get_common_headers (token )
410+ headers = get_common_headers (token , extra_headers = extra_headers )
384411
385412 async with get_async_http_client () as client :
386413 response = await client .get (url , headers = headers , timeout = RAIService .TIMEOUT )
@@ -725,17 +752,23 @@ def _parse_content_harm_response(
725752 return result
726753
727754
728- async def _get_service_discovery_url (azure_ai_project : AzureAIProject , token : str ) -> str :
755+ async def _get_service_discovery_url (
756+ azure_ai_project : AzureAIProject ,
757+ token : str ,
758+ extra_headers : Optional [Dict [str , str ]] = None ,
759+ ) -> str :
729760 """Get the discovery service URL for the Azure AI project
730761
731762 :param azure_ai_project: The Azure AI project details.
732763 :type azure_ai_project: ~azure.ai.evaluation.AzureAIProject
733764 :param token: The Azure authentication token.
734765 :type token: str
766+ :param extra_headers: Additional headers to include in the request. Default is None.
767+ :type extra_headers: Optional[Dict[str, str]]
735768 :return: The discovery service URL.
736769 :rtype: str
737770 """
738- headers = get_common_headers (token )
771+ headers = get_common_headers (token , extra_headers = extra_headers )
739772
740773 async with get_async_http_client_with_timeout () as client :
741774 response = await client .get (
@@ -764,17 +797,25 @@ async def _get_service_discovery_url(azure_ai_project: AzureAIProject, token: st
764797 return f"{ base_url .scheme } ://{ base_url .netloc } "
765798
766799
767- async def get_rai_svc_url (project_scope : AzureAIProject , token : str ) -> str :
800+ async def get_rai_svc_url (
801+ project_scope : AzureAIProject ,
802+ token : str ,
803+ extra_headers : Optional [Dict [str , str ]] = None ,
804+ ) -> str :
768805 """Get the Responsible AI service URL
769806
770807 :param project_scope: The Azure AI project scope details.
771808 :type project_scope: Dict
772809 :param token: The Azure authentication token.
773810 :type token: str
811+ :param extra_headers: Additional headers to include in the request. Default is None.
812+ :type extra_headers: Optional[Dict[str, str]]
774813 :return: The Responsible AI service URL.
775814 :rtype: str
776815 """
777- discovery_url = await _get_service_discovery_url (azure_ai_project = project_scope , token = token )
816+ discovery_url = await _get_service_discovery_url (
817+ azure_ai_project = project_scope , token = token , extra_headers = extra_headers
818+ )
778819 subscription_id = project_scope ["subscription_id" ]
779820 resource_group_name = project_scope ["resource_group_name" ]
780821 project_name = project_scope ["project_name" ]
@@ -826,6 +867,7 @@ async def evaluate_with_rai_service(
826867 metric_display_name = None ,
827868 evaluator_name = None ,
828869 scan_session_id : Optional [str ] = None ,
870+ extra_headers : Optional [Dict [str , str ]] = None ,
829871) -> Dict [str , Union [str , float ]]:
830872 """Evaluate the content safety of the response using Responsible AI service (legacy endpoint)
831873
@@ -849,12 +891,12 @@ async def evaluate_with_rai_service(
849891 :return: The parsed annotation result.
850892 :rtype: Dict[str, Union[str, float]]
851893 """
852-
853894 if is_onedp_project (project_scope ):
854895 client = AIProjectClient (
855896 endpoint = project_scope ,
856897 credential = credential ,
857898 user_agent_policy = UserAgentPolicy (base_user_agent = UserAgentSingleton ().value ),
899+ headers = extra_headers or {},
858900 )
859901 token = await fetch_or_reuse_token (credential = credential , workspace = COG_SRV_WORKSPACE )
860902 await ensure_service_availability_onedp (client , token , annotation_task )
@@ -867,12 +909,22 @@ async def evaluate_with_rai_service(
867909 else :
868910 # Get RAI service URL from discovery service and check service availability
869911 token = await fetch_or_reuse_token (credential )
870- rai_svc_url = await get_rai_svc_url (project_scope , token )
871- await ensure_service_availability (rai_svc_url , token , annotation_task )
912+ rai_svc_url = await get_rai_svc_url (project_scope , token , extra_headers = extra_headers )
913+ await ensure_service_availability (rai_svc_url , token , annotation_task , extra_headers = extra_headers )
872914
873915 # Submit annotation request and fetch result
874- operation_id = await submit_request (data , metric_name , rai_svc_url , token , annotation_task , evaluator_name )
875- annotation_response = cast (List [Dict ], await fetch_result (operation_id , rai_svc_url , credential , token ))
916+ operation_id = await submit_request (
917+ data ,
918+ metric_name ,
919+ rai_svc_url ,
920+ token ,
921+ annotation_task ,
922+ evaluator_name ,
923+ extra_headers = extra_headers ,
924+ )
925+ annotation_response = cast (
926+ List [Dict ], await fetch_result (operation_id , rai_svc_url , credential , token , extra_headers = extra_headers )
927+ )
876928 result = parse_response (annotation_response , metric_name , metric_display_name )
877929
878930 return result
@@ -910,7 +962,9 @@ def generate_payload_multimodal(content_type: str, messages, metric: str) -> Dic
910962 }
911963
912964
913- async def submit_multimodal_request (messages , metric : str , rai_svc_url : str , token : str ) -> str :
965+ async def submit_multimodal_request (
966+ messages , metric : str , rai_svc_url : str , token : str , extra_headers : Optional [Dict [str , str ]] = None
967+ ) -> str :
914968 """Submit request to Responsible AI service for evaluation and return operation ID
915969 :param messages: The normalized list of messages to be entered as the "Contents" in the payload.
916970 :type messages: str
@@ -942,7 +996,7 @@ async def submit_multimodal_request(messages, metric: str, rai_svc_url: str, tok
942996
943997 ## calling rai service for annotation
944998 url = rai_svc_url + "/submitannotation"
945- headers = get_common_headers (token )
999+ headers = get_common_headers (token , extra_headers = extra_headers )
9461000 async with get_async_http_client () as client :
9471001 response = await client .post ( # pylint: disable=too-many-function-args,unexpected-keyword-arg
9481002 url , json = payload , headers = headers
@@ -1077,6 +1131,7 @@ async def evaluate_with_rai_service_sync(
10771131 evaluator_name = None ,
10781132 scan_session_id : Optional [str ] = None ,
10791133 use_legacy_endpoint : bool = False ,
1134+ extra_headers : Optional [Dict [str , str ]] = None ,
10801135):
10811136 """Evaluate the content safety of the response using Responsible AI service.
10821137
@@ -1120,19 +1175,25 @@ async def evaluate_with_rai_service_sync(
11201175 metric_display_name = metric_display_name ,
11211176 evaluator_name = evaluator_name ,
11221177 scan_session_id = scan_session_id ,
1178+ extra_headers = extra_headers ,
11231179 )
11241180
11251181 # Sync evals endpoint implementation (default)
11261182 api_version = "2025-11-15-preview"
11271183 if not is_onedp_project (project_scope ):
11281184 # Get RAI service URL from discovery service and check service availability
11291185 token = await fetch_or_reuse_token (credential )
1130- rai_svc_url = await get_rai_svc_url (project_scope , token )
1131- await ensure_service_availability (rai_svc_url , token , annotation_task )
1186+ rai_svc_url = await get_rai_svc_url (project_scope , token , extra_headers = extra_headers )
1187+ await ensure_service_availability (rai_svc_url , token , annotation_task , extra_headers = extra_headers )
11321188
11331189 # Submit annotation request and fetch result
11341190 url = rai_svc_url + f"/sync_evals:run?api-version={ api_version } "
1135- headers = {"aml-user-token" : token , "Authorization" : "Bearer " + token , "Content-Type" : "application/json" }
1191+ # Apply extra_headers first, then SDK-owned headers on top so that
1192+ # auth and content-type can never be silently overridden.
1193+ headers = dict (extra_headers ) if extra_headers else {}
1194+ headers ["aml-user-token" ] = token
1195+ headers ["Authorization" ] = "Bearer " + token
1196+ headers ["Content-Type" ] = "application/json"
11361197 sync_eval_payload = _build_sync_eval_payload (data , metric_name , annotation_task , scan_session_id )
11371198 sync_eval_payload_json = json .dumps (sync_eval_payload , cls = SdkJSONEncoder )
11381199
@@ -1150,6 +1211,7 @@ async def evaluate_with_rai_service_sync(
11501211 endpoint = project_scope ,
11511212 credential = credential ,
11521213 user_agent_policy = UserAgentPolicy (base_user_agent = UserAgentSingleton ().value ),
1214+ headers = extra_headers or {},
11531215 )
11541216
11551217 sync_eval_payload = _build_sync_eval_payload (data , metric_name , annotation_task , scan_session_id )
@@ -1305,6 +1367,7 @@ async def evaluate_with_rai_service_sync_multimodal(
13051367 credential : TokenCredential ,
13061368 scan_session_id : Optional [str ] = None ,
13071369 use_legacy_endpoint : bool = False ,
1370+ extra_headers : Optional [Dict [str , str ]] = None ,
13081371):
13091372 """Evaluate multimodal content using Responsible AI service.
13101373
@@ -1336,6 +1399,7 @@ async def evaluate_with_rai_service_sync_multimodal(
13361399 project_scope = project_scope ,
13371400 credential = credential ,
13381401 metric_display_name = metric_display_name ,
1402+ extra_headers = extra_headers ,
13391403 )
13401404
13411405 # Sync evals endpoint implementation (default)
@@ -1347,6 +1411,7 @@ async def evaluate_with_rai_service_sync_multimodal(
13471411 endpoint = project_scope ,
13481412 credential = credential ,
13491413 user_agent_policy = UserAgentPolicy (base_user_agent = UserAgentSingleton ().value ),
1414+ headers = extra_headers or {},
13501415 )
13511416
13521417 headers = {"x-ms-client-request-id" : scan_session_id } if scan_session_id else None
@@ -1355,15 +1420,16 @@ async def evaluate_with_rai_service_sync_multimodal(
13551420 return client .sync_evals .create (eval = sync_eval_payload )
13561421
13571422 token = await fetch_or_reuse_token (credential )
1358- rai_svc_url = await get_rai_svc_url (project_scope , token )
1359- await ensure_service_availability (rai_svc_url , token , Tasks .CONTENT_HARM )
1423+ rai_svc_url = await get_rai_svc_url (project_scope , token , extra_headers = extra_headers )
1424+ await ensure_service_availability (rai_svc_url , token , Tasks .CONTENT_HARM , extra_headers = extra_headers )
13601425
13611426 url = rai_svc_url + f"/sync_evals:run?api-version={ api_version } "
1362- headers = {
1363- "aml-user-token" : token ,
1364- "Authorization" : "Bearer " + token ,
1365- "Content-Type" : "application/json" ,
1366- }
1427+ # Apply extra_headers first, then SDK-owned headers on top so that
1428+ # auth, content-type, and correlation IDs can never be silently overridden.
1429+ headers = dict (extra_headers ) if extra_headers else {}
1430+ headers ["aml-user-token" ] = token
1431+ headers ["Authorization" ] = "Bearer " + token
1432+ headers ["Content-Type" ] = "application/json"
13671433 if scan_session_id :
13681434 headers ["x-ms-client-request-id" ] = scan_session_id
13691435
@@ -1385,6 +1451,7 @@ async def evaluate_with_rai_service_multimodal(
13851451 project_scope : Union [str , AzureAIProject ],
13861452 credential : TokenCredential ,
13871453 metric_display_name : Optional [str ] = None ,
1454+ extra_headers : Optional [Dict [str , str ]] = None ,
13881455):
13891456 """Evaluate the content safety of the response using Responsible AI service (legacy endpoint)
13901457 :param messages: The normalized list of messages.
@@ -1407,6 +1474,7 @@ async def evaluate_with_rai_service_multimodal(
14071474 endpoint = project_scope ,
14081475 credential = credential ,
14091476 user_agent_policy = UserAgentPolicy (base_user_agent = UserAgentSingleton ().value ),
1477+ headers = extra_headers or {},
14101478 )
14111479 token = await fetch_or_reuse_token (credential = credential , workspace = COG_SRV_WORKSPACE )
14121480 await ensure_service_availability_onedp (client , token , Tasks .CONTENT_HARM )
@@ -1416,10 +1484,14 @@ async def evaluate_with_rai_service_multimodal(
14161484 return result
14171485 else :
14181486 token = await fetch_or_reuse_token (credential )
1419- rai_svc_url = await get_rai_svc_url (project_scope , token )
1420- await ensure_service_availability (rai_svc_url , token , Tasks .CONTENT_HARM )
1487+ rai_svc_url = await get_rai_svc_url (project_scope , token , extra_headers = extra_headers )
1488+ await ensure_service_availability (rai_svc_url , token , Tasks .CONTENT_HARM , extra_headers = extra_headers )
14211489 # Submit annotation request and fetch result
1422- operation_id = await submit_multimodal_request (messages , metric_name , rai_svc_url , token )
1423- annotation_response = cast (List [Dict ], await fetch_result (operation_id , rai_svc_url , credential , token ))
1490+ operation_id = await submit_multimodal_request (
1491+ messages , metric_name , rai_svc_url , token , extra_headers = extra_headers
1492+ )
1493+ annotation_response = cast (
1494+ List [Dict ], await fetch_result (operation_id , rai_svc_url , credential , token , extra_headers = extra_headers )
1495+ )
14241496 result = parse_response (annotation_response , metric_name , metric_display_name )
14251497 return result
0 commit comments