Skip to content

Commit 70ac013

Browse files
committed
Use the same UdpSocket in WASIp{2,3}
This commit refactors the implementation of `wasi:sockets` for WASIp2 and WASIp3 to use the same underlying host data structure for the `UdpSocket` resource in WIT. Previously each version of WASI had its own socket which resulted in duplicated code. There's some minor differences between WASIp2 and WASIp3 but it's easy enough to paper over with the same socket type. This is intended to help with the maintainability of this going forward to only have one type to operate on rather than two (which also ensures that bugfixes for one should affect the other). One other change made in this commit is that sprinkled checks for whether or not UDP is allowed are all removed and canonicalized during UDP socket creation. This means that UDP socket creation is the only location that checks for whether UDP is allowed. Once a UDP socket is created it can be used freely regardless of whether the UDP setting is enabled or disabled. This is not intended to have a large practical effect but it does mean the behavior of hosts that deny UDP but manually give access to a UDP socket resource to a component may behave subtly differently.
1 parent d526299 commit 70ac013

File tree

12 files changed

+159
-289
lines changed

12 files changed

+159
-289
lines changed

crates/wasi/src/p2/bindings.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ pub mod sync {
173173
"wasi:sockets/tcp/tcp-socket": super::super::sockets::tcp::TcpSocket,
174174
"wasi:sockets/udp/incoming-datagram-stream": super::super::sockets::udp::IncomingDatagramStream,
175175
"wasi:sockets/udp/outgoing-datagram-stream": super::super::sockets::udp::OutgoingDatagramStream,
176-
"wasi:sockets/udp/udp-socket": super::super::sockets::udp::UdpSocket,
176+
"wasi:sockets/udp/udp-socket": crate::sockets::UdpSocket,
177177

178178
// Error host trait from wasmtime-wasi-io is synchronous, so we can alias it
179179
"wasi:io/error": wasmtime_wasi_io::bindings::wasi::io::error,
@@ -394,7 +394,7 @@ mod async_io {
394394
// this crate
395395
"wasi:sockets/network/network": crate::p2::network::Network,
396396
"wasi:sockets/tcp/tcp-socket": crate::p2::tcp::TcpSocket,
397-
"wasi:sockets/udp/udp-socket": crate::p2::udp::UdpSocket,
397+
"wasi:sockets/udp/udp-socket": crate::sockets::UdpSocket,
398398
"wasi:sockets/udp/incoming-datagram-stream": crate::p2::udp::IncomingDatagramStream,
399399
"wasi:sockets/udp/outgoing-datagram-stream": crate::p2::udp::OutgoingDatagramStream,
400400
"wasi:sockets/ip-name-lookup/resolve-address-stream": crate::p2::ip_name_lookup::ResolveAddressStream,

crates/wasi/src/p2/host/udp.rs

Lines changed: 47 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,13 @@
11
use crate::p2::bindings::sockets::network::{ErrorCode, IpAddressFamily, IpSocketAddress, Network};
22
use crate::p2::bindings::sockets::udp;
3-
use crate::p2::udp::{IncomingDatagramStream, OutgoingDatagramStream, SendState, UdpState};
3+
use crate::p2::udp::{IncomingDatagramStream, OutgoingDatagramStream, SendState};
44
use crate::p2::{Pollable, SocketError, SocketResult};
5-
use crate::sockets::util::{
6-
get_ip_ttl, get_ipv6_unicast_hops, is_valid_address_family, is_valid_remote_address,
7-
receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,
8-
set_unicast_hop_limit, udp_bind, udp_disconnect,
9-
};
5+
use crate::sockets::util::{is_valid_address_family, is_valid_remote_address};
106
use crate::sockets::{
11-
MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, SocketAddressFamily, WasiSocketsCtxView,
7+
MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, SocketAddressFamily, UdpSocket, WasiSocketsCtxView,
128
};
139
use anyhow::anyhow;
1410
use async_trait::async_trait;
15-
use io_lifetimes::AsSocketlike;
16-
use rustix::io::Errno;
1711
use std::net::SocketAddr;
1812
use tokio::io::Interest;
1913
use wasmtime::component::Resource;
@@ -28,51 +22,20 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
2822
network: Resource<Network>,
2923
local_address: IpSocketAddress,
3024
) -> SocketResult<()> {
31-
self.ctx.allowed_network_uses.check_allowed_udp()?;
32-
33-
match self.table.get(&this)?.udp_state {
34-
UdpState::Default => {}
35-
UdpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()),
36-
UdpState::Bound | UdpState::Connected => return Err(ErrorCode::InvalidState.into()),
37-
}
38-
39-
// Set the socket addr check on the socket so later functions have access to it through the socket handle
25+
let local_address = SocketAddr::from(local_address);
4026
let check = self.table.get(&network)?.socket_addr_check.clone();
41-
self.table
42-
.get_mut(&this)?
43-
.socket_addr_check
44-
.replace(check.clone());
45-
46-
let socket = self.table.get(&this)?;
47-
let local_address: SocketAddr = local_address.into();
48-
49-
if !is_valid_address_family(local_address.ip(), socket.family) {
50-
return Err(ErrorCode::InvalidArgument.into());
51-
}
52-
53-
{
54-
check.check(local_address, SocketAddrUse::UdpBind).await?;
55-
56-
// Perform the OS bind call.
57-
udp_bind(socket.udp_socket(), local_address)?;
58-
}
27+
check.check(local_address, SocketAddrUse::UdpBind).await?;
5928

6029
let socket = self.table.get_mut(&this)?;
61-
socket.udp_state = UdpState::BindStarted;
30+
socket.bind(local_address)?;
31+
socket.set_socket_addr_check(Some(check.clone()));
6232

6333
Ok(())
6434
}
6535

