-
Notifications
You must be signed in to change notification settings - Fork 202
Expand file tree
/
Copy pathtest_async_client_modes.py
More file actions
346 lines (268 loc) · 13.6 KB
/
test_async_client_modes.py
File metadata and controls
346 lines (268 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
from contextlib import asynccontextmanager
import httpx
import pytest
from basic_memory.cli.auth import CLIAuth
from basic_memory.config import ProjectMode
from basic_memory.mcp import async_client as async_client_module
from basic_memory.mcp.async_client import (
get_client,
get_cloud_control_plane_client,
set_client_factory,
)
@pytest.fixture(autouse=True)
def _reset_async_client_state(monkeypatch):
    """Give every test in this module a clean client-routing state.

    Clears the module-level ``_client_factory`` before and after the test and
    removes the routing-related environment variables while the test runs.
    """
    async_client_module._client_factory = None

    routing_env_vars = (
        "BASIC_MEMORY_FORCE_LOCAL",
        "BASIC_MEMORY_FORCE_CLOUD",
        "BASIC_MEMORY_EXPLICIT_ROUTING",
    )
    for env_var in routing_env_vars:
        # raising=False: it is fine if the variable was never set.
        monkeypatch.delenv(env_var, raising=False)

    yield

    # Reset again so a factory assigned during the test cannot leak into the next one.
    async_client_module._client_factory = None
@pytest.mark.asyncio
async def test_get_client_uses_injected_factory(monkeypatch):
    """get_client yields the client produced by a factory set via set_client_factory."""
    factory_used = False

    @asynccontextmanager
    async def factory():
        nonlocal factory_used
        factory_used = True
        async with httpx.AsyncClient(base_url="https://example.test") as injected:
            yield injected

    # This module's autouse fixture resets `_client_factory` after the test,
    # which is what keeps an injected factory from leaking into other tests.
    set_client_factory(factory)

    async with get_client() as client:
        base_url = str(client.base_url)
        assert base_url == "https://example.test"

    assert factory_used is True
@pytest.mark.asyncio
async def test_get_client_default_uses_local_asgi_transport(config_manager):
    """Without a project name or routing env vars, get_client uses the ASGI transport.

    Cloud host and API key are configured to show that their presence alone
    does not switch the default route to cloud.
    """
    config = config_manager.load_config()
    config.cloud_host = "https://cloud.example.test"
    config.cloud_api_key = "bmc_test_key_123"
    config_manager.save_config(config)

    async with get_client() as client:
        transport = client._transport  # pyright: ignore[reportPrivateUsage]
        assert isinstance(transport, httpx.ASGITransport)
@pytest.mark.asyncio
async def test_get_client_explicit_cloud_uses_api_key(config_manager, monkeypatch):
    """FORCE_CLOUD + EXPLICIT_ROUTING sends get_client to the cloud proxy with the API key."""
    config = config_manager.load_config()
    config.cloud_host = "https://cloud.example.test"
    config.cloud_api_key = "bmc_test_key_123"
    config_manager.save_config(config)

    for env_var in ("BASIC_MEMORY_FORCE_CLOUD", "BASIC_MEMORY_EXPLICIT_ROUTING"):
        monkeypatch.setenv(env_var, "true")

    async with get_client() as client:
        # Trailing slash is stripped so the comparison ignores URL normalization.
        base_url = str(client.base_url).rstrip("/")
        assert base_url == "https://cloud.example.test/proxy"

        auth_header = client.headers.get("Authorization")
        assert auth_header == "Bearer bmc_test_key_123"
@pytest.mark.asyncio
async def test_get_client_cloud_adds_workspace_header(config_manager):
    """A workspace passed to get_client for a cloud project is sent as X-Workspace-ID."""
    config = config_manager.load_config()
    config.cloud_host = "https://cloud.example.test"
    config.cloud_api_key = "bmc_test_key_123"
    config.set_project_mode("research", ProjectMode.CLOUD)
    config_manager.save_config(config)

    async with get_client(project_name="research", workspace="tenant-123") as client:
        base_url = str(client.base_url).rstrip("/")
        assert base_url == "https://cloud.example.test/proxy"

        workspace_header = client.headers.get("X-Workspace-ID")
        assert workspace_header == "tenant-123"
@pytest.mark.asyncio
async def test_get_client_explicit_cloud_raises_without_credentials(config_manager, monkeypatch):
    """Explicit cloud routing raises RuntimeError when no credentials are available.

    The API key is cleared and this test writes no OAuth token file.
    """
    config = config_manager.load_config()
    config.cloud_host = "https://cloud.example.test"
    config.cloud_api_key = None
    config.cloud_client_id = "cid"
    config.cloud_domain = "https://auth.example.test"
    config_manager.save_config(config)

    for env_var in ("BASIC_MEMORY_FORCE_CLOUD", "BASIC_MEMORY_EXPLICIT_ROUTING"):
        monkeypatch.setenv(env_var, "true")

    expected_message = "Cloud routing requested but no credentials found"
    with pytest.raises(RuntimeError, match=expected_message):
        async with get_client():
            pass
@pytest.mark.asyncio
async def test_get_client_per_project_cloud_uses_api_key(config_manager):
    """A project in cloud mode is routed to the cloud proxy using API key auth."""
    config = config_manager.load_config()
    config.cloud_host = "https://cloud.example.test"
    config.cloud_api_key = "bmc_test_key_123"
    config.set_project_mode("research", ProjectMode.CLOUD)
    config_manager.save_config(config)

    async with get_client(project_name="research") as client:
        base_url = str(client.base_url).rstrip("/")
        assert base_url == "https://cloud.example.test/proxy"

        auth_header = client.headers.get("Authorization")
        assert auth_header == "Bearer bmc_test_key_123"
@pytest.mark.asyncio
async def test_get_client_per_project_cloud_raises_without_credentials(config_manager):
    """A cloud-mode project with no API key raises RuntimeError naming the project."""
    config = config_manager.load_config()
    config.cloud_api_key = None
    config.set_project_mode("research", ProjectMode.CLOUD)
    config_manager.save_config(config)

    expected_message = "Project 'research' is set to cloud mode"
    with pytest.raises(RuntimeError, match=expected_message):
        async with get_client(project_name="research"):
            pass
@pytest.mark.asyncio
async def test_get_client_local_project_uses_asgi_transport(config_manager):
    """A project in local mode gets the ASGI transport even when an API key is set."""
    config = config_manager.load_config()
    config.cloud_api_key = "bmc_test_key_123"
    config.set_project_mode("main", ProjectMode.LOCAL)
    config_manager.save_config(config)

    async with get_client(project_name="main") as client:
        transport = client._transport  # pyright: ignore[reportPrivateUsage]
        assert isinstance(transport, httpx.ASGITransport)
@pytest.mark.asyncio
async def test_get_client_no_project_name_defaults_local(config_manager):
    """Calling get_client without project_name routes locally over ASGI.

    A different project ("research") is in cloud mode to show that it does
    not influence the no-project default.
    """
    config = config_manager.load_config()
    config.cloud_api_key = "bmc_test_key_123"
    config.set_project_mode("research", ProjectMode.CLOUD)
    config_manager.save_config(config)

    async with get_client() as client:
        transport = client._transport  # pyright: ignore[reportPrivateUsage]
        assert isinstance(transport, httpx.ASGITransport)
@pytest.mark.asyncio
async def test_get_client_factory_overrides_per_project_routing(config_manager):
    """An injected factory wins over the project's configured cloud mode."""
    config = config_manager.load_config()
    config.cloud_api_key = "bmc_test_key_123"
    config.set_project_mode("research", ProjectMode.CLOUD)
    config_manager.save_config(config)

    @asynccontextmanager
    async def factory():
        async with httpx.AsyncClient(base_url="https://factory.test") as injected:
            yield injected

    set_client_factory(factory)

    async with get_client(project_name="research") as client:
        base_url = str(client.base_url)
        assert base_url == "https://factory.test"
@pytest.mark.asyncio
async def test_get_client_force_local_without_explicit_does_not_override_project_mode(
    config_manager, monkeypatch
):
    """With only FORCE_LOCAL set, a cloud-mode project is still routed to the cloud proxy."""
    config = config_manager.load_config()
    config.cloud_host = "https://cloud.example.test"
    config.cloud_api_key = "bmc_test_key_123"
    config.set_project_mode("research", ProjectMode.CLOUD)
    config_manager.save_config(config)

    # EXPLICIT_ROUTING is deliberately left unset here.
    monkeypatch.setenv("BASIC_MEMORY_FORCE_LOCAL", "true")

    async with get_client(project_name="research") as client:
        base_url = str(client.base_url).rstrip("/")
        assert base_url == "https://cloud.example.test/proxy"
@pytest.mark.asyncio
async def test_get_client_explicit_local_overrides_cloud_project(config_manager, monkeypatch):
    """FORCE_LOCAL together with EXPLICIT_ROUTING routes a cloud project over local ASGI."""
    config = config_manager.load_config()
    config.cloud_host = "https://cloud.example.test"
    config.cloud_api_key = "bmc_test_key_123"
    config.set_project_mode("research", ProjectMode.CLOUD)
    config_manager.save_config(config)

    for env_var in ("BASIC_MEMORY_FORCE_LOCAL", "BASIC_MEMORY_EXPLICIT_ROUTING"):
        monkeypatch.setenv(env_var, "true")

    async with get_client(project_name="research") as client:
        transport = client._transport  # pyright: ignore[reportPrivateUsage]
        assert isinstance(transport, httpx.ASGITransport)
@pytest.mark.asyncio
async def test_get_client_per_project_cloud_oauth_fallback(config_manager):
    """A cloud-mode project without an API key authenticates with the stored OAuth token."""
    config = config_manager.load_config()
    config.cloud_host = "https://cloud.example.test"
    config.cloud_api_key = None
    config.cloud_client_id = "cid"
    config.cloud_domain = "https://auth.example.test"
    config.set_project_mode("research", ProjectMode.CLOUD)
    config_manager.save_config(config)

    # Seed the OAuth token file so CLIAuth.get_valid_token() returns it
    token_json = '{"access_token":"oauth-token-456","refresh_token":null,"expires_at":9999999999,"token_type":"Bearer"}'
    auth = CLIAuth(client_id=config.cloud_client_id, authkit_domain=config.cloud_domain)
    token_path = auth.token_file
    token_path.parent.mkdir(parents=True, exist_ok=True)
    token_path.write_text(token_json, encoding="utf-8")

    async with get_client(project_name="research") as client:
        base_url = str(client.base_url).rstrip("/")
        assert base_url == "https://cloud.example.test/proxy"

        auth_header = client.headers.get("Authorization")
        assert auth_header == "Bearer oauth-token-456"
@pytest.mark.asyncio
async def test_get_client_explicit_cloud_overrides_local_project(config_manager, monkeypatch):
    """FORCE_CLOUD together with EXPLICIT_ROUTING routes project "main" to the cloud proxy."""
    config = config_manager.load_config()
    config.cloud_host = "https://cloud.example.test"
    config.cloud_api_key = "bmc_test_key_123"
    config_manager.save_config(config)

    for env_var in ("BASIC_MEMORY_FORCE_CLOUD", "BASIC_MEMORY_EXPLICIT_ROUTING"):
        monkeypatch.setenv(env_var, "true")

    async with get_client(project_name="main") as client:
        base_url = str(client.base_url).rstrip("/")
        assert base_url == "https://cloud.example.test/proxy"

        auth_header = client.headers.get("Authorization")
        assert auth_header == "Bearer bmc_test_key_123"
@pytest.mark.asyncio
async def test_get_cloud_control_plane_client_uses_api_key_when_available(config_manager):
    """The control-plane client targets cloud_host directly and sends the API key as Bearer."""
    config = config_manager.load_config()
    config.cloud_host = "https://cloud.example.test"
    config.cloud_api_key = "bmc_test_key_123"
    config.cloud_client_id = "cid"
    config.cloud_domain = "https://auth.example.test"
    config_manager.save_config(config)

    async with get_cloud_control_plane_client() as client:
        # The expected base URL is the bare host, with no "/proxy" suffix.
        base_url = str(client.base_url).rstrip("/")
        assert base_url == "https://cloud.example.test"

        auth_header = client.headers.get("Authorization")
        assert auth_header == "Bearer bmc_test_key_123"
@pytest.mark.asyncio
async def test_get_cloud_control_plane_client_uses_oauth_token(config_manager):
    """With no API key, the control-plane client sends the stored OAuth token as Bearer."""
    config = config_manager.load_config()
    config.cloud_host = "https://cloud.example.test"
    config.cloud_api_key = None
    config.cloud_client_id = "cid"
    config.cloud_domain = "https://auth.example.test"
    config_manager.save_config(config)

    # Seed the token file at the location CLIAuth reports for these settings.
    token_json = '{"access_token":"oauth-control-123","refresh_token":null,"expires_at":9999999999,"token_type":"Bearer"}'
    auth = CLIAuth(client_id=config.cloud_client_id, authkit_domain=config.cloud_domain)
    token_path = auth.token_file
    token_path.parent.mkdir(parents=True, exist_ok=True)
    token_path.write_text(token_json, encoding="utf-8")

    async with get_cloud_control_plane_client() as client:
        base_url = str(client.base_url).rstrip("/")
        assert base_url == "https://cloud.example.test"

        auth_header = client.headers.get("Authorization")
        assert auth_header == "Bearer oauth-control-123"
@pytest.mark.asyncio
async def test_get_cloud_control_plane_client_with_workspace(config_manager):
    """The control-plane client sets X-Workspace-ID only when a workspace is passed."""
    config = config_manager.load_config()
    config.cloud_host = "https://cloud.example.test"
    config.cloud_api_key = "bmc_test_key_123"
    config_manager.save_config(config)

    async with get_cloud_control_plane_client(workspace="tenant-abc") as scoped_client:
        workspace_header = scoped_client.headers.get("X-Workspace-ID")
        assert workspace_header == "tenant-abc"

    # When no workspace is given, the header must be absent altogether.
    async with get_cloud_control_plane_client() as unscoped_client:
        assert "X-Workspace-ID" not in unscoped_client.headers
@pytest.mark.asyncio
async def test_get_client_auto_resolves_workspace_from_project_config(config_manager):
    """With no workspace argument, get_client uses the project's configured workspace_id."""
    config = config_manager.load_config()
    config.cloud_host = "https://cloud.example.test"
    config.cloud_api_key = "bmc_test_key_123"
    config.set_project_mode("research", ProjectMode.CLOUD)
    research_project = config.projects["research"]
    research_project.workspace_id = "tenant-from-config"
    config_manager.save_config(config)

    async with get_client(project_name="research") as client:
        workspace_header = client.headers.get("X-Workspace-ID")
        assert workspace_header == "tenant-from-config"
@pytest.mark.asyncio
async def test_get_client_auto_resolves_workspace_from_default(config_manager):
    """When the project sets no workspace_id, get_client uses default_workspace."""
    config = config_manager.load_config()
    config.cloud_host = "https://cloud.example.test"
    config.cloud_api_key = "bmc_test_key_123"
    config.set_project_mode("research", ProjectMode.CLOUD)
    config.default_workspace = "default-tenant-456"
    config_manager.save_config(config)

    async with get_client(project_name="research") as client:
        workspace_header = client.headers.get("X-Workspace-ID")
        assert workspace_header == "default-tenant-456"
@pytest.mark.asyncio
async def test_get_client_explicit_workspace_overrides_config(config_manager):
    """A workspace passed to get_client beats the workspace_id stored on the project."""
    config = config_manager.load_config()
    config.cloud_host = "https://cloud.example.test"
    config.cloud_api_key = "bmc_test_key_123"
    config.set_project_mode("research", ProjectMode.CLOUD)
    research_project = config.projects["research"]
    research_project.workspace_id = "tenant-from-config"
    config_manager.save_config(config)

    async with get_client(project_name="research", workspace="explicit-tenant") as client:
        workspace_header = client.headers.get("X-Workspace-ID")
        assert workspace_header == "explicit-tenant"
@pytest.mark.asyncio
async def test_get_cloud_control_plane_client_raises_without_credentials(config_manager):
    """The control-plane client raises RuntimeError when no credentials are available.

    The API key is cleared and this test writes no OAuth token file.
    """
    config = config_manager.load_config()
    config.cloud_host = "https://cloud.example.test"
    config.cloud_api_key = None
    config.cloud_client_id = "cid"
    config.cloud_domain = "https://auth.example.test"
    config_manager.save_config(config)

    expected_message = "Cloud routing requested but no credentials found"
    with pytest.raises(RuntimeError, match=expected_message):
        async with get_cloud_control_plane_client():
            pass