Skip to content
This repository was archived by the owner on Sep 8, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 105 additions & 85 deletions crates/wasmtime/src/runtime/component/concurrent.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![deny(unsafe_op_in_unsafe_fn)]

//! Runtime support for the Component Model Async ABI.
//!
//! This module and its submodules provide host runtime support for Component
Expand Down Expand Up @@ -2464,17 +2466,19 @@ impl Instance {
};

// Queue the call as a "high priority" work item.
self.queue_call(
store.as_context_mut(),
guest_task,
callee,
param_count,
result_count,
instance_flags,
(flags & START_FLAG_ASYNC_CALLEE) != 0,
NonNull::new(callback).map(SendSyncPtr::new),
NonNull::new(post_return).map(SendSyncPtr::new),
)?;
unsafe {
self.queue_call(
store.as_context_mut(),
guest_task,
callee,
param_count,
result_count,
instance_flags,
(flags & START_FLAG_ASYNC_CALLEE) != 0,
NonNull::new(callback).map(SendSyncPtr::new),
NonNull::new(post_return).map(SendSyncPtr::new),
)?;
}

let state = self.concurrent_state_mut(store.0);

Expand Down Expand Up @@ -3490,30 +3494,32 @@ impl<T: 'static> VMComponentAsyncStore for StoreInner<T> {
// pointer containing at least `storage_len` items.
let params = unsafe { std::slice::from_raw_parts(storage, storage_len) }.to_vec();

instance.prepare_call(
StoreContextMut(self),
start,
return_,
caller_instance,
callee_instance,
task_return_type,
memory,
string_encoding,
match result_count_or_max_if_async {
PREPARE_ASYNC_NO_RESULT => CallerInfo::Async {
params,
has_result: false,
},
PREPARE_ASYNC_WITH_RESULT => CallerInfo::Async {
params,
has_result: true,
},
result_count => CallerInfo::Sync {
params,
result_count,
unsafe {
instance.prepare_call(
StoreContextMut(self),
start,
return_,
caller_instance,
callee_instance,
task_return_type,
memory,
string_encoding,
match result_count_or_max_if_async {
PREPARE_ASYNC_NO_RESULT => CallerInfo::Async {
params,
has_result: false,
},
PREPARE_ASYNC_WITH_RESULT => CallerInfo::Async {
params,
has_result: true,
},
result_count => CallerInfo::Sync {
params,
result_count,
},
},
},
)
)
}
}

unsafe fn sync_start(
Expand All @@ -3525,21 +3531,23 @@ impl<T: 'static> VMComponentAsyncStore for StoreInner<T> {
storage: *mut MaybeUninit<ValRaw>,
storage_len: usize,
) -> Result<()> {
instance
.start_call(
StoreContextMut(self),
callback,
ptr::null_mut(),
callee,
param_count,
1,
START_FLAG_ASYNC_CALLEE,
// SAFETY: The `wasmtime_cranelift`-generated code that calls
// this method will have ensured that `storage` is a valid
// pointer containing at least `storage_len` items.
Some(unsafe { std::slice::from_raw_parts_mut(storage, storage_len) }),
)
.map(drop)
unsafe {
instance
.start_call(
StoreContextMut(self),
callback,
ptr::null_mut(),
callee,
param_count,
1,
START_FLAG_ASYNC_CALLEE,
// SAFETY: The `wasmtime_cranelift`-generated code that calls
// this method will have ensured that `storage` is a valid
// pointer containing at least `storage_len` items.
Some(std::slice::from_raw_parts_mut(storage, storage_len)),
)
.map(drop)
}
}

unsafe fn async_start(
Expand All @@ -3552,16 +3560,18 @@ impl<T: 'static> VMComponentAsyncStore for StoreInner<T> {
result_count: u32,
flags: u32,
) -> Result<u32> {
instance.start_call(
StoreContextMut(self),
callback,
post_return,
callee,
param_count,
result_count,
flags,
None,
)
unsafe {
instance.start_call(
StoreContextMut(self),
callback,
post_return,
callee,
param_count,
result_count,
flags,
None,
)
}
}

