Skip to content

Commit 6e8bc21

Browse files
committed
feat: add token refresher
1 parent 5717b13 commit 6e8bc21

7 files changed

Lines changed: 560 additions & 15 deletions

File tree

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "uipath-mcp"
3-
version = "0.1.1"
3+
version = "0.1.2"
44
description = "UiPath MCP SDK"
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.11"
@@ -42,6 +42,7 @@ dev = [
4242
"mypy>=1.14.1",
4343
"ruff>=0.9.4",
4444
"pytest>=7.4.0",
45+
"pytest-asyncio>=0.23.0",
4546
"pytest-cov>=4.1.0",
4647
"pytest-mock>=3.11.1",
4748
"pre-commit>=4.5.1",
@@ -84,6 +85,7 @@ disallow_untyped_defs = false
8485
testpaths = ["tests"]
8586
python_files = "test_*.py"
8687
addopts = "-ra -q"
88+
asyncio_mode = "auto"
8789

8890
[[tool.uv.index]]
8991
name = "testpypi"

src/uipath_mcp/_cli/_runtime/_runtime.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from ._context import UiPathServerType
3939
from ._exception import McpErrorCode, UiPathMcpRuntimeError
4040
from ._session import BaseSessionServer, StdioSessionServer, StreamableHttpSessionServer
41+
from ._token_refresh import TokenRefresher
4142

4243
logger = logging.getLogger(__name__)
4344
tracer = trace.get_tracer(__name__)
@@ -85,6 +86,7 @@ def __init__(
8586
self._http_stderr_drain_task: asyncio.Task[None] | None = None
8687
self._http_server_stderr_lines: list[str] = []
8788
self._uipath = UiPath()
89+
self._token_refresher: TokenRefresher | None = None
8890
self._cleanup_done = False
8991

9092
# Context fields from UiPathConfig
@@ -206,15 +208,19 @@ async def _run_server(self) -> UiPathRuntimeResult:
206208
root_span.set_attribute("command", str(self._server.command))
207209
root_span.set_attribute("args", json.dumps(self._server.args))
208210
root_span.set_attribute("span_type", "MCP Server")
209-
bearer_token = self._uipath._config.secret
211+
212+
signalr_headers = {
213+
"X-UiPath-Internal-TenantId": str(self._tenant_id),
214+
"X-UiPath-Internal-AccountId": str(self._org_id),
215+
"X-UIPATH-FolderKey": self._folder_key,
216+
"Authorization": f"Bearer {self._uipath._config.secret}",
217+
}
218+
219+
self._token_refresher = TokenRefresher(self._uipath)
220+
210221
self._signalr_client = SignalRClient(
211222
signalr_url,
212-
headers={
213-
"X-UiPath-Internal-TenantId": str(self._tenant_id),
214-
"X-UiPath-Internal-AccountId": str(self._org_id),
215-
"X-UIPATH-FolderKey": self._folder_key,
216-
"Authorization": f"Bearer {bearer_token}",
217-
},
223+
headers=signalr_headers,
218224
)
219225
self._signalr_client.on("MessageReceived", self._handle_signalr_message)
220226
self._signalr_client.on(
@@ -236,6 +242,7 @@ async def _run_server(self) -> UiPathRuntimeResult:
236242
run_task = asyncio.create_task(self._signalr_client.run())
237243
cancel_task = asyncio.create_task(self._cancel_event.wait())
238244
self._keep_alive_task = asyncio.create_task(self._keep_alive())
245+
self._token_refresher.start()
239246

240247
try:
241248
# Wait for either the run to complete or cancellation
@@ -297,6 +304,9 @@ async def _cleanup(self) -> None:
297304

298305
await self._on_runtime_abort()
299306

307+
if self._token_refresher:
308+
await self._token_refresher.stop()
309+
300310
if self._keep_alive_task:
301311
self._keep_alive_task.cancel()
302312
try:
@@ -374,11 +384,11 @@ async def _handle_signalr_message(self, args: list[str]) -> None:
374384
session_server: BaseSessionServer
375385
if self._server.is_streamable_http:
376386
session_server = StreamableHttpSessionServer(
377-
self._server, self.slug, session_id
387+
self._server, self.slug, session_id, self._uipath
378388
)
379389
else:
380390
session_server = StdioSessionServer(
381-
self._server, self.slug, session_id
391+
self._server, self.slug, session_id, self._uipath
382392
)
383393
try:
384394
await session_server.start()

src/uipath_mcp/_cli/_runtime/_session.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@
3131
class BaseSessionServer(ABC):
3232
"""Base class with transport-agnostic message relay logic."""
3333

34-
def __init__(self, server_config: McpServer, server_slug: str, session_id: str):
34+
def __init__(
35+
self,
36+
server_config: McpServer,
37+
server_slug: str,
38+
session_id: str,
39+
uipath: UiPath,
40+
):
3541
self._server_config = server_config
3642
self._server_slug = server_slug
3743
self._session_id = session_id
@@ -42,7 +48,7 @@ def __init__(self, server_config: McpServer, server_slug: str, session_id: str):
4248
self._active_requests: dict[str, str] = {}
4349
self._last_request_id: str | None = None
4450
self._last_message_id: str | None = None
45-
self._uipath = UiPath()
51+
self._uipath = uipath
4652
self._mcp_tracer = McpTracer(tracer, logger)
4753

4854
@property
@@ -284,8 +290,14 @@ def _get_message_id(self, message: JSONRPCMessage) -> str:
284290
class StdioSessionServer(BaseSessionServer):
285291
"""Manages a stdio server process for a specific session."""
286292

287-
def __init__(self, server_config: McpServer, server_slug: str, session_id: str):
288-
super().__init__(server_config, server_slug, session_id)
293+
def __init__(
294+
self,
295+
server_config: McpServer,
296+
server_slug: str,
297+
session_id: str,
298+
uipath: UiPath,
299+
):
300+
super().__init__(server_config, server_slug, session_id, uipath)
289301
self._server_stderr_output: str | None = None
290302

291303
@property
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
import asyncio
2+
import logging
3+
import os
4+
import time
5+
from enum import Enum
6+
7+
import httpx
8+
from uipath._cli._auth._portal_service import PortalService
9+
from uipath._cli._auth._url_utils import build_service_url, resolve_domain
10+
from uipath._cli._auth._utils import get_auth_data, update_auth_file
11+
from uipath._utils._auth import parse_access_token
12+
from uipath._utils._ssl_context import get_httpx_client_kwargs
13+
from uipath._utils.constants import ENV_UIPATH_ACCESS_TOKEN
14+
from uipath.platform import UiPath
15+
from uipath.platform.common import TokenData
16+
from uipath.platform.common._config import UiPathApiConfig
17+
18+
logger = logging.getLogger(__name__)
19+
20+
REFRESH_MARGIN_SECONDS = 300 # Refresh 5 minutes before expiry
21+
FALLBACK_REFRESH_INTERVAL = 45 * 60 # 45 minutes when exp claim is unavailable
22+
MAX_RETRY_ATTEMPTS = 3
23+
RETRY_BASE_DELAY = 5 # seconds
24+
RETRY_FALLBACK_INTERVAL = 60 # seconds to wait after all retries fail
25+
26+
27+
class AuthStrategy(Enum):
28+
OAUTH = "oauth"
29+
CLIENT_CREDENTIALS = "client_credentials"
30+
NONE = "none"
31+
32+
33+
class TokenRefresher:
34+
"""Manages token refresh for long-lived MCP runtime connections."""
35+
36+
def __init__(self, uipath: UiPath):
37+
self._uipath = uipath
38+
self._refresh_task: asyncio.Task[None] | None = None
39+
self._cancel_event = asyncio.Event()
40+
41+
self._client_id: str | None = os.environ.get("UIPATH_CLIENT_ID")
42+
self._client_secret: str | None = os.environ.get("UIPATH_CLIENT_SECRET")
43+
44+
self._base_url: str = uipath._config.base_url
45+
self._domain: str = resolve_domain(self._base_url, environment=None)
46+
47+
self._strategy = self._detect_strategy()
48+
self._token_url: str | None = self._resolve_token_url()
49+
50+
if (
51+
self._strategy == AuthStrategy.CLIENT_CREDENTIALS
52+
and self._token_url is None
53+
):
54+
logger.error("Token refresh disabled: could not resolve token URL")
55+
self._strategy = AuthStrategy.NONE
56+
57+
def _detect_strategy(self) -> AuthStrategy:
58+
"""Detect which auth flow is available for token refresh."""
59+
if self._client_id and self._client_secret:
60+
return AuthStrategy.CLIENT_CREDENTIALS
61+
62+
try:
63+
auth_data = get_auth_data()
64+
if auth_data.refresh_token:
65+
return AuthStrategy.OAUTH
66+
except Exception as e:
67+
logger.debug(f"Could not read auth file for strategy detection: {e}")
68+
69+
return AuthStrategy.NONE
70+
71+
def _resolve_token_url(self) -> str | None:
72+
"""Derive the identity token endpoint for client_credentials flow."""
73+
if self._strategy != AuthStrategy.CLIENT_CREDENTIALS:
74+
return None
75+
76+
try:
77+
return build_service_url(self._domain, "/identity_/connect/token")
78+
except Exception as e:
79+
logger.error(
80+
f"Could not resolve token URL from base_url '{self._base_url}': {e}"
81+
)
82+
return None
83+
84+
@property
85+
def strategy(self) -> AuthStrategy:
86+
return self._strategy
87+
88+
def start(self) -> None:
89+
"""Start the background refresh task."""
90+
if self._strategy == AuthStrategy.NONE:
91+
logger.info("No token refresh strategy available; refresh disabled")
92+
return
93+
94+
self._cancel_event.clear()
95+
self._refresh_task = asyncio.create_task(self._refresh_loop())
96+
logger.info("Token refresh background task started")
97+
98+
async def stop(self) -> None:
99+
"""Stop the background refresh task."""
100+
self._cancel_event.set()
101+
if self._refresh_task and not self._refresh_task.done():
102+
self._refresh_task.cancel()
103+
try:
104+
await asyncio.wait_for(self._refresh_task, timeout=5.0)
105+
except (asyncio.CancelledError, asyncio.TimeoutError):
106+
pass
107+
self._refresh_task = None
108+
logger.info("Token refresh stopped")
109+
110+
async def _wait_for_cancel(self, seconds: float) -> bool:
111+
"""Sleep for `seconds`, returning True if cancellation was requested."""
112+
try:
113+
await asyncio.wait_for(self._cancel_event.wait(), timeout=seconds)
114+
return True
115+
except asyncio.TimeoutError:
116+
return False
117+
118+
async def _refresh_loop(self) -> None:
119+
"""Background loop that refreshes the token before expiry."""
120+
try:
121+
while not self._cancel_event.is_set():
122+
wait_seconds = self._seconds_until_refresh()
123+
if wait_seconds > 0 and await self._wait_for_cancel(wait_seconds):
124+
break
125+
126+
if not await self._try_refresh() and not self._cancel_event.is_set():
127+
logger.error(
128+
"All token refresh attempts failed. "
129+
"The token may expire causing failures."
130+
)
131+
# Avoid retry loop when the token is already expired
132+
if await self._wait_for_cancel(RETRY_FALLBACK_INTERVAL):
133+
break
134+
except asyncio.CancelledError:
135+
logger.info("Token refresh loop cancelled")
136+
raise
137+
138+
async def _try_refresh(self) -> bool:
139+
"""Attempt to refresh the token with retries. Returns True on success."""
140+
for attempt in range(MAX_RETRY_ATTEMPTS):
141+
try:
142+
if self._strategy == AuthStrategy.OAUTH:
143+
token_data = await self._refresh_oauth()
144+
else:
145+
token_data = await self._refresh_client_credentials()
146+
147+
self._propagate_token(token_data)
148+
logger.info("Token refreshed successfully.")
149+
return True
150+
151+
except Exception as e:
152+
safe_msg = (
153+
f"HTTP {e.response.status_code}"
154+
if isinstance(e, httpx.HTTPStatusError)
155+
else type(e).__name__
156+
)
157+
logger.error(
158+
f"Token refresh attempt {attempt + 1}/{MAX_RETRY_ATTEMPTS} "
159+
f"failed: {safe_msg}"
160+
)
161+
if attempt < MAX_RETRY_ATTEMPTS - 1:
162+
logger.info(f"Retrying in {RETRY_BASE_DELAY}s...")
163+
if await self._wait_for_cancel(RETRY_BASE_DELAY):
164+
return False
165+
166+
return False
167+
168+
async def _refresh_oauth(self) -> TokenData:
169+
"""Refresh using OAuth refresh_token grant."""
170+
auth_data = get_auth_data()
171+
refresh_token = auth_data.refresh_token
172+
if not refresh_token:
173+
raise ValueError("No refresh_token found in .uipath/.auth.json")
174+
175+
def _do_refresh() -> TokenData:
176+
with PortalService(domain=self._domain) as portal:
177+
return portal.refresh_access_token(refresh_token)
178+
179+
# run in a thread to avoid blocking
180+
token_data = await asyncio.to_thread(_do_refresh)
181+
182+
try:
183+
update_auth_file(token_data)
184+
except Exception as e:
185+
logger.warning(f"Failed to update .auth.json: {type(e).__name__}")
186+
187+
return token_data
188+
189+
async def _refresh_client_credentials(self) -> TokenData:
190+
"""Refresh using client_credentials grant."""
191+
assert self._token_url is not None, (
192+
"token_url must be set for client_credentials strategy"
193+
)
194+
195+
data = {
196+
"grant_type": "client_credentials",
197+
"client_id": self._client_id,
198+
"client_secret": self._client_secret,
199+
"scope": os.environ.get("UIPATH_CLIENT_SCOPE", "OR.Execution"),
200+
}
201+
202+
async with httpx.AsyncClient(**get_httpx_client_kwargs()) as client:
203+
response = await client.post(
204+
self._token_url,
205+
data=data,
206+
headers={"Content-Type": "application/x-www-form-urlencoded"},
207+
)
208+
response.raise_for_status()
209+
return TokenData.model_validate(response.json())
210+
211+
def _propagate_token(self, token_data: TokenData) -> None:
212+
"""Update all token consumers after a successful refresh."""
213+
new_token = token_data.access_token
214+
215+
self._uipath._config = UiPathApiConfig(
216+
base_url=self._uipath._config.base_url,
217+
secret=new_token,
218+
)
219+
220+
os.environ[ENV_UIPATH_ACCESS_TOKEN] = new_token
221+
222+
def _seconds_until_refresh(self) -> float:
223+
"""Calculate seconds to wait before next refresh attempt."""
224+
try:
225+
claims = parse_access_token(self._uipath._config.secret)
226+
exp = claims.get("exp")
227+
if exp is not None:
228+
remaining = float(exp) - time.time()
229+
if remaining <= REFRESH_MARGIN_SECONDS:
230+
return 0
231+
return remaining - REFRESH_MARGIN_SECONDS
232+
except Exception as e:
233+
logger.warning(f"Failed to parse token expiry: {e}")
234+
235+
return FALLBACK_REFRESH_INTERVAL

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import pytest
2+
3+
4+
@pytest.fixture(autouse=True)
5+
def clean_env(monkeypatch: pytest.MonkeyPatch) -> None:
6+
"""Clean environment variables before each test."""
7+
monkeypatch.delenv("UIPATH_URL", raising=False)
8+
monkeypatch.delenv("UIPATH_ACCESS_TOKEN", raising=False)
9+
monkeypatch.delenv("UIPATH_CLIENT_ID", raising=False)
10+
monkeypatch.delenv("UIPATH_CLIENT_SECRET", raising=False)
11+
monkeypatch.delenv("UIPATH_CLIENT_SCOPE", raising=False)

0 commit comments

Comments
 (0)