Skip to content

Commit cc3d532

Browse files
committed
feat: support binary WebSocket frames in message handlers
Closes #1332 Propagate binary WebSocket frames (opcode 0x2) through the entire Rust→Python pipeline instead of silently echoing them back. - Add WsPayload enum (Text/Binary) to replace String throughout the WebSocket channel, registry messages, and send methods - Forward ws::Message::Binary through the channel so Python handlers receive bytes for binary frames and str for text frames - Accept str|bytes in send/broadcast methods, emitting the correct frame type on the wire - Add receive() returning str|bytes, fix receive_text/receive_bytes to validate frame type Made-with: Cursor
1 parent 4c3be94 commit cc3d532

4 files changed

Lines changed: 136 additions & 79 deletions

File tree

robyn/ws.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def __init__(self, websocket_connector: WebSocketConnector, channel=None):
3333
self._connector = websocket_connector
3434
self._channel = channel
3535

36-
async def receive_text(self) -> str:
37-
"""Receive the next text message. Blocks until a message arrives.
36+
async def receive(self) -> str | bytes:
37+
"""Receive the next message. Returns str for text frames, bytes for binary frames.
3838
Raises WebSocketDisconnect when the connection is closed."""
3939
if self._channel is None:
4040
raise WebSocketDisconnect(reason="No message channel available")
@@ -43,31 +43,46 @@ async def receive_text(self) -> str:
4343
raise WebSocketDisconnect()
4444
return result
4545

46+
async def receive_text(self) -> str:
47+
"""Receive the next text frame. Raises TypeError if a binary frame arrives.
48+
Raises WebSocketDisconnect when the connection is closed."""
49+
msg = await self.receive()
50+
if not isinstance(msg, str):
51+
raise TypeError(f"Expected text frame, got {type(msg).__name__}")
52+
return msg
53+
4654
async def receive_bytes(self) -> bytes:
47-
"""Receive binary data (decoded from text)."""
48-
text = await self.receive_text()
49-
return text.encode("utf-8")
55+
"""Receive the next binary frame. Raises TypeError if a text frame arrives.
56+
Raises WebSocketDisconnect when the connection is closed."""
57+
msg = await self.receive()
58+
if not isinstance(msg, bytes):
59+
raise TypeError(f"Expected binary frame, got {type(msg).__name__}")
60+
return msg
5061

5162
async def receive_json(self):
52-
"""Receive and decode JSON data.
63+
"""Receive and decode JSON data from a text frame.
5364
Raises WebSocketDisconnect when the connection is closed."""
5465
text = await self.receive_text()
5566
return orjson.loads(text)
5667

68+
async def send(self, data: str | bytes):
69+
"""Send data to this WebSocket client. str sends a text frame, bytes sends a binary frame."""
70+
await self._connector.async_send_to(self._connector.id, data)
71+
5772
async def send_text(self, data: str):
58-
"""Send text data to this WebSocket client."""
73+
"""Send a text frame to this WebSocket client."""
5974
await self._connector.async_send_to(self._connector.id, data)
6075

6176
async def send_bytes(self, data: bytes):
62-
"""Send binary data (as text) to this WebSocket client."""
63-
await self._connector.async_send_to(self._connector.id, data.decode("utf-8"))
77+
"""Send a binary frame to this WebSocket client."""
78+
await self._connector.async_send_to(self._connector.id, data)
6479

6580
async def send_json(self, data):
66-
"""Send JSON data to this WebSocket client."""
81+
"""Send JSON data as a text frame to this WebSocket client."""
6782
await self.send_text(orjson.dumps(data).decode())
6883

69-
async def broadcast(self, data: str):
70-
"""Broadcast text data to all connected WebSocket clients on this endpoint."""
84+
async def broadcast(self, data: str | bytes):
85+
"""Broadcast data to all connected WebSocket clients on this endpoint."""
7186
await self._connector.async_broadcast(data)
7287

7388
async def close(self):

src/executors/web_socket_executors.rs

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,20 @@ use pyo3::prelude::*;
44
use pyo3_async_runtimes::TaskLocals;
55

