@@ -3,7 +3,7 @@ pub mod registry;
33use crate :: executors:: web_socket_executors:: execute_ws_function;
44use crate :: types:: function_info:: FunctionInfo ;
55use crate :: types:: multimap:: QueryParams ;
6- use registry:: { Close , SendMessageToAll , SendText } ;
6+ use registry:: { Close , SendMessage , SendMessageToAll } ;
77
88use actix:: prelude:: * ;
99use actix:: { Actor , AsyncContext , StreamHandler } ;
@@ -13,6 +13,7 @@ use log::debug;
1313use once_cell:: sync:: OnceCell ;
1414use parking_lot:: RwLock ;
1515use pyo3:: prelude:: * ;
16+ use pyo3:: types:: PyBytes ;
1617use pyo3:: IntoPyObject ;
1718use pyo3_async_runtimes:: TaskLocals ;
1819use std:: sync:: Arc ;
@@ -24,24 +25,47 @@ use crate::runtime;
2425use registry:: { Register , WebSocketRegistry } ;
2526use std:: collections:: HashMap ;
2627
28+ #[ derive( Clone ) ]
29+ pub enum WsPayload {
30+ Text ( String ) ,
31+ Binary ( Vec < u8 > ) ,
32+ }
33+
34+ fn extract_payload ( message : & Bound < ' _ , PyAny > ) -> PyResult < WsPayload > {
35+ if let Ok ( s) = message. extract :: < String > ( ) {
36+ Ok ( WsPayload :: Text ( s) )
37+ } else if let Ok ( b) = message. extract :: < Vec < u8 > > ( ) {
38+ Ok ( WsPayload :: Binary ( b) )
39+ } else {
40+ Err ( pyo3:: exceptions:: PyTypeError :: new_err (
41+ "message must be str or bytes" ,
42+ ) )
43+ }
44+ }
45+
2746/// A Rust-backed channel receiver exposed to Python.
2847/// Python handlers call `await channel.receive()` to get the next message.
29- /// Returns the message string, or None when the connection is closed.
48+ /// Returns str for text frames, bytes for binary frames, or None when closed.
3049#[ pyclass]
3150pub struct WebSocketChannel {
32- receiver : Arc < tokio:: sync:: Mutex < mpsc:: UnboundedReceiver < Option < String > > > > ,
51+ receiver : Arc < tokio:: sync:: Mutex < mpsc:: UnboundedReceiver < Option < WsPayload > > > > ,
3352}
3453
3554#[ pymethods]
3655impl WebSocketChannel {
3756 /// Await the next message from the WebSocket.
38- /// Returns the message string, or None if the connection was closed.
57+ /// Returns str for text frames, bytes for binary frames, or None if closed.
3958 fn receive < ' py > ( & self , py : Python < ' py > ) -> PyResult < Bound < ' py , PyAny > > {
4059 let receiver = self . receiver . clone ( ) ;
4160 pyo3_async_runtimes:: tokio:: future_into_py ( py, async move {
4261 let mut rx = receiver. lock ( ) . await ;
4362 match rx. recv ( ) . await {
44- Some ( Some ( msg) ) => Ok ( Some ( msg) ) ,
63+ Some ( Some ( WsPayload :: Text ( s) ) ) => Python :: with_gil ( |py| {
64+ Ok ( Some ( s. into_pyobject ( py) . unwrap ( ) . into_any ( ) . unbind ( ) ) )
65+ } ) ,
66+ Some ( Some ( WsPayload :: Binary ( b) ) ) => {
67+ Python :: with_gil ( |py| Ok ( Some ( PyBytes :: new ( py, & b) . into_any ( ) . unbind ( ) ) ) )
68+ }
4569 Some ( None ) | None => Ok ( None ) ,
4670 }
4771 } )
@@ -56,9 +80,7 @@ pub struct WebSocketConnector {
5680 pub task_locals : TaskLocals ,
5781 pub registry_addr : Addr < WebSocketRegistry > ,
5882 pub query_params : QueryParams ,
59- /// Sender side of the message channel (stays in the Actix actor).
60- pub message_sender : Option < mpsc:: UnboundedSender < Option < String > > > ,
61- /// Receiver side exposed to Python via WebSocketChannel.
83+ pub message_sender : Option < mpsc:: UnboundedSender < Option < WsPayload > > > ,
6284 pub message_channel : Option < Py < WebSocketChannel > > ,
6385}
6486
@@ -73,7 +95,7 @@ impl Actor for WebSocketConnector {
7395 addr : addr. clone ( ) ,
7496 } ) ;
7597
76- let ( tx, rx) = mpsc:: unbounded_channel :: < Option < String > > ( ) ;
98+ let ( tx, rx) = mpsc:: unbounded_channel :: < Option < WsPayload > > ( ) ;
7799 self . message_sender = Some ( tx) ;
78100 self . message_channel = Python :: with_gil ( |py| {
79101 Some (
@@ -94,9 +116,6 @@ impl Actor for WebSocketConnector {
94116 }
95117
96118 fn stopped ( & mut self , ctx : & mut Self :: Context ) {
97- // Drop the sender to close the channel.
98- // This causes any pending `channel.receive()` in Python to return None,
99- // which the WebSocketAdapter converts to WebSocketDisconnect.
100119 self . message_sender . take ( ) ;
101120
102121 let function = self . router . get ( "close" ) . unwrap ( ) ;
@@ -123,15 +142,21 @@ impl Clone for WebSocketConnector {
123142 }
124143}
125144
126- impl Handler < SendText > for WebSocketConnector {
145+ impl Handler < SendMessage > for WebSocketConnector {
127146 type Result = ( ) ;
128147
129- fn handle ( & mut self , msg : SendText , ctx : & mut Self :: Context ) {
148+ fn handle ( & mut self , msg : SendMessage , ctx : & mut Self :: Context ) {
130149 if self . id == msg. recipient_id {
131- ctx. text ( msg. message . clone ( ) ) ;
132- if msg. message == "Connection closed" {
133- // Close the WebSocket connection
134- ctx. stop ( ) ;
150+ match & msg. payload {
151+ WsPayload :: Text ( s) => {
152+ ctx. text ( s. clone ( ) ) ;
153+ if s == "Connection closed" {
154+ ctx. stop ( ) ;
155+ }
156+ }
157+ WsPayload :: Binary ( b) => {
158+ ctx. binary ( b. clone ( ) ) ;
159+ }
135160 }
136161 }
137162 }
@@ -151,14 +176,17 @@ impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WebSocketConnecto
151176 Ok ( ws:: Message :: Text ( text) ) => {
152177 debug ! ( "Text message received {:?}" , text) ;
153178 if let Some ( ref sender) = self . message_sender {
154- let _ = sender. send ( Some ( text. to_string ( ) ) ) ;
179+ let _ = sender. send ( Some ( WsPayload :: Text ( text. to_string ( ) ) ) ) ;
180+ }
181+ }
182+ Ok ( ws:: Message :: Binary ( bin) ) => {
183+ debug ! ( "Binary message received ({} bytes)" , bin. len( ) ) ;
184+ if let Some ( ref sender) = self . message_sender {
185+ let _ = sender. send ( Some ( WsPayload :: Binary ( bin. to_vec ( ) ) ) ) ;
155186 }
156187 }
157- Ok ( ws:: Message :: Binary ( bin) ) => ctx. binary ( bin) ,
158188 Ok ( ws:: Message :: Close ( _close_reason) ) => {
159189 debug ! ( "Socket was closed" ) ;
160- // Drop sender to signal channel closure so receive() returns None.
161- // The close handler is called once from stopped().
162190 self . message_sender . take ( ) ;
163191 ctx. stop ( ) ;
164192 }
@@ -169,63 +197,69 @@ impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WebSocketConnecto
169197
170198#[ pymethods]
171199impl WebSocketConnector {
172- pub fn sync_send_to ( & self , recipient_id : String , message : String ) {
200+ pub fn sync_send_to ( & self , recipient_id : String , message : & Bound < ' _ , PyAny > ) -> PyResult < ( ) > {
201+ let payload = extract_payload ( message) ?;
173202 let recipient_id = Uuid :: parse_str ( & recipient_id) . unwrap ( ) ;
174203
175- match self . registry_addr . try_send ( SendText {
176- message ,
204+ match self . registry_addr . try_send ( SendMessage {
205+ payload ,
177206 sender_id : self . id ,
178207 recipient_id,
179208 } ) {
180- Ok ( _) => println ! ( "Message sent successfully" ) ,
181- Err ( e) => println ! ( "Failed to send message: {}" , e) ,
209+ Ok ( _) => debug ! ( "Message sent successfully" ) ,
210+ Err ( e) => debug ! ( "Failed to send message: {}" , e) ,
182211 }
212+ Ok ( ( ) )
183213 }
184214
185215 pub fn async_send_to (
186216 & self ,
187217 py : Python ,
188218 recipient_id : String ,
189- message : String ,
219+ message : & Bound < ' _ , PyAny > ,
190220 ) -> PyResult < Py < PyAny > > {
221+ let payload = extract_payload ( message) ?;
191222 let registry = self . registry_addr . clone ( ) ;
192223 let recipient_id = Uuid :: parse_str ( & recipient_id) . unwrap ( ) ;
193224 let sender_id = self . id ;
194225
195226 let awaitable = runtime:: future_into_py ( py, async move {
196- match registry. try_send ( SendText {
197- message ,
227+ match registry. try_send ( SendMessage {
228+ payload ,
198229 sender_id,
199230 recipient_id,
200231 } ) {
201- Ok ( _) => println ! ( "Message sent successfully" ) ,
202- Err ( e) => println ! ( "Failed to send message: {}" , e) ,
232+ Ok ( _) => debug ! ( "Message sent successfully" ) ,
233+ Err ( e) => debug ! ( "Failed to send message: {}" , e) ,
203234 }
204235 Ok ( ( ) )
205236 } ) ?;
206237
207238 Ok ( awaitable. into_pyobject ( py) ?. into_any ( ) . into ( ) )
208239 }
209240
210- pub fn sync_broadcast ( & self , message : String ) {
241+ pub fn sync_broadcast ( & self , message : & Bound < ' _ , PyAny > ) -> PyResult < ( ) > {
242+ let payload = extract_payload ( message) ?;
211243 let registry = self . registry_addr . clone ( ) ;
212244 match registry. try_send ( SendMessageToAll {
213- message ,
245+ payload ,
214246 sender_id : self . id ,
215247 } ) {
216- Ok ( _) => println ! ( "Message sent successfully" ) ,
217- Err ( e) => println ! ( "Failed to send message: {}" , e) ,
248+ Ok ( _) => debug ! ( "Message sent successfully" ) ,
249+ Err ( e) => debug ! ( "Failed to send message: {}" , e) ,
218250 }
251+ Ok ( ( ) )
219252 }
220253
221- pub fn async_broadcast ( & self , py : Python , message : String ) -> PyResult < Py < PyAny > > {
254+ pub fn async_broadcast ( & self , py : Python , message : & Bound < ' _ , PyAny > ) -> PyResult < Py < PyAny > > {
255+ let payload = extract_payload ( message) ?;
222256 let registry = self . registry_addr . clone ( ) ;
223257 let sender_id = self . id ;
224258
225259 let awaitable = runtime:: future_into_py ( py, async move {
226- match registry. try_send ( SendMessageToAll { message , sender_id } ) {
227- Ok ( _) => println ! ( "Message sent successfully" ) ,
228- Err ( e) => println ! ( "Failed to send message: {}" , e) ,
260+ match registry. try_send ( SendMessageToAll { payload , sender_id } ) {
261+ Ok ( _) => debug ! ( "Message sent successfully" ) ,
262+ Err ( e) => debug ! ( "Failed to send message: {}" , e) ,
229263 }
230264 Ok ( ( ) )
231265 } ) ?;
@@ -247,7 +281,6 @@ impl WebSocketConnector {
247281 self . query_params . clone ( )
248282 }
249283
250- /// Get the message channel for WebSocket handlers.
251284 #[ getter]
252285 pub fn get_message_channel ( & self , py : Python ) -> Option < Py < WebSocketChannel > > {
253286 self . message_channel . as_ref ( ) . map ( |c| c. clone_ref ( py) )
0 commit comments