diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index 9f327c590..517f7d3e6 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -7,7 +7,7 @@ from typing import List, Optional, TypedDict from integration_tests.subroutes import di_subrouter, static_router, sub_router -from robyn import Headers, Request, Response, Robyn, SSEMessage, SSEResponse, WebSocketDisconnect, jsonify, serve_file, serve_html +from robyn import Headers, Request, Response, Robyn, SSEMessage, SSEResponse, StreamingResponse, WebSocketDisconnect, jsonify, serve_file, serve_html from robyn.authentication import AuthenticationHandler, BearerGetter, Identity from robyn.robyn import QueryParams, Url from robyn.templating import JinjaTemplate @@ -1646,6 +1646,64 @@ def sync_pydantic_return_list(user: UserCreate) -> list[UserCreate]: async def async_pydantic_return_list(user: UserCreate) -> list[UserCreate]: return [user, user] +# --- Binary streaming endpoints --- + + +@app.get("/stream/bytes") +def stream_bytes(request): + """Stream binary data using bytes chunks""" + + def bytes_generator(): + # Generate 3 chunks of known binary data + for i in range(3): + yield bytes([i] * 1024) # 1KB chunks filled with the chunk index + + return StreamingResponse( + content=bytes_generator(), + media_type="application/octet-stream", + headers=Headers({"Content-Type": "application/octet-stream"}), + ) + + +@app.get("/stream/bytes_file") +def stream_bytes_file(request): + """Stream a file in binary mode using yield from""" + test_file = os.path.join(current_file_path, "build", "index.html") + + def file_generator(): + with open(test_file, "rb") as f: + while True: + chunk = f.read(512) + if not chunk: + break + yield chunk + + return StreamingResponse( + content=file_generator(), + media_type="application/octet-stream", + headers=Headers( + { + "Content-Type": "application/octet-stream", + "Content-Disposition": "attachment; filename=index.html", + } + ), + ) + + +@app.get("/stream/mixed_text") +def stream_mixed_text(request): + """Stream text data using string chunks (ensures str still works)""" + + def text_generator(): + for i in range(3): + yield f"text chunk {i}\n" + + return StreamingResponse( + content=text_generator(), + media_type="text/plain", + headers=Headers({"Content-Type": "text/plain"}), + ) + def main(): app.set_response_header("server", "robyn") diff --git a/integration_tests/test_binary_streaming.py b/integration_tests/test_binary_streaming.py new file mode 100644 index 000000000..a4953ea91 --- /dev/null +++ b/integration_tests/test_binary_streaming.py @@ -0,0 +1,79 @@ +import os + +import pytest +import requests + +from integration_tests.helpers.http_methods_helpers import BASE_URL + + +@pytest.mark.benchmark +def test_stream_bytes_basic(session): + """Test that binary bytes can be streamed without error""" + response = requests.get(f"{BASE_URL}/stream/bytes", stream=True, timeout=5) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/octet-stream" + + # Collect all streamed data + data = b"" + for chunk in response.iter_content(chunk_size=None): + if chunk: + data += chunk + + # We expect 3 chunks of 1024 bytes each + assert len(data) == 3 * 1024 + + # Verify chunk contents: chunk i is filled with byte value i + for i in range(3): + chunk = data[i * 1024 : (i + 1) * 1024] + assert chunk == bytes([i] * 1024), f"Chunk {i} has unexpected content" + + +@pytest.mark.benchmark +def test_stream_bytes_no_sse_headers(session): + """Test that binary streaming responses do NOT include SSE-specific headers""" + response = requests.get(f"{BASE_URL}/stream/bytes", stream=True, timeout=5) + assert response.status_code == 200 + + # SSE-specific headers should NOT be present for binary streams + assert response.headers.get("X-Accel-Buffering") is None + assert response.headers.get("Pragma") is None + assert response.headers.get("Expires") is None + + +@pytest.mark.benchmark +def test_stream_bytes_file(session): + """Test streaming a file in binary mode""" + response = requests.get(f"{BASE_URL}/stream/bytes_file", stream=True, timeout=5) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/octet-stream" + assert "attachment" in response.headers.get("Content-Disposition", "") + + # Collect all streamed data + streamed_data = b"" + for chunk in response.iter_content(chunk_size=None): + if chunk: + streamed_data += chunk + + # Read the original file to compare + test_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "build", "index.html") + with open(test_file, "rb") as f: + original_data = f.read() + + assert streamed_data == original_data, "Streamed file content does not match original" + + +@pytest.mark.benchmark +def test_stream_text_still_works(session): + """Test that string-based streaming still works after the bytes change""" + response = requests.get(f"{BASE_URL}/stream/mixed_text", stream=True, timeout=5) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "text/plain" + + content = b"" + for chunk in response.iter_content(chunk_size=None): + if chunk: + content += chunk + + text = content.decode("utf-8") + for i in range(3): + assert f"text chunk {i}" in text diff --git a/robyn/responses.py b/robyn/responses.py index 8d3717f11..c88d06971 100644 --- a/robyn/responses.py +++ b/robyn/responses.py @@ -65,7 +65,7 @@ def serve_file(file_path: str, file_name: Optional[str] = None) -> FileResponse: class AsyncGeneratorWrapper: """Optimized true-streaming wrapper for async generators""" - def __init__(self, async_gen: AsyncGenerator[str, None]): + def __init__(self, async_gen: AsyncGenerator[Union[str, bytes], None]): self.async_gen = async_gen self._loop = None self._iterator = None @@ -124,7 +124,10 @@ async def get_next(): class StreamingResponse: def __init__( self, - content: Union[Generator[str, None, None], AsyncGenerator[str, None]], + content: Union[ + Generator[Union[str, bytes], None, None], + AsyncGenerator[Union[str, bytes], None], + ], status_code: Optional[int] = None, headers: Optional[Headers] = None, media_type: str = "text/event-stream", @@ -149,7 +152,10 @@ def __init__( def SSEResponse( - content: Union[Generator[str, None, None], AsyncGenerator[str, None]], + content: Union[ + Generator[Union[str, bytes], None, None], + AsyncGenerator[Union[str, bytes], None], + ], status_code: Optional[int] = None, headers: Optional[Headers] = None, ) -> StreamingResponse: diff --git a/src/types/response.rs b/src/types/response.rs index b13d7d47e..757bd8fe3 100644 --- a/src/types/response.rs +++ b/src/types/response.rs @@ -31,6 +31,7 @@ pub struct StreamingResponse { pub status_code: u16, pub headers: Headers, pub content_generator: Py, + pub media_type: String, } #[derive(Debug)] @@ -85,11 +86,17 @@ impl Responder for Response { } impl StreamingResponse { - pub fn new(status_code: u16, headers: Headers, content_generator: Py) -> Self { + pub fn new( + status_code: u16, + headers: Headers, + content_generator: Py, + media_type: String, + ) -> Self { Self { status_code, headers, content_generator, + media_type, } } } @@ -104,13 +111,25 @@ impl Responder for StreamingResponse { apply_hashmap_headers(&mut response_builder, &self.headers); - // Optimized headers for SSE streaming - response_builder - .append_header(("Connection", "keep-alive")) - .append_header(("X-Accel-Buffering", "no")) // Disable nginx buffering - .append_header(("Cache-Control", "no-cache, no-store, must-revalidate")) - .append_header(("Pragma", "no-cache")) - .append_header(("Expires", "0")); + // Only add SSE-specific headers for event-stream responses if not already present + if self.media_type == "text/event-stream" { + if !self.headers.contains("Connection".to_string()) { + response_builder.append_header(("Connection", "keep-alive")); + } + if !self.headers.contains("X-Accel-Buffering".to_string()) { + response_builder.append_header(("X-Accel-Buffering", "no")); // Disable nginx buffering + } + if !self.headers.contains("Cache-Control".to_string()) { + response_builder + .append_header(("Cache-Control", "no-cache, no-store, must-revalidate")); + } + if !self.headers.contains("Pragma".to_string()) { + response_builder.append_header(("Pragma", "no-cache")); + } + if !self.headers.contains("Expires".to_string()) { + response_builder.append_header(("Expires", "0")); + } + } // Create the optimized stream from the Python generator let stream = create_python_stream(self.content_generator); @@ -129,7 +148,15 @@ fn create_python_stream( let gen = generator.bind(py); match gen.call_method0("__next__") { - Ok(value) => value.extract::().ok().map(|s| (s, generator)), + Ok(value) => { + if let Ok(py_bytes) = value.downcast::() { + Some((py_bytes.as_bytes().to_vec(), generator)) + } else if let Ok(s) = value.extract::() { + Some((s.into_bytes(), generator)) + } else { + None + } + } Err(e) => { if !e.is_instance_of::(py) { log::error!("Generator error: {}", e); @@ -141,7 +168,7 @@ fn create_python_stream( }) .await { - Ok(Some((string_value, generator))) => Some((Ok(Bytes::from(string_value)), generator)), + Ok(Some((data, generator))) => Some((Ok(Bytes::from(data)), generator)), _ => None, } })) @@ -282,7 +309,6 @@ impl PyStreamingResponse { let mut headers = Headers::new(None); if media_type == "text/event-stream" { headers.set("Content-Type".to_string(), "text/event-stream".to_string()); - headers.set("Cache-Control".to_string(), "no-cache".to_string()); headers.set("Connection".to_string(), "keep-alive".to_string()); } else { // For non-SSE streaming responses, still set appropriate headers @@ -443,18 +469,15 @@ impl FromPyObject<'_, '_> for StreamingResponse { .and_then(|a| a.extract()) .unwrap_or_else(|_| "text/event-stream".to_string()); - if media_type == "text/event-stream" { - headers.set("Content-Type".to_string(), "text/event-stream".to_string()); - if headers.get("Cache-Control".to_string()).is_none() { - headers.set("Cache-Control".to_string(), "no-cache".to_string()); - } - if headers.get("Connection".to_string()).is_none() { - headers.set("Connection".to_string(), "keep-alive".to_string()); - } - } + headers.set("Content-Type".to_string(), media_type.clone()); let content: pyo3::Py = obj.getattr("content")?.unbind(); - Ok(StreamingResponse::new(status_code, headers, content)) + Ok(StreamingResponse::new( + status_code, + headers, + content, + media_type, + )) } }