66
use crate::types::function_info::FunctionInfo;
7-
use crate::websockets::WebSocketConnector;
7+
use crate::websockets::{WebSocketConnector, WsPayload};
8+
9+
fn extract_ws_return(_py: Python, output: &Bound<'_, PyAny>) -> Option<WsPayload> {
10+
if output.is_none() {
11+
return None;
12+
}
13+
if let Ok(s) = output.extract::<String>() {
14+
Some(WsPayload::Text(s))
15+
} else if let Ok(b) = output.extract::<Vec<u8>>() {
16+
Some(WsPayload::Binary(b))
17+
} else {
18+
None
19+
}
20+
}
821

922
pub fn execute_ws_function(
1023
function: &FunctionInfo,
@@ -23,25 +36,23 @@ pub fn execute_ws_function(
2336
});
2437
let f = async {
2538
let output = fut.await.unwrap();
26-
Python::with_gil(|py| output.extract::<Option<String>>(py).unwrap())
39+
Python::with_gil(|py| extract_ws_return(py, output.bind(py)))
2740
}
2841
.into_actor(ws)
29-
.map(|res, _, ctx| {
30-
if let Some(msg) = res {
31-
ctx.text(msg);
32-
}
42+
.map(|res, _, ctx| match res {
43+
Some(WsPayload::Text(s)) => ctx.text(s),
44+
Some(WsPayload::Binary(b)) => ctx.binary(b),
45+
None => {}
3346
});
3447
ctx.spawn(f);
3548
} else {
3649
Python::with_gil(|py| {
3750
let handler = function.handler.bind(py).downcast().unwrap();
38-
if let Some(op) = handler
39-
.call1((ws.clone(),))
40-
.unwrap()
41-
.extract::<Option<String>>()
42-
.unwrap()
43-
{
44-
ctx.text(op);
51+
let result = handler.call1((ws.clone(),)).unwrap();
52+
match extract_ws_return(py, &result) {
53+
Some(WsPayload::Text(s)) => ctx.text(s),
54+
Some(WsPayload::Binary(b)) => ctx.binary(b),
55+
None => {}
4556
}
4657
});
4758
}

src/websockets/mod.rs

Lines changed: 74 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ pub mod registry;
33
use crate::executors::web_socket_executors::execute_ws_function;
44
use crate::types::function_info::FunctionInfo;
55
use crate::types::multimap::QueryParams;
6-
use registry::{Close, SendMessageToAll, SendText};
6+
use registry::{Close, SendMessage, SendMessageToAll};
77

88
use actix::prelude::*;
99
use actix::{Actor, AsyncContext, StreamHandler};
@@ -13,6 +13,7 @@ use log::debug;
1313
use once_cell::sync::OnceCell;
1414
use parking_lot::RwLock;
1515
use pyo3::prelude::*;
16+
use pyo3::types::PyBytes;
1617
use pyo3::IntoPyObject;
1718
use pyo3_async_runtimes::TaskLocals;
1819
use std::sync::Arc;
@@ -24,24 +25,47 @@ use crate::runtime;
2425
use registry::{Register, WebSocketRegistry};
2526
use std::collections::HashMap;
2627

28+
#[derive(Clone)]
29+
pub enum WsPayload {
30+
Text(String),
31+
Binary(Vec<u8>),
32+
}
33+
34+
fn extract_payload(message: &Bound<'_, PyAny>) -> PyResult<WsPayload> {
35+
if let Ok(s) = message.extract::<String>() {
36+
Ok(WsPayload::Text(s))
37+
} else if let Ok(b) = message.extract::<Vec<u8>>() {
38+
Ok(WsPayload::Binary(b))
39+
} else {
40+
Err(pyo3::exceptions::PyTypeError::new_err(
41+
"message must be str or bytes",
42+
))
43+
}
44+
}
45+
2746
/// A Rust-backed channel receiver exposed to Python.
2847
/// Python handlers call `await channel.receive()` to get the next message.
29-
/// Returns the message string, or None when the connection is closed.
48+
/// Returns str for text frames, bytes for binary frames, or None when closed.
3049
#[pyclass]
3150
pub struct WebSocketChannel {
32-
receiver: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<Option<String>>>>,
51+
receiver: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<Option<WsPayload>>>>,
3352
}
3453

3554
#[pymethods]
3655
impl WebSocketChannel {
3756
/// Await the next message from the WebSocket.
38-
/// Returns the message string, or None if the connection was closed.
57+
/// Returns str for text frames, bytes for binary frames, or None if closed.
3958
fn receive<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
4059
let receiver = self.receiver.clone();
4160
pyo3_async_runtimes::tokio::future_into_py(py, async move {
4261
let mut rx = receiver.lock().await;
4362
match rx.recv().await {
44-
Some(Some(msg)) => Ok(Some(msg)),
63+
Some(Some(WsPayload::Text(s))) => Python::with_gil(|py| {
64+
Ok(Some(s.into_pyobject(py).unwrap().into_any().unbind()))
65+
}),
66+
Some(Some(WsPayload::Binary(b))) => {
67+
Python::with_gil(|py| Ok(Some(PyBytes::new(py, &b).into_any().unbind())))
68+
}
4569
Some(None) | None => Ok(None),
4670
}
4771
})
@@ -56,9 +80,7 @@ pub struct WebSocketConnector {
5680
pub task_locals: TaskLocals,
5781
pub registry_addr: Addr<WebSocketRegistry>,
5882
pub query_params: QueryParams,
59-
/// Sender side of the message channel (stays in the Actix actor).
60-
pub message_sender: Option<mpsc::UnboundedSender<Option<String>>>,
61-
/// Receiver side exposed to Python via WebSocketChannel.
83+
pub message_sender: Option<mpsc::UnboundedSender<Option<WsPayload>>>,
6284
pub message_channel: Option<Py<WebSocketChannel>>,
6385
}
6486

@@ -73,7 +95,7 @@ impl Actor for WebSocketConnector {
7395
addr: addr.clone(),
7496
});
7597

76-
let (tx, rx) = mpsc::unbounded_channel::<Option<String>>();
98+
let (tx, rx) = mpsc::unbounded_channel::<Option<WsPayload>>();
7799
self.message_sender = Some(tx);
78100
self.message_channel = Python::with_gil(|py| {
79101
Some(
@@ -94,9 +116,6 @@ impl Actor for WebSocketConnector {
94116
}
95117

96118
fn stopped(&mut self, ctx: &mut Self::Context) {
97-
// Drop the sender to close the channel.
98-
// This causes any pending `channel.receive()` in Python to return None,
99-
// which the WebSocketAdapter converts to WebSocketDisconnect.
100119
self.message_sender.take();
101120

102121
let function = self.router.get("close").unwrap();
@@ -123,15 +142,21 @@ impl Clone for WebSocketConnector {
123142
}
124143
}
125144

126-
impl Handler<SendText> for WebSocketConnector {
145+
impl Handler<SendMessage> for WebSocketConnector {
127146
type Result = ();
128147

129-
fn handle(&mut self, msg: SendText, ctx: &mut Self::Context) {
148+
fn handle(&mut self, msg: SendMessage, ctx: &mut Self::Context) {
130149
if self.id == msg.recipient_id {
131-
ctx.text(msg.message.clone());
132-
if msg.message == "Connection closed" {
133-
// Close the WebSocket connection
134-
ctx.stop();
150+
match &msg.payload {
151+
WsPayload::Text(s) => {
152+
ctx.text(s.clone());
153+
if s == "Connection closed" {
154+
ctx.stop();
155+
}
156+
}
157+
WsPayload::Binary(b) => {
158+
ctx.binary(b.clone());
159+
}
135160
}
136161
}
137162
}
@@ -151,14 +176,17 @@ impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WebSocketConnecto
151176
Ok(ws::Message::Text(text)) => {
152177
debug!("Text message received {:?}", text);
153178
if let Some(ref sender) = self.message_sender {
154-
let _ = sender.send(Some(text.to_string()));
179+
let _ = sender.send(Some(WsPayload::Text(text.to_string())));
180+
}
181+
}
182+
Ok(ws::Message::Binary(bin)) => {
183+
debug!("Binary message received ({} bytes)", bin.len());
184+
if let Some(ref sender) = self.message_sender {
185+
let _ = sender.send(Some(WsPayload::Binary(bin.to_vec())));
155186
}
156187
}
157-
Ok(ws::Message::Binary(bin)) => ctx.binary(bin),
158188
Ok(ws::Message::Close(_close_reason)) => {
159189
debug!("Socket was closed");
160-
// Drop sender to signal channel closure so receive() returns None.
161-
// The close handler is called once from stopped().
162190
self.message_sender.take();
163191
ctx.stop();
164192
}
@@ -169,63 +197,69 @@ impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WebSocketConnecto
169197

170198
#[pymethods]
171199
impl WebSocketConnector {
172-
pub fn sync_send_to(&self, recipient_id: String, message: String) {
200+
pub fn sync_send_to(&self, recipient_id: String, message: &Bound<'_, PyAny>) -> PyResult<()> {
201+
let payload = extract_payload(message)?;
173202
let recipient_id = Uuid::parse_str(&recipient_id).unwrap();
174203

175-
match self.registry_addr.try_send(SendText {
176-
message,
204+
match self.registry_addr.try_send(SendMessage {
205+
payload,
177206
sender_id: self.id,
178207
recipient_id,
179208
}) {
180-
Ok(_) => println!("Message sent successfully"),
181-
Err(e) => println!("Failed to send message: {}", e),
209+
Ok(_) => debug!("Message sent successfully"),
210+
Err(e) => debug!("Failed to send message: {}", e),
182211
}
212+
Ok(())
183213
}
184214

185215
pub fn async_send_to(
186216
&self,
187217
py: Python,
188218
recipient_id: String,
189-
message: String,
219+
message: &Bound<'_, PyAny>,
190220
) -> PyResult<Py<PyAny>> {
221+
let payload = extract_payload(message)?;
191222
let registry = self.registry_addr.clone();
192223
let recipient_id = Uuid::parse_str(&recipient_id).unwrap();
193224
let sender_id = self.id;
194225

195226
let awaitable = runtime::future_into_py(py, async move {
196-
match registry.try_send(SendText {
197-
message,
227+
match registry.try_send(SendMessage {
228+
payload,
198229
sender_id,
199230
recipient_id,
200231
}) {
201-
Ok(_) => println!("Message sent successfully"),
202-
Err(e) => println!("Failed to send message: {}", e),
232+
Ok(_) => debug!("Message sent successfully"),
233+
Err(e) => debug!("Failed to send message: {}", e),
203234
}
204235
Ok(())
205236
})?;
206237

207238
Ok(awaitable.into_pyobject(py)?.into_any().into())
208239
}
209240

210-
pub fn sync_broadcast(&self, message: String) {
241+
pub fn sync_broadcast(&self, message: &Bound<'_, PyAny>) -> PyResult<()> {
242+
let payload = extract_payload(message)?;
211243
let registry = self.registry_addr.clone();
212244
match registry.try_send(SendMessageToAll {
213-
message,
245+
payload,
214246
sender_id: self.id,
215247
}) {
216-
Ok(_) => println!("Message sent successfully"),
217-
Err(e) => println!("Failed to send message: {}", e),
248+
Ok(_) => debug!("Message sent successfully"),
249+
Err(e) => debug!("Failed to send message: {}", e),
218250
}
251+
Ok(())
219252
}
220253

221-
pub fn async_broadcast(&self, py: Python, message: String) -> PyResult<Py<PyAny>> {
254+
pub fn async_broadcast(&self, py: Python, message: &Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
255+
let payload = extract_payload(message)?;
222256
let registry = self.registry_addr.clone();
223257
let sender_id = self.id;
224258

225259
let awaitable = runtime::future_into_py(py, async move {
226-
match registry.try_send(SendMessageToAll { message, sender_id }) {
227-
Ok(_) => println!("Message sent successfully"),
228-
Err(e) => println!("Failed to send message: {}", e),
260+
match registry.try_send(SendMessageToAll { payload, sender_id }) {
261+
Ok(_) => debug!("Message sent successfully"),
262+
Err(e) => debug!("Failed to send message: {}", e),
229263
}
230264
Ok(())
231265
})?;
@@ -247,7 +281,6 @@ impl WebSocketConnector {
247281
self.query_params.clone()
248282
}
249283

250-
/// Get the message channel for WebSocket handlers.
251284
#[getter]
252285
pub fn get_message_channel(&self, py: Python) -> Option<Py<WebSocketChannel>> {
253286
self.message_channel.as_ref().map(|c| c.clone_ref(py))

0 commit comments

Comments
 (0)