unsafe fn future_write(
Expand Down Expand Up @@ -4403,11 +4413,13 @@ impl AsyncCx {
///
/// SAFETY: TODO
unsafe fn poll<U>(&self, mut future: Pin<&mut (dyn Future<Output = U> + Send)>) -> Poll<U> {
let poll_cx = *self.current_poll_cx;
let _reset = Reset(self.current_poll_cx, poll_cx);
*self.current_poll_cx = PollContext::default();
assert!(!poll_cx.future_context.is_null());
future.as_mut().poll(&mut *poll_cx.future_context)
unsafe {
let poll_cx = *self.current_poll_cx;
let _reset = Reset(self.current_poll_cx, poll_cx);
*self.current_poll_cx = PollContext::default();
assert!(!poll_cx.future_context.is_null());
future.as_mut().poll(&mut *poll_cx.future_context)
}
}

/// "Stackfully" poll the specified future by alternately polling it and
Expand All @@ -4419,10 +4431,12 @@ impl AsyncCx {
mut future: Pin<&mut (dyn Future<Output = U> + Send)>,
) -> Result<U> {
loop {
match self.poll(future.as_mut()) {
Poll::Ready(v) => break Ok(v),
Poll::Pending => {
self.suspend(None)?;
unsafe {
match self.poll(future.as_mut()) {
Poll::Ready(v) => break Ok(v),
Poll::Pending => {
self.suspend(None)?;
}
}
}
}
Expand All @@ -4440,7 +4454,7 @@ impl AsyncCx {
} else {
ProtectionMask::all()
};
let store = suspend_fiber(self.current_suspend, self.current_stack_limit, store);
let store = unsafe { suspend_fiber(self.current_suspend, self.current_stack_limit, store) };
if self.track_pkey_context_switch {
mpk::allow(previous_mask);
}
Expand Down Expand Up @@ -4886,17 +4900,19 @@ unsafe fn resume_fiber_raw<'a>(
}
}

let _reset_executor = Reset((*fiber).executor_ptr, *(*fiber).executor_ptr);
*(*fiber).executor_ptr = &raw mut (*fiber).executor;
let _reset_suspend = Reset((*fiber).suspend, *(*fiber).suspend);
let _reset_stack_limit = Reset((*fiber).stack_limit, *(*fiber).stack_limit);
let state = Some((*fiber).state.take().unwrap().push());
let restore = Restore { fiber, state };
(*restore.fiber)
.fiber
.as_ref()
.unwrap()
.resume((store, result))
unsafe {
let _reset_executor = Reset((*fiber).executor_ptr, *(*fiber).executor_ptr);
*(*fiber).executor_ptr = &raw mut (*fiber).executor;
let _reset_suspend = Reset((*fiber).suspend, *(*fiber).suspend);
let _reset_stack_limit = Reset((*fiber).stack_limit, *(*fiber).stack_limit);
let state = Some((*fiber).state.take().unwrap().push());
let restore = Restore { fiber, state };
(*restore.fiber)
.fiber
.as_ref()
.unwrap()
.resume((store, result))
}
}

/// See `resume_fiber_raw`
Expand All @@ -4905,7 +4921,9 @@ unsafe fn resume_fiber(
store: Option<*mut dyn VMStore>,
result: Result<()>,
) -> Result<Result<(*mut dyn VMStore, Result<()>), Option<*mut dyn VMStore>>> {
match resume_fiber_raw(fiber, store, result).map(|(store, result)| (store.unwrap(), result)) {
match unsafe {
resume_fiber_raw(fiber, store, result).map(|(store, result)| (store.unwrap(), result))
} {
Ok(pair) => Ok(Ok(pair)),
Err(s) => {
if let Some(range) = fiber.fiber.as_ref().unwrap().stack().range() {
Expand All @@ -4932,10 +4950,12 @@ unsafe fn suspend_fiber(
stack_limit: *mut usize,
store: Option<*mut dyn VMStore>,
) -> Result<Option<*mut dyn VMStore>> {
let _reset_suspend = Reset(suspend, *suspend);
let _reset_stack_limit = Reset(stack_limit, *stack_limit);
assert!(!(*suspend).is_null());
let (store, result) = (**suspend).suspend(store);
let (store, result) = unsafe {
let _reset_suspend = Reset(suspend, *suspend);
let _reset_stack_limit = Reset(stack_limit, *stack_limit);
assert!(!(*suspend).is_null());
(**suspend).suspend(store)
};
result?;
Ok(store)
}
Expand Down