@@ -16,7 +16,7 @@ use buttplug_core::{
1616 message:: serializer:: ButtplugSerializedMessage ,
1717} ;
1818use futures:: { FutureExt , SinkExt , StreamExt , future:: BoxFuture } ;
19- use std:: { sync:: Arc , time:: Duration } ;
19+ use std:: { fmt , sync:: Arc , time:: Duration } ;
2020use tokio:: {
2121 net:: { TcpListener , TcpStream } ,
2222 select,
@@ -27,19 +27,42 @@ use tokio::{
2727 time:: sleep,
2828} ;
2929
30+ #[ derive( Clone ) ]
31+ struct ListenerBoundCallback ( Arc < dyn Fn ( u16 ) + Send + Sync > ) ;
32+
33+ impl ListenerBoundCallback {
34+ fn new ( callback : impl Fn ( u16 ) + Send + Sync + ' static ) -> Self {
35+ Self ( Arc :: new ( callback) )
36+ }
37+
38+ fn call ( & self , port : u16 ) {
39+ ( self . 0 ) ( port) ;
40+ }
41+ }
42+
43+ impl fmt:: Debug for ListenerBoundCallback {
44+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
45+ f. debug_struct ( "ListenerBoundCallback" )
46+ . finish_non_exhaustive ( )
47+ }
48+ }
49+
3050#[ derive( Clone , Debug ) ]
3151pub struct ButtplugWebsocketServerTransportBuilder {
3252 /// If true, listens all on available interfaces. Otherwise, only listens on 127.0.0.1.
3353 listen_on_all_interfaces : bool ,
3454 /// Insecure port for listening for websocket connections.
3555 port : u16 ,
56+ /// Optional callback fired after the listener is bound and the actual local port is known.
57+ listener_bound_callback : Option < ListenerBoundCallback > ,
3658}
3759
3860impl Default for ButtplugWebsocketServerTransportBuilder {
3961 fn default ( ) -> Self {
4062 Self {
4163 listen_on_all_interfaces : false ,
4264 port : 12345 ,
65+ listener_bound_callback : None ,
4366 }
4467 }
4568}
@@ -55,10 +78,16 @@ impl ButtplugWebsocketServerTransportBuilder {
5578 self
5679 }
5780
81+ pub fn on_listener_bound ( & mut self , callback : impl Fn ( u16 ) + Send + Sync + ' static ) -> & mut Self {
82+ self . listener_bound_callback = Some ( ListenerBoundCallback :: new ( callback) ) ;
83+ self
84+ }
85+
5886 pub fn finish ( & self ) -> ButtplugWebsocketServerTransport {
5987 ButtplugWebsocketServerTransport {
6088 port : self . port ,
6189 listen_on_all_interfaces : self . listen_on_all_interfaces ,
90+ listener_bound_callback : self . listener_bound_callback . clone ( ) ,
6291 disconnect_notifier : Arc :: new ( Notify :: new ( ) ) ,
6392 }
6493 }
@@ -193,6 +222,7 @@ async fn run_connection_loop(
193222pub struct ButtplugWebsocketServerTransport {
194223 port : u16 ,
195224 listen_on_all_interfaces : bool ,
225+ listener_bound_callback : Option < ListenerBoundCallback > ,
196226 disconnect_notifier : Arc < Notify > ,
197227}
198228
@@ -203,6 +233,7 @@ impl ButtplugConnectorTransport for ButtplugWebsocketServerTransport {
203233 incoming_sender : Sender < ButtplugTransportIncomingMessage > ,
204234 ) -> BoxFuture < ' static , Result < ( ) , ButtplugConnectorError > > {
205235 let disconnect_notifier = self . disconnect_notifier . clone ( ) ;
236+ let listener_bound_callback = self . listener_bound_callback . clone ( ) ;
206237
207238 let base_addr = if self . listen_on_all_interfaces {
208239 "0.0.0.0"
@@ -231,6 +262,19 @@ impl ButtplugConnectorTransport for ButtplugWebsocketServerTransport {
231262 )
232263 } ) ?;
233264 debug ! ( "Websocket: Listening on: {}" , addr) ;
265+ if let Some ( callback) = & listener_bound_callback {
266+ let local_port = listener
267+ . local_addr ( )
268+ . map_err ( |e| {
269+ ButtplugConnectorError :: TransportSpecificError (
270+ ButtplugConnectorTransportSpecificError :: GenericNetworkError ( format ! (
271+ "Could not determine websocket listener local address: {e}"
272+ ) ) ,
273+ )
274+ } ) ?
275+ . port ( ) ;
276+ callback. call ( local_port) ;
277+ }
234278 if let Ok ( ( stream, _) ) = listener. accept ( ) . await {
235279 info ! ( "Websocket: Got connection" ) ;
236280 let ws_stream = tokio_tungstenite:: accept_async ( stream)
@@ -288,6 +332,7 @@ mod test {
288332 message:: serializer:: ButtplugSerializedMessage ,
289333 } ;
290334 use std:: io:: ErrorKind ;
335+ use std:: sync:: { Arc , Mutex } ;
291336 use tokio:: { net:: TcpListener , sync:: mpsc} ;
292337
293338 #[ tokio:: test]
@@ -322,4 +367,37 @@ mod test {
322367 other => panic ! ( "Unexpected error: {other:?}" ) ,
323368 }
324369 }
370+
371+ #[ tokio:: test]
372+ async fn listener_bound_callback_receives_actual_port ( ) {
373+ let bound_port = Arc :: new ( Mutex :: new ( None ) ) ;
374+ let callback_port = bound_port. clone ( ) ;
375+ let transport = ButtplugWebsocketServerTransportBuilder :: default ( )
376+ . on_listener_bound ( move |port| {
377+ * callback_port. lock ( ) . unwrap ( ) = Some ( port) ;
378+ } )
379+ . finish ( ) ;
380+ let ( _outgoing_sender, outgoing_receiver) = mpsc:: channel :: < ButtplugSerializedMessage > ( 1 ) ;
381+ let ( incoming_sender, _incoming_receiver) =
382+ mpsc:: channel :: < ButtplugTransportIncomingMessage > ( 1 ) ;
383+ let connect_task = tokio:: spawn ( async move {
384+ let _ = transport. connect ( outgoing_receiver, incoming_sender) . await ;
385+ } ) ;
386+
387+ tokio:: time:: timeout ( std:: time:: Duration :: from_secs ( 1 ) , async {
388+ loop {
389+ if let Some ( port) = * bound_port. lock ( ) . unwrap ( ) {
390+ return port;
391+ }
392+ tokio:: task:: yield_now ( ) . await ;
393+ }
394+ } )
395+ . await
396+ . expect ( "listener bound callback was not called" ) ;
397+
398+ let port = bound_port. lock ( ) . unwrap ( ) . unwrap ( ) ;
399+ assert ! ( port > 0 ) ;
400+
401+ connect_task. abort ( ) ;
402+ }
325403}
0 commit comments