Skip to content

Commit 660d933

Browse files
[feat] add extra_headers for rai service evaluators + move gen_ai usage to internal properties (#46685)
* [feat] add extra_headers for rai service evaluators * reformat * change to internal properties * fix * sdk-owned headers + changelogs --------- Co-authored-by: zyysurely <yingyingzhao@microsoft.com>
1 parent 8ed1412 commit 660d933

6 files changed

Lines changed: 463 additions & 55 deletions

File tree

sdk/evaluation/azure-ai-evaluation/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@
44

55
### Features Added
66

7+
- Added `extra_headers` keyword argument to `RaiServiceEvaluatorBase` (and all content safety evaluators) to allow passing custom HTTP headers to all backend RAI service calls. SDK-owned headers (`Authorization`, `User-Agent`, `Content-Type`, `aml-user-token`, `x-ms-client-request-id`) cannot be overridden by `extra_headers`.
8+
79
### Breaking Changes
810

911
### Bugs Fixed
1012

1113
### Other Changes
1214

15+
- Moved token usage attributes (`gen_ai.evaluation.usage.input_tokens`, `gen_ai.evaluation.usage.output_tokens`) from standard App Insights event attributes into the `internal_properties` JSON bag to align with internal telemetry conventions.
16+
1317
## 1.16.6 (2026-04-27)
1418

1519
### Bugs Fixed

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_common/rai_service.py

Lines changed: 108 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,19 @@ def get_formatted_template(data: dict, annotation_task: str) -> str:
132132
return user_text.replace("'", '\\"')
133133

134134

135-
def get_common_headers(token: str, evaluator_name: Optional[str] = None) -> Dict:
135+
def get_common_headers(
136+
token: str,
137+
evaluator_name: Optional[str] = None,
138+
extra_headers: Optional[Dict[str, str]] = None,
139+
) -> Dict:
136140
"""Get common headers for the HTTP request
137141
138142
:param token: The Azure authentication token.
139143
:type token: str
140144
:param evaluator_name: The evaluator name. Default is None.
141145
:type evaluator_name: str
146+
:param extra_headers: Additional headers to include in the request. Default is None.
147+
:type extra_headers: Optional[Dict[str, str]]
142148
:return: The common headers.
143149
:rtype: Dict
144150
"""
@@ -147,10 +153,12 @@ def get_common_headers(token: str, evaluator_name: Optional[str] = None) -> Dict
147153
if evaluator_name
148154
else UserAgentSingleton().value
149155
)
150-
return {
151-
"Authorization": f"Bearer {token}",
152-
"User-Agent": user_agent,
153-
}
156+
# Apply extra_headers first, then SDK-owned headers on top so that
157+
# Authorization, User-Agent, etc. can never be silently overridden.
158+
headers = dict(extra_headers) if extra_headers else {}
159+
headers["Authorization"] = f"Bearer {token}"
160+
headers["User-Agent"] = user_agent
161+
return headers
154162

155163

156164
def get_async_http_client_with_timeout() -> AsyncHttpPipeline:
@@ -203,7 +211,12 @@ async def ensure_service_availability_onedp(
203211
)
204212

205213

206-
async def ensure_service_availability(rai_svc_url: str, token: str, capability: Optional[str] = None) -> None:
214+
async def ensure_service_availability(
215+
rai_svc_url: str,
216+
token: str,
217+
capability: Optional[str] = None,
218+
extra_headers: Optional[Dict[str, str]] = None,
219+
) -> None:
207220
"""Check if the Responsible AI service is available in the region and has the required capability, if relevant.
208221
209222
:param rai_svc_url: The Responsible AI service URL.
@@ -212,9 +225,11 @@ async def ensure_service_availability(rai_svc_url: str, token: str, capability:
212225
:type token: str
213226
:param capability: The capability to check. Default is None.
214227
:type capability: str
228+
:param extra_headers: Additional headers to include in the request. Default is None.
229+
:type extra_headers: Optional[Dict[str, str]]
215230
:raises Exception: If the service is not available in the region or the capability is not available.
216231
"""
217-
headers = get_common_headers(token)
232+
headers = get_common_headers(token, extra_headers=extra_headers)
218233
svc_liveness_url = rai_svc_url + "/checkannotation"
219234

