-
Notifications
You must be signed in to change notification settings - Fork 188
Expand file tree
/
Copy pathapi_client.py
More file actions
139 lines (112 loc) · 5.19 KB
/
api_client.py
File metadata and controls
139 lines (112 loc) · 5.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""Cloud API client utilities."""
from collections.abc import AsyncIterator
from typing import Optional
from contextlib import asynccontextmanager
from typing import AsyncContextManager, Callable
import httpx
import typer
from rich.console import Console
from basic_memory.cli.auth import CLIAuth
from basic_memory.config import ConfigManager
console = Console()
HttpClientFactory = Callable[[], AsyncContextManager[httpx.AsyncClient]]
class CloudAPIError(Exception):
"""Exception raised for cloud API errors."""
def __init__(
self, message: str, status_code: Optional[int] = None, detail: Optional[dict] = None
):
super().__init__(message)
self.status_code = status_code
self.detail = detail or {}
class SubscriptionRequiredError(CloudAPIError):
"""Exception raised when user needs an active subscription."""
def __init__(self, message: str, subscribe_url: str):
super().__init__(message, status_code=403, detail={"error": "subscription_required"})
self.subscribe_url = subscribe_url
def get_cloud_config() -> tuple[str, str, str]:
"""Get cloud OAuth configuration from config."""
config_manager = ConfigManager()
config = config_manager.config
return config.cloud_client_id, config.cloud_domain, config.cloud_host
async def get_authenticated_headers(auth: CLIAuth | None = None) -> dict[str, str]:
"""
Get authentication headers for cloud API requests.
Credential priority mirrors async_client._resolve_cloud_token():
1. API key (config.cloud_api_key) — fast, no refresh needed
2. OAuth token via CLIAuth — handles JWT refresh automatically
"""
# --- API key (preferred) ---
config_manager = ConfigManager()
api_key = config_manager.config.cloud_api_key
if api_key:
return {"Authorization": f"Bearer {api_key}"}
# --- OAuth fallback ---
client_id, domain, _ = get_cloud_config()
auth_obj = auth or CLIAuth(client_id=client_id, authkit_domain=domain)
token = await auth_obj.get_valid_token()
if not token:
console.print(
"[red]Not authenticated. Run 'bm cloud set-key <key>' or 'bm cloud login' first.[/red]"
)
raise typer.Exit(1)
return {"Authorization": f"Bearer {token}"}
@asynccontextmanager
async def _default_http_client(timeout: float) -> AsyncIterator[httpx.AsyncClient]:
async with httpx.AsyncClient(timeout=timeout) as client:
yield client
async def make_api_request(
method: str,
url: str,
headers: Optional[dict] = None,
json_data: Optional[dict] = None,
timeout: float = 30.0,
*,
auth: CLIAuth | None = None,
http_client_factory: HttpClientFactory | None = None,
) -> httpx.Response:
"""Make an API request to the cloud service."""
headers = headers or {}
auth_headers = await get_authenticated_headers(auth=auth)
headers.update(auth_headers)
# Add debug headers to help with compression issues
headers.setdefault("Accept-Encoding", "identity") # Disable compression for debugging
client_factory = http_client_factory or (lambda: _default_http_client(timeout))
async with client_factory() as client:
try:
response = await client.request(method=method, url=url, headers=headers, json=json_data)
response.raise_for_status()
return response
except httpx.HTTPError as e:
# Check if this is a response error with response details
if hasattr(e, "response") and e.response is not None: # pyright: ignore [reportAttributeAccessIssue]
response = e.response # type: ignore
# Try to parse error detail from response
error_detail = None
try:
error_detail = response.json()
except Exception:
# If JSON parsing fails, we'll handle it as a generic error
pass
# Check for subscription_required error (403)
if response.status_code == 403 and isinstance(error_detail, dict):
# Handle both FastAPI HTTPException format (nested under "detail")
# and direct format
detail_obj = error_detail.get("detail", error_detail)
if (
isinstance(detail_obj, dict)
and detail_obj.get("error") == "subscription_required"
):
message = detail_obj.get("message", "Active subscription required")
subscribe_url = detail_obj.get(
"subscribe_url", "https://basicmemory.com/subscribe"
)
raise SubscriptionRequiredError(
message=message, subscribe_url=subscribe_url
) from e
# Raise generic CloudAPIError with status code and detail
raise CloudAPIError(
f"API request failed: {e}",
status_code=response.status_code,
detail=error_detail if isinstance(error_detail, dict) else {},
) from e
raise CloudAPIError(f"API request failed: {e}") from e