Skip to content

Commit f22ad42

Browse files
xuming-msMing Xu
andauthored
[Core] Differentiate Copilot agent requests from manual requests by adding session ID into token claims (#33309)
Co-authored-by: Ming Xu <xumi@microsoft.com>
1 parent 62e498c commit f22ad42

9 files changed

Lines changed: 441 additions & 6 deletions

File tree

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# --------------------------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for license information.
4+
# --------------------------------------------------------------------------------------------
5+
6+
"""
7+
Support for Entra Agentic Sessions.
8+
9+
When CLI runs inside an agent context (e.g., Copilot, Azure MCP), the orchestrator sets the
10+
COPILOT_AGENT_SESSION_ID environment variable. CLI reads it and passes it to MSAL as both:
11+
- A query parameter (`client_session`) so ESTS can identify the agentic session
12+
- A claims challenge so ESTS embeds an agentic marker claim in the token (and MSAL bypasses
13+
the access token cache to ensure a fresh, agent-tagged token is always fetched)
14+
15+
This enables downstream systems (RBAC, Defender, Purview) to enforce differentiated policies
16+
for agent-driven vs. human-driven operations.
17+
"""
18+
19+
import json
20+
import os
21+
22+
from knack.log import get_logger
23+
24+
logger = get_logger(__name__)
25+
26+
COPILOT_AGENT_SESSION_ID = "COPILOT_AGENT_SESSION_ID"
27+
28+
29+
def build_agentic_session_params():
30+
"""Read COPILOT_AGENT_SESSION_ID and build the agentic claims challenge.
31+
32+
:returns: (session_id, claims_challenge) — both None when env var is not set.
33+
"""
34+
session_id = os.environ.get(COPILOT_AGENT_SESSION_ID) or None
35+
if not session_id:
36+
return None, None
37+
38+
logger.debug("Agentic session detected (COPILOT_AGENT_SESSION_ID is set)")
39+
40+
claims_challenge = json.dumps({
41+
"access_token": {
42+
"xms_cli_sid": {"values": [session_id]}
43+
}
44+
})
45+
return session_id, claims_challenge
46+
47+
48+
def merge_access_token_claims(existing_claims, new_claims):
49+
"""Merge new claims into an existing claims_challenge JSON string.
50+
51+
:param existing_claims: Existing claims_challenge JSON string (or None).
52+
:param new_claims: New claims_challenge JSON string to merge in. Must not be None or empty,
53+
and must contain a non-empty ``access_token`` object.
54+
:returns: Merged claims_challenge JSON string.
55+
:raises ValueError: If ``new_claims`` is None, empty, or does not contain a non-empty
56+
``access_token`` object.
57+
"""
58+
if not new_claims:
59+
raise ValueError("new_claims must not be None or empty")
60+
new_access_token = json.loads(new_claims).get("access_token")
61+
if not new_access_token:
62+
raise ValueError("new_claims must contain a non-empty access_token")
63+
64+
claims_dict = json.loads(existing_claims) if existing_claims else {}
65+
claims_dict["access_token"] = claims_dict.get("access_token") or {}
66+
claims_dict["access_token"].update(new_access_token)
67+
return json.dumps(claims_dict)

src/azure-cli-core/azure/cli/core/auth/msal_credentials.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,27 @@ def acquire_token(self, scopes, claims_challenge=None, **kwargs):
5050
logger.debug("UserCredential.acquire_token: scopes=%r, claims_challenge=%r, kwargs=%r",
5151
scopes, claims_challenge, kwargs)
5252

53+
# Apply agentic session parameters for user identity flows
54+
from .agentic_session import build_agentic_session_params, merge_access_token_claims
55+
agentic_session_id, agentic_claims = build_agentic_session_params()
56+
if agentic_session_id:
57+
# Both paths: client_session in data and params so eSTS can identify the agentic session
58+
kwargs["data"] = kwargs.get("data") or {}
59+
kwargs["data"]["client_session"] = agentic_session_id
60+
kwargs["params"] = kwargs.get("params") or {}
61+
kwargs["params"]["client_session"] = agentic_session_id
62+
63+
if getattr(self._msal_app, '_enable_broker', False):
64+
# Broker path: claims_challenge flows to MSALRuntime cache key via set_decoded_claims.
65+
# This causes MSAL to skip its local AT cache and forward claims to the broker,
66+
# where requestedClaims becomes part of the C++ cache key.
67+
claims_challenge = merge_access_token_claims(claims_challenge, agentic_claims)
68+
# Non-broker path: client_session in data flows into ext_cache_key (SHA256 hash),
69+
# which partitions the MSAL Python token cache. No claims_challenge needed.
70+
71+
from azure.cli.core.telemetry import set_agentic_session
72+
set_agentic_session(True)
73+
5374
if claims_challenge:
5475
logger.info('Acquiring new access token silently with claims challenge: %s', claims_challenge)
5576
result = self._msal_app.acquire_token_silent_with_error(
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
# --------------------------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for license information.
4+
# --------------------------------------------------------------------------------------------
5+
6+
import json
7+
import os
8+
import unittest
9+
from unittest.mock import patch
10+
11+
from azure.cli.core.auth.agentic_session import (
12+
COPILOT_AGENT_SESSION_ID,
13+
build_agentic_session_params,
14+
merge_access_token_claims,
15+
)
16+
17+
18+
class TestBuildAgenticSessionParams(unittest.TestCase):
19+
20+
def test_returns_none_when_env_not_set(self):
21+
with patch.dict(os.environ, {}, clear=True):
22+
session_id, claims = build_agentic_session_params()
23+
self.assertIsNone(session_id)
24+
self.assertIsNone(claims)
25+
26+
def test_returns_none_when_env_is_empty_string(self):
27+
with patch.dict(os.environ, {COPILOT_AGENT_SESSION_ID: ""}):
28+
session_id, claims = build_agentic_session_params()
29+
self.assertIsNone(session_id)
30+
self.assertIsNone(claims)
31+
32+
def test_returns_session_id_and_claims(self):
33+
with patch.dict(os.environ, {COPILOT_AGENT_SESSION_ID: "sess-456"}):
34+
session_id, claims = build_agentic_session_params()
35+
self.assertEqual(session_id, "sess-456")
36+
parsed = json.loads(claims)
37+
self.assertEqual(parsed["access_token"]["xms_cli_sid"]["values"], ["sess-456"])
38+
39+
def _agentic_claims(session_id="s1"):
40+
return json.dumps({"access_token": {"xms_cli_sid": {"values": [session_id]}}})
41+
42+
43+
class TestMergeAccessTokenClaims(unittest.TestCase):
44+
45+
# --- Validation ---
46+
47+
def test_raises_when_new_claims_is_none(self):
48+
with self.assertRaises(ValueError):
49+
merge_access_token_claims(None, None)
50+
51+
def test_raises_when_new_access_token_is_null(self):
52+
new = json.dumps({"access_token": None})
53+
with self.assertRaises(ValueError):
54+
merge_access_token_claims(None, new)
55+
56+
# --- Merging ---
57+
58+
def test_merges_into_none(self):
59+
result = merge_access_token_claims(None, _agentic_claims("s1"))
60+
claims = json.loads(result)
61+
self.assertEqual(len(claims), 1)
62+
self.assertEqual(len(claims["access_token"]), 1)
63+
self.assertEqual(claims["access_token"]["xms_cli_sid"], {"values": ["s1"]})
64+
65+
def test_merges_into_existing(self):
66+
existing = json.dumps({"access_token": {"nbf": {"essential": True, "value": "999"}}})
67+
result = merge_access_token_claims(existing, _agentic_claims("s1"))
68+
merged = json.loads(result)
69+
self.assertEqual(len(merged), 1)
70+
self.assertEqual(len(merged["access_token"]), 2)
71+
self.assertEqual(merged["access_token"]["nbf"], {"essential": True, "value": "999"})
72+
self.assertEqual(merged["access_token"]["xms_cli_sid"], {"values": ["s1"]})
73+
74+
def test_preserves_non_access_token_keys(self):
75+
existing = json.dumps({
76+
"access_token": {"nbf": {"essential": True}},
77+
"id_token": {"auth_time": {"essential": True}}
78+
})
79+
result = merge_access_token_claims(existing, _agentic_claims())
80+
merged = json.loads(result)
81+
self.assertEqual(len(merged), 2)
82+
self.assertEqual(len(merged["access_token"]), 2)
83+
self.assertEqual(merged["id_token"], {"auth_time": {"essential": True}})
84+
self.assertEqual(merged["access_token"]["nbf"], {"essential": True})
85+
self.assertEqual(merged["access_token"]["xms_cli_sid"], {"values": ["s1"]})
86+
87+
def test_new_claims_overwrites_existing_key(self):
88+
existing = json.dumps({"access_token": {"xms_cli_sid": {"values": ["old"]}}})
89+
result = merge_access_token_claims(existing, _agentic_claims("new"))
90+
merged = json.loads(result)
91+
self.assertEqual(len(merged), 1)
92+
self.assertEqual(len(merged["access_token"]), 1)
93+
self.assertEqual(merged["access_token"]["xms_cli_sid"], {"values": ["new"]})
94+
95+
def test_creates_access_token_when_missing_in_existing(self):
96+
existing = json.dumps({"id_token": {"auth_time": {"essential": True}}})
97+
result = merge_access_token_claims(existing, _agentic_claims())
98+
merged = json.loads(result)
99+
self.assertEqual(len(merged), 2)
100+
self.assertEqual(len(merged["access_token"]), 1)
101+
self.assertEqual(merged["id_token"], {"auth_time": {"essential": True}})
102+
self.assertEqual(merged["access_token"]["xms_cli_sid"], {"values": ["s1"]})
103+
104+
def test_handles_null_access_token_in_existing(self):
105+
existing = json.dumps({"access_token": None})
106+
result = merge_access_token_claims(existing, _agentic_claims())
107+
merged = json.loads(result)
108+
self.assertEqual(len(merged), 1)
109+
self.assertEqual(len(merged["access_token"]), 1)
110+
self.assertEqual(merged["access_token"]["xms_cli_sid"], {"values": ["s1"]})
111+
112+
113+
class TestUserCredentialAgenticSession(unittest.TestCase):
114+
"""Verify that UserCredential.acquire_token merges agentic claims and passes
115+
client_session param when COPILOT_AGENT_SESSION_ID is set."""
116+
117+
def _build_user_credential(self, enable_broker=False):
118+
"""Build a UserCredential with mocked MSAL app."""
119+
from unittest.mock import MagicMock, PropertyMock
120+
from azure.cli.core.auth.msal_credentials import UserCredential
121+
122+
cred = object.__new__(UserCredential)
123+
124+
cred._msal_app = MagicMock()
125+
cred._msal_app.client_id = "test-client-id"
126+
cred._msal_app._enable_broker = enable_broker
127+
type(cred._msal_app).authority = PropertyMock(return_value=MagicMock(
128+
instance="login.microsoftonline.com",
129+
tenant="test-tenant",
130+
is_adfs=False,
131+
))
132+
cred._account = {
133+
"home_account_id": "uid.utid",
134+
"username": "user@test.com",
135+
}
136+
return cred
137+
138+
@patch.dict(os.environ, {COPILOT_AGENT_SESSION_ID: "agent-sess-1"})
139+
def test_non_broker_passes_data_only(self):
140+
"""Non-broker path: client_session in data for ext_cache_key, no claims_challenge."""
141+
cred = self._build_user_credential(enable_broker=False)
142+
cred._msal_app.acquire_token_silent_with_error.return_value = {
143+
"access_token": "agent-tagged-token",
144+
"token_type": "Bearer",
145+
"expires_in": 3600,
146+
}
147+
148+
result = cred.acquire_token(["https://management.azure.com/.default"])
149+
150+
self.assertEqual(result["access_token"], "agent-tagged-token")
151+
152+
call_kwargs = cred._msal_app.acquire_token_silent_with_error.call_args
153+
self.assertIsNone(call_kwargs.kwargs.get("claims_challenge"))
154+
self.assertEqual(call_kwargs.kwargs["data"], {"client_session": "agent-sess-1"})
155+
self.assertEqual(call_kwargs.kwargs["params"], {"client_session": "agent-sess-1"})
156+
157+
@patch.dict(os.environ, {COPILOT_AGENT_SESSION_ID: "agent-sess-1"})
158+
def test_broker_passes_claims_and_data(self):
159+
"""Broker path: claims_challenge with xms_cli_sid AND client_session in data."""
160+
cred = self._build_user_credential(enable_broker=True)
161+
cred._msal_app.acquire_token_silent_with_error.return_value = {
162+
"access_token": "agent-tagged-token",
163+
"token_type": "Bearer",
164+
"expires_in": 3600,
165+
}
166+
167+
result = cred.acquire_token(["https://management.azure.com/.default"])
168+
169+
self.assertEqual(result["access_token"], "agent-tagged-token")
170+
171+
call_kwargs = cred._msal_app.acquire_token_silent_with_error.call_args
172+
claims = json.loads(call_kwargs.kwargs["claims_challenge"])
173+
self.assertEqual(claims["access_token"]["xms_cli_sid"]["values"], ["agent-sess-1"])
174+
self.assertEqual(call_kwargs.kwargs["data"], {"client_session": "agent-sess-1"})
175+
self.assertEqual(call_kwargs.kwargs["params"], {"client_session": "agent-sess-1"})
176+
177+
@patch.dict(os.environ, {}, clear=True)
178+
def test_no_agentic_params_without_env(self):
179+
"""When COPILOT_AGENT_SESSION_ID is not set, no agentic params are added."""
180+
cred = self._build_user_credential(enable_broker=False)
181+
cred._msal_app.acquire_token_silent_with_error.return_value = {
182+
"access_token": "normal-token",
183+
"token_type": "Bearer",
184+
"expires_in": 3600,
185+
}
186+
187+
result = cred.acquire_token(["https://management.azure.com/.default"])
188+
189+
self.assertEqual(result["access_token"], "normal-token")
190+
191+
call_kwargs = cred._msal_app.acquire_token_silent_with_error.call_args
192+
self.assertIsNone(call_kwargs.kwargs.get("claims_challenge"))
193+
self.assertNotIn("params", call_kwargs.kwargs)
194+
195+
@patch.dict(os.environ, {COPILOT_AGENT_SESSION_ID: "agent-sess-2"})
196+
def test_broker_merges_with_existing_claims(self):
197+
"""Broker path: agentic claims are merged with existing claims_challenge."""
198+
cred = self._build_user_credential(enable_broker=True)
199+
cred._msal_app.acquire_token_silent_with_error.return_value = {
200+
"access_token": "token",
201+
"token_type": "Bearer",
202+
"expires_in": 3600,
203+
}
204+
205+
existing_claims = json.dumps({"access_token": {"nbf": {"essential": True, "value": "999"}}})
206+
cred.acquire_token(["scope"], claims_challenge=existing_claims)
207+
208+
call_kwargs = cred._msal_app.acquire_token_silent_with_error.call_args
209+
claims = json.loads(call_kwargs.kwargs["claims_challenge"])
210+
self.assertEqual(claims["access_token"]["nbf"], {"essential": True, "value": "999"})
211+
self.assertEqual(claims["access_token"]["xms_cli_sid"]["values"], ["agent-sess-2"])
212+
213+
214+
if __name__ == '__main__':
215+
unittest.main()

src/azure-cli-core/azure/cli/core/telemetry.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(self, correlation_id=None, application=None):
7979
self.enable_broker_on_windows = None
8080
self.msal_telemetry = None
8181
self.login_experience_v2 = None
82+
self.agentic_session = False
8283

8384
def add_event(self, name, properties):
8485
for key in self.instrumentation_key:
@@ -239,6 +240,7 @@ def _get_azure_cli_properties(self):
239240
set_custom_properties(result, 'EnableBrokerOnWindows', str(self.enable_broker_on_windows))
240241
set_custom_properties(result, 'MsalTelemetry', self.msal_telemetry)
241242
set_custom_properties(result, 'LoginExperienceV2', str(self.login_experience_v2))
243+
set_custom_properties(result, 'AgenticSession', str(self.agentic_session))
242244

243245
return result
244246

@@ -497,6 +499,11 @@ def set_msal_telemetry(msal_telemetry):
497499
@decorators.suppress_all_exceptions()
498500
def set_login_experience_v2(login_experience_v2):
499501
_session.login_experience_v2 = login_experience_v2
502+
503+
504+
@decorators.suppress_all_exceptions()
505+
def set_agentic_session(agentic_session):
506+
_session.agentic_session = agentic_session
500507
# endregion
501508

502509

src/azure-cli-core/setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@
5555
'knack~=0.14.0',
5656
'microsoft-security-utilities-secret-masker~=1.0.0b4',
5757
'msal-extensions==1.3.1',
58-
'msal[broker]==1.35.1; sys_platform == "win32"',
59-
'msal==1.35.1; sys_platform != "win32"',
58+
'msal[broker]==1.36.0; sys_platform == "win32"',
59+
'msal==1.36.0; sys_platform != "win32"',
6060
'packaging>=20.9',
6161
'pkginfo>=1.5.0.1',
6262
# psutil can't install on cygwin: https://github.com/Azure/azure-cli/issues/9399

0 commit comments

Comments
 (0)