Skip to content

Commit 61612a8

Browse files
First iteration of changes to SDK handling new client secrets and token-exchange
1 parent 5171475 commit 61612a8

15 files changed

Lines changed: 1590 additions & 4 deletions
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""Shared helpers for exercising token-exchange auth callables.
2+
3+
The mock transports here let tests script a sequence of responses / exceptions
4+
for the ``POST /auth/token-exchange`` endpoint without standing up a real
5+
account-service.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import json
11+
from typing import Any, Callable, Iterable, List, Optional, Union
12+
13+
import httpx
14+
15+
16+
ResponseStep = Union[httpx.Response, Exception, Callable[[httpx.Request], httpx.Response]]
17+
18+
19+
class ScriptedTransport(httpx.MockTransport):
20+
"""A MockTransport that walks through a scripted sequence of responses.
21+
22+
Each element can be an :class:`httpx.Response`, an ``Exception`` instance
23+
(raised instead of returned), or a callable that accepts the request and
24+
returns a response. Tests can inspect :attr:`requests` to assert how many
25+
exchanges took place and what bodies were sent.
26+
"""
27+
28+
def __init__(self, steps: Iterable[ResponseStep]) -> None:
29+
self._steps: List[ResponseStep] = list(steps)
30+
self.requests: List[httpx.Request] = []
31+
super().__init__(self._handler)
32+
33+
def _handler(self, request: httpx.Request) -> httpx.Response:
34+
self.requests.append(request)
35+
if not self._steps:
36+
raise AssertionError(
37+
"ScriptedTransport exhausted; unexpected extra request to "
38+
f"{request.url}",
39+
)
40+
step = self._steps.pop(0)
41+
if isinstance(step, Exception):
42+
raise step
43+
if callable(step):
44+
return step(request)
45+
return step
46+
47+
48+
class AsyncScriptedTransport(httpx.MockTransport):
49+
"""Async counterpart to :class:`ScriptedTransport`."""
50+
51+
def __init__(self, steps: Iterable[ResponseStep]) -> None:
52+
self._steps: List[ResponseStep] = list(steps)
53+
self.requests: List[httpx.Request] = []
54+
55+
async def _handler(request: httpx.Request) -> httpx.Response:
56+
self.requests.append(request)
57+
if not self._steps:
58+
raise AssertionError(
59+
"AsyncScriptedTransport exhausted; unexpected extra "
60+
f"request to {request.url}",
61+
)
62+
step = self._steps.pop(0)
63+
if isinstance(step, Exception):
64+
raise step
65+
if callable(step):
66+
return step(request)
67+
return step
68+
69+
super().__init__(_handler)
70+
71+
72+
def exchange_response(
73+
access_token: Optional[str] = "jwt-1",
74+
*,
75+
expires_in: int = 900,
76+
token_exchange_enabled: bool = True,
77+
token_type: str = "bearer",
78+
status_code: int = 200,
79+
extra: Optional[dict] = None,
80+
) -> httpx.Response:
81+
"""Build a canned ``/auth/token-exchange`` response body."""
82+
body: dict[str, Any] = {
83+
"access_token": access_token,
84+
"token_type": token_type,
85+
"expires_in": expires_in,
86+
"token_exchange_enabled": token_exchange_enabled,
87+
}
88+
if extra:
89+
body.update(extra)
90+
return httpx.Response(status_code, json=body)
91+
92+
93+
def body_of(request: httpx.Request) -> dict:
94+
"""Decode the JSON body from an outgoing exchange request."""
95+
return json.loads(request.content.decode("utf-8"))
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
"""Unit tests for :class:`unstructured_client.auth.AsyncClientCredentials`."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
from typing import List
7+
8+
import httpx
9+
import pytest
10+
11+
from unstructured_client.auth import (
12+
AsyncClientCredentials,
13+
InvalidCredentialError,
14+
TokenExchangeError,
15+
)
16+
17+
from ._mock_transport import AsyncScriptedTransport, body_of, exchange_response
18+
19+
SERVER_URL = "https://accounts.example.test"
20+
SECRET = "uns_sk_async_example"
21+
22+
23+
@pytest.fixture(autouse=True)
24+
def _no_sleep(monkeypatch):
25+
async def _noop(*_args, **_kwargs):
26+
return None
27+
28+
monkeypatch.setattr(
29+
"unstructured_client.auth.client_credentials.asyncio.sleep",
30+
_noop,
31+
)
32+
33+
34+
@pytest.fixture
35+
def fake_clock(monkeypatch):
36+
state = {"now": 2_000_000.0}
37+
38+
def _now() -> float:
39+
return state["now"]
40+
41+
monkeypatch.setattr("unstructured_client.auth._base.time.monotonic", _now)
42+
monkeypatch.setattr(
43+
"unstructured_client.auth.client_credentials.time.monotonic", _now
44+
)
45+
return state
46+
47+
48+
class DescribeAsyncClientCredentials:
49+
@pytest.mark.asyncio
50+
async def it_exchanges_then_caches(self, fake_clock):
51+
transport = AsyncScriptedTransport(
52+
[exchange_response(access_token="jwt-1", expires_in=900)]
53+
)
54+
http_client = httpx.AsyncClient(transport=transport)
55+
acc = AsyncClientCredentials(
56+
client_secret=SECRET,
57+
server_url=SERVER_URL,
58+
http_client=http_client,
59+
)
60+
61+
first = await acc.acquire()
62+
second = await acc.acquire()
63+
64+
assert first == second == "jwt-1"
65+
assert len(transport.requests) == 1
66+
assert body_of(transport.requests[0]) == {
67+
"grant_type": "client_credentials",
68+
"client_secret": SECRET,
69+
}
70+
71+
@pytest.mark.asyncio
72+
async def it_raises_invalid_credential_on_401(self, fake_clock):
73+
transport = AsyncScriptedTransport(
74+
[httpx.Response(401, json={"detail": "bad"})]
75+
)
76+
http_client = httpx.AsyncClient(transport=transport)
77+
acc = AsyncClientCredentials(
78+
client_secret=SECRET,
79+
server_url=SERVER_URL,
80+
http_client=http_client,
81+
max_retries=5,
82+
)
83+
84+
with pytest.raises(InvalidCredentialError):
85+
await acc.acquire()
86+
87+
@pytest.mark.asyncio
88+
async def it_retries_5xx_then_succeeds(self, fake_clock):
89+
transport = AsyncScriptedTransport(
90+
[
91+
httpx.Response(500),
92+
httpx.Response(502),
93+
exchange_response(access_token="jwt-1", expires_in=900),
94+
]
95+
)
96+
http_client = httpx.AsyncClient(transport=transport)
97+
acc = AsyncClientCredentials(
98+
client_secret=SECRET,
99+
server_url=SERVER_URL,
100+
http_client=http_client,
101+
max_retries=3,
102+
)
103+
104+
assert await acc.acquire() == "jwt-1"
105+
assert len(transport.requests) == 3
106+
107+
@pytest.mark.asyncio
108+
async def it_serializes_concurrent_acquires(self, fake_clock):
109+
"""Ten concurrent ``acquire()`` calls must share one exchange."""
110+
transport = AsyncScriptedTransport(
111+
[exchange_response(access_token="jwt-1", expires_in=900)]
112+
)
113+
http_client = httpx.AsyncClient(transport=transport)
114+
acc = AsyncClientCredentials(
115+
client_secret=SECRET,
116+
server_url=SERVER_URL,
117+
http_client=http_client,
118+
)
119+
120+
results: List[str] = await asyncio.gather(*(acc.acquire() for _ in range(10)))
121+
122+
assert results == ["jwt-1"] * 10
123+
assert len(transport.requests) == 1
124+
125+
@pytest.mark.asyncio
126+
async def it_raises_outage_error_without_cached_token(self, fake_clock):
127+
transport = AsyncScriptedTransport([httpx.Response(500)] * 4)
128+
http_client = httpx.AsyncClient(transport=transport)
129+
acc = AsyncClientCredentials(
130+
client_secret=SECRET,
131+
server_url=SERVER_URL,
132+
http_client=http_client,
133+
max_retries=3,
134+
)
135+
136+
with pytest.raises(TokenExchangeError):
137+
await acc.acquire()
138+
139+
def it_sync_call_works_outside_running_loop(self, fake_clock):
140+
"""``__call__`` is the SDK entry point; must work without a loop."""
141+
transport = AsyncScriptedTransport(
142+
[exchange_response(access_token="jwt-1", expires_in=900)]
143+
)
144+
http_client = httpx.AsyncClient(transport=transport)
145+
acc = AsyncClientCredentials(
146+
client_secret=SECRET,
147+
server_url=SERVER_URL,
148+
http_client=http_client,
149+
)
150+
151+
assert acc() == "jwt-1"
152+
153+
@pytest.mark.asyncio
154+
async def it_sync_call_works_inside_running_loop(self, fake_clock):
155+
"""Driving __call__ from a running loop offloads to a worker thread."""
156+
transport = AsyncScriptedTransport(
157+
[exchange_response(access_token="jwt-1", expires_in=900)]
158+
)
159+
http_client = httpx.AsyncClient(transport=transport)
160+
acc = AsyncClientCredentials(
161+
client_secret=SECRET,
162+
server_url=SERVER_URL,
163+
http_client=http_client,
164+
)
165+
166+
token = await asyncio.to_thread(acc)
167+
assert token == "jwt-1"

0 commit comments

Comments
 (0)