220235
async with get_async_http_client() as client:
@@ -285,7 +300,13 @@ def generate_payload(normalized_user_text: str, metric: str, annotation_task: st
285300

286301

287302
async def submit_request(
288-
data: dict, metric: str, rai_svc_url: str, token: str, annotation_task: str, evaluator_name: str
303+
data: dict,
304+
metric: str,
305+
rai_svc_url: str,
306+
token: str,
307+
annotation_task: str,
308+
evaluator_name: str,
309+
extra_headers: Optional[Dict[str, str]] = None,
289310
) -> str:
290311
"""Submit request to Responsible AI service for evaluation and return operation ID
291312
@@ -308,7 +329,7 @@ async def submit_request(
308329
payload = generate_payload(normalized_user_text, metric, annotation_task=annotation_task)
309330

310331
url = rai_svc_url + "/submitannotation"
311-
headers = get_common_headers(token, evaluator_name)
332+
headers = get_common_headers(token, evaluator_name, extra_headers=extra_headers)
312333

313334
async with get_async_http_client_with_timeout() as client:
314335
http_response = await client.post(url, json=payload, headers=headers)
@@ -360,7 +381,13 @@ async def submit_request_onedp(
360381
return operation_id
361382

362383

363-
async def fetch_result(operation_id: str, rai_svc_url: str, credential: TokenCredential, token: str) -> Dict:
384+
async def fetch_result(
385+
operation_id: str,
386+
rai_svc_url: str,
387+
credential: TokenCredential,
388+
token: str,
389+
extra_headers: Optional[Dict[str, str]] = None,
390+
) -> Dict:
364391
"""Fetch the annotation result from Responsible AI service
365392
366393
:param operation_id: The operation ID.
@@ -380,7 +407,7 @@ async def fetch_result(operation_id: str, rai_svc_url: str, credential: TokenCre
380407
url = rai_svc_url + "/operations/" + operation_id
381408
while True:
382409
token = await fetch_or_reuse_token(credential, token)
383-
headers = get_common_headers(token)
410+
headers = get_common_headers(token, extra_headers=extra_headers)
384411

385412
async with get_async_http_client() as client:
386413
response = await client.get(url, headers=headers, timeout=RAIService.TIMEOUT)
@@ -725,17 +752,23 @@ def _parse_content_harm_response(
725752
return result
726753

727754

728-
async def _get_service_discovery_url(azure_ai_project: AzureAIProject, token: str) -> str:
755+
async def _get_service_discovery_url(
756+
azure_ai_project: AzureAIProject,
757+
token: str,
758+
extra_headers: Optional[Dict[str, str]] = None,
759+
) -> str:
729760
"""Get the discovery service URL for the Azure AI project
730761
731762
:param azure_ai_project: The Azure AI project details.
732763
:type azure_ai_project: ~azure.ai.evaluation.AzureAIProject
733764
:param token: The Azure authentication token.
734765
:type token: str
766+
:param extra_headers: Additional headers to include in the request. Default is None.
767+
:type extra_headers: Optional[Dict[str, str]]
735768
:return: The discovery service URL.
736769
:rtype: str
737770
"""
738-
headers = get_common_headers(token)
771+
headers = get_common_headers(token, extra_headers=extra_headers)
739772

740773
async with get_async_http_client_with_timeout() as client:
741774
response = await client.get(
@@ -764,17 +797,25 @@ async def _get_service_discovery_url(azure_ai_project: AzureAIProject, token: st
764797
return f"{base_url.scheme}://{base_url.netloc}"
765798

766799

767-
async def get_rai_svc_url(project_scope: AzureAIProject, token: str) -> str:
800+
async def get_rai_svc_url(
801+
project_scope: AzureAIProject,
802+
token: str,
803+
extra_headers: Optional[Dict[str, str]] = None,
804+
) -> str:
768805
"""Get the Responsible AI service URL
769806
770807
:param project_scope: The Azure AI project scope details.
771808
:type project_scope: Dict
772809
:param token: The Azure authentication token.
773810
:type token: str
811+
:param extra_headers: Additional headers to include in the request. Default is None.
812+
:type extra_headers: Optional[Dict[str, str]]
774813
:return: The Responsible AI service URL.
775814
:rtype: str
776815
"""
777-
discovery_url = await _get_service_discovery_url(azure_ai_project=project_scope, token=token)
816+
discovery_url = await _get_service_discovery_url(
817+
azure_ai_project=project_scope, token=token, extra_headers=extra_headers
818+
)
778819
subscription_id = project_scope["subscription_id"]
779820
resource_group_name = project_scope["resource_group_name"]
780821
project_name = project_scope["project_name"]
@@ -826,6 +867,7 @@ async def evaluate_with_rai_service(
826867
metric_display_name=None,
827868
evaluator_name=None,
828869
scan_session_id: Optional[str] = None,
870+
extra_headers: Optional[Dict[str, str]] = None,
829871
) -> Dict[str, Union[str, float]]:
830872
"""Evaluate the content safety of the response using Responsible AI service (legacy endpoint)
831873
@@ -849,12 +891,12 @@ async def evaluate_with_rai_service(
849891
:return: The parsed annotation result.
850892
:rtype: Dict[str, Union[str, float]]
851893
"""
852-
853894
if is_onedp_project(project_scope):
854895
client = AIProjectClient(
855896
endpoint=project_scope,
856897
credential=credential,
857898
user_agent_policy=UserAgentPolicy(base_user_agent=UserAgentSingleton().value),
899+
headers=extra_headers or {},
858900
)
859901
token = await fetch_or_reuse_token(credential=credential, workspace=COG_SRV_WORKSPACE)
860902
await ensure_service_availability_onedp(client, token, annotation_task)
@@ -867,12 +909,22 @@ async def evaluate_with_rai_service(
867909
else:
868910
# Get RAI service URL from discovery service and check service availability
869911
token = await fetch_or_reuse_token(credential)
870-
rai_svc_url = await get_rai_svc_url(project_scope, token)
871-
await ensure_service_availability(rai_svc_url, token, annotation_task)
912+
rai_svc_url = await get_rai_svc_url(project_scope, token, extra_headers=extra_headers)
913+
await ensure_service_availability(rai_svc_url, token, annotation_task, extra_headers=extra_headers)
872914

873915
# Submit annotation request and fetch result
874-
operation_id = await submit_request(data, metric_name, rai_svc_url, token, annotation_task, evaluator_name)
875-
annotation_response = cast(List[Dict], await fetch_result(operation_id, rai_svc_url, credential, token))
916+
operation_id = await submit_request(
917+
data,
918+
metric_name,
919+
rai_svc_url,
920+
token,
921+
annotation_task,
922+
evaluator_name,
923+
extra_headers=extra_headers,
924+
)
925+
annotation_response = cast(
926+
List[Dict], await fetch_result(operation_id, rai_svc_url, credential, token, extra_headers=extra_headers)
927+
)
876928
result = parse_response(annotation_response, metric_name, metric_display_name)
877929

878930
return result
@@ -910,7 +962,9 @@ def generate_payload_multimodal(content_type: str, messages, metric: str) -> Dic
910962
}
911963

912964

913-
async def submit_multimodal_request(messages, metric: str, rai_svc_url: str, token: str) -> str:
965+
async def submit_multimodal_request(
966+
messages, metric: str, rai_svc_url: str, token: str, extra_headers: Optional[Dict[str, str]] = None
967+
) -> str:
914968
"""Submit request to Responsible AI service for evaluation and return operation ID
915969
:param messages: The normalized list of messages to be entered as the "Contents" in the payload.
916970
:type messages: str
@@ -942,7 +996,7 @@ async def submit_multimodal_request(messages, metric: str, rai_svc_url: str, tok
942996

943997
## calling rai service for annotation
944998
url = rai_svc_url + "/submitannotation"
945-
headers = get_common_headers(token)
999+
headers = get_common_headers(token, extra_headers=extra_headers)
9461000
async with get_async_http_client() as client:
9471001
response = await client.post( # pylint: disable=too-many-function-args,unexpected-keyword-arg
9481002
url, json=payload, headers=headers
@@ -1077,6 +1131,7 @@ async def evaluate_with_rai_service_sync(
10771131
evaluator_name=None,
10781132
scan_session_id: Optional[str] = None,
10791133
use_legacy_endpoint: bool = False,
1134+
extra_headers: Optional[Dict[str, str]] = None,
10801135
):
10811136
"""Evaluate the content safety of the response using Responsible AI service.
10821137
@@ -1120,19 +1175,25 @@ async def evaluate_with_rai_service_sync(
11201175
metric_display_name=metric_display_name,
11211176
evaluator_name=evaluator_name,
11221177
scan_session_id=scan_session_id,
1178+
extra_headers=extra_headers,
11231179
)
11241180

11251181
# Sync evals endpoint implementation (default)
11261182
api_version = "2025-11-15-preview"
11271183
if not is_onedp_project(project_scope):
11281184
# Get RAI service URL from discovery service and check service availability
11291185
token = await fetch_or_reuse_token(credential)
1130-
rai_svc_url = await get_rai_svc_url(project_scope, token)
1131-
await ensure_service_availability(rai_svc_url, token, annotation_task)
1186+
rai_svc_url = await get_rai_svc_url(project_scope, token, extra_headers=extra_headers)
1187+
await ensure_service_availability(rai_svc_url, token, annotation_task, extra_headers=extra_headers)
11321188

11331189
# Submit annotation request and fetch result
11341190
url = rai_svc_url + f"/sync_evals:run?api-version={api_version}"
1135-
headers = {"aml-user-token": token, "Authorization": "Bearer " + token, "Content-Type": "application/json"}
1191+
# Apply extra_headers first, then SDK-owned headers on top so that
1192+
# auth and content-type can never be silently overridden.
1193+
headers = dict(extra_headers) if extra_headers else {}
1194+
headers["aml-user-token"] = token
1195+
headers["Authorization"] = "Bearer " + token
1196+
headers["Content-Type"] = "application/json"
11361197
sync_eval_payload = _build_sync_eval_payload(data, metric_name, annotation_task, scan_session_id)
11371198
sync_eval_payload_json = json.dumps(sync_eval_payload, cls=SdkJSONEncoder)
11381199

@@ -1150,6 +1211,7 @@ async def evaluate_with_rai_service_sync(
11501211
endpoint=project_scope,
11511212
credential=credential,
11521213
user_agent_policy=UserAgentPolicy(base_user_agent=UserAgentSingleton().value),
1214+
headers=extra_headers or {},
11531215
)
11541216

11551217
sync_eval_payload = _build_sync_eval_payload(data, metric_name, annotation_task, scan_session_id)
@@ -1305,6 +1367,7 @@ async def evaluate_with_rai_service_sync_multimodal(
13051367
credential: TokenCredential,
13061368
scan_session_id: Optional[str] = None,
13071369
use_legacy_endpoint: bool = False,
1370+
extra_headers: Optional[Dict[str, str]] = None,
13081371
):
13091372
"""Evaluate multimodal content using Responsible AI service.
13101373
@@ -1336,6 +1399,7 @@ async def evaluate_with_rai_service_sync_multimodal(
13361399
project_scope=project_scope,
13371400
credential=credential,
13381401
metric_display_name=metric_display_name,
1402+
extra_headers=extra_headers,
13391403
)
13401404

13411405
# Sync evals endpoint implementation (default)
@@ -1347,6 +1411,7 @@ async def evaluate_with_rai_service_sync_multimodal(
13471411
endpoint=project_scope,
13481412
credential=credential,
13491413
user_agent_policy=UserAgentPolicy(base_user_agent=UserAgentSingleton().value),
1414+
headers=extra_headers or {},
13501415
)
13511416

13521417
headers = {"x-ms-client-request-id": scan_session_id} if scan_session_id else None
@@ -1355,15 +1420,16 @@ async def evaluate_with_rai_service_sync_multimodal(
13551420
return client.sync_evals.create(eval=sync_eval_payload)
13561421

13571422
token = await fetch_or_reuse_token(credential)
1358-
rai_svc_url = await get_rai_svc_url(project_scope, token)
1359-
await ensure_service_availability(rai_svc_url, token, Tasks.CONTENT_HARM)
1423+
rai_svc_url = await get_rai_svc_url(project_scope, token, extra_headers=extra_headers)
1424+
await ensure_service_availability(rai_svc_url, token, Tasks.CONTENT_HARM, extra_headers=extra_headers)
13601425

13611426
url = rai_svc_url + f"/sync_evals:run?api-version={api_version}"
1362-
headers = {
1363-
"aml-user-token": token,
1364-
"Authorization": "Bearer " + token,
1365-
"Content-Type": "application/json",
1366-
}
1427+
# Apply extra_headers first, then SDK-owned headers on top so that
1428+
# auth, content-type, and correlation IDs can never be silently overridden.
1429+
headers = dict(extra_headers) if extra_headers else {}
1430+
headers["aml-user-token"] = token
1431+
headers["Authorization"] = "Bearer " + token
1432+
headers["Content-Type"] = "application/json"
13671433
if scan_session_id:
13681434
headers["x-ms-client-request-id"] = scan_session_id
13691435

@@ -1385,6 +1451,7 @@ async def evaluate_with_rai_service_multimodal(
13851451
project_scope: Union[str, AzureAIProject],
13861452
credential: TokenCredential,
13871453
metric_display_name: Optional[str] = None,
1454+
extra_headers: Optional[Dict[str, str]] = None,
13881455
):
13891456
"""Evaluate the content safety of the response using Responsible AI service (legacy endpoint)
13901457
:param messages: The normalized list of messages.
@@ -1407,6 +1474,7 @@ async def evaluate_with_rai_service_multimodal(
14071474
endpoint=project_scope,
14081475
credential=credential,
14091476
user_agent_policy=UserAgentPolicy(base_user_agent=UserAgentSingleton().value),
1477+
headers=extra_headers or {},
14101478
)
14111479
token = await fetch_or_reuse_token(credential=credential, workspace=COG_SRV_WORKSPACE)
14121480
await ensure_service_availability_onedp(client, token, Tasks.CONTENT_HARM)
@@ -1416,10 +1484,14 @@ async def evaluate_with_rai_service_multimodal(
14161484
return result
14171485
else:
14181486
token = await fetch_or_reuse_token(credential)
1419-
rai_svc_url = await get_rai_svc_url(project_scope, token)
1420-
await ensure_service_availability(rai_svc_url, token, Tasks.CONTENT_HARM)
1487+
rai_svc_url = await get_rai_svc_url(project_scope, token, extra_headers=extra_headers)
1488+
await ensure_service_availability(rai_svc_url, token, Tasks.CONTENT_HARM, extra_headers=extra_headers)
14211489
# Submit annotation request and fetch result
1422-
operation_id = await submit_multimodal_request(messages, metric_name, rai_svc_url, token)
1423-
annotation_response = cast(List[Dict], await fetch_result(operation_id, rai_svc_url, credential, token))
1490+
operation_id = await submit_multimodal_request(
1491+
messages, metric_name, rai_svc_url, token, extra_headers=extra_headers
1492+
)
1493+
annotation_response = cast(
1494+
List[Dict], await fetch_result(operation_id, rai_svc_url, credential, token, extra_headers=extra_headers)
1495+
)
14241496
result = parse_response(annotation_response, metric_name, metric_display_name)
14251497
return result

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_evaluate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,9 +1277,9 @@ def _log_events_to_app_insights(
12771277
usage = sample.get("usage", {})
12781278
usage = usage if isinstance(usage, dict) else {}
12791279
if usage.get("prompt_tokens") is not None:
1280-
standard_log_attributes["gen_ai.evaluation.usage.input_tokens"] = str(usage["prompt_tokens"])
1280+
internal_log_attributes["gen_ai.evaluation.usage.input_tokens"] = str(usage["prompt_tokens"])
12811281
if usage.get("completion_tokens") is not None:
1282-
standard_log_attributes["gen_ai.evaluation.usage.output_tokens"] = str(usage["completion_tokens"])
1282+
internal_log_attributes["gen_ai.evaluation.usage.output_tokens"] = str(usage["completion_tokens"])
12831283

12841284
# Combine standard and internal attributes, put internal under the properties bag
12851285
standard_log_attributes["internal_properties"] = json.dumps(internal_log_attributes)

0 commit comments

Comments
 (0)