Skip to content

Commit 0503d4a

Browse files
committed
feat: add token refresher
1 parent 95b546d commit 0503d4a

9 files changed

Lines changed: 584 additions & 7 deletions

File tree

.github/workflows/ci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,6 @@ jobs:
2020

2121
lint:
2222
uses: ./.github/workflows/lint.yml
23+
24+
test:
25+
uses: ./.github/workflows/test.yml

.github/workflows/test.yml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
name: Test
2+
3+
on:
4+
workflow_call
5+
6+
jobs:
7+
test:
8+
name: Test
9+
runs-on: ${{ matrix.os }}
10+
timeout-minutes: 10
11+
strategy:
12+
matrix:
13+
python-version: ["3.11", "3.12", "3.13"]
14+
os: [ubuntu-latest, windows-latest]
15+
16+
permissions:
17+
contents: read
18+
19+
steps:
20+
- name: Checkout
21+
uses: actions/checkout@v4
22+
23+
- name: Setup uv
24+
uses: astral-sh/setup-uv@v5
25+
26+
- name: Setup Python
27+
uses: actions/setup-python@v5
28+
with:
29+
python-version: ${{ matrix.python-version }}
30+
31+
- name: Install dependencies
32+
run: uv sync --all-extras
33+
34+
- name: Run tests
35+
run: uv run pytest
36+
37+
continue-on-error: true

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ dev = [
4242
"mypy>=1.14.1",
4343
"ruff>=0.9.4",
4444
"pytest>=7.4.0",
45+
"pytest-asyncio>=0.23.0",
4546
"pytest-cov>=4.1.0",
4647
"pytest-mock>=3.11.1",
4748
"pre-commit>=4.5.1",
@@ -84,6 +85,7 @@ disallow_untyped_defs = false
8485
testpaths = ["tests"]
8586
python_files = "test_*.py"
8687
addopts = "-ra -q"
88+
asyncio_mode = "auto"
8789

8890
[[tool.uv.index]]
8991
name = "testpypi"

src/uipath_mcp/_cli/_runtime/_runtime.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from ._context import UiPathServerType
3939
from ._exception import McpErrorCode, UiPathMcpRuntimeError
4040
from ._session import BaseSessionServer, StdioSessionServer, StreamableHttpSessionServer
41+
from ._token_refresh import TokenRefresher
4142

4243
logger = logging.getLogger(__name__)
4344
tracer = trace.get_tracer(__name__)
@@ -85,6 +86,7 @@ def __init__(
8586
self._http_stderr_drain_task: asyncio.Task[None] | None = None
8687
self._http_server_stderr_lines: list[str] = []
8788
self._uipath = UiPath()
89+
self._token_refresher: TokenRefresher | None = None
8890
self._cleanup_done = False
8991

9092
# Context fields from UiPathConfig
@@ -207,6 +209,8 @@ async def _run_server(self) -> UiPathRuntimeResult:
207209
root_span.set_attribute("args", json.dumps(self._server.args))
208210
root_span.set_attribute("span_type", "MCP Server")
209211
bearer_token = self._uipath._config.secret
212+
self._token_refresher = TokenRefresher(self._uipath)
213+
210214
self._signalr_client = SignalRClient(
211215
signalr_url,
212216
headers={
@@ -236,6 +240,7 @@ async def _run_server(self) -> UiPathRuntimeResult:
236240
run_task = asyncio.create_task(self._signalr_client.run())
237241
cancel_task = asyncio.create_task(self._cancel_event.wait())
238242
self._keep_alive_task = asyncio.create_task(self._keep_alive())
243+
self._token_refresher.start()
239244

240245
try:
241246
# Wait for either the run to complete or cancellation
@@ -297,6 +302,9 @@ async def _cleanup(self) -> None:
297302

298303
await self._on_runtime_abort()
299304

305+
if self._token_refresher:
306+
await self._token_refresher.stop()
307+
300308
if self._keep_alive_task:
301309
self._keep_alive_task.cancel()
302310
try:
@@ -374,11 +382,11 @@ async def _handle_signalr_message(self, args: list[str]) -> None:
374382
session_server: BaseSessionServer
375383
if self._server.is_streamable_http:
376384
session_server = StreamableHttpSessionServer(
377-
self._server, self.slug, session_id
385+
self._server, self.slug, session_id, self._uipath
378386
)
379387
else:
380388
session_server = StdioSessionServer(
381-
self._server, self.slug, session_id
389+
self._server, self.slug, session_id, self._uipath
382390
)
383391
try:
384392
await session_server.start()

src/uipath_mcp/_cli/_runtime/_session.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@
3131
class BaseSessionServer(ABC):
3232
"""Base class with transport-agnostic message relay logic."""
3333

34-
def __init__(self, server_config: McpServer, server_slug: str, session_id: str):
34+
def __init__(
35+
self,
36+
server_config: McpServer,
37+
server_slug: str,
38+
session_id: str,
39+
uipath: UiPath,
40+
):
3541
self._server_config = server_config
3642
self._server_slug = server_slug
3743
self._session_id = session_id
@@ -42,7 +48,7 @@ def __init__(self, server_config: McpServer, server_slug: str, session_id: str):
4248
self._active_requests: dict[str, str] = {}
4349
self._last_request_id: str | None = None
4450
self._last_message_id: str | None = None
45-
self._uipath = UiPath()
51+
self._uipath = uipath
4652
self._mcp_tracer = McpTracer(tracer, logger)
4753

4854
@property
@@ -284,9 +290,7 @@ def _get_message_id(self, message: JSONRPCMessage) -> str:
284290
class StdioSessionServer(BaseSessionServer):
285291
"""Manages a stdio server process for a specific session."""
286292

287-
def __init__(self, server_config: McpServer, server_slug: str, session_id: str):
288-
super().__init__(server_config, server_slug, session_id)
289-
self._server_stderr_output: str | None = None
293+
_server_stderr_output: str | None = None
290294

291295
@property
292296
def output(self) -> str | None:
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
import asyncio
2+
import logging
3+
import os
4+
import time
5+
from enum import Enum
6+
7+
import httpx
8+
from uipath._cli._auth._portal_service import PortalService
9+
from uipath._cli._auth._url_utils import build_service_url, resolve_domain
10+
from uipath._cli._auth._utils import get_auth_data, update_auth_file
11+
from uipath._utils._auth import parse_access_token
12+
from uipath._utils._ssl_context import get_httpx_client_kwargs
13+
from uipath._utils.constants import ENV_UIPATH_ACCESS_TOKEN
14+
from uipath.platform import UiPath
15+
from uipath.platform.common import TokenData
16+
from uipath.platform.common._config import UiPathApiConfig
17+
18+
logger = logging.getLogger(__name__)
19+
20+
REFRESH_MARGIN_SECONDS = 300 # Refresh 5 minutes before expiry
21+
FALLBACK_REFRESH_INTERVAL = 45 * 60 # 45 minutes when exp claim is unavailable
22+
MAX_RETRY_ATTEMPTS = 3
23+
RETRY_BASE_DELAY = 5 # seconds
24+
RETRY_FALLBACK_INTERVAL = 60 # seconds to wait after all retries fail
25+
26+
27+
class AuthStrategy(Enum):
28+
OAUTH = "oauth"
29+
CLIENT_CREDENTIALS = "client_credentials"
30+
NONE = "none"
31+
32+
33+
class TokenRefresher:
34+
"""Manages token refresh for long-lived MCP runtime connections."""
35+
36+
def __init__(self, uipath: UiPath):
37+
self._uipath = uipath
38+
self._refresh_task: asyncio.Task[None] | None = None
39+
self._cancel_event = asyncio.Event()
40+
41+
self._client_id: str | None = os.environ.get("UIPATH_CLIENT_ID")
42+
self._client_secret: str | None = os.environ.get("UIPATH_CLIENT_SECRET")
43+
44+
self._base_url: str = uipath._config.base_url
45+
self._domain: str = resolve_domain(self._base_url, environment=None)
46+
47+
self._strategy = self._detect_strategy()
48+
self._token_url: str | None = self._resolve_token_url()
49+
50+
if (
51+
self._strategy == AuthStrategy.CLIENT_CREDENTIALS
52+
and self._token_url is None
53+
):
54+
logger.error("Token refresh disabled: could not resolve token URL")
55+
self._strategy = AuthStrategy.NONE
56+
57+
def _detect_strategy(self) -> AuthStrategy:
58+
"""Detect which auth flow is available for token refresh."""
59+
if self._client_id and self._client_secret:
60+
return AuthStrategy.CLIENT_CREDENTIALS
61+
62+
try:
63+
auth_data = get_auth_data()
64+
if auth_data.refresh_token:
65+
return AuthStrategy.OAUTH
66+
except Exception as e:
67+
logger.debug(f"Could not read auth file for strategy detection: {e}")
68+
69+
return AuthStrategy.NONE
70+
71+
def _resolve_token_url(self) -> str | None:
72+
"""Derive the identity token endpoint for client_credentials flow."""
73+
if self._strategy != AuthStrategy.CLIENT_CREDENTIALS:
74+
return None
75+
76+
try:
77+
return build_service_url(self._domain, "/identity_/connect/token")
78+
except Exception as e:
79+
logger.error(
80+
f"Could not resolve token URL from base_url '{self._base_url}': {e}"
81+
)
82+
return None
83+
84+
@property
85+
def strategy(self) -> AuthStrategy:
86+
return self._strategy
87+
88+
def start(self) -> None:
89+
"""Start the background refresh task."""
90+
if self._strategy == AuthStrategy.NONE:
91+
logger.info("No token refresh strategy available; refresh disabled")
92+
return
93+
94+
self._cancel_event.clear()
95+
self._refresh_task = asyncio.create_task(self._refresh_loop())
96+
logger.info("Token refresh background task started")
97+
98+
async def stop(self) -> None:
99+
"""Stop the background refresh task."""
100+
self._cancel_event.set()
101+
if self._refresh_task and not self._refresh_task.done():
102+
self._refresh_task.cancel()
103+
try:
104+
await asyncio.wait_for(self._refresh_task, timeout=5.0)
105+
except (asyncio.CancelledError, asyncio.TimeoutError):
106+
pass
107+
self._refresh_task = None
108+
logger.info("Token refresh stopped")
109+
110+
async def _wait_for_cancel(self, seconds: float) -> bool:
111+
"""Sleep for `seconds`, returning True if cancellation was requested."""
112+
try:
113+
await asyncio.wait_for(self._cancel_event.wait(), timeout=seconds)
114+
return True
115+
except asyncio.TimeoutError:
116+
return False
117+
118+
async def _refresh_loop(self) -> None:
119+
"""Background loop that refreshes the token before expiry."""
120+
try:
121+
while not self._cancel_event.is_set():
122+
wait_seconds = self._seconds_until_refresh()
123+
if wait_seconds > 0 and await self._wait_for_cancel(wait_seconds):
124+
break
125+
126+
if not await self._try_refresh() and not self._cancel_event.is_set():
127+
logger.error(
128+
"All token refresh attempts failed. "
129+
"The token may expire causing failures."
130+
)
131+
# Avoid retry loop when the token is already expired
132+
if await self._wait_for_cancel(RETRY_FALLBACK_INTERVAL):
133+
break
134+
except asyncio.CancelledError:
135+
logger.info("Token refresh loop cancelled")
136+
raise
137+
138+
async def _try_refresh(self) -> bool:
139+
"""Attempt to refresh the token with retries. Returns True on success."""
140+
for attempt in range(MAX_RETRY_ATTEMPTS):
141+
try:
142+
if self._strategy == AuthStrategy.OAUTH:
143+
token_data = await self._refresh_oauth()
144+
else:
145+
token_data = await self._refresh_client_credentials()
146+
147+
self._propagate_token(token_data)
148+
logger.info("Token refreshed successfully.")
149+
return True
150+
151+
except Exception as e:
152+
safe_msg = (
153+
f"HTTP {e.response.status_code}"
154+
if isinstance(e, httpx.HTTPStatusError)
155+
else type(e).__name__
156+
)
157+
logger.error(
158+
f"Token refresh attempt {attempt + 1}/{MAX_RETRY_ATTEMPTS} "
159+
f"failed: {safe_msg}"
160+
)
161+
if attempt < MAX_RETRY_ATTEMPTS - 1:
162+
logger.info(f"Retrying in {RETRY_BASE_DELAY}s...")
163+
if await self._wait_for_cancel(RETRY_BASE_DELAY):
164+
return False
165+
166+
return False
167+
168+
async def _refresh_oauth(self) -> TokenData:
169+
"""Refresh using OAuth refresh_token grant."""
170+
auth_data = get_auth_data()
171+
refresh_token = auth_data.refresh_token
172+
if not refresh_token:
173+
raise ValueError("No refresh_token found in .uipath/.auth.json")
174+
175+
def _do_refresh() -> TokenData:
176+
with PortalService(domain=self._domain) as portal:
177+
return portal.refresh_access_token(refresh_token)
178+
179+
# run in a thread to avoid blocking
180+
token_data = await asyncio.to_thread(_do_refresh)
181+
182+
try:
183+
update_auth_file(token_data)
184+
except Exception as e:
185+
logger.warning(f"Failed to update .auth.json: {type(e).__name__}")
186+
187+
return token_data
188+
189+
async def _refresh_client_credentials(self) -> TokenData:
190+
"""Refresh using client_credentials grant."""
191+
if self._token_url is None:
192+
raise RuntimeError("token_url must be set for client_credentials strategy")
193+
194+
data = {
195+
"grant_type": "client_credentials",
196+
"client_id": self._client_id,
197+
"client_secret": self._client_secret,
198+
"scope": os.environ.get("UIPATH_CLIENT_SCOPE", "OR.Execution"),
199+
}
200+
201+
async with httpx.AsyncClient(**get_httpx_client_kwargs()) as client:
202+
response = await client.post(
203+
self._token_url,
204+
data=data,
205+
headers={"Content-Type": "application/x-www-form-urlencoded"},
206+
)
207+
response.raise_for_status()
208+
return TokenData.model_validate(response.json())
209+
210+
def _propagate_token(self, token_data: TokenData) -> None:
211+
"""Update all token consumers after a successful refresh."""
212+
new_token = token_data.access_token
213+
214+
self._uipath._config = UiPathApiConfig(
215+
base_url=self._uipath._config.base_url,
216+
secret=new_token,
217+
)
218+
219+
os.environ[ENV_UIPATH_ACCESS_TOKEN] = new_token
220+
221+
def _seconds_until_refresh(self) -> float:
222+
"""Calculate seconds to wait before next refresh attempt."""
223+
try:
224+
claims = parse_access_token(self._uipath._config.secret)
225+
exp = claims.get("exp")
226+
if exp is not None:
227+
remaining = float(exp) - time.time()
228+
if remaining <= REFRESH_MARGIN_SECONDS:
229+
return 0
230+
return remaining - REFRESH_MARGIN_SECONDS
231+
except Exception as e:
232+
logger.warning(f"Failed to parse token expiry: {e}")
233+
234+
return FALLBACK_REFRESH_INTERVAL

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import pytest
2+
3+
4+
@pytest.fixture(autouse=True)
5+
def clean_env(monkeypatch: pytest.MonkeyPatch) -> None:
6+
"""Clean environment variables before each test."""
7+
monkeypatch.delenv("UIPATH_URL", raising=False)
8+
monkeypatch.delenv("UIPATH_ACCESS_TOKEN", raising=False)
9+
monkeypatch.delenv("UIPATH_CLIENT_ID", raising=False)
10+
monkeypatch.delenv("UIPATH_CLIENT_SECRET", raising=False)
11+
monkeypatch.delenv("UIPATH_CLIENT_SCOPE", raising=False)

0 commit comments

Comments
 (0)