-
Notifications
You must be signed in to change notification settings - Fork 368
Expand file tree
/
Copy pathmiddleware.py
More file actions
404 lines (346 loc) · 15.5 KB
/
middleware.py
File metadata and controls
404 lines (346 loc) · 15.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
import json
import logging
import os
import re
import time
from pathlib import Path
from typing import Any, Callable
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Match
from starlette.types import Receive, Scope, Send
from app.desktop.git_sync.config import get_git_sync_config, project_path_from_id
from app.desktop.git_sync.errors import (
CorruptRepoError,
GitSyncError,
RemoteUnreachableError,
SyncConflictError,
WriteLockTimeoutError,
)
from app.desktop.git_sync.git_sync_manager import GitSyncManager
from app.desktop.git_sync.registry import GitSyncRegistry
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# HTTP methods treated as writes: a project-scoped request using one of these
# is wrapped in the write lock unless its endpoint opts out.
MUTATING_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
# Captures the project id: the first path segment after /api/projects/.
PROJECT_ID_PATTERN = re.compile(r"^/api/projects/([^/]+)")
# Threshold in seconds (compared against a time.monotonic() delta) above which
# a write-locked request logs a warning suggesting @no_write_lock.
LONG_LOCK_HOLD_THRESHOLD = 5.0
def _is_dev_mode() -> bool:
    """Return True only when the KILN_DEV_MODE env var is exactly "true".

    The comparison is case-sensitive; an unset variable counts as disabled.
    The environment is re-read on every call.
    """
    flag = os.getenv("KILN_DEV_MODE")
    return flag == "true"
class _StreamingUnderWriteLock(Exception):
    """Internal sentinel: an SSE response was produced while holding the write lock.

    The middleware raises this from inside an ``atomic_write`` block so that
    any dirty changes are rolled back, then catches it immediately outside
    the block and answers with a 500 JSON response. It carries no payload.
    """
