Skip to content

Commit 26f5397

Browse files
committed
address coderabbit review comments
Signed-off-by: Jordan Dubrick <jdubrick@redhat.com>
1 parent 97de845 commit 26f5397

4 files changed

Lines changed: 68 additions & 11 deletions

File tree

src/app/endpoints/stream_interrupt.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,19 @@ async def stream_interrupt_endpoint_handler(
4545
StreamInterruptRegistry, Depends(get_stream_interrupt_registry)
4646
],
4747
) -> StreamingInterruptResponse:
48-
"""Interrupt an in-progress streaming query by request identifier."""
48+
"""Interrupt an in-progress streaming query by request identifier.
49+
50+
Parameters:
51+
interrupt_request: Request payload containing the stream request ID.
52+
auth: Auth context tuple resolved from the authentication dependency.
53+
registry: Stream interrupt registry dependency used to cancel streams.
54+
55+
Returns:
56+
StreamingInterruptResponse: Confirmation payload when interruption succeeds.
57+
58+
Raises:
59+
HTTPException: If no active stream for the given request ID can be interrupted.
60+
"""
4961
user_id, _, _, _ = auth
5062
request_id = interrupt_request.request_id
5163
interrupted = registry.cancel_stream(request_id, user_id)

src/app/endpoints/streaming_query.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,11 @@ async def generate_response(
353353
user_id=user_id,
354354
task=current_task,
355355
)
356+
else:
357+
logger.warning(
358+
"No current asyncio task for request %s; stream interruption will not be available",
359+
request_id,
360+
)
356361

357362
stream_completed = False
358363
try:

src/models/requests.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,11 @@ def validate_media_type(self) -> Self:
268268

269269

270270
class StreamingInterruptRequest(BaseModel):
271-
"""Model representing a request to interrupt an active streaming query."""
271+
"""Model representing a request to interrupt an active streaming query.
272+
273+
Attributes:
274+
request_id: Unique ID of the active streaming request to interrupt.
275+
"""
272276

273277
request_id: str = Field(
274278
description="The active streaming request ID to interrupt",
@@ -287,7 +291,17 @@ class StreamingInterruptRequest(BaseModel):
287291
@field_validator("request_id")
288292
@classmethod
289293
def check_request_id(cls, value: str) -> str:
290-
"""Validate that request identifier matches expected SUID format."""
294+
"""Validate that request identifier matches expected SUID format.
295+
296+
Parameters:
297+
value: Request identifier submitted by the caller.
298+
299+
Returns:
300+
str: The validated request identifier.
301+
302+
Raises:
303+
ValueError: If the request identifier is not a valid SUID.
304+
"""
291305
if not suid.check_suid(value):
292306
raise ValueError(f"Improper request ID {value}")
293307
return value

src/utils/stream_interrupts.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
"""In-memory registry for interrupting active streaming requests."""
22

33
import asyncio
4-
import logging
4+
from log import get_logger
55
from dataclasses import dataclass
66
from threading import Lock
7-
from typing import Optional
7+
from typing import Any
88

9-
logger = logging.getLogger(__name__)
9+
logger = get_logger(__name__)
1010

1111

1212
@dataclass
1313
class ActiveStream:
14-
"""Represents one active streaming request bound to a user."""
14+
"""Represents one active streaming request bound to a user.
15+
16+
Attributes:
17+
user_id: Owner of the streaming request.
18+
task: Asyncio task producing the stream response.
19+
"""
1520

1621
user_id: str
1722
task: asyncio.Task
@@ -28,7 +33,13 @@ def __init__(self) -> None:
2833
def register_stream(
2934
self, request_id: str, user_id: str, task: asyncio.Task
3035
) -> None:
31-
"""Register an active stream task for interrupt support."""
36+
"""Register an active stream task for interrupt support.
37+
38+
Parameters:
39+
request_id: Unique streaming request identifier.
40+
user_id: User identifier that owns the stream.
41+
task: Asyncio task associated with the stream.
42+
"""
3243
with self._lock:
3344
self._streams[request_id] = ActiveStream(user_id=user_id, task=task)
3445

@@ -39,6 +50,10 @@ def cancel_stream(self, request_id: str, user_id: str) -> bool:
3950
lock so that a concurrent ``deregister_stream`` cannot remove
4051
the entry between the ownership check and the cancel call.
4152
53+
Parameters:
54+
request_id: Unique streaming request identifier.
55+
user_id: User identifier attempting the interruption.
56+
4257
Returns:
4358
bool: True when cancellation was requested, otherwise False.
4459
"""
@@ -59,12 +74,23 @@ def cancel_stream(self, request_id: str, user_id: str) -> bool:
5974
return True
6075

6176
def deregister_stream(self, request_id: str) -> None:
62-
"""Remove stream task from registry once completed/cancelled."""
77+
"""Remove stream task from registry once completed/cancelled.
78+
79+
Parameters:
80+
request_id: Unique streaming request identifier.
81+
"""
6382
with self._lock:
6483
self._streams.pop(request_id, None)
6584

66-
def get_stream(self, request_id: str) -> Optional[ActiveStream]:
67-
"""Get currently registered stream metadata for tests/introspection."""
85+
def get_stream(self, request_id: str) -> ActiveStream | None:
86+
"""Get currently registered stream metadata for tests/introspection.
87+
88+
Parameters:
89+
request_id: Unique streaming request identifier.
90+
91+
Returns:
92+
ActiveStream | None: Registered stream metadata, or None when absent.
93+
"""
6894
with self._lock:
6995
return self._streams.get(request_id)
7096

0 commit comments

Comments
 (0)