From 881ab364d84bcd963ccba46bd606d3ec69b980ff Mon Sep 17 00:00:00 2001 From: Joel Dice Date: Wed, 30 Apr 2025 16:04:02 -0600 Subject: [PATCH] add comments to `unsafe` code in `concurrent.rs` This also removes a few cases of `unsafe` that turned out to be unnecessary, as well as adding the `unsafe` modifier to a few functions where appropriate. `futures_and_streams.rs` also has quite a few `unsafe` blocks; I'll tackle those next. Signed-off-by: Joel Dice --- .../src/runtime/component/concurrent.rs | 438 ++++++++++++++---- crates/wasmtime/src/runtime/component/func.rs | 39 +- .../src/runtime/component/func/host.rs | 15 +- .../src/runtime/component/func/typed.rs | 53 ++- crates/wasmtime/src/runtime/component/mod.rs | 4 +- 5 files changed, 409 insertions(+), 140 deletions(-) diff --git a/crates/wasmtime/src/runtime/component/concurrent.rs b/crates/wasmtime/src/runtime/component/concurrent.rs index 495ff2c336..1fde72e083 100644 --- a/crates/wasmtime/src/runtime/component/concurrent.rs +++ b/crates/wasmtime/src/runtime/component/concurrent.rs @@ -2,7 +2,7 @@ use { crate::{ component::{ func::{self, Func, Options}, - Instance, Lift, Lower, Val, + Instance, }, store::{StoreInner, StoreOpaque}, vm::{ @@ -161,6 +161,28 @@ pub struct Access<'a, T, U>(&'a mut Accessor); impl<'a, T, U> Access<'a, T, U> { /// Get mutable access to the store data. pub fn get(&mut self) -> &mut U { + // SAFETY: This relies on `Accessor::get` either returning a pair of + // pointers such that the first is a valid `*mut U` and the second is a + // `*mut dyn VMStore` whose data is of type `T` _or_ panicking if it is + // called outside its intended scope. + // + // See `ComponentInstance::wrap_call`, `Instance::run_with_raw`, and the + // code generated in + // `wasmtime_wit_bindgen::InterfaceGenerator::generate_guest_import_closure` + // for where we create `Accessor` instances and pass them `get` + // functions backed by thread-local state, thereby ensuring type- and + // lifetime-safety. 
+ // + // See also `poll_with_state` in this module and in the code generated + // by `wasmtime_wit_bindgen::concurrent_declarations`, where we populate + // and then clear the thread-local state before and after polling + // futures which have `Accessor` instances, respectively. + // + // Finally, see `ComponentInstance::poll_until`, which is the only + // place we poll `ConcurrentState::futures`, which contains all futures + // which have `Accessor` instances, and note that we do so while we have + // exclusive access to the store and `ComponentInstance` can thus + // soundly populate the thread-local state to match. unsafe { &mut *(self.0.get)().0.cast() } } @@ -184,7 +206,8 @@ impl<'a, T, U> Deref for Access<'a, T, U> { type Target = U; fn deref(&self) -> &U { - unsafe { &mut *(self.0.get)().0.cast() } + // SAFETY: See comment in `Access::get` + unsafe { &*(self.0.get)().0.cast() } } } @@ -198,12 +221,14 @@ impl<'a, T, U> AsContext for Access<'a, T, U> { type Data = T; fn as_context(&self) -> StoreContext { + // SAFETY: See comment in `Access::get` unsafe { StoreContext(&*(self.0.get)().1.cast()) } } } impl<'a, T, U> AsContextMut for Access<'a, T, U> { fn as_context_mut(&mut self) -> StoreContextMut { + // SAFETY: See comment in `Access::get` unsafe { StoreContextMut(&mut *(self.0.get)().1.cast()) } } } @@ -214,16 +239,26 @@ impl<'a, T, U> AsContextMut for Access<'a, T, U> { /// This allows multiple host import futures to execute concurrently and access /// the store data between (but not across) `await` points. pub struct Accessor { - get: Arc (*mut u8, *mut u8)>, + get: Arc (*mut u8, *mut u8) + Send + Sync>, spawn: fn(Spawned), instance: Option, _phantom: PhantomData (*mut U, *mut StoreInner)>, } -unsafe impl Send for Accessor {} -unsafe impl Sync for Accessor {} - impl Accessor { + /// Creates a new `Accessor` backed by the specified functions. 
+ /// + /// - `get`: used to retrieve the host data and store + /// + /// - `spawn`: used to queue spawned background tasks to be run later + /// + /// - `instance`: used to access the `Instance` to which this `Accessor` + /// (and the future which closes over it) belongs + /// + /// SAFETY: This relies on `get` either returning a pair of pointers such + /// that the first is a valid `*mut U` and the second is a `*mut dyn VMStore` + /// whose data is of type `T` _or_ panicking if it is called outside its + /// intended scope. See the comment in `Access::get` for further details. #[doc(hidden)] pub unsafe fn new( get: fn() -> (*mut u8, *mut u8), @@ -263,7 +298,9 @@ impl Accessor { let mut accessor = Accessor { get: Arc::new(move || { let (host, store) = get(); - unsafe { ((fun(&mut *host.cast()) as *mut V).cast(), store) } + // SAFETY: See `Accessor::new` doc comment + let host = unsafe { &mut *host.cast() }; + ((fun(host) as *mut V).cast(), store) }), spawn: self.spawn, instance: self.instance, @@ -289,14 +326,19 @@ impl Accessor { _phantom: PhantomData, }; let future = Arc::new(Mutex::new(AbortWrapper::Unpolled(unsafe { - // This is to avoid a `U: 'static` bound. Rationale: We don't - // actually store a value of type `U` in the `Accessor` we're - // `move`ing into the `async` block, and access to a `U` is brokered - // via `Accessor::with` by way of a thread-local variable in - // `wasmtime-wit-bindgen`-generated code. Furthermore, + // This `transmute` is to avoid requiring a `U: 'static` bound, + // which should be unnecessary. + // + // SAFETY: We don't actually store a value of type `U` in the + // `Accessor` we're `move`ing into the `async` block; access to a + // `U` is brokered via `Accessor::with` by way of a thread-local + // variable in `wasmtime-wit-bindgen`-generated code or the + // `poll_with_state` function in this module. Furthermore, // `AccessorTask` implementations are required to be `'static`, so // no lifetime issues there. 
We have no way to explain any of that // to the compiler, though, so we resort to a transmute here. + // + // See the comment in `Access::get` for further details. mem::transmute::< Pin> + Send>>, Pin> + Send + 'static>>, @@ -437,12 +479,23 @@ fn spawn_task(task: Spawned) { STATE.with(|v| v.borrow_mut().as_mut().unwrap().spawned.push(task)); } -fn poll_with_state( +// SAFETY: `store` must be a valid `*mut dyn VMStore` with a data type of `T`, +// and `instance` must be a valid `*mut ComponentInstance`. +// +// Note that we must smuggle these pointers in as `VMStoreRawPtr` and +// `SendSyncPtr`, respectively, to allow this function to be +// called within a future that is `Send`. This is sound because +// `ComponentInstance::poll_until` is the only place those futures are polled, +// that function has exclusive access to both the store and the +// `ComponentInstance`. +unsafe fn poll_with_state( store: VMStoreRawPtr, instance: SendSyncPtr, cx: &mut Context, future: Pin<&mut F>, ) -> Poll { + // SAFETY: Per the function precondition, `store` must be a valid `*mut dyn + // VMStore` with a data type of `T`. let mut store_cx = unsafe { StoreContextMut::new(&mut *store.0.as_ptr().cast()) }; let (result, spawned) = { @@ -460,6 +513,8 @@ fn poll_with_state( (future.poll(cx), STATE.with(|v| v.take()).unwrap().spawned) }; + // SAFETY: Per the function precondition, `instance` must be a valid `*mut + // ComponentInstance` let instance_ref = unsafe { &mut *instance.as_ptr() }; for spawned in spawned { instance_ref.spawn(future::poll_fn(move |cx| { @@ -707,6 +762,8 @@ impl ComponentInstance { if task.lift_result.is_some() { log::trace!("push call context for {guest_task:?}"); let call_context = task.call_context.take().unwrap(); + // SAFETY: This `ComponentInstance` belongs to the store in which it + // resides, so if it is valid then so is its store. 
unsafe { &mut (*self.store()) } .store_opaque_mut() .component_resource_state() @@ -720,6 +777,8 @@ impl ComponentInstance { if self.get_mut(guest_task)?.lift_result.is_some() { log::trace!("pop call context for {guest_task:?}"); let call_context = Some( + // SAFETY: This `ComponentInstance` belongs to the store in + // which it resides, so if it is valid then so is its store. unsafe { &mut (*self.store()) } .component_resource_state() .0 @@ -901,6 +960,16 @@ impl ComponentInstance { assert!(async_); Box::new(move |instance: &mut ComponentInstance| { + // SAFETY: This `ComponentInstance` belongs to the store in + // which it resides, so if it is valid then so is its store. + // + // In addition, the store's data type is known to be `T` because + // this closure will have been called with the same + // `ComponentInstance` that was passed to the outer `queue_call` + // scope. See + // `ComponentInstance::{poll_until,handle_work_item,handle_guest_call}`, + // where we pop the work item containing this closure and pass + // it the same `ComponentInstance`. let mut store = unsafe { StoreContextMut(&mut *instance.store().cast()) }; let old_task = instance.guest_task().replace(guest_task); log::trace!( @@ -920,6 +989,8 @@ impl ComponentInstance { *instance.guest_task() = old_task; log::trace!("stackless call: restored {old_task:?} as current task"); + // SAFETY: `wasmparser` will have validated that the callback + // function returns a `i32` result. let code = unsafe { storage[0].assume_init() }.get_i32() as u32; instance.handle_callback_code( @@ -935,6 +1006,8 @@ impl ComponentInstance { }) as Box Result<()> + Send + Sync> } else { Box::new(move |instance: &mut ComponentInstance| { + // SAFETY: See comment in the closure created in the + // `callback.is_some()` case above. 
let mut store = unsafe { StoreContextMut::(&mut *instance.store().cast()) }; let old_task = instance.guest_task().replace(guest_task); log::trace!( @@ -962,6 +1035,8 @@ impl ComponentInstance { assert!(instance.get(guest_task)?.result.is_none()); + // SAFETY: `result_count` represents the number of core Wasm + // results returned, per `wasmparser`. let result = (lift.lift)(instance, unsafe { mem::transmute::<&[MaybeUninit], &[ValRaw]>( &storage[..result_count], @@ -973,9 +1048,16 @@ impl ComponentInstance { if let Some(func) = post_return { let arg = match result_count { 0 => ValRaw::i32(0), + // SAFETY: `result_count` represents the number of + // core Wasm results returned, per `wasmparser`. 1 => unsafe { storage[0].assume_init() }, _ => unreachable!(), }; + // SAFETY: `func` is a valid `*mut VMFuncRef` from + // either `wasmtime-cranelift`-generated fused adapter + // code or `component::Options`. Per `wasmparser` + // post-return signature validation, we know it takes a + // single parameter. unsafe { crate::Func::call_unchecked_raw( &mut store, @@ -1042,6 +1124,16 @@ impl ComponentInstance { let new_task = GuestTask::new( self, Box::new(move |instance, dst| { + // SAFETY: This `ComponentInstance` belongs to the store in + // which it resides, so if it is valid then so is its store. + // + // In addition, the store's data type is known to be `T` because + // this closure will have been called with the same + // `ComponentInstance` that was passed to the outer `enter_call` + // scope. See + // `ComponentInstance::{poll_until,handle_work_item,handle_guest_call}`, + // where we pop the work item containing this closure and pass + // it the same `ComponentInstance`. let mut store = unsafe { StoreContextMut::(&mut *instance.store().cast()) }; assert!(dst.len() <= MAX_FLAT_PARAMS); let mut src = [MaybeUninit::uninit(); MAX_FLAT_PARAMS]; @@ -1051,12 +1143,20 @@ impl ComponentInstance { 1 } CallerInfo::Sync { params, .. 
} => { + // SAFETY: Transmuting from `&[T]` to + // `&[MaybeUninit<T>]` should be sound for any `T`. src[..params.len()].copy_from_slice(unsafe { mem::transmute::<&[ValRaw], &[MaybeUninit<ValRaw>]>(&params) }); params.len() } }; + // SAFETY: `start` is a valid `*mut VMFuncRef` from + // `wasmtime-cranelift`-generated fused adapter code. Based on + // how it was constructed (see + // `wasmtime_environ::fact::trampoline::Compiler::compile_async_start_adapter` + // for details) we know it takes count parameters and returns + // `dst.len()` results. unsafe { crate::Func::call_unchecked_raw( &mut store, @@ -1079,11 +1179,19 @@ impl ComponentInstance { }), LiftResult { lift: Box::new(move |instance, src| { + // SAFETY: See comment in closure passed as `lower_params` + // parameter above. let mut store = unsafe { StoreContextMut::(&mut *instance.store().cast()) }; let mut my_src = src.to_owned(); // TODO: use stack to avoid allocation? if let ResultInfo::Heap { results } = &result_info { my_src.push(ValRaw::u32(*results)); } + // SAFETY: `return_` is a valid `*mut VMFuncRef` from + // `wasmtime-cranelift`-generated fused adapter code. Based + // on how it was constructed (see + // `wasmtime_environ::fact::trampoline::Compiler::compile_async_return_adapter` + // for details) we know it takes `src.len()` parameters and + // returns up to 1 result. unsafe { crate::Func::call_unchecked_raw( &mut store, @@ -1141,6 +1249,17 @@ impl ComponentInstance { event: Event, handle: u32, ) -> Result { + // SAFETY: This `ComponentInstance` belongs to the store in which it + // resides, so if it is valid then so is its store. + // + // In addition, the store's data type is known to be `T` because this + // function will have been called with the same `ComponentInstance` that + // was in scope in `ComponentInstance::exit_call` or `prepare_call` -- + // the two functions where this function is monomorphized and stored as + // an `fn(..)` in `GuestTask::callback`. 
See + // `ComponentInstance::{poll_until,handle_work_item,handle_guest_call}`, + // where we pop the work item containing the pointer to this function + // and pass it the same `ComponentInstance`. let mut store = unsafe { StoreContextMut::(&mut *self.store().cast()) }; let mut flags = self.instance_flags(callee_instance); @@ -1150,6 +1269,10 @@ impl ComponentInstance { ValRaw::u32(handle), ValRaw::u32(result), ]; + // SAFETY: `func` is a valid `*mut VMFuncRef` from either + // `wasmtime-cranelift`-generated fused adapter code or + // `component::Options`. Per `wasmparser` callback signature + // validation, we know it takes three parameters and returns one. unsafe { flags.set_may_enter(false); crate::Func::call_unchecked_raw( @@ -1267,7 +1390,9 @@ impl ComponentInstance { Ok(status.pack(waitable)) } - pub(crate) fn wrap_call( + /// SAFETY: The returned future must only be polled (directly or + /// transitively) using `ComponentInstance::poll_until`. + pub(crate) unsafe fn wrap_call( &mut self, store: StoreContextMut, closure: Arc, @@ -1281,18 +1406,35 @@ impl ComponentInstance { + Send + Sync + 'static, - P: Lift + Send + Sync + 'static, - R: Lower + Send + Sync + 'static, + P: Send + Sync + 'static, + R: Send + Sync + 'static, { + // SAFETY: The `get_host_and_store` function we pass here is backed by a + // thread-local variable which `poll_with_state` will populate and reset + // with valid pointers to the store data and the store itself each time + // the returned future is polled, respectively. let mut accessor = unsafe { Accessor::new(get_host_and_store, spawn_task, self.instance()) }; let mut future = Box::pin(async move { closure(&mut accessor, params).await }); let store = VMStoreRawPtr(store.traitobj()); let instance = SendSyncPtr::new(NonNull::new(self).unwrap()); - let future = future::poll_fn(move |cx| { + // SAFETY: `poll_with_state` will populate and reset the thread-local + // state as described above. 
+ // + // This is sound because `ComponentInstance::poll_until` is the only + // place we will poll this future (see the doc comment on this + // function), and it has exclusive access to the store and + // `ComponentInstance` when doing so. + let future = future::poll_fn(move |cx| unsafe { poll_with_state::(store, instance, cx, future.as_mut()) }); + // This `transmute` is to avoid requiring a `T: 'static` bound, which + // should be unnecessary. + // + // SAFETY: We don't store a value of type `T` in the above future, and + // access to the data of type `T` will only happen via the thread-local + // state described above. unsafe { mem::transmute::< Pin> + Send>>, @@ -1301,38 +1443,6 @@ impl ComponentInstance { } } - pub(crate) fn wrap_dynamic_call( - &mut self, - store: StoreContextMut, - closure: Arc, - params: Vec, - ) -> Pin>> + Send + 'static>> - where - F: for<'a> Fn( - &'a mut Accessor, - Vec, - ) -> Pin>> + Send + 'a>> - + Send - + Sync - + 'static, - { - let mut accessor = - unsafe { Accessor::new(get_host_and_store, spawn_task, self.instance()) }; - let mut future = Box::pin(async move { closure(&mut accessor, params).await }); - let store = VMStoreRawPtr(store.traitobj()); - let instance = SendSyncPtr::new(NonNull::new(self).unwrap()); - let future = future::poll_fn(move |cx| { - poll_with_state::(store, instance, cx, future.as_mut()) - }); - - unsafe { - mem::transmute::< - Pin>> + Send>>, - Pin>> + Send + 'static>>, - >(Box::pin(future)) - } - } - pub(crate) fn first_poll( &mut self, store: StoreContextMut, @@ -1352,6 +1462,11 @@ impl ComponentInstance { // respectively. This involves unsafe shenanigans in order to smuggle the // store pointer into the wrapping future, alas. // + // SAFETY: We'll only poll the future in (at most) two places: here in + // this function where we have exclusive access to the store, and (if + // necessary) in `ComponentInstance::poll_until`, where we will again + // have exclusive access to the same store. 
+ // + // Note that we also wrap the future in order to provide cancellation // support via `AbortWrapper`. @@ -1360,6 +1475,7 @@ impl ComponentInstance { call_context: &mut Option, task: TableId, ) { + // SAFETY: See the SAFETY comment above in the wrapping function. let store = unsafe { StoreContextMut::(&mut *store.0.as_ptr().cast()) }; if let Some(call_context) = call_context.take() { log::trace!("push call context for {task:?}"); @@ -1373,6 +1489,7 @@ impl ComponentInstance { task: TableId, ) { log::trace!("pop call context for {task:?}"); + // SAFETY: See the SAFETY comment above in the wrapping function. let store = unsafe { StoreContextMut::(&mut *store.0.as_ptr().cast()) }; *call_context = Some(store.0.component_resource_state().0.pop().unwrap()); } @@ -1413,6 +1530,7 @@ impl ComponentInstance { let mut future = Box::pin(future.map(move |result| { if let Some(result) = result { HostTaskOutput::Function(Box::new(move |instance| { + // SAFETY: See the SAFETY comment above in the wrapping function. 
let store = unsafe { StoreContextMut(&mut *instance.store().cast()) }; lower(store, result?)?; instance.get_mut(task)?.abort_handle.take(); @@ -1453,9 +1571,8 @@ impl ComponentInstance { ) } - pub(crate) fn poll_and_block( + pub(crate) fn poll_and_block( &mut self, - mut store: StoreContextMut, future: impl Future> + Send + 'static, caller_instance: RuntimeComponentInstanceIndex, ) -> Result { @@ -1490,34 +1607,35 @@ impl ComponentInstance { })) })) as HostTaskFuture; - let Some(cx) = AsyncCx::try_new(&mut store.0) else { - return Err(anyhow!("future dropped")); - }; - - Ok(match unsafe { cx.poll(future.as_mut()) } { - Poll::Ready(output) => { - output.consume(self)?; - log::trace!("delete host task {task:?} (already ready)"); - self.delete(task)?; - let result = *mem::replace(&mut self.get_mut(caller)?.result, old_result) - .unwrap() - .downcast() - .unwrap(); - result - } - Poll::Pending => { - self.push_future(future); + Ok( + match future + .as_mut() + .poll(&mut Context::from_waker(&dummy_waker())) + { + Poll::Ready(output) => { + output.consume(self)?; + log::trace!("delete host task {task:?} (already ready)"); + self.delete(task)?; + let result = *mem::replace(&mut self.get_mut(caller)?.result, old_result) + .unwrap() + .downcast() + .unwrap(); + result + } + Poll::Pending => { + self.push_future(future); - let set = self.get_mut(caller)?.sync_call_set; - Waitable::Host(task).join(self, Some(set))?; + let set = self.get_mut(caller)?.sync_call_set; + Waitable::Host(task).join(self, Some(set))?; - self.suspend(SuspendReason::Waiting { set, task: caller })?; + self.suspend(SuspendReason::Waiting { set, task: caller })?; - let result = self.get_mut(caller)?.result.take().unwrap(); - self.get_mut(caller)?.result = old_result; - *result.downcast().unwrap() - } - }) + let result = self.get_mut(caller)?.result.take().unwrap(); + self.get_mut(caller)?.result = old_result; + *result.downcast().unwrap() + } + }, + ) } async fn poll_until( @@ -1534,6 +1652,16 @@ impl 
ComponentInstance { future.as_mut().poll(cx) } + // Here we smuggle the `ComponentInstance` pointer into the future so we + // can use it while polling without upsetting the borrow checker given + // that we're also mutably borrowing `ConcurrentState::futures` to poll + // it. + // + // SAFETY: This is morally equivalent to a split borrow, since we are + // careful not to touch `ConcurrentState::futures` at any time while + // polling. See `ComponentInstance::push_future`, which explicitly + // defers touching `futures` by queuing a work item which we'll run only + // _after_ polling. let instance = SendSyncPtr::new(NonNull::new(self).unwrap()); let mut future = pin!(future); @@ -1548,6 +1676,7 @@ impl ComponentInstance { let next = match poll_with(instance, &mut next, cx) { Poll::Ready(Some(output)) => { + // SAFETY: See SAFETY comment in outer scope above. let me = unsafe { &mut *instance.as_ptr() }; if let Err(e) = output.consume(me) { return Poll::Ready(Err(e)); @@ -1558,6 +1687,7 @@ impl ComponentInstance { Poll::Pending => Poll::Pending, }; + // SAFETY: See SAFETY comment in outer scope above. let me = unsafe { &mut *instance.as_ptr() }; let ready = mem::take(&mut me.concurrent_state.high_priority); let ready = if ready.is_empty() { @@ -1688,6 +1818,14 @@ impl ComponentInstance { log::trace!("resume_fiber: save current task {old_task:?}"); let guard_range = fiber.guard_range(); let mut fiber = Some(fiber); + // Here we pass control of the store to the fiber, which requires + // smuggling it as a `VMStoreRawPtr` in order to ensure the future is + // `Send`. + // + // SAFETY: This `ComponentInstance` belongs to the store in which it + // resides, so if it is valid then so is its store. By the time the + // future returned by `poll_fn` completes, we'll have exclusive access + // to it again. 
let fiber = unsafe { poll_fn( VMStoreRawPtr(NonNull::new(self.store()).unwrap()), @@ -1804,10 +1942,23 @@ impl ComponentInstance { let worker = if let Some(fiber) = self.worker().take() { fiber } else { - let instance = SendSyncPtr::new(NonNull::new(self).unwrap()); + // Here we smuggle the `ComponentInstance` pointer into the closure + // so that the fiber can use it without upsetting the borrow + // checker. + // + // SAFETY: We will only resume this fiber in either + // `ComponentInstance::handle_work_item` or + // `ComponentInstance::run_on_worker`, where we'll have exclusive + // access to the same `ComponentInstance` and thus be able to grant + // the same access to the fiber we're resuming. + // + // TODO: Consider adding `*mut ComponentInstance` parameters to + // `StoreFiber`'s `suspend` and `resume` signatures to make this + // handoff more explicit. + let instance = self as *mut Self; unsafe { make_fiber(self.store(), move |_| { - let instance = &mut *instance.as_ptr(); + let instance = &mut *instance; loop { let call = instance.guest_call().take().unwrap(); instance.handle_guest_call(call)?; @@ -1842,6 +1993,11 @@ impl ComponentInstance { assert!(self.suspend_reason().is_none()); *self.suspend_reason() = Some(reason); + // SAFETY: This `ComponentInstance` belongs to the store in which it + // resides, so if it is valid then so is its store. In addition, this + // is only ever called from a fiber that belongs (via the + // `ComponentInstance`) to that store (and would in any case panic if + // called from outside any fiber). unsafe { let async_cx = AsyncCx::new((*self.store()).store_opaque_mut()); async_cx.suspend(Some(self.store()))?; @@ -1863,6 +2019,9 @@ impl ComponentInstance { storage: *mut ValRaw, storage_len: usize, ) -> Result<()> { + // SAFETY: The `wasmtime_cranelift`-generated code that calls this + // method will have ensured that `storage` is a valid pointer containing + // at least `storage_len` items. 
let storage = unsafe { std::slice::from_raw_parts(storage, storage_len) }; let guest_task = self.guest_task().unwrap(); let lift = self @@ -1916,6 +2075,8 @@ impl ComponentInstance { result: Box, status: Status, ) -> Result<()> { + // SAFETY: This `ComponentInstance` belongs to the store in which it + // resides, so if it is valid then so is its store. let (calls, host_table, _) = unsafe { &mut *self.store() } .store_opaque_mut() .component_resource_state(); @@ -2070,6 +2231,11 @@ impl ComponentInstance { WaitableCheck::Wait(params) | WaitableCheck::Poll(params) => { let event = self.get_event(guest_task, params.caller_instance, Some(params.set))?; + // SAFETY: This `ComponentInstance` belongs to the store in + // which it resides, so if it is valid then so is its store. In + // addition, `params.memory` is a valid `*mut + // VMMemoryDefinition` passed to this intrinsic via + // `wasmtime_cranelift`-generated code. let store_and_options = |me: &mut Self| unsafe { let store = (*me.store()).store_opaque_mut(); let options = Options::new( @@ -2380,6 +2546,9 @@ impl Instance { ) -> Result { check_recursive_run(); let store = store.as_context_mut(); + // SAFETY: We have exclusive access to the store, which means we have + // exclusive access to any `ComponentInstance` which resides in the + // store. let instance = unsafe { &mut *store.0[self.0].as_ref().unwrap().instance_ptr() }; instance.poll_until(fut).await } @@ -2496,17 +2665,23 @@ impl Instance { + 'static, { let store = store.as_context_mut(); + // SAFETY: We have exclusive access to the store, which means we have + // exclusive access to any `ComponentInstance` which resides in the + // store. let instance = unsafe { &mut *store.0[self.0].as_ref().unwrap().instance_ptr() }; + // SAFETY: See corresponding comment in `ComponentInstance::wrap_call`. 
let mut accessor = unsafe { Accessor::new(get_host_and_store, spawn_task, instance.instance()) }; let mut future = Box::pin(async move { fun(&mut accessor).await }); let store = VMStoreRawPtr(store.traitobj()); let instance = SendSyncPtr::new(NonNull::new(instance).unwrap()); - let future = future::poll_fn(move |cx| { + // SAFETY: See corresponding comment in `ComponentInstance::wrap_call`. + let future = future::poll_fn(move |cx| unsafe { poll_with_state::(store, instance, cx, future.as_mut()) }); + // SAFETY: See corresponding comment in `ComponentInstance::wrap_call`. unsafe { mem::transmute::< Pin + Send>>, @@ -2547,6 +2722,9 @@ impl Instance { mut store: impl AsContextMut, task: impl std::future::Future> + Send + 'static, ) { + // SAFETY: We have exclusive access to the store, which means we have + // exclusive access to any `ComponentInstance` which resides in the + // store. let instance = unsafe { &mut *store.as_context_mut().0[self.0] .as_ref() @@ -2741,6 +2919,9 @@ unsafe impl VMComponentAsyncStore for StoreInner { memory, string_encoding, CallerInfo::Sync { + // SAFETY: The `wasmtime_cranelift`-generated code that calls + // this method will have ensured that `storage` is a valid + // pointer containing at least `storage_len` items. params: unsafe { std::slice::from_raw_parts(storage, storage_len) }.to_vec(), result_count, }, @@ -2764,6 +2945,9 @@ unsafe impl VMComponentAsyncStore for StoreInner { param_count, 1, EXIT_FLAG_ASYNC_CALLEE, + // SAFETY: The `wasmtime_cranelift`-generated code that calls + // this method will have ensured that `storage` is a valid + // pointer containing at least `storage_len` items. 
Some(unsafe { std::slice::from_raw_parts_mut(storage, storage_len) }), ) .map(drop) @@ -3388,7 +3572,8 @@ impl TableDebug for WaitableSet { type RawLower = Box]) -> Result<()> + Send + Sync>; -pub type LowerFn = fn(Func, *mut dyn VMStore, *mut u8, &mut [MaybeUninit]) -> Result<()>; +pub type LowerFn = + unsafe fn(Func, *mut dyn VMStore, *mut u8, &mut [MaybeUninit]) -> Result<()>; type RawLift = Box< dyn FnOnce(&mut ComponentInstance, &[ValRaw]) -> Result> @@ -3396,7 +3581,8 @@ type RawLift = Box< + Sync, >; -pub type LiftFn = fn(Func, *mut dyn VMStore, &[ValRaw]) -> Result>; +pub type LiftFn = + unsafe fn(Func, *mut dyn VMStore, &[ValRaw]) -> Result>; type LiftedResult = Box; @@ -3952,6 +4138,8 @@ impl<'a> Drop for ResetPtr<'a> { } } +// SAFETY: The specified pointer must be a valid `*mut T` which was originally +// allocated as a `Box`. pub(crate) unsafe fn drop_params(pointer: *mut u8) { drop(unsafe { Box::from_raw(pointer as *mut T) }) } @@ -3971,7 +4159,34 @@ impl PreparedCall { } } -pub(crate) fn prepare_call( +/// Prepare a call to the specified exported Wasm function, providing functions +/// for lowering the parameters and lifting the result. +/// +/// To enqueue the returned `PreparedCall` in the `ComponentInstance`'s event +/// loop, use `queue_call`. +/// +/// Note that this function is used in `TypedFunc::call_async`, which accepts +/// parameters of a generic type which might not be `'static`. However the +/// `GuestTask` created by this function must be `'static`, so it can't safely +/// close over those parameters. Instead, `PreparedCall` has a `params` field +/// of type `Arc>`, which the caller is responsible for setting to +/// a valid, non-null pointer to the params prior to polling the event loop (at +/// least until the parameters have been lowered), and then resetting back to +/// null afterward. 
That ensures that the lowering code never sees a stale +/// pointer, even if the application `drop`s or `mem::forget`s the future +/// returned by `TypedFunc::call_async`. +/// +/// In the case where the parameters are passed using a type that _is_ +/// `'static`, they can be boxed and stored in `PreparedCall::params` +/// indefinitely; `drop_params` will be called when they are no longer needed. +/// +/// SAFETY: The `lower_params` and `drop_params` functions must accept (and +/// either ignore or safely use) any non-null pointer stored in the `params` +/// field of the returned `PreparedCall`, and that pointer must be valid for as +/// long as it is stored in that field. Also, the `lower_params` and +/// `lift_result` functions must both use their other pointer arguments safely +/// or not at all. +pub(crate) unsafe fn prepare_call( store: StoreContextMut, lower_params: LowerFn, drop_params: unsafe fn(*mut u8), @@ -3987,6 +4202,8 @@ pub(crate) fn prepare_call( let memory = func_data.options.memory.map(SendSyncPtr::new); let string_encoding = func_data.options.string_encoding(); + // SAFETY: We have exclusive access to the store, which means we have + // exclusive access to any `ComponentInstance` which resides in the store. let instance = unsafe { &mut *store.0[instance.0].as_ref().unwrap().instance_ptr() }; let params = Arc::new(AtomicPtr::new(ptr::null_mut())); @@ -4004,6 +4221,9 @@ pub(crate) fn prepare_call( fn drop(&mut self) { let ptr = self.params.swap(ptr::null_mut(), Relaxed); if !ptr.is_null() { + // SAFETY: Per the contract of `prepare_call`, `ptr` must be + // valid and `self.dropper` must either use it safely or not at + // all. 
unsafe { (self.dropper)(ptr); } @@ -4024,13 +4244,20 @@ pub(crate) fn prepare_call( let ptr = param_ptr.load(Relaxed); let result = if ptr.is_null() { // If we've reached here, it presumably means we were called - // via `{Typed}Func::call_async` and the future was dropped - // or `mem::forget`ed by the caller, meaning we no longer - // have access to the parameters. In that case, we should + // via `TypedFunc::call_async` and the future was dropped or + // `mem::forget`ed by the caller, meaning we no longer have + // access to the parameters. In that case, we should // gracefully cancel the call without trapping or panicking. todo!("gracefully cancel `call_async` tasks when future is dropped") } else { - lower_params(handle, instance.store(), ptr, params) + // SAFETY: The provided `ComponentInstance` belongs to the + // store in which it resides, so if it is valid then so is + // its store. + // + // Also, per the contract of `prepare_call`, `ptr` must be + // valid and `lower_params` must either use it safely or not at + // all. + unsafe { lower_params(handle, instance.store(), ptr, params) } }; drop(drop_params); result @@ -4038,7 +4265,9 @@ pub(crate) fn prepare_call( })), LiftResult { lift: Box::new(for_any_lift(move |instance, result| { - lift_result(handle, instance.store(), result) + // SAFETY: The provided `ComponentInstance` belongs to the store + // in which it resides, so if it is valid then so is its store. 
+ unsafe { lift_result(handle, instance.store(), result) } })), ty: task_return_type, memory, @@ -4064,7 +4293,7 @@ pub(crate) fn prepare_call( }) } -pub(crate) fn defer_call( +pub(crate) fn queue_call( mut store: StoreContextMut, prepared: PreparedCall, ) -> Result> + Send + 'static + use> { @@ -4106,6 +4335,8 @@ fn start_call( let callback = func_data.options.callback; let post_return = func_data.post_return; + // SAFETY: We have exclusive access to the store, which means we have + // exclusive access to any `ComponentInstance` which resides in the store. let instance = unsafe { &mut *store.0[instance.0].as_ref().unwrap().instance_ptr() }; log::trace!("starting call {guest_task:?}"); @@ -4142,6 +4373,20 @@ fn start_call( Ok(()) } +/// Wrap the specified function in a future which, when polled, will store a +/// pointer to the `Context` in the `AsyncState::current_poll_cx` field for the +/// specified store and then call the function. +/// +/// This is intended for use with functions that resume fibers which may need to +/// poll futures using the stored `Context` pointer. The function should return +/// `Ok(_)` when complete, or `Err(_)` if the future should be polled again +/// later. +/// +/// SAFETY: The `store` parameter must be a valid `*mut dyn VMStore` and `fun` +/// must use it safely. The returned future must be polled to completion before +/// the store can be used by the caller again. Finally, `fun` must not attempt +/// to use the `Context` pointer stored in `AsyncState::current_poll_cx` beyond +/// the scope of the current call. 
async unsafe fn poll_fn( store: VMStoreRawPtr, guard_range: (Option>, Option>), @@ -4162,15 +4407,20 @@ async unsafe fn poll_fn( future::poll_fn({ let mut store = Some(store); - move |cx| unsafe { + move |cx| { let _reset = Reset(poll_cx.0, *poll_cx.0); let guard_range_start = guard_range.0.map(|v| v.as_ptr()).unwrap_or(ptr::null_mut()); let guard_range_end = guard_range.1.map(|v| v.as_ptr()).unwrap_or(ptr::null_mut()); - *poll_cx.0 = PollContext { - future_context: mem::transmute::<&mut Context<'_>, *mut Context<'static>>(cx), - guard_range_start, - guard_range_end, - }; + // SAFETY: We store the pointer to the `Context` only for the + // duration of this call and then reset it to its previous value + // afterward, thereby ensuring `fun` never sees a stale pointer. + unsafe { + *poll_cx.0 = PollContext { + future_context: mem::transmute::<&mut Context<'_>, *mut Context<'static>>(cx), + guard_range_start, + guard_range_end, + }; + } #[allow(dropping_copy_types)] drop(poll_cx); diff --git a/crates/wasmtime/src/runtime/component/func.rs b/crates/wasmtime/src/runtime/component/func.rs index a84d27ad9d..edae9fe31b 100644 --- a/crates/wasmtime/src/runtime/component/func.rs +++ b/crates/wasmtime/src/runtime/component/func.rs @@ -16,7 +16,7 @@ use wasmtime_environ::component::{ }; #[cfg(feature = "component-model-async")] -use crate::component::concurrent::{self, LiftFn, LowerFn, PreparedCall}; +use crate::component::concurrent::{self, PreparedCall}; #[cfg(feature = "component-model-async")] use crate::VMStore; #[cfg(feature = "component-model-async")] @@ -387,14 +387,21 @@ impl Func { let result = (|| { self.check_param_count(store.as_context_mut(), params.len())?; - let prepared = self.prepare_call_dynamic( - store.as_context_mut(), - concurrent::drop_params::>, - )?; + // SAFETY: We uphold the contract documented in + // `concurrent::prepare_call` by setting `PreparedCall::params` to a + // valid pointer prior to polling the event loop for this function's + // instance 
and providing a `drop_params` parameter which will + // correctly dispose of it after lowering. + let prepared = unsafe { + self.prepare_call_dynamic( + store.as_context_mut(), + concurrent::drop_params::>, + ) + }?; prepared .params() .store(Box::into_raw(Box::new(params)).cast(), Relaxed); - concurrent::defer_call(store, prepared) + concurrent::queue_call(store, prepared) })(); match result { @@ -403,8 +410,12 @@ impl Func { } } + /// Calls `concurrent::prepare_call` with monomorphized functions for + /// lowering the parameters and lifting the result. + /// + /// SAFETY: See `concurrent::prepare_call`. #[cfg(feature = "component-model-async")] - fn prepare_call_dynamic<'a, T: Send>( + unsafe fn prepare_call_dynamic<'a, T: Send>( self, mut store: StoreContextMut<'a, T>, drop_params: unsafe fn(*mut u8), @@ -418,7 +429,7 @@ impl Func { Self::lift_results_sync_fn:: }; - self.prepare_call(store, lower, drop_params, lift, MAX_FLAT_PARAMS) + concurrent::prepare_call(store, lower, drop_params, lift, self, MAX_FLAT_PARAMS) } fn call_impl( @@ -470,18 +481,6 @@ impl Func { ) } - #[cfg(feature = "component-model-async")] - fn prepare_call<'a, T: Send, Return: Send + Sync + 'static>( - &self, - store: StoreContextMut<'a, T>, - lower: LowerFn, - drop_params: unsafe fn(*mut u8), - lift: LiftFn, - param_count: usize, - ) -> Result> { - concurrent::prepare_call(store, lower, drop_params, lift, *self, param_count) - } - /// Invokes the underlying wasm function, lowering arguments and lifting the /// result. 
/// diff --git a/crates/wasmtime/src/runtime/component/func/host.rs b/crates/wasmtime/src/runtime/component/func/host.rs index dc9620e7af..bee5fd9e4f 100644 --- a/crates/wasmtime/src/runtime/component/func/host.rs +++ b/crates/wasmtime/src/runtime/component/func/host.rs @@ -83,9 +83,8 @@ impl HostFunc { R: ComponentNamedList + Lower + Send + Sync + 'static, { let func = Arc::new(func); - Self::from_canonical(move |store, instance, params| { - let instance = unsafe { &mut *instance }; - instance.wrap_call(store, func.clone(), params) + Self::from_canonical(move |store, instance, params| unsafe { + (*instance).wrap_call(store, func.clone(), params) }) } @@ -184,9 +183,8 @@ impl HostFunc { + 'static, { let func = Arc::new(func); - Self::new_dynamic_canonical(move |store, instance, params, _| { - let instance = unsafe { &mut *instance }; - instance.wrap_dynamic_call(store, func.clone(), params) + Self::new_dynamic_canonical(move |store, instance, params, _| unsafe { + (*instance).wrap_dynamic_call(store, func.clone(), params) }) } @@ -374,7 +372,7 @@ where let future = closure(cx.as_context_mut(), instance, params); - let ret = (*instance).poll_and_block(cx.as_context_mut(), future, caller_instance)?; + let ret = (*instance).poll_and_block(future, caller_instance)?; flags.set_may_leave(false); let mut lower = LowerContext::new(cx, &options, types, instance); @@ -639,8 +637,7 @@ where args, result_tys.types.len(), ); - let result_vals = - (*instance).poll_and_block(store.as_context_mut(), future, caller_instance)?; + let result_vals = (*instance).poll_and_block(future, caller_instance)?; flags.set_may_leave(false); diff --git a/crates/wasmtime/src/runtime/component/func/typed.rs b/crates/wasmtime/src/runtime/component/func/typed.rs index 5635de71f5..d5146aaf38 100644 --- a/crates/wasmtime/src/runtime/component/func/typed.rs +++ b/crates/wasmtime/src/runtime/component/func/typed.rs @@ -210,9 +210,14 @@ where let mut params = params; let mut store = store; let instance = 
store.0[self.func.0].instance; - let prepared = self.prepare_call(store.as_context_mut(), drop)?; + // SAFETY: We uphold the contract documented in + // `concurrent::prepare_call` by only setting `PreparedCall::params` + // to a valid pointer while polling the event loop and resetting it + // to null afterward, thus ensuring that the parameter lowering code + // never sees a stale pointer. + let prepared = unsafe { self.prepare_call(store.as_context_mut(), drop) }?; let param_ptr = prepared.params().clone(); - let call = concurrent::defer_call(store.as_context_mut(), prepared)?; + let call = concurrent::queue_call(store.as_context_mut(), prepared)?; let mut future = pin!(instance.run(store, call)); future::poll_fn(move |cx| { let params = &mut params; @@ -260,12 +265,18 @@ where ); let result = (|| { - let prepared = - self.prepare_call(store.as_context_mut(), concurrent::drop_params::)?; + // SAFETY: We uphold the contract documented in + // `concurrent::prepare_call` by setting `PreparedCall::params` to a + // valid pointer prior to polling the event loop for this function's + // instance and providing a `drop_params` parameter which will + // correctly dispose of it after lowering. + let prepared = unsafe { + self.prepare_call(store.as_context_mut(), concurrent::drop_params::) + }?; prepared .params() .store(Box::into_raw(Box::new(params)).cast(), Relaxed); - concurrent::defer_call(store, prepared) + concurrent::queue_call(store, prepared) })(); match result { @@ -274,8 +285,14 @@ where } } + /// Calls `concurrent::prepare_call` with monomorphized functions for + /// lowering the parameters and lifting the result according to the number + /// of core Wasm parameters and results in the signature of the function to + /// be called. + /// + /// SAFETY: See `concurrent::prepare_call`. 
#[cfg(feature = "component-model-async")] - fn prepare_call<'a, T: Send>( + unsafe fn prepare_call<'a, T: Send>( self, store: StoreContextMut<'a, T>, drop_params: unsafe fn(*mut u8), @@ -288,74 +305,82 @@ where if store.0[self.func.0].options.async_() { if Params::flatten_count() <= MAX_FLAT_PARAMS { if Return::flatten_count() <= MAX_FLAT_PARAMS { - self.func.prepare_call( + concurrent::prepare_call( store, Self::lower_stack_args_fn::, drop_params, Self::lift_stack_result_fn::, + self.func, param_count, ) } else { - self.func.prepare_call( + concurrent::prepare_call( store, Self::lower_stack_args_fn::, drop_params, Self::lift_heap_result_fn::, + self.func, param_count, ) } } else { if Return::flatten_count() <= MAX_FLAT_PARAMS { - self.func.prepare_call( + concurrent::prepare_call( store, Self::lower_heap_args_fn::, drop_params, Self::lift_stack_result_fn::, + self.func, 1, ) } else { - self.func.prepare_call( + concurrent::prepare_call( store, Self::lower_heap_args_fn::, drop_params, Self::lift_heap_result_fn::, + self.func, 1, ) } } } else if Params::flatten_count() <= MAX_FLAT_PARAMS { if Return::flatten_count() <= MAX_FLAT_RESULTS { - self.func.prepare_call( + concurrent::prepare_call( store, Self::lower_stack_args_fn::, drop_params, Self::lift_stack_result_fn::, + self.func, param_count, ) } else { - self.func.prepare_call( + concurrent::prepare_call( store, Self::lower_stack_args_fn::, drop_params, Self::lift_heap_result_fn::, + self.func, param_count, ) } } else { if Return::flatten_count() <= MAX_FLAT_RESULTS { - self.func.prepare_call( + concurrent::prepare_call( store, Self::lower_heap_args_fn::, drop_params, Self::lift_stack_result_fn::, + self.func, 1, ) } else { - self.func.prepare_call( + concurrent::prepare_call( store, Self::lower_heap_args_fn::, drop_params, Self::lift_heap_result_fn::, + self.func, 1, ) } diff --git a/crates/wasmtime/src/runtime/component/mod.rs b/crates/wasmtime/src/runtime/component/mod.rs index a6f2cef5ef..787dd1c372 100644 --- 
a/crates/wasmtime/src/runtime/component/mod.rs +++ b/crates/wasmtime/src/runtime/component/mod.rs @@ -705,7 +705,6 @@ pub(crate) mod concurrent { Val, }, vm::component::ComponentInstance, - StoreContextMut, }, alloc::{sync::Arc, task::Wake}, anyhow::Result, @@ -729,9 +728,8 @@ pub(crate) mod concurrent { } impl ComponentInstance { - pub(crate) fn poll_and_block<'a, T, R: Send + Sync + 'static>( + pub(crate) fn poll_and_block( &mut self, - _store: StoreContextMut<'a, T>, future: impl Future> + Send + 'static, _caller_instance: RuntimeComponentInstanceIndex, ) -> Result {