Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions app/desktop/git_sync/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Match
from starlette.types import Receive, Scope, Send

from app.desktop.git_sync.config import get_git_sync_config, project_path_from_id
from app.desktop.git_sync.errors import (
Expand Down Expand Up @@ -71,6 +72,65 @@ class GitSyncMiddleware(BaseHTTPMiddleware):
passes through without buffering (preserves streaming responses).
"""

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:  # type: ignore[override]
    """ASGI entry point: route @no_write_lock endpoints around BaseHTTPMiddleware.

    BaseHTTPMiddleware wraps the request receive channel in its own anyio
    task group, which breaks disconnect propagation to StreamingResponse
    endpoints: the task-group-owned receive never delivers http.disconnect
    to the downstream generator, so SSE jobs (evals, extractions) can't
    detect a browser hard-refresh and keep running. For self-managed
    (@no_write_lock) endpoints we bypass BaseHTTPMiddleware entirely and
    hand the real ASGI receive/send to the endpoint.
    """
    if scope["type"] != "http":
        # Lifespan / websocket traffic: plain pass-through.
        await self.app(scope, receive, send)
        return

    request = Request(scope)
    endpoint = self._resolve_endpoint(request)
    is_self_managed = endpoint is not None and getattr(
        endpoint, "_git_sync_no_write_lock", False
    )
    if is_self_managed:
        await self._handle_self_managed(scope, receive, send, request)
    else:
        # Everything else takes the normal BaseHTTPMiddleware dispatch() path.
        await super().__call__(scope, receive, send)

async def _handle_self_managed(
    self,
    scope: Scope,
    receive: Receive,
    send: Send,
    request: Request,
) -> None:
    """Pure-ASGI pass-through for @no_write_lock endpoints.

    Mirrors the self-managed branch of dispatch() without going through
    BaseHTTPMiddleware. Each self-managed endpoint builds its own
    save_context per write inside its worker loop.
    """
    manager = self._get_manager_for_request(request)
    if manager is None:
        # No git sync manager for this project: nothing to do, just forward.
        await self.app(scope, receive, send)
        return

    try:
        await manager.ensure_fresh_for_read()
    except GitSyncError as err:
        # Surface sync failures as a JSON error response instead of
        # invoking the endpoint.
        status, message = self._map_error(err)
        body = json.dumps({"detail": message}, ensure_ascii=False)
        error_response = Response(
            content=body,
            status_code=status,
            media_type="application/json",
        )
        await error_response(scope, receive, send)
        return

    self._notify_background_sync(manager)
    # Attach the manager via scope["state"] so build_save_context(request)
    # can find it via request.state.git_sync_manager.
    scope.setdefault("state", {})["git_sync_manager"] = manager

    await self.app(scope, receive, send)

async def dispatch(self, request: Request, call_next): # type: ignore[override]
manager = self._get_manager_for_request(request)

Expand Down
127 changes: 127 additions & 0 deletions app/desktop/git_sync/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,3 +945,130 @@ def test_unresolved_endpoint_no_warning_in_prod_mode(git_repos, monkeypatch, cap
if r.levelno == logging.WARNING and "could not resolve endpoint" in r.message
]
assert len(warning_records) == 0


# --- @no_write_lock pure-ASGI bypass ---


def test_no_write_lock_bypasses_base_http_middleware_dispatch(git_repos, monkeypatch):
    """@no_write_lock endpoints must skip BaseHTTPMiddleware.dispatch() so SSE
    streams see the real ASGI receive/send and disconnects propagate.
    """
    local_path, _ = git_repos
    config = _auto_config(str(local_path))

    @no_write_lock
    def bypass_endpoint():
        return {"ok": True}

    app = _build_app(get_endpoint=bypass_endpoint)

    # Spy on dispatch() by recording every invocation before delegating.
    dispatch_calls: list[bool] = []
    real_dispatch = GitSyncMiddleware.dispatch

    async def spying_dispatch(self, request, call_next):
        dispatch_calls.append(True)
        return await real_dispatch(self, request, call_next)

    monkeypatch.setattr(GitSyncMiddleware, "dispatch", spying_dispatch)

    with mock_git_sync_config(config):
        resp = TestClient(app).get(f"/api/projects/{PROJECT_ID}/items")

    assert resp.status_code == 200
    assert resp.json() == {"ok": True}
    assert len(dispatch_calls) == 0, (
        "dispatch() should not be called for @no_write_lock endpoints — "
        "bypass must route around BaseHTTPMiddleware to preserve SSE disconnect "
        "propagation."
    )


def test_no_write_lock_bypass_still_attaches_manager_to_state(git_repos):
    """The bypass path must still attach the git sync manager to request.state
    so build_save_context(request) can find it inside the endpoint.
    """
    local_path, _ = git_repos
    config = _auto_config(str(local_path))

    # The endpoint stashes whatever it finds on request.state so the test
    # can assert on it after the request completes.
    seen: dict = {}

    @no_write_lock
    def endpoint(request: FastAPIRequest):
        seen["value"] = getattr(request.state, "git_sync_manager", "MISSING")
        return {"ok": True}

    app = _build_app(get_endpoint=endpoint)

    with mock_git_sync_config(config):
        resp = TestClient(app).get(f"/api/projects/{PROJECT_ID}/items")

    assert resp.status_code == 200
    manager = seen["value"]
    assert manager != "MISSING"
    assert manager is not None


@pytest.mark.asyncio
async def test_no_write_lock_bypass_delivers_http_disconnect_to_endpoint(git_repos):
    """The bypass must hand the real ASGI receive to the endpoint, so
    http.disconnect messages reach the endpoint's receive channel directly.
    Under the old BaseHTTPMiddleware path, the middleware's wrapped receive
    masked disconnects, so the endpoint never saw them.
    """
    local_path, _ = git_repos
    config = _auto_config(str(local_path))

    # Records every ASGI message type the endpoint pulls off its receive channel.
    observed = {"messages": []}

    @no_write_lock
    async def endpoint(request: FastAPIRequest):
        # Consume the ASGI receive channel directly to prove the endpoint
        # gets the real receive under the bypass.
        for _ in range(2):
            msg = await request.receive()
            observed["messages"].append(msg["type"])
            if msg["type"] == "http.disconnect":
                break
        return {"ok": True}

    app = _build_app(get_endpoint=endpoint)

    # Hand-rolled ASGI receive: delivers one complete empty-body http.request,
    # then http.disconnect on every subsequent call — mimicking a client that
    # sent its request and immediately went away.
    sent_request_once = {"done": False}

    async def receive():
        if not sent_request_once["done"]:
            sent_request_once["done"] = True
            return {"type": "http.request", "body": b"", "more_body": False}
        return {"type": "http.disconnect"}

    sent: list[dict] = []

    async def send(message):
        sent.append(message)

    # Minimal hand-built HTTP scope; calling the app directly (instead of via
    # TestClient) keeps full control over the raw receive/send channels.
    scope = {
        "type": "http",
        "method": "GET",
        "path": f"/api/projects/{PROJECT_ID}/items",
        "raw_path": f"/api/projects/{PROJECT_ID}/items".encode(),
        "query_string": b"",
        "headers": [],
        "asgi": {"version": "3.0", "spec_version": "2.4"},
        "http_version": "1.1",
        "scheme": "http",
        "server": ("testserver", 80),
        "client": ("testclient", 12345),
        "app": app,
    }

    with mock_git_sync_config(config):
        await app(scope, receive, send)

    # The endpoint must see the real http.disconnect — proving the bypass
    # handed it the unwrapped ASGI receive channel.
    assert "http.disconnect" in observed["messages"]
45 changes: 32 additions & 13 deletions app/desktop/studio_server/eval_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import json
from collections import defaultdict
from typing import Annotated, Any, Dict, List, Set, Tuple
Expand Down Expand Up @@ -31,6 +32,7 @@
from kiln_ai.datamodel.task_output import normalize_rating
from kiln_ai.utils.name_generator import generate_memorable_name
from kiln_server.git_sync_decorators import build_save_context, no_write_lock
from kiln_server.sse import stream_with_heartbeat
from kiln_server.task_api import task_from_id
from kiln_server.utils.agent_checks.policy import (
ALLOW_AGENT,
Expand Down Expand Up @@ -123,20 +125,37 @@ def task_run_config_from_id(
)


async def run_eval_runner_with_status(eval_runner: EvalRunner) -> StreamingResponse:
def _format_progress_sse(progress) -> str:
data = {
"progress": progress.complete,
"total": progress.total,
"errors": progress.errors,
}
return f"data: {json.dumps(data)}\n\n"


async def run_eval_runner_with_status(
eval_runner: EvalRunner, request: Request
) -> StreamingResponse:
# Yields async messages designed to be used with server sent events (SSE)
# https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events
async def event_generator():
async for progress in eval_runner.run():
data = {
"progress": progress.complete,
"total": progress.total,
"errors": progress.errors,
}
yield f"data: {json.dumps(data)}\n\n"

# Send the final complete message the app expects, and uses to stop listening
yield "data: complete\n\n"
# aclosing ensures prompt cleanup: when event_generator is cancelled
# (client disconnect → Starlette aclose's the body iterator → GeneratorExit
# here), hb.aclose() fires, which runs stream_with_heartbeat's finally,
# which aclose's the runner, which runs AsyncJobRunner's finally and
# cancels its workers.
hb = stream_with_heartbeat(
eval_runner.run(),
_format_progress_sse,
is_disconnected=request.is_disconnected,
)
async with contextlib.aclosing(hb) as stream:
async for chunk in stream:
yield chunk
if not await request.is_disconnected():
# Send the final complete message the app expects, and uses to stop listening
yield "data: complete\n\n"

return StreamingResponse(
content=event_generator(),
Expand Down Expand Up @@ -939,7 +958,7 @@ async def run_eval_config(
save_context=build_save_context(request),
)

return await run_eval_runner_with_status(eval_runner)
return await run_eval_runner_with_status(eval_runner, request)

@app.post(
"/api/projects/{project_id}/tasks/{task_id}/evals/{eval_id}/set_current_eval_config/{eval_config_id}",
Expand Down Expand Up @@ -1017,7 +1036,7 @@ async def run_eval_config_eval(
save_context=build_save_context(request),
)

return await run_eval_runner_with_status(eval_runner)
return await run_eval_runner_with_status(eval_runner, request)

@app.get(
"/api/projects/{project_id}/tasks/{task_id}/evals/{eval_id}/eval_config/{eval_config_id}/run_config/{run_config_id}/results",
Expand Down
Loading
Loading