Skip to content

Commit 2b4e56c

Browse files
committed
fix: sse downloads
1 parent 5b1e77f commit 2b4e56c

4 files changed

Lines changed: 190 additions & 26 deletions

File tree

integration_tests/base_routes.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Optional
88

99
from integration_tests.subroutes import di_subrouter, static_router, sub_router
10-
from robyn import Headers, Request, Response, Robyn, SSEMessage, SSEResponse, WebSocketDisconnect, jsonify, serve_file, serve_html
10+
from robyn import Headers, Request, Response, Robyn, SSEMessage, SSEResponse, StreamingResponse, WebSocketDisconnect, jsonify, serve_file, serve_html
1111
from robyn.authentication import AuthenticationHandler, BearerGetter, Identity
1212
from robyn.robyn import QueryParams, Url
1313
from robyn.templating import JinjaTemplate
@@ -1353,6 +1353,63 @@ def event_generator():
13531353
return SSEResponse(event_generator(), status_code=201)
13541354

13551355

1356+
# --- Binary streaming endpoints ---
1357+
1358+
1359+
@app.get("/stream/bytes")
1360+
def stream_bytes(request):
1361+
"""Stream binary data using bytes chunks"""
1362+
1363+
def bytes_generator():
1364+
# Generate 3 chunks of known binary data
1365+
for i in range(3):
1366+
yield bytes([i] * 1024) # 1KB chunks filled with the chunk index
1367+
1368+
return StreamingResponse(
1369+
content=bytes_generator(),
1370+
media_type="application/octet-stream",
1371+
headers=Headers({"Content-Type": "application/octet-stream"}),
1372+
)
1373+
1374+
1375+
@app.get("/stream/bytes_file")
1376+
def stream_bytes_file(request):
1377+
"""Stream a file in binary mode using yield from"""
1378+
test_file = os.path.join(current_file_path, "build", "index.html")
1379+
1380+
def file_generator():
1381+
with open(test_file, "rb") as f:
1382+
while True:
1383+
chunk = f.read(512)
1384+
if not chunk:
1385+
break
1386+
yield chunk
1387+
1388+
return StreamingResponse(
1389+
content=file_generator(),
1390+
media_type="application/octet-stream",
1391+
headers=Headers({
1392+
"Content-Type": "application/octet-stream",
1393+
"Content-Disposition": "attachment; filename=index.html",
1394+
}),
1395+
)
1396+
1397+
1398+
@app.get("/stream/mixed_text")
1399+
def stream_mixed_text(request):
1400+
"""Stream text data using string chunks (ensures str still works)"""
1401+
1402+
def text_generator():
1403+
for i in range(3):
1404+
yield f"text chunk {i}\n"
1405+
1406+
return StreamingResponse(
1407+
content=text_generator(),
1408+
media_type="text/plain",
1409+
headers=Headers({"Content-Type": "text/plain"}),
1410+
)
1411+
1412+
13561413
def main():
13571414
app.set_response_header("server", "robyn")
13581415
app.serve_directory(
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import os
2+
3+
import pytest
4+
import requests
5+
6+
from integration_tests.helpers.http_methods_helpers import BASE_URL
7+
8+
9+
@pytest.mark.benchmark
10+
def test_stream_bytes_basic(session):
11+
"""Test that binary bytes can be streamed without error"""
12+
response = requests.get(f"{BASE_URL}/stream/bytes", stream=True, timeout=5)
13+
assert response.status_code == 200
14+
assert response.headers.get("Content-Type") == "application/octet-stream"
15+
16+
# Collect all streamed data
17+
data = b""
18+
for chunk in response.iter_content(chunk_size=None):
19+
if chunk:
20+
data += chunk
21+
22+
# We expect 3 chunks of 1024 bytes each
23+
assert len(data) == 3 * 1024
24+
25+
# Verify chunk contents: chunk i is filled with byte value i
26+
for i in range(3):
27+
chunk = data[i * 1024 : (i + 1) * 1024]
28+
assert chunk == bytes([i] * 1024), f"Chunk {i} has unexpected content"
29+
30+
31+
@pytest.mark.benchmark
32+
def test_stream_bytes_no_sse_headers(session):
33+
"""Test that binary streaming responses do NOT include SSE-specific headers"""
34+
response = requests.get(f"{BASE_URL}/stream/bytes", stream=True, timeout=5)
35+
assert response.status_code == 200
36+
37+
# SSE-specific headers should NOT be present for binary streams
38+
assert response.headers.get("X-Accel-Buffering") is None
39+
assert response.headers.get("Pragma") is None
40+
assert response.headers.get("Expires") is None
41+
42+
43+
@pytest.mark.benchmark
44+
def test_stream_bytes_file(session):
45+
"""Test streaming a file in binary mode"""
46+
response = requests.get(f"{BASE_URL}/stream/bytes_file", stream=True, timeout=5)
47+
assert response.status_code == 200
48+
assert response.headers.get("Content-Type") == "application/octet-stream"
49+
assert "attachment" in response.headers.get("Content-Disposition", "")
50+
51+
# Collect all streamed data
52+
streamed_data = b""
53+
for chunk in response.iter_content(chunk_size=None):
54+
if chunk:
55+
streamed_data += chunk
56+
57+
# Read the original file to compare
58+
test_file = os.path.join(
59+
os.path.dirname(os.path.abspath(__file__)), "build", "index.html"
60+
)
61+
with open(test_file, "rb") as f:
62+
original_data = f.read()
63+
64+
assert streamed_data == original_data, "Streamed file content does not match original"
65+
66+
67+
@pytest.mark.benchmark
68+
def test_stream_text_still_works(session):
69+
"""Test that string-based streaming still works after the bytes change"""
70+
response = requests.get(f"{BASE_URL}/stream/mixed_text", stream=True, timeout=5)
71+
assert response.status_code == 200
72+
assert response.headers.get("Content-Type") == "text/plain"
73+
74+
content = b""
75+
for chunk in response.iter_content(chunk_size=None):
76+
if chunk:
77+
content += chunk
78+
79+
text = content.decode("utf-8")
80+
for i in range(3):
81+
assert f"text chunk {i}" in text

robyn/responses.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def serve_file(file_path: str, file_name: Optional[str] = None) -> FileResponse:
6565
class AsyncGeneratorWrapper:
6666
"""Optimized true-streaming wrapper for async generators"""
6767

68-
def __init__(self, async_gen: AsyncGenerator[str, None]):
68+
def __init__(self, async_gen: AsyncGenerator[Union[str, bytes], None]):
6969
self.async_gen = async_gen
7070
self._loop = None
7171
self._iterator = None
@@ -124,7 +124,10 @@ async def get_next():
124124
class StreamingResponse:
125125
def __init__(
126126
self,
127-
content: Union[Generator[str, None, None], AsyncGenerator[str, None]],
127+
content: Union[
128+
Generator[Union[str, bytes], None, None],
129+
AsyncGenerator[Union[str, bytes], None],
130+
],
128131
status_code: Optional[int] = None,
129132
headers: Optional[Headers] = None,
130133
media_type: str = "text/event-stream",
@@ -149,7 +152,10 @@ def __init__(
149152

150153

151154
def SSEResponse(
152-
content: Union[Generator[str, None, None], AsyncGenerator[str, None]],
155+
content: Union[
156+
Generator[Union[str, bytes], None, None],
157+
AsyncGenerator[Union[str, bytes], None],
158+
],
153159
status_code: Optional[int] = None,
154160
headers: Optional[Headers] = None,
155161
) -> StreamingResponse:

src/types/response.rs

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ pub struct StreamingResponse {
3232
pub status_code: u16,
3333
pub headers: Headers,
3434
pub content_generator: Py<PyAny>,
35+
pub media_type: String,
3536
}
3637

3738
#[derive(Debug)]
@@ -86,11 +87,17 @@ impl Responder for Response {
8687
}
8788

8889
impl StreamingResponse {
89-
pub fn new(status_code: u16, headers: Headers, content_generator: Py<PyAny>) -> Self {
90+
pub fn new(
91+
status_code: u16,
92+
headers: Headers,
93+
content_generator: Py<PyAny>,
94+
media_type: String,
95+
) -> Self {
9096
Self {
9197
status_code,
9298
headers,
9399
content_generator,
100+
media_type,
94101
}
95102
}
96103
}
@@ -105,13 +112,15 @@ impl Responder for StreamingResponse {
105112

106113
apply_hashmap_headers(&mut response_builder, &self.headers);
107114

108-
// Optimized headers for SSE streaming
109-
response_builder
110-
.append_header(("Connection", "keep-alive"))
111-
.append_header(("X-Accel-Buffering", "no")) // Disable nginx buffering
112-
.append_header(("Cache-Control", "no-cache, no-store, must-revalidate"))
113-
.append_header(("Pragma", "no-cache"))
114-
.append_header(("Expires", "0"));
115+
// Only add SSE-specific headers for event-stream responses
116+
if self.media_type == "text/event-stream" {
117+
response_builder
118+
.append_header(("Connection", "keep-alive"))
119+
.append_header(("X-Accel-Buffering", "no")) // Disable nginx buffering
120+
.append_header(("Cache-Control", "no-cache, no-store, must-revalidate"))
121+
.append_header(("Pragma", "no-cache"))
122+
.append_header(("Expires", "0"));
123+
}
115124

116125
// Create the optimized stream from the Python generator
117126
let stream = create_python_stream(self.content_generator);
@@ -142,18 +151,29 @@ fn create_python_stream(
142151
// Try to get the next value from the generator (sync)
143152
match gen.call_method0("__next__") {
144153
Ok(value) => {
145-
match value.extract::<String>() {
146-
Ok(string_value) => {
147-
debug!("Generator yielded: {}", string_value);
148-
Some((string_value, generator))
149-
}
150-
Err(extract_err) => {
151-
debug!(
152-
"Failed to extract string from generator value: {}",
153-
extract_err
154-
);
155-
None // End of stream
156-
}
154+
// Try bytes first (common for binary file streaming),
155+
// then fall back to string extraction
156+
if let Ok(py_bytes) = value.downcast::<PyBytes>() {
157+
let data = py_bytes.as_bytes().to_vec();
158+
debug!("Generator yielded {} bytes", data.len());
159+
Some((data, generator))
160+
} else if let Ok(string_value) = value.extract::<String>() {
161+
debug!(
162+
"Generator yielded string of len {}",
163+
string_value.len()
164+
);
165+
Some((string_value.into_bytes(), generator))
166+
} else {
167+
let type_name = value
168+
.get_type()
169+
.name()
170+
.map(|n| n.to_string())
171+
.unwrap_or_else(|_| "unknown".to_string());
172+
debug!(
173+
"Generator yielded unsupported type: {}",
174+
type_name
175+
);
176+
None // End of stream
157177
}
158178
}
159179
Err(call_err) => {
@@ -171,7 +191,7 @@ fn create_python_stream(
171191
})
172192
.await
173193
{
174-
Ok(Some((string_value, generator))) => Some((Ok(Bytes::from(string_value)), generator)),
194+
Ok(Some((data, generator))) => Some((Ok(Bytes::from(data)), generator)),
175195
Ok(None) => None,
176196
Err(join_err) => {
177197
debug!(
@@ -583,6 +603,6 @@ impl FromPyObject<'_, '_> for StreamingResponse {
583603
"Successfully extracted StreamingResponse with status {} from type {}",
584604
status_code, type_name
585605
);
586-
Ok(StreamingResponse::new(status_code, headers, content))
606+
Ok(StreamingResponse::new(status_code, headers, content, media_type))
587607
}
588608
}

0 commit comments

Comments
 (0)