# Maps git-sync error types to the (HTTP status, user-facing detail message)
# returned to the client. GitSyncMiddleware._map_error scans the entries in
# insertion order with isinstance() and uses the first match; errors that
# match no entry fall back to a generic 500 there.
ERROR_MAP: dict[type[GitSyncError], tuple[int, str]] = {
    RemoteUnreachableError: (
        503,
        "Cannot sync with remote. Check your connection.",
    ),
    SyncConflictError: (
        409,
        "There was a problem saving. Please try again.",
    ),
    WriteLockTimeoutError: (
        503,
        "Another save is in progress. Please wait a moment and try again.",
    ),
    CorruptRepoError: (
        500,
        "Git repository is in an unexpected state.",
    ),
}
class GitSyncMiddleware(BaseHTTPMiddleware):
    """Wraps mutating requests with write lock + git commit/push.

    For non-mutating requests and non-auto-sync routes,
    passes through without buffering (preserves streaming responses).

    Request flow for HTTP scopes:

    * Endpoints flagged ``_git_sync_no_write_lock`` bypass
      ``BaseHTTPMiddleware`` entirely (see ``__call__``).
    * URLs that do not resolve to an auto-sync project go through
      ``_unmatched_dispatch``.
    * Project-scoped requests take either the read path
      (``_dispatch_unlocked``) or the write-locked path
      (``_dispatch_locked``), decided in ``dispatch``.
    """

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:  # type: ignore[override]
        """ASGI entry point; routes self-managed endpoints around BaseHTTPMiddleware."""
        # BaseHTTPMiddleware wraps the request receive channel in its own anyio
        # task group, which breaks disconnect propagation to StreamingResponse
        # endpoints: the task-group-owned receive never delivers http.disconnect
        # to the downstream generator, so SSE jobs (evals, extractions) can't
        # detect a browser hard-refresh and keep running. For self-managed
        # (@no_write_lock) endpoints we bypass BaseHTTPMiddleware entirely and
        # hand the real ASGI receive/send to the endpoint.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = Request(scope)
        endpoint = self._resolve_endpoint(request)
        if endpoint is not None and getattr(endpoint, "_git_sync_no_write_lock", False):
            await self._handle_self_managed(scope, receive, send, request)
            return

        await super().__call__(scope, receive, send)

    async def _handle_self_managed(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
        request: Request,
    ) -> None:
        """Pure-ASGI pass-through for @no_write_lock endpoints.

        Mirrors the self-managed branch of dispatch() without going through
        BaseHTTPMiddleware. Each self-managed endpoint builds its own
        save_context per write inside its worker loop.

        If the request does not map to an auto-sync project, the app is called
        untouched. If the freshness check raises a GitSyncError, the mapped
        JSON error is sent and the app is not called.
        """
        manager = self._get_manager_for_request(request)
        if manager is None:
            await self.app(scope, receive, send)
            return

        try:
            await manager.ensure_fresh_for_read()
        except GitSyncError as e:
            await self._git_error_response(e)(scope, receive, send)
            return

        self._notify_background_sync(manager)

        # Attach the manager via scope["state"] so build_save_context(request)
        # can find it via request.state.git_sync_manager.
        scope.setdefault("state", {})["git_sync_manager"] = manager
        await self.app(scope, receive, send)

    async def dispatch(self, request: Request, call_next):  # type: ignore[override]
        """Choose between the unmatched, read, and write-locked paths.

        A request needs the write lock when its method is in MUTATING_METHODS
        or its endpoint is flagged ``_git_sync_write_lock``, unless the
        endpoint is flagged ``_git_sync_no_write_lock``.
        """
        manager = self._get_manager_for_request(request)
        if manager is None:
            return await self._unmatched_dispatch(request, call_next)

        endpoint = self._resolve_endpoint(request)
        if endpoint is None and _is_dev_mode():
            logger.warning(
                "GitSyncMiddleware: could not resolve endpoint for "
                "project-scoped URL %s %s. Falling back to HTTP-method-only "
                "lock decision. If this is a @no_write_lock SSE endpoint, "
                "the lock fallback may buffer the response.",
                request.method,
                request.url.path,
            )

        # getattr on a None endpoint simply yields the default (False).
        is_self_managed = getattr(endpoint, "_git_sync_no_write_lock", False)
        wants_lock = request.method in MUTATING_METHODS or getattr(
            endpoint, "_git_sync_write_lock", False
        )
        if wants_lock and not is_self_managed:
            return await self._dispatch_locked(request, call_next, manager)
        return await self._dispatch_unlocked(
            request, call_next, manager, is_self_managed
        )

    async def _dispatch_unlocked(
        self,
        request: Request,
        call_next,
        manager: GitSyncManager,
        is_self_managed,
    ) -> Response:
        """Read path: refresh the repo, run the endpoint, never take the lock.

        The downstream response is returned as-is (not buffered), except in
        dev mode where ``_dev_mode_dirty_check`` may replace it with a 500.
        """
        # Expose the manager so @no_write_lock endpoints can build a
        # SaveContext without importing desktop-layer code.
        request.state.git_sync_manager = manager

        try:
            await manager.ensure_fresh_for_read()
        except GitSyncError as e:
            return self._git_error_response(e)

        self._notify_background_sync(manager)

        response = await call_next(request)

        # @no_write_lock endpoints manage their own atomic_write blocks
        # per job, so a dirty check here would race in-flight commits.
        # Skip them entirely, per the functional spec.
        if is_self_managed or not _is_dev_mode():
            return response
        return await self._dev_mode_dirty_check(request, response, manager)

    async def _dispatch_locked(
        self,
        request: Request,
        call_next,
        manager: GitSyncManager,
    ) -> Response:
        """Write path: run the endpoint inside ``manager.atomic_write``.

        The downstream body is fully buffered while the lock is held and
        re-emitted as a plain Response after the block exits. SSE responses
        are rejected with a 500; GitSyncError is mapped via ``_map_error``.
        Any other exception propagates to the caller.
        """
        self._notify_background_sync(manager)

        # Started before entering atomic_write, so the measured duration
        # covers everything from here until the block has exited.
        lock_start = time.monotonic()
        try:
            async with manager.atomic_write(f"{request.method} {request.url.path}"):
                response = await call_next(request)

                if "text/event-stream" in response.headers.get("content-type", ""):
                    logger.error(
                        "Streaming response under write lock for %s %s -- "
                        "use @no_write_lock instead",
                        request.method,
                        request.url.path,
                    )
                    # Raised inside the block so dirty changes are rolled back.
                    raise _StreamingUnderWriteLock()

                # body_iterator is always present on StreamingResponse from
                # call_next; the union type includes None only because the
                # base Response class doesn't guarantee it.
                body = b"".join(
                    [chunk async for chunk in response.body_iterator]  # type: ignore[union-attr]
                )

            held = time.monotonic() - lock_start
            if held > LONG_LOCK_HOLD_THRESHOLD:
                logger.warning(
                    "Write lock held %.1fs for %s %s -- consider @no_write_lock",
                    held,
                    request.method,
                    request.url.path,
                )

            proxy = Response(
                content=body,
                status_code=response.status_code,
                media_type=response.media_type,
                background=response.background,
            )
            # Use raw_headers to preserve duplicate headers (e.g. Set-Cookie)
            # that dict(response.headers) would collapse.
            proxy.raw_headers = response.raw_headers
            return proxy
        except _StreamingUnderWriteLock:
            return self._json_error(
                500,
                "Internal error: streaming endpoint missing @no_write_lock decorator.",
            )
        except GitSyncError as e:
            return self._git_error_response(e)

    async def _dev_mode_dirty_check(
        self,
        request: Request,
        response: Response,
        manager: GitSyncManager,
    ) -> Response:
        """In dev mode, surface missing write locks immediately.

        Runs only on the regular read path (not write-locked, not
        @no_write_lock). If the response is SSE, log the missing decorator
        and return the response unchanged. If the repo is dirty, log the
        offending request and return 500. Otherwise return the response.
        """
        if "text/event-stream" in response.headers.get("content-type", ""):
            logger.error(
                "DEV MODE: SSE endpoint missing @no_write_lock: %s %s",
                request.method,
                request.url.path,
            )
            return response

        dirty = await manager.get_dirty_file_paths()
        if not dirty:
            return response

        logger.error(
            "DEV MODE: Request left repo dirty without write lock! "
            "(May also be caused by a parallel request with @no_write_lock "
            "mid-atomic_write — check all recent logs before blaming this request.)\n"
            "  API: %s %s\n  Project: %s\n  Dirty files: %s",
            request.method,
            request.url.path,
            manager.repo_path,
            dirty,
        )
        return self._json_error(
            500,
            "Dev mode: this endpoint wrote files without "
            "holding a write lock, or a parallel request is "
            "mid-write. See server logs for details.",
        )

    async def _unmatched_dispatch(self, request: Request, call_next) -> Response:
        """Handle requests whose URL does not match /api/projects/{id}/...

        In dev mode, after any request completes, sweep all cached managers
        for dirty repos. A dirty repo here means the endpoint wrote project
        files but lives outside the middleware-matched URL prefix, silently
        bypassing git commit/push. This runs on all methods (not just
        mutating ones) to match the dev-mode dirty check on matched URLs,
        since a GET that dirties a synced repo is also a bug worth surfacing.

        Only detects projects whose manager is currently cached in the
        registry (i.e. accessed at least once this session). Projects
        configured for auto-sync but not yet opened would be missed.
        """
        response = await call_next(request)
        if not _is_dev_mode():
            return response

        for mgr in GitSyncRegistry.all_managers():
            dirty = await mgr.get_dirty_file_paths()
            if not dirty:
                continue
            # Report the first dirty repo found and stop sweeping.
            logger.error(
                "DEV MODE: Non-project-scoped endpoint wrote to a synced repo! "
                "Endpoints that write project files MUST live under "
                "/api/projects/{project_id}/... so GitSyncMiddleware can "
                "commit and push changes.\n"
                "(May also be caused by a parallel request with @no_write_lock "
                "mid-atomic_write — check all recent logs before blaming this request.)\n"
                "  API: %s %s\n  Repo: %s\n  Dirty files: %s",
                request.method,
                request.url.path,
                mgr.repo_path,
                dirty,
            )
            return self._json_error(
                500,
                "Dev mode: a non-project-scoped endpoint wrote "
                "to a synced repo without going through "
                "GitSyncMiddleware. See server logs for details.",
            )
        return response

    def _resolve_endpoint(self, request: Request) -> Callable[..., Any] | None:
        """Resolve the endpoint function for this request by matching routes.

        BaseHTTPMiddleware runs before routing, so request.scope["endpoint"]
        is not yet populated. We manually match against the app's routes to
        find the endpoint and read decorator attributes.

        Returns the "endpoint" entry of the first fully matching route's
        child scope (which may be absent, giving None), or None when the
        scope has no "app" or no route matches fully.

        Note: This performs a linear scan over all registered routes on each
        request. This is acceptable for typical apps with tens-to-low-hundreds
        of routes, but would need revisiting for very large route tables.
        """
        app = request.scope.get("app")
        if app is None:
            return None
        for route in app.routes:
            match, child_scope = route.matches(request.scope)
            if match == Match.FULL:
                return child_scope.get("endpoint")
        return None

    def _get_manager_for_request(self, request: Request) -> GitSyncManager | None:
        """Extract project_id from URL, resolve to path, return manager if auto-sync enabled.

        Returns None when the URL is not project-scoped, the project id has
        no known path, the project has no git-sync config, its sync_mode is
        not "auto", or the config has no clone_path.
        """
        url_match = PROJECT_ID_PATTERN.match(request.url.path)
        if url_match is None:
            return None

        project_path = project_path_from_id(url_match.group(1))
        if project_path is None:
            return None

        config = get_git_sync_config(project_path)
        if config is None or config["sync_mode"] != "auto":
            return None

        clone_path = config.get("clone_path")
        if clone_path is None:
            return None

        return GitSyncRegistry.get_or_create(
            repo_path=Path(clone_path),
            remote_name=config["remote_name"],
            pat_token=config.get("pat_token"),
            oauth_token=config.get("oauth_token"),
            auth_mode=config["auth_mode"],
        )

    def _notify_background_sync(self, manager: GitSyncManager) -> None:
        """Notify background sync of activity to prevent idle pause."""
        bg_sync = GitSyncRegistry.get_background_sync(manager.repo_path)
        if bg_sync is not None:
            bg_sync.notify_request()

    def _map_error(self, error: GitSyncError) -> tuple[int, str]:
        """Translate a GitSyncError into an (HTTP status, detail message) pair.

        ERROR_MAP is scanned in insertion order and the first isinstance
        match wins; unmapped errors yield a generic 500.
        """
        for error_type, (status, message) in ERROR_MAP.items():
            if isinstance(error, error_type):
                return status, message
        return 500, "An unexpected git sync error occurred."

    def _git_error_response(self, error: GitSyncError) -> Response:
        """Build the JSON error response for a GitSyncError via ``_map_error``."""
        status, message = self._map_error(error)
        return self._json_error(status, message)

    @staticmethod
    def _json_error(status_code: int, detail: str) -> Response:
        """Build an ``application/json`` response whose body is ``{"detail": detail}``."""
        return Response(
            content=json.dumps({"detail": detail}, ensure_ascii=False),
            status_code=status_code,
            media_type="application/json",
        )