11use crate :: p2:: bindings:: sockets:: network:: { ErrorCode , IpAddressFamily , IpSocketAddress , Network } ;
22use crate :: p2:: bindings:: sockets:: udp;
3- use crate :: p2:: udp:: { IncomingDatagramStream , OutgoingDatagramStream , SendState , UdpState } ;
3+ use crate :: p2:: udp:: { IncomingDatagramStream , OutgoingDatagramStream , SendState } ;
44use crate :: p2:: { Pollable , SocketError , SocketResult } ;
5- use crate :: sockets:: util:: {
6- get_ip_ttl, get_ipv6_unicast_hops, is_valid_address_family, is_valid_remote_address,
7- receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,
8- set_unicast_hop_limit, udp_bind, udp_disconnect,
9- } ;
5+ use crate :: sockets:: util:: { is_valid_address_family, is_valid_remote_address} ;
106use crate :: sockets:: {
11- MAX_UDP_DATAGRAM_SIZE , SocketAddrUse , SocketAddressFamily , WasiSocketsCtxView ,
7+ MAX_UDP_DATAGRAM_SIZE , SocketAddrUse , SocketAddressFamily , UdpSocket , WasiSocketsCtxView ,
128} ;
139use anyhow:: anyhow;
1410use async_trait:: async_trait;
15- use io_lifetimes:: AsSocketlike ;
16- use rustix:: io:: Errno ;
1711use std:: net:: SocketAddr ;
1812use tokio:: io:: Interest ;
1913use wasmtime:: component:: Resource ;
@@ -28,51 +22,20 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
2822 network : Resource < Network > ,
2923 local_address : IpSocketAddress ,
3024 ) -> SocketResult < ( ) > {
31- self . ctx . allowed_network_uses . check_allowed_udp ( ) ?;
32-
33- match self . table . get ( & this) ?. udp_state {
34- UdpState :: Default => { }
35- UdpState :: BindStarted => return Err ( ErrorCode :: ConcurrencyConflict . into ( ) ) ,
36- UdpState :: Bound | UdpState :: Connected => return Err ( ErrorCode :: InvalidState . into ( ) ) ,
37- }
38-
39- // Set the socket addr check on the socket so later functions have access to it through the socket handle
25+ let local_address = SocketAddr :: from ( local_address) ;
4026 let check = self . table . get ( & network) ?. socket_addr_check . clone ( ) ;
41- self . table
42- . get_mut ( & this) ?
43- . socket_addr_check
44- . replace ( check. clone ( ) ) ;
45-
46- let socket = self . table . get ( & this) ?;
47- let local_address: SocketAddr = local_address. into ( ) ;
48-
49- if !is_valid_address_family ( local_address. ip ( ) , socket. family ) {
50- return Err ( ErrorCode :: InvalidArgument . into ( ) ) ;
51- }
52-
53- {
54- check. check ( local_address, SocketAddrUse :: UdpBind ) . await ?;
55-
56- // Perform the OS bind call.
57- udp_bind ( socket. udp_socket ( ) , local_address) ?;
58- }
27+ check. check ( local_address, SocketAddrUse :: UdpBind ) . await ?;
5928
6029 let socket = self . table . get_mut ( & this) ?;
61- socket. udp_state = UdpState :: BindStarted ;
30+ socket. bind ( local_address) ?;
31+ socket. set_socket_addr_check ( Some ( check. clone ( ) ) ) ;
6232
6333 Ok ( ( ) )
6434 }
6535
6636 fn finish_bind ( & mut self , this : Resource < udp:: UdpSocket > ) -> SocketResult < ( ) > {
67- let socket = self . table . get_mut ( & this) ?;
68-
69- match socket. udp_state {
70- UdpState :: BindStarted => {
71- socket. udp_state = UdpState :: Bound ;
72- Ok ( ( ) )
73- }
74- _ => Err ( ErrorCode :: NotInProgress . into ( ) ) ,
75- }
37+ self . table . get_mut ( & this) ?. finish_bind ( ) ?;
38+ Ok ( ( ) )
7639 }
7740
7841 async fn stream (
@@ -95,9 +58,8 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
9558 let socket = self . table . get_mut ( & this) ?;
9659 let remote_address = remote_address. map ( SocketAddr :: from) ;
9760
98- match socket. udp_state {
99- UdpState :: Bound | UdpState :: Connected => { }
100- _ => return Err ( ErrorCode :: InvalidState . into ( ) ) ,
61+ if !socket. is_bound ( ) {
62+ return Err ( ErrorCode :: InvalidState . into ( ) ) ;
10163 }
10264
10365 // We disconnect & (re)connect in two distinct steps for two reasons:
@@ -107,48 +69,30 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
10769 // if there isn't a disconnect in between.
10870
10971 // Step #1: Disconnect
110- if let UdpState :: Connected = socket. udp_state {
111- udp_disconnect ( socket. udp_socket ( ) ) ?;
112- socket. udp_state = UdpState :: Bound ;
72+ if socket. is_connected ( ) {
73+ socket. disconnect ( ) ?;
11374 }
11475
11576 // Step #2: (Re)connect
11677 if let Some ( connect_addr) = remote_address {
117- let Some ( check) = socket. socket_addr_check . as_ref ( ) else {
78+ let connect_addr = SocketAddr :: from ( connect_addr) ;
79+ let Some ( check) = socket. socket_addr_check ( ) else {
11880 return Err ( ErrorCode :: InvalidState . into ( ) ) ;
11981 } ;
120- if !is_valid_remote_address ( connect_addr)
121- || !is_valid_address_family ( connect_addr. ip ( ) , socket. family )
122- {
123- return Err ( ErrorCode :: InvalidArgument . into ( ) ) ;
124- }
12582 check. check ( connect_addr, SocketAddrUse :: UdpConnect ) . await ?;
126-
127- rustix:: net:: connect ( socket. udp_socket ( ) , & connect_addr) . map_err (
128- |error| match error {
129- Errno :: AFNOSUPPORT => ErrorCode :: InvalidArgument , // See `bind` implementation.
130- Errno :: INPROGRESS => {
131- tracing:: debug!(
132- "UDP connect returned EINPROGRESS, which should never happen"
133- ) ;
134- ErrorCode :: Unknown
135- }
136- _ => ErrorCode :: from ( error) ,
137- } ,
138- ) ?;
139- socket. udp_state = UdpState :: Connected ;
83+ socket. connect ( connect_addr) ?;
14084 }
14185
14286 let incoming_stream = IncomingDatagramStream {
143- inner : socket. inner . clone ( ) ,
87+ inner : socket. socket ( ) . clone ( ) ,
14488 remote_address,
14589 } ;
14690 let outgoing_stream = OutgoingDatagramStream {
147- inner : socket. inner . clone ( ) ,
91+ inner : socket. socket ( ) . clone ( ) ,
14892 remote_address,
149- family : socket. family ,
93+ family : socket. address_family ( ) ,
15094 send_state : SendState :: Idle ,
151- socket_addr_check : socket. socket_addr_check . clone ( ) ,
95+ socket_addr_check : socket. socket_addr_check ( ) . cloned ( ) ,
15296 } ;
15397
15498 Ok ( (
@@ -159,56 +103,25 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
159103
160104 fn local_address ( & mut self , this : Resource < udp:: UdpSocket > ) -> SocketResult < IpSocketAddress > {
161105 let socket = self . table . get ( & this) ?;
162-
163- match socket. udp_state {
164- UdpState :: Default => return Err ( ErrorCode :: InvalidState . into ( ) ) ,
165- UdpState :: BindStarted => return Err ( ErrorCode :: ConcurrencyConflict . into ( ) ) ,
166- _ => { }
167- }
168-
169- let addr = socket
170- . udp_socket ( )
171- . as_socketlike_view :: < std:: net:: UdpSocket > ( )
172- . local_addr ( ) ?;
173- Ok ( addr. into ( ) )
106+ Ok ( socket. local_address ( ) ?. into ( ) )
174107 }
175108
176109 fn remote_address ( & mut self , this : Resource < udp:: UdpSocket > ) -> SocketResult < IpSocketAddress > {
177110 let socket = self . table . get ( & this) ?;
178-
179- match socket. udp_state {
180- UdpState :: Connected => { }
181- _ => return Err ( ErrorCode :: InvalidState . into ( ) ) ,
182- }
183-
184- let addr = socket
185- . udp_socket ( )
186- . as_socketlike_view :: < std:: net:: UdpSocket > ( )
187- . peer_addr ( ) ?;
188- Ok ( addr. into ( ) )
111+ Ok ( socket. remote_address ( ) ?. into ( ) )
189112 }
190113
191114 fn address_family (
192115 & mut self ,
193116 this : Resource < udp:: UdpSocket > ,
194117 ) -> Result < IpAddressFamily , anyhow:: Error > {
195118 let socket = self . table . get ( & this) ?;
196-
197- match socket. family {
198- SocketAddressFamily :: Ipv4 => Ok ( IpAddressFamily :: Ipv4 ) ,
199- SocketAddressFamily :: Ipv6 => Ok ( IpAddressFamily :: Ipv6 ) ,
200- }
119+ Ok ( socket. address_family ( ) . into ( ) )
201120 }
202121
203122 fn unicast_hop_limit ( & mut self , this : Resource < udp:: UdpSocket > ) -> SocketResult < u8 > {
204123 let socket = self . table . get ( & this) ?;
205-
206- let ttl = match socket. family {
207- SocketAddressFamily :: Ipv4 => get_ip_ttl ( socket. udp_socket ( ) ) ?,
208- SocketAddressFamily :: Ipv6 => get_ipv6_unicast_hops ( socket. udp_socket ( ) ) ?,
209- } ;
210-
211- Ok ( ttl)
124+ Ok ( socket. unicast_hop_limit ( ) ?)
212125 }
213126
214127 fn set_unicast_hop_limit (
@@ -217,17 +130,13 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
217130 value : u8 ,
218131 ) -> SocketResult < ( ) > {
219132 let socket = self . table . get ( & this) ?;
220-
221- set_unicast_hop_limit ( socket. udp_socket ( ) , socket. family , value) ?;
222-
133+ socket. set_unicast_hop_limit ( value) ?;
223134 Ok ( ( ) )
224135 }
225136
226137 fn receive_buffer_size ( & mut self , this : Resource < udp:: UdpSocket > ) -> SocketResult < u64 > {
227138 let socket = self . table . get ( & this) ?;
228-
229- let value = receive_buffer_size ( socket. udp_socket ( ) ) ?;
230- Ok ( value)
139+ Ok ( socket. receive_buffer_size ( ) ?)
231140 }
232141
233142 fn set_receive_buffer_size (
@@ -236,33 +145,22 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
236145 value : u64 ,
237146 ) -> SocketResult < ( ) > {
238147 let socket = self . table . get ( & this) ?;
239-
240- set_receive_buffer_size ( socket. udp_socket ( ) , value) ?;
148+ socket. set_receive_buffer_size ( value) ?;
241149 Ok ( ( ) )
242150 }
243151
244152 fn send_buffer_size ( & mut self , this : Resource < udp:: UdpSocket > ) -> SocketResult < u64 > {
245153 let socket = self . table . get ( & this) ?;
246-
247- let value = send_buffer_size ( socket. udp_socket ( ) ) ?;
248- Ok ( value)
154+ Ok ( socket. send_buffer_size ( ) ?)
249155 }
250156
251- fn set_send_buffer_size (
252- & mut self ,
253- this : Resource < udp:: UdpSocket > ,
254- value : u64 ,
255- ) -> SocketResult < ( ) > {
157+ fn set_send_buffer_size ( & mut self , this : Resource < UdpSocket > , value : u64 ) -> SocketResult < ( ) > {
256158 let socket = self . table . get ( & this) ?;
257-
258- set_send_buffer_size ( socket. udp_socket ( ) , value) ?;
159+ socket. set_send_buffer_size ( value) ?;
259160 Ok ( ( ) )
260161 }
261162
262- fn subscribe (
263- & mut self ,
264- this : Resource < udp:: UdpSocket > ,
265- ) -> anyhow:: Result < Resource < DynPollable > > {
163+ fn subscribe ( & mut self , this : Resource < UdpSocket > ) -> anyhow:: Result < Resource < DynPollable > > {
266164 wasmtime_wasi_io:: poll:: subscribe ( self . table , this)
267165 }
268166
@@ -276,6 +174,13 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
276174 }
277175}
278176
177+ #[ async_trait]
178+ impl Pollable for UdpSocket {
179+ async fn ready ( & mut self ) {
180+ // None of the socket-level operations block natively
181+ }
182+ }
183+
279184impl udp:: HostIncomingDatagramStream for WasiSocketsCtxView < ' _ > {
280185 fn receive (
281186 & mut self ,
@@ -504,6 +409,15 @@ impl Pollable for OutgoingDatagramStream {
504409 }
505410}
506411
412+ impl From < SocketAddressFamily > for IpAddressFamily {
413+ fn from ( family : SocketAddressFamily ) -> IpAddressFamily {
414+ match family {
415+ SocketAddressFamily :: Ipv4 => IpAddressFamily :: Ipv4 ,
416+ SocketAddressFamily :: Ipv6 => IpAddressFamily :: Ipv6 ,
417+ }
418+ }
419+ }
420+
507421pub mod sync {
508422 use wasmtime:: component:: Resource ;
509423
0 commit comments