diff --git a/crates/wasmtime/src/runtime/component/concurrent.rs b/crates/wasmtime/src/runtime/component/concurrent.rs index cf813b7f23..ebbf395eb0 100644 --- a/crates/wasmtime/src/runtime/component/concurrent.rs +++ b/crates/wasmtime/src/runtime/component/concurrent.rs @@ -157,8 +157,7 @@ enum Event { }, StreamRead { code: ReturnCode, - handle: u32, - ty: TypeStreamTableIndex, + pending: Option<(TypeStreamTableIndex, u32)>, }, StreamWrite { code: ReturnCode, @@ -166,8 +165,7 @@ enum Event { }, FutureRead { code: ReturnCode, - handle: u32, - ty: TypeFutureTableIndex, + pending: Option<(TypeFutureTableIndex, u32)>, }, FutureWrite { code: ReturnCode, @@ -4291,7 +4289,10 @@ impl Waitable { /// the state of the stream or future. fn on_delivery(&self, instance: &mut ComponentInstance, event: Event) { match event { - Event::FutureRead { ty, handle, .. } + Event::FutureRead { + pending: Some((ty, handle)), + .. + } | Event::FutureWrite { pending: Some((ty, handle)), .. @@ -4313,7 +4314,10 @@ impl Waitable { _ => unreachable!(), }; } - Event::StreamRead { ty, handle, .. } + Event::StreamRead { + pending: Some((ty, handle)), + .. + } | Event::StreamWrite { pending: Some((ty, handle)), .. diff --git a/crates/wasmtime/src/runtime/component/concurrent/futures_and_streams.rs b/crates/wasmtime/src/runtime/component/concurrent/futures_and_streams.rs index e9b13520e6..dab1e0c6a4 100644 --- a/crates/wasmtime/src/runtime/component/concurrent/futures_and_streams.rs +++ b/crates/wasmtime/src/runtime/component/concurrent/futures_and_streams.rs @@ -1675,7 +1675,7 @@ impl ComponentInstance { })? } WriteEvent::Close => super::with_local_instance(|_, instance| { - instance.host_close_writer(rep) + instance.host_close_writer(rep, kind) })?, WriteEvent::Watch { tx } => super::with_local_instance(|_, instance| { let state = instance.get_mut(TableId::::new(rep))?; @@ -1736,7 +1736,7 @@ impl ComponentInstance { })? } ReadEvent::Close => super::with_local_instance(|store, instance| { - instance.host_close_reader(store, rep) + instance.host_close_reader(store, rep, kind) })?, ReadEvent::Watch { tx } => super::with_local_instance(|_, instance| { let state = instance.get_mut(TableId::::new(rep))?; @@ -1767,6 +1767,82 @@ impl ComponentInstance { Waitable::Transmit(TableId::::new(waitable)).set_event(self, Some(event)) } + /// Set or update the event for the specified waitable. + /// + /// If there is already an event set for this waitable, we assert that it is + /// of the same variant as the new one and reuse the `ReturnCode` count and + /// the `pending` field if applicable. + // TODO: This is a bit awkward due to how + // `Event::{Stream,Future}{Write,Read}` and + // `ReturnCode::{Completed,Closed,Cancelled}` are currently represented. + // Consider updating those representations in a way that allows this + // function to be simplified. + fn update_event(&mut self, waitable: u32, event: Event) -> Result<()> { + let waitable = Waitable::Transmit(TableId::::new(waitable)); + + fn update_code(old: ReturnCode, new: ReturnCode) -> ReturnCode { + let (ReturnCode::Completed(count) + | ReturnCode::Closed(count) + | ReturnCode::Cancelled(count)) = old + else { + unreachable!() + }; + + match new { + ReturnCode::Closed(0) => ReturnCode::Closed(count), + ReturnCode::Cancelled(0) => ReturnCode::Cancelled(count), + _ => unreachable!(), + } + } + + let event = match (waitable.take_event(self)?, event) { + (None, _) => event, + ( + Some(Event::FutureWrite { + code: old_code, + pending: old_pending, + }), + Event::FutureWrite { code, pending }, + ) => Event::FutureWrite { + code: update_code(old_code, code), + pending: old_pending.or(pending), + }, + ( + Some(Event::FutureRead { + code: old_code, + pending: old_pending, + }), + Event::FutureRead { code, pending }, + ) => Event::FutureRead { + code: update_code(old_code, code), + pending: old_pending.or(pending), + }, + ( + Some(Event::StreamWrite { + code: old_code, + pending: old_pending, + }), + Event::StreamWrite { code, pending }, + ) => Event::StreamWrite { + code: update_code(old_code, code), + pending: old_pending.or(pending), + }, + ( + Some(Event::StreamRead { + code: old_code, + pending: old_pending, + }), + Event::StreamRead { code, pending }, + ) => Event::StreamRead { + code: update_code(old_code, code), + pending: old_pending.or(pending), + }, + _ => unreachable!(), + }; + + waitable.set_event(self, Some(event)) + } + fn get_mut_by_index( &mut self, ty: TableIndex, @@ -1895,8 +1971,14 @@ impl ComponentInstance { self.set_event( read_handle.rep(), match ty { - TableIndex::Future(ty) => Event::FutureRead { code, ty, handle }, - TableIndex::Stream(ty) => Event::StreamRead { code, ty, handle }, + TableIndex::Future(ty) => Event::FutureRead { + code, + pending: Some((ty, handle)), + }, + TableIndex::Stream(ty) => Event::StreamRead { + code, + pending: Some((ty, handle)), + }, }, )?; } @@ -1926,7 +2008,7 @@ impl ComponentInstance { } if let PostWrite::Close = post_write { - self.host_close_writer(transmit_rep)?; + self.host_close_writer(transmit_rep, kind)?; } Ok(()) @@ -2125,7 +2207,7 @@ impl ComponentInstance { /// # Arguments /// /// * `transmit_rep` - The `TransmitState` rep for the stream or future. - fn host_close_writer(&mut self, transmit_rep: u32) -> Result<()> { + fn host_close_writer(&mut self, transmit_rep: u32, kind: TransmitKind) -> Result<()> { let transmit_id = TableId::::new(transmit_rep); let transmit = self .get_mut(transmit_id) @@ -2158,32 +2240,26 @@ impl ComponentInstance { ReadState::Open }; + let read_handle = transmit.read_handle; + // Swap in the new read state match mem::replace(&mut transmit.read, new_state) { // If the guest was ready to read, then we cannot close the reader (or writer) // we must deliver the event, and update the state associated with the handle to // represent that a read must be performed ReadState::GuestReady { ty, handle, .. } => { - let read_handle = transmit.read_handle; - - let code = ReturnCode::Closed( - if let Some(Event::StreamRead { - code: ReturnCode::Completed(count), - .. - }) = self.take_event(read_handle.rep())? - { - count - } else { - 0 - }, - ); - // Ensure the final read of the guest is queued, with appropriate closure indicator - self.set_event( + self.update_event( read_handle.rep(), match ty { - TableIndex::Future(ty) => Event::FutureRead { code, ty, handle }, - TableIndex::Stream(ty) => Event::StreamRead { code, ty, handle }, + TableIndex::Future(ty) => Event::FutureRead { + code: ReturnCode::Closed(0), + pending: Some((ty, handle)), + }, + TableIndex::Stream(ty) => Event::StreamRead { + code: ReturnCode::Closed(0), + pending: Some((ty, handle)), + }, }, )?; } @@ -2195,7 +2271,21 @@ impl ComponentInstance { } // If the read state is open, then there are no registered readers of the stream/future - ReadState::Open => {} + ReadState::Open => { + self.update_event( + read_handle.rep(), + match kind { + TransmitKind::Future => Event::FutureRead { + code: ReturnCode::Closed(0), + pending: None, + }, + TransmitKind::Stream => Event::StreamRead { + code: ReturnCode::Closed(0), + pending: None, + }, + }, + )?; + } // If the read state was already closed, then we can remove the transmit state completely // (both writer and reader have been closed) @@ -2213,7 +2303,12 @@ impl ComponentInstance { /// /// * `store` - The store to which this instance belongs /// * `transmit_rep` - The `TransmitState` rep for the stream or future. - fn host_close_reader(&mut self, store: &mut dyn VMStore, transmit_rep: u32) -> Result<()> { + fn host_close_reader( + &mut self, + store: &mut dyn VMStore, + transmit_rep: u32, + kind: TransmitKind, + ) -> Result<()> { let transmit_id = TableId::::new(transmit_rep); let transmit = self .get_mut(transmit_id) @@ -2230,6 +2325,8 @@ impl ComponentInstance { WriteState::Open }; + let write_handle = transmit.write_handle; + match mem::replace(&mut transmit.write, new_state) { // If a guest is waiting to write, notify it that the read end has // been closed. @@ -2239,48 +2336,45 @@ impl ComponentInstance { post_write, .. } => { - let write_handle = transmit.write_handle; - - let pending = if let PostWrite::Close = post_write { + if let PostWrite::Close = post_write { self.delete_transmit(transmit_id)?; - false } else { - true + self.update_event( + write_handle.rep(), + match ty { + TableIndex::Future(ty) => Event::FutureWrite { + code: ReturnCode::Closed(0), + pending: Some((ty, handle)), + }, + TableIndex::Stream(ty) => Event::StreamWrite { + code: ReturnCode::Closed(0), + pending: Some((ty, handle)), + }, + }, + )?; }; + } - let code = ReturnCode::Closed( - if let Some(Event::StreamWrite { - code: ReturnCode::Completed(count), - .. - }) = self.take_event(write_handle.rep())? - { - count - } else { - 0 - }, - ); + WriteState::HostReady { accept, .. } => { + accept(store, self, Reader::End)?; + } - self.set_event( + WriteState::Open => { + self.update_event( write_handle.rep(), - match ty { - TableIndex::Future(ty) => Event::FutureWrite { - code, - pending: pending.then_some((ty, handle)), + match kind { + TransmitKind::Future => Event::FutureWrite { + code: ReturnCode::Closed(0), + pending: None, }, - TableIndex::Stream(ty) => Event::StreamWrite { - code, - pending: pending.then_some((ty, handle)), + TransmitKind::Stream => Event::StreamWrite { + code: ReturnCode::Closed(0), + pending: None, }, }, )?; } - WriteState::HostReady { accept, .. } => { - accept(store, self, Reader::End)?; - } - - WriteState::Open => {} - WriteState::Closed => { log::trace!("host_close_reader delete {transmit_rep}"); self.delete_transmit(transmit_id)?; @@ -2605,13 +2699,11 @@ impl ComponentInstance { match read_ty { TableIndex::Future(ty) => Event::FutureRead { code, - ty, - handle: read_handle, + pending: Some((ty, read_handle)), }, TableIndex::Stream(ty) => Event::StreamRead { code, - ty, - handle: read_handle, + pending: Some((ty, read_handle)), }, }, )?; @@ -2933,12 +3025,16 @@ impl ComponentInstance { /// Close the writable end of the specified stream or future from the guest. fn guest_close_writable(&mut self, ty: TableIndex, writer: u32) -> Result<()> { - let (transmit_rep, WaitableState::Stream(_, state) | WaitableState::Future(_, state)) = - self.state_table(ty) - .remove_by_index(writer) - .context("failed to find writer")? - else { - bail!("invalid stream or future handle"); + let (transmit_rep, state) = self + .state_table(ty) + .remove_by_index(writer) + .context("failed to find writer")?; + let (state, kind) = match state { + WaitableState::Stream(_, state) => (state, TransmitKind::Stream), + WaitableState::Future(_, state) => (state, TransmitKind::Future), + _ => { + bail!("invalid stream or future handle"); + } }; match state { StreamFutureState::Write => {} @@ -2952,7 +3048,7 @@ impl ComponentInstance { .get(TableId::::new(transmit_rep))? .state .rep(); - self.host_close_writer(transmit_rep) + self.host_close_writer(transmit_rep, kind) } /// Close the readable end of the specified stream or future from the guest. @@ -2962,10 +3058,13 @@ impl ComponentInstance { ty: TableIndex, reader: u32, ) -> Result<()> { - let (rep, WaitableState::Stream(_, state) | WaitableState::Future(_, state)) = - self.state_table(ty).remove_by_index(reader)? - else { - bail!("invalid stream or future handle"); + let (rep, state) = self.state_table(ty).remove_by_index(reader)?; + let (state, kind) = match state { + WaitableState::Stream(_, state) => (state, TransmitKind::Stream), + WaitableState::Future(_, state) => (state, TransmitKind::Future), + _ => { + bail!("invalid stream or future handle"); + } }; match state { StreamFutureState::Read => {} @@ -2977,7 +3076,7 @@ impl ComponentInstance { let id = TableId::::new(rep); let rep = self.get(id)?.state.rep(); log::trace!("guest_close_readable: close reader {id:?}"); - self.host_close_reader(store, rep) + self.host_close_reader(store, rep, kind) } /// Create a new error context for the given component. diff --git a/tests/misc_testsuite/component-model-async/partial-stream-copies.wast b/tests/misc_testsuite/component-model-async/partial-stream-copies.wast index d4573fbbf1..a847c96241 100644 --- a/tests/misc_testsuite/component-model-async/partial-stream-copies.wast +++ b/tests/misc_testsuite/component-model-async/partial-stream-copies.wast @@ -21,7 +21,6 @@ (import "" "task.return" (func $task.return (param i32))) (import "" "waitable.join" (func $waitable.join (param i32 i32))) (import "" "waitable-set.new" (func $waitable-set.new (result i32))) - (import "" "waitable-set.wait" (func $waitable-set.wait (param i32 i32) (result i32))) (import "" "stream.new" (func $stream.new (result i64))) (import "" "stream.read" (func $stream.read (param i32 i32 i32) (result i32))) (import "" "stream.write" (func $stream.write (param i32 i32 i32) (result i32))) @@ -50,14 +49,14 @@ ;; create a new stream r/w pair $outsr/$outsw (local.set $ret64 (call $stream.new)) (local.set $outsr (i32.wrap_i64 (local.get $ret64))) - (if (i32.ne (i32.const 4) (local.get $outsr)) + (if (i32.ne (i32.const 3) (local.get $outsr)) (then unreachable)) (global.set $outsw (i32.wrap_i64 (i64.shr_u (local.get $ret64) (i64.const 32)))) - (if (i32.ne (i32.const 3) (global.get $outsw)) + (if (i32.ne (i32.const 4) (global.get $outsw)) (then unreachable)) ;; start async read on $insr which will block - (local.set $ret (call $stream.read (global.get $insr) (i32.const 12) (global.get $inbufp))) + (local.set $ret (call $stream.read (global.get $insr) (global.get $inbufp) (i32.const 12))) (if (i32.ne (i32.const -1 (; BLOCKED ;)) (local.get $ret)) (then unreachable)) @@ -70,14 +69,50 @@ (i32.or (i32.const 2 (; WAIT ;)) (i32.shl (global.get $ws) (i32.const 4))) ) (func $transform_cb (export "transform_cb") (param $event_code i32) (param $index i32) (param $payload i32) (result i32) - unreachable + (local $ret i32) (local $ret64 i64) + + ;; confirm the read succeeded fully + (if (i32.ne (local.get $event_code) (i32.const 2 (; STREAM_READ ;))) + (then unreachable)) + (if (i32.ne (local.get $index) (global.get $insr)) + (then unreachable)) + (if (i32.ne (local.get $payload) (i32.const 0xc0 (; COMPLETED=0 | (12 << 4) ;))) + (then unreachable)) + (if (i32.ne (i32.const 0x89abcdef) (i32.load offset=0 (global.get $inbufp))) + (then unreachable)) + (if (i32.ne (i32.const 0x01234567) (i32.load offset=4 (global.get $inbufp))) + (then unreachable)) + (if (i32.ne (i32.const 0x89abcdef) (i32.load offset=8 (global.get $inbufp))) + (then unreachable)) + + ;; multiple read calls succeed until 12-byte buffer is consumed + (local.set $ret (call $stream.read (global.get $insr) (global.get $inbufp) (i32.const 4))) + (if (i32.ne (i32.const 0x40) (local.get $ret)) + (then unreachable)) + (if (i32.ne (i32.const 0x76543210) (i32.load (global.get $inbufp))) + (then unreachable)) + (local.set $ret (call $stream.read (global.get $insr) (global.get $inbufp) (i32.const 2))) + (if (i32.ne (i32.const 0x20) (local.get $ret)) + (then unreachable)) + (if (i32.ne (i32.const 0xba98) (i32.load16_u (global.get $inbufp))) + (then unreachable)) + (local.set $ret (call $stream.read (global.get $insr) (global.get $inbufp) (i32.const 8))) + (if (i32.ne (i32.const 0x60) (local.get $ret)) + (then unreachable)) + (if (i32.ne (i32.const 0x3210fedc) (i32.load (global.get $inbufp))) + (then unreachable)) + (if (i32.ne (i32.const 0x7654) (i32.load16_u offset=4 (global.get $inbufp))) + (then unreachable)) + + (call $stream.close-readable (global.get $insr)) + (call $stream.close-writable (global.get $outsw)) + (return (i32.const 0 (; EXIT ;))) ) ) (type $ST (stream u8)) (canon task.return (result $ST) (memory $memory "mem") (core func $task.return)) (canon waitable.join (core func $waitable.join)) (canon waitable-set.new (core func $waitable-set.new)) - (canon waitable-set.wait (memory $memory "mem") (core func $waitable-set.wait)) (canon stream.new $ST (core func $stream.new)) (canon stream.read $ST async (memory $memory "mem") (core func $stream.read)) (canon stream.write $ST async (memory $memory "mem") (core func $stream.write)) @@ -88,7 +123,6 @@ (export "task.return" (func $task.return)) (export "waitable.join" (func $waitable.join)) (export "waitable-set.new" (func $waitable-set.new)) - (export "waitable-set.wait" (func $waitable-set.wait)) (export "stream.new" (func $stream.new)) (export "stream.read" (func $stream.read)) (export "stream.write" (func $stream.write)) @@ -138,23 +172,46 @@ (local.set $retp (i32.const 8)) (i32.store (local.get $paramp) (local.get $insr)) (local.set $ret (call $transform (local.get $paramp) (local.get $retp))) - (if (i32.ne (i32.const 2) (local.get $ret)) + (if (i32.ne (i32.const 2 (; RETURNED=2 | (0<<4) ;)) (local.get $ret)) (then unreachable)) (local.set $outsr (i32.load (local.get $retp))) - (if (i32.ne (i32.const 2) (local.get $outsr)) + (if (i32.ne (i32.const 1) (local.get $outsr)) (then unreachable)) ;; multiple write calls succeed until 12-byte buffer is filled - (i64.store (i32.const 0) (i64.const 0x0123456789abcdef)) - (local.set $ret (call $stream.write (local.get $insw) (i32.const 0) (i32.const 8))) + (i64.store (i32.const 16) (i64.const 0x0123456789abcdef)) + (local.set $ret (call $stream.write (local.get $insw) (i32.const 16) (i32.const 8))) (if (i32.ne (i32.const 0x80) (local.get $ret)) (then unreachable)) - (local.set $ret (call $stream.write (local.get $insw) (i32.const 0) (i32.const 8))) + (local.set $ret (call $stream.write (local.get $insw) (i32.const 16) (i32.const 8))) (if (i32.ne (i32.const 0x40) (local.get $ret)) (then unreachable)) - ;; return 44 to the top-level test harness - (i32.const 44) + ;; start a blocking write with a 12-byte buffer + (i64.store (i32.const 16) (i64.const 0xfedcba9876543210)) + (i32.store (i32.const 24) (i32.const 0x76543210)) + (local.set $ret (call $stream.write (local.get $insw) (i32.const 16) (i32.const 12))) + (if (i32.ne (i32.const -1 (; BLOCKED ;)) (local.get $ret)) + (then unreachable)) + + ;; wait for transform to read our write and close all the streams + (local.set $ws (call $waitable-set.new)) + (call $waitable.join (local.get $insw) (local.get $ws)) + (local.set $ret (call $waitable-set.wait (local.get $ws) (i32.const 0))) + + ;; confirm the write and the closed stream + (if (i32.ne (i32.const 3 (; STREAM_WRITE ;)) (local.get $ret)) + (then unreachable)) + (if (i32.ne (local.get $insw) (i32.load (i32.const 0))) + (then unreachable)) + (if (i32.ne (i32.const 0xc1 (; CLOSED=1 | (12 << 4) ;) (; TODO: currently returns 0xc0 ;)) (i32.load (i32.const 4))) + (then unreachable)) + + (call $stream.close-writable (local.get $insw)) + (call $stream.close-readable (local.get $outsr)) + + ;; return 42 to the top-level test harness + (i32.const 42) ) ) (type $ST (stream u8)) @@ -186,4 +243,4 @@ (instance $d (instantiate $D (with "transform" (func $c "transform")))) (func (export "run") (alias export $d "run")) ) -;;(assert_return (invoke "run") (u32.const 44)) +(assert_return (invoke "run") (u32.const 42))