6636
fn finish_bind(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<()> {
67-
let socket = self.table.get_mut(&this)?;
68-
69-
match socket.udp_state {
70-
UdpState::BindStarted => {
71-
socket.udp_state = UdpState::Bound;
72-
Ok(())
73-
}
74-
_ => Err(ErrorCode::NotInProgress.into()),
75-
}
37+
self.table.get_mut(&this)?.finish_bind()?;
38+
Ok(())
7639
}
7740

7841
async fn stream(
@@ -95,9 +58,8 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
9558
let socket = self.table.get_mut(&this)?;
9659
let remote_address = remote_address.map(SocketAddr::from);
9760

98-
match socket.udp_state {
99-
UdpState::Bound | UdpState::Connected => {}
100-
_ => return Err(ErrorCode::InvalidState.into()),
61+
if !socket.is_bound() {
62+
return Err(ErrorCode::InvalidState.into());
10163
}
10264

10365
// We disconnect & (re)connect in two distinct steps for two reasons:
@@ -107,48 +69,30 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
10769
// if there isn't a disconnect in between.
10870

10971
// Step #1: Disconnect
110-
if let UdpState::Connected = socket.udp_state {
111-
udp_disconnect(socket.udp_socket())?;
112-
socket.udp_state = UdpState::Bound;
72+
if socket.is_connected() {
73+
socket.disconnect()?;
11374
}
11475

11576
// Step #2: (Re)connect
11677
if let Some(connect_addr) = remote_address {
117-
let Some(check) = socket.socket_addr_check.as_ref() else {
78+
let connect_addr = SocketAddr::from(connect_addr);
79+
let Some(check) = socket.socket_addr_check() else {
11880
return Err(ErrorCode::InvalidState.into());
11981
};
120-
if !is_valid_remote_address(connect_addr)
121-
|| !is_valid_address_family(connect_addr.ip(), socket.family)
122-
{
123-
return Err(ErrorCode::InvalidArgument.into());
124-
}
12582
check.check(connect_addr, SocketAddrUse::UdpConnect).await?;
126-
127-
rustix::net::connect(socket.udp_socket(), &connect_addr).map_err(
128-
|error| match error {
129-
Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, // See `bind` implementation.
130-
Errno::INPROGRESS => {
131-
tracing::debug!(
132-
"UDP connect returned EINPROGRESS, which should never happen"
133-
);
134-
ErrorCode::Unknown
135-
}
136-
_ => ErrorCode::from(error),
137-
},
138-
)?;
139-
socket.udp_state = UdpState::Connected;
83+
socket.connect(connect_addr)?;
14084
}
14185

14286
let incoming_stream = IncomingDatagramStream {
143-
inner: socket.inner.clone(),
87+
inner: socket.socket().clone(),
14488
remote_address,
14589
};
14690
let outgoing_stream = OutgoingDatagramStream {
147-
inner: socket.inner.clone(),
91+
inner: socket.socket().clone(),
14892
remote_address,
149-
family: socket.family,
93+
family: socket.address_family(),
15094
send_state: SendState::Idle,
151-
socket_addr_check: socket.socket_addr_check.clone(),
95+
socket_addr_check: socket.socket_addr_check().cloned(),
15296
};
15397

15498
Ok((
@@ -159,56 +103,25 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
159103

160104
fn local_address(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<IpSocketAddress> {
161105
let socket = self.table.get(&this)?;
162-
163-
match socket.udp_state {
164-
UdpState::Default => return Err(ErrorCode::InvalidState.into()),
165-
UdpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()),
166-
_ => {}
167-
}
168-
169-
let addr = socket
170-
.udp_socket()
171-
.as_socketlike_view::<std::net::UdpSocket>()
172-
.local_addr()?;
173-
Ok(addr.into())
106+
Ok(socket.local_address()?.into())
174107
}
175108

176109
fn remote_address(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<IpSocketAddress> {
177110
let socket = self.table.get(&this)?;
178-
179-
match socket.udp_state {
180-
UdpState::Connected => {}
181-
_ => return Err(ErrorCode::InvalidState.into()),
182-
}
183-
184-
let addr = socket
185-
.udp_socket()
186-
.as_socketlike_view::<std::net::UdpSocket>()
187-
.peer_addr()?;
188-
Ok(addr.into())
111+
Ok(socket.remote_address()?.into())
189112
}
190113

191114
fn address_family(
192115
&mut self,
193116
this: Resource<udp::UdpSocket>,
194117
) -> Result<IpAddressFamily, anyhow::Error> {
195118
let socket = self.table.get(&this)?;
196-
197-
match socket.family {
198-
SocketAddressFamily::Ipv4 => Ok(IpAddressFamily::Ipv4),
199-
SocketAddressFamily::Ipv6 => Ok(IpAddressFamily::Ipv6),
200-
}
119+
Ok(socket.address_family().into())
201120
}
202121

203122
fn unicast_hop_limit(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u8> {
204123
let socket = self.table.get(&this)?;
205-
206-
let ttl = match socket.family {
207-
SocketAddressFamily::Ipv4 => get_ip_ttl(socket.udp_socket())?,
208-
SocketAddressFamily::Ipv6 => get_ipv6_unicast_hops(socket.udp_socket())?,
209-
};
210-
211-
Ok(ttl)
124+
Ok(socket.unicast_hop_limit()?)
212125
}
213126

214127
fn set_unicast_hop_limit(
@@ -217,17 +130,13 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
217130
value: u8,
218131
) -> SocketResult<()> {
219132
let socket = self.table.get(&this)?;
220-
221-
set_unicast_hop_limit(socket.udp_socket(), socket.family, value)?;
222-
133+
socket.set_unicast_hop_limit(value)?;
223134
Ok(())
224135
}
225136

226137
fn receive_buffer_size(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u64> {
227138
let socket = self.table.get(&this)?;
228-
229-
let value = receive_buffer_size(socket.udp_socket())?;
230-
Ok(value)
139+
Ok(socket.receive_buffer_size()?)
231140
}
232141

233142
fn set_receive_buffer_size(
@@ -236,33 +145,22 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
236145
value: u64,
237146
) -> SocketResult<()> {
238147
let socket = self.table.get(&this)?;
239-
240-
set_receive_buffer_size(socket.udp_socket(), value)?;
148+
socket.set_receive_buffer_size(value)?;
241149
Ok(())
242150
}
243151

244152
fn send_buffer_size(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u64> {
245153
let socket = self.table.get(&this)?;
246-
247-
let value = send_buffer_size(socket.udp_socket())?;
248-
Ok(value)
154+
Ok(socket.send_buffer_size()?)
249155
}
250156

251-
fn set_send_buffer_size(
252-
&mut self,
253-
this: Resource<udp::UdpSocket>,
254-
value: u64,
255-
) -> SocketResult<()> {
157+
fn set_send_buffer_size(&mut self, this: Resource<UdpSocket>, value: u64) -> SocketResult<()> {
256158
let socket = self.table.get(&this)?;
257-
258-
set_send_buffer_size(socket.udp_socket(), value)?;
159+
socket.set_send_buffer_size(value)?;
259160
Ok(())
260161
}
261162

262-
fn subscribe(
263-
&mut self,
264-
this: Resource<udp::UdpSocket>,
265-
) -> anyhow::Result<Resource<DynPollable>> {
163+
fn subscribe(&mut self, this: Resource<UdpSocket>) -> anyhow::Result<Resource<DynPollable>> {
266164
wasmtime_wasi_io::poll::subscribe(self.table, this)
267165
}
268166

@@ -276,6 +174,13 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
276174
}
277175
}
278176

177+
#[async_trait]
178+
impl Pollable for UdpSocket {
179+
async fn ready(&mut self) {
180+
// None of the socket-level operations block natively
181+
}
182+
}
183+
279184
impl udp::HostIncomingDatagramStream for WasiSocketsCtxView<'_> {
280185
fn receive(
281186
&mut self,
@@ -504,6 +409,15 @@ impl Pollable for OutgoingDatagramStream {
504409
}
505410
}
506411

412+
impl From<SocketAddressFamily> for IpAddressFamily {
413+
fn from(family: SocketAddressFamily) -> IpAddressFamily {
414+
match family {
415+
SocketAddressFamily::Ipv4 => IpAddressFamily::Ipv4,
416+
SocketAddressFamily::Ipv6 => IpAddressFamily::Ipv6,
417+
}
418+
}
419+
}
420+
507421
pub mod sync {
508422
use wasmtime::component::Resource;
509423

crates/wasi/src/p2/host/udp_create_socket.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::p2::SocketResult;
22
use crate::p2::bindings::{sockets::network::IpAddressFamily, sockets::udp_create_socket};
3-
use crate::p2::udp::UdpSocket;
3+
use crate::sockets::UdpSocket;
44
use crate::sockets::WasiSocketsCtxView;
55
use wasmtime::component::Resource;
66

@@ -9,7 +9,7 @@ impl udp_create_socket::Host for WasiSocketsCtxView<'_> {
99
&mut self,
1010
address_family: IpAddressFamily,
1111
) -> SocketResult<Resource<UdpSocket>> {
12-
let socket = UdpSocket::new(address_family.into())?;
12+
let socket = UdpSocket::new(self.ctx, address_family.into())?;
1313
let socket = self.table.push(socket)?;
1414
Ok(socket)
1515
}

crates/wasi/src/p2/network.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ impl From<crate::sockets::util::ErrorCode> for ErrorCode {
4848
crate::sockets::util::ErrorCode::ConnectionReset => Self::ConnectionReset,
4949
crate::sockets::util::ErrorCode::ConnectionAborted => Self::ConnectionAborted,
5050
crate::sockets::util::ErrorCode::DatagramTooLarge => Self::DatagramTooLarge,
51+
crate::sockets::util::ErrorCode::NotInProgress => Self::NotInProgress,
5152
}
5253
}
5354
}

0 commit comments

Comments
 (0)