@@ -3,7 +3,7 @@ use std::{
33 io,
44 net:: SocketAddr ,
55 pin:: Pin ,
6- sync:: { atomic:: AtomicBool , RwLock , RwLockReadGuard , TryLockError } ,
6+ sync:: { atomic:: AtomicBool , Arc , RwLock , RwLockReadGuard , TryLockError } ,
77 task:: { Context , Poll } ,
88} ;
99
@@ -321,7 +321,7 @@ impl UdpSocket {
321321 panic ! ( "lock poisoned: {:?}" , e) ;
322322 }
323323 Err ( TryLockError :: WouldBlock ) => {
324- return Err ( io:: Error :: new ( io:: ErrorKind :: WouldBlock , "" ) ) ;
324+ return Err ( io:: Error :: new ( io:: ErrorKind :: WouldBlock , "locked " ) ) ;
325325 }
326326 } ;
327327 let ( socket, state) = guard. try_get_connected ( ) ?;
@@ -340,6 +340,50 @@ impl UdpSocket {
340340 }
341341 }
342342
343+ /// poll send a quinn based `Transmit`.
344+ pub fn poll_send_quinn (
345+ & self ,
346+ cx : & mut Context ,
347+ transmit : & Transmit < ' _ > ,
348+ ) -> Poll < io:: Result < ( ) > > {
349+ loop {
350+ if let Err ( err) = self . maybe_rebind ( ) {
351+ return Poll :: Ready ( Err ( err) ) ;
352+ }
353+
354+ let guard = n0_future:: ready!( self . poll_read_socket( & self . send_waker, cx) ) ;
355+ let ( socket, state) = guard. try_get_connected ( ) ?;
356+
357+ match socket. poll_send_ready ( cx) {
358+ Poll :: Pending => {
359+ self . send_waker . register ( cx. waker ( ) ) ;
360+ return Poll :: Pending ;
361+ }
362+ Poll :: Ready ( Ok ( ( ) ) ) => {
363+ let res =
364+ socket. try_io ( Interest :: WRITABLE , || state. send ( socket. into ( ) , transmit) ) ;
365+ if let Err ( err) = res {
366+ if err. kind ( ) == io:: ErrorKind :: WouldBlock {
367+ continue ;
368+ }
369+
370+ if let Some ( err) = self . handle_write_error ( err) {
371+ return Poll :: Ready ( Err ( err) ) ;
372+ }
373+ continue ;
374+ }
375+ return Poll :: Ready ( res) ;
376+ }
377+ Poll :: Ready ( Err ( err) ) => {
378+ if let Some ( err) = self . handle_write_error ( err) {
379+ return Poll :: Ready ( Err ( err) ) ;
380+ }
381+ continue ;
382+ }
383+ }
384+ }
385+ }
386+
343387 /// quinn based `poll_recv`
344388 pub fn poll_recv_quinn (
345389 & self ,
@@ -401,6 +445,11 @@ impl UdpSocket {
401445 }
402446 }
403447
448+ /// Creates a [`UdpSender`] sender.
449+ pub fn create_sender ( self : Arc < Self > ) -> UdpSender {
450+ UdpSender :: new ( self . clone ( ) )
451+ }
452+
404453 /// Whether transmitted datagrams might get fragmented by the IP layer
405454 ///
406455 /// Returns `false` on targets which employ e.g. the `IPV6_DONTFRAG` socket option.
@@ -806,6 +855,151 @@ impl Drop for UdpSocket {
806855 }
807856}
808857
858+ pin_project_lite:: pin_project! {
859+ pub struct UdpSender {
860+ socket: Arc <UdpSocket >,
861+ #[ pin]
862+ fut: Option <Pin <Box <dyn Future <Output = io:: Result <( ) >> + Send + Sync + ' static >>>,
863+ }
864+ }
865+
866+ impl std:: fmt:: Debug for UdpSender {
867+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
868+ f. write_str ( "UdpSender" )
869+ }
870+ }
871+
872+ impl UdpSender {
873+ fn new ( socket : Arc < UdpSocket > ) -> Self {
874+ Self { socket, fut : None }
875+ }
876+
877+ /// Async sending
878+ pub fn send < ' a , ' b > ( & self , transmit : & ' a quinn_udp:: Transmit < ' b > ) -> SendFutQuinn < ' a , ' b > {
879+ SendFutQuinn {
880+ socket : self . socket . clone ( ) ,
881+ transmit,
882+ }
883+ }
884+
885+ /// Poll send
886+ pub fn poll_send (
887+ self : Pin < & mut Self > ,
888+ transmit : & quinn_udp:: Transmit ,
889+ cx : & mut Context ,
890+ ) -> Poll < io:: Result < ( ) > > {
891+ let mut this = self . project ( ) ;
892+ loop {
893+ if let Err ( err) = this. socket . maybe_rebind ( ) {
894+ return Poll :: Ready ( Err ( err) ) ;
895+ }
896+
897+ let guard =
898+ n0_future:: ready!( this. socket. poll_read_socket( & this. socket. send_waker, cx) ) ;
899+
900+ if this. fut . is_none ( ) {
901+ let socket = this. socket . clone ( ) ;
902+ this. fut . set ( Some ( Box :: pin ( async move {
903+ n0_future:: future:: poll_fn ( |cx| socket. poll_writable ( cx) ) . await
904+ } ) ) ) ;
905+ }
906+ // We're forced to `unwrap` here because `Fut` may be `!Unpin`, which means we can't safely
907+ // obtain an `&mut Fut` after storing it in `this.fut` when `this` is already behind `Pin`,
908+ // and if we didn't store it then we wouldn't be able to keep it alive between
909+ // `poll_writable` calls.
910+ let result = n0_future:: ready!( this. fut. as_mut( ) . as_pin_mut( ) . unwrap( ) . poll( cx) ) ;
911+
912+ // Polling an arbitrary `Future` after it becomes ready is a logic error, so arrange for
913+ // a new `Future` to be created on the next call.
914+ this. fut . set ( None ) ;
915+
916+ // If .writable() fails, propagate the error
917+ result?;
918+
919+ let ( socket, state) = guard. try_get_connected ( ) ?;
920+ let result = socket. try_io ( Interest :: WRITABLE , || state. send ( socket. into ( ) , transmit) ) ;
921+
922+ match result {
923+ // We thought the socket was writable, but it wasn't, then retry so that either another
924+ // `writable().await` call determines that the socket is indeed not writable and
925+ // registers us for a wakeup, or the send succeeds if this really was just a
926+ // transient failure.
927+ Err ( ref e) if e. kind ( ) == io:: ErrorKind :: WouldBlock => continue ,
928+ // In all other cases, either propagate the error or we're Ok
929+ _ => return Poll :: Ready ( result) ,
930+ }
931+ }
932+ }
933+
934+ /// Best effort sending
935+ pub fn try_send ( & self , transmit : & quinn_udp:: Transmit ) -> io:: Result < ( ) > {
936+ self . socket . maybe_rebind ( ) ?;
937+
938+ match self . socket . socket . try_read ( ) {
939+ Ok ( guard) => {
940+ let ( socket, state) = guard. try_get_connected ( ) ?;
941+ socket. try_io ( Interest :: WRITABLE , || state. send ( socket. into ( ) , transmit) )
942+ }
943+ Err ( TryLockError :: Poisoned ( e) ) => panic ! ( "socket lock poisoned: {e}" ) ,
944+ Err ( TryLockError :: WouldBlock ) => {
945+ Err ( io:: Error :: new ( io:: ErrorKind :: WouldBlock , "locked" ) )
946+ }
947+ }
948+ }
949+ }
950+
951+ /// Send future quinn
952+ #[ derive( Debug ) ]
953+ pub struct SendFutQuinn < ' a , ' b > {
954+ socket : Arc < UdpSocket > ,
955+ transmit : & ' a quinn_udp:: Transmit < ' b > ,
956+ }
957+
958+ impl Future for SendFutQuinn < ' _ , ' _ > {
959+ type Output = io:: Result < ( ) > ;
960+
961+ fn poll ( self : Pin < & mut Self > , cx : & mut std:: task:: Context < ' _ > ) -> Poll < Self :: Output > {
962+ loop {
963+ if let Err ( err) = self . socket . maybe_rebind ( ) {
964+ return Poll :: Ready ( Err ( err) ) ;
965+ }
966+
967+ let guard =
968+ n0_future:: ready!( self . socket. poll_read_socket( & self . socket. send_waker, cx) ) ;
969+ let ( socket, state) = guard. try_get_connected ( ) ?;
970+
971+ match socket. poll_send_ready ( cx) {
972+ Poll :: Pending => {
973+ self . socket . send_waker . register ( cx. waker ( ) ) ;
974+ return Poll :: Pending ;
975+ }
976+ Poll :: Ready ( Ok ( ( ) ) ) => {
977+ let res = socket. try_io ( Interest :: WRITABLE , || {
978+ state. send ( socket. into ( ) , self . transmit )
979+ } ) ;
980+
981+ if let Err ( err) = res {
982+ if err. kind ( ) == io:: ErrorKind :: WouldBlock {
983+ continue ;
984+ }
985+ if let Some ( err) = self . socket . handle_write_error ( err) {
986+ return Poll :: Ready ( Err ( err) ) ;
987+ }
988+ continue ;
989+ }
990+ return Poll :: Ready ( res) ;
991+ }
992+ Poll :: Ready ( Err ( err) ) => {
993+ if let Some ( err) = self . socket . handle_write_error ( err) {
994+ return Poll :: Ready ( Err ( err) ) ;
995+ }
996+ continue ;
997+ }
998+ }
999+ }
1000+ }
1001+ }
1002+
8091003#[ cfg( test) ]
8101004mod tests {
8111005 use testresult:: TestResult ;
0 commit comments