diff --git a/crates/guest-rust/macro/src/lib.rs b/crates/guest-rust/macro/src/lib.rs index b4be3798c..1daf59268 100644 --- a/crates/guest-rust/macro/src/lib.rs +++ b/crates/guest-rust/macro/src/lib.rs @@ -8,7 +8,7 @@ use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::{braced, token, LitStr, Token}; use wit_bindgen_core::wit_parser::{PackageId, Resolve, UnresolvedPackageGroup, WorldId}; -use wit_bindgen_rust::{AsyncConfig, Opts, Ownership, WithOption}; +use wit_bindgen_rust::{Async, AsyncFilter, Opts, Ownership, WithOption}; #[proc_macro] pub fn generate(input: proc_macro::TokenStream) -> proc_macro::TokenStream { @@ -155,7 +155,7 @@ impl Parse for Config { return Err(Error::new(span, "cannot specify second async config")); } async_configured = true; - if !matches!(val, AsyncConfig::None) && !cfg!(feature = "async") { + if val.iter().any(|v| v.enabled) && !cfg!(feature = "async") { return Err(Error::new( span, "must enable `async` feature to enable async imports and/or exports", @@ -369,11 +369,6 @@ impl From for wit_bindgen_rust::ExportKey { } } -enum AsyncConfigSomeKind { - Imports, - Exports, -} - enum Opt { World(syn::LitStr), Path(Span, Vec), @@ -399,7 +394,7 @@ enum Opt { GenerateUnusedTypes(syn::LitBool), Features(Vec), DisableCustomSectionLinkHelpers(syn::LitBool), - Async(AsyncConfig, Span), + Async(Vec, Span), Debug(syn::LitBool), } @@ -563,25 +558,22 @@ impl Parse for Opt { let span = input.parse::()?.span; input.parse::()?; if input.peek(syn::LitBool) { - if input.parse::()?.value { - Ok(Opt::Async(AsyncConfig::All, span)) - } else { - Ok(Opt::Async(AsyncConfig::None, span)) - } + let enabled = input.parse::()?.value; + Ok(Opt::Async( + vec![Async { + enabled, + filter: AsyncFilter::All, + }], + span, + )) } else { - let mut imports = Vec::new(); - let mut exports = Vec::new(); + let mut vals = Vec::new(); let contents; - syn::braced!(contents in input); - for (kind, values) in - contents.parse_terminated(parse_async_some_field, Token![,])? - { - match kind { - AsyncConfigSomeKind::Imports => imports = values, - AsyncConfigSomeKind::Exports => exports = values, - } + syn::bracketed!(contents in input); + for val in contents.parse_terminated(parse_async, Token![,])? { + vals.push(val); } - Ok(Opt::Async(AsyncConfig::Some { imports, exports }, span)) + Ok(Opt::Async(vals, span)) } } else { Err(l.error()) @@ -642,26 +634,7 @@ fn fmt(input: &str) -> Result { Ok(prettyplease::unparse(&syntax_tree)) } -fn parse_async_some_field(input: ParseStream<'_>) -> Result<(AsyncConfigSomeKind, Vec)> { - let lookahead = input.lookahead1(); - let kind = if lookahead.peek(kw::imports) { - input.parse::()?; - input.parse::()?; - AsyncConfigSomeKind::Imports - } else if lookahead.peek(kw::exports) { - input.parse::()?; - input.parse::()?; - AsyncConfigSomeKind::Exports - } else { - return Err(lookahead.error()); - }; - - let list; - syn::bracketed!(list in input); - let fields = list.parse_terminated(Parse::parse, Token![,])?; - - Ok(( - kind, - fields.iter().map(|s: &syn::LitStr| s.value()).collect(), - )) +fn parse_async(input: ParseStream<'_>) -> Result { + let value = input.parse::()?.value(); + Ok(Async::parse(&value)) } diff --git a/crates/guest-rust/src/lib.rs b/crates/guest-rust/src/lib.rs index aad2a54a0..e0376b63f 100644 --- a/crates/guest-rust/src/lib.rs +++ b/crates/guest-rust/src/lib.rs @@ -839,18 +839,19 @@ /// // /// // The resulting bindings will use the component model /// // [async ABI](https://github.com/WebAssembly/component-model/blob/main/design/mvp/Async.md). -/// // This may be specified either as a boolean (e.g. `async: true`, meaning -/// // all imports and exports should use the async ABI) or as lists of -/// // specific imports and/or exports as shown here: -/// async: { -/// imports: [ -/// "wasi:http/types@0.3.0-draft#[static]body.finish", -/// "wasi:http/handler@0.3.0-draft#handle", -/// ], -/// exports: [ -/// "wasi:http/handler@0.3.0-draft#handle", -/// ] -/// } +/// // +/// // If this option is not provided then the WIT's source annotation will +/// // be used instead. +/// async: true, // all bindings are async +/// async: false, // all bindings are sync +/// // With an array per-function configuration can be specified. A leading +/// // '-' will disable async for that particular function. +/// async: [ +/// "wasi:http/types@0.3.0-draft#[static]body.finish", +/// "import:wasi:http/handler@0.3.0-draft#handle", +/// "-export:wasi:http/handler@0.3.0-draft#handle", +/// "all", +/// ], /// }); /// ``` /// diff --git a/crates/rust/src/bindgen.rs b/crates/rust/src/bindgen.rs index e37d71c83..976b7caf8 100644 --- a/crates/rust/src/bindgen.rs +++ b/crates/rust/src/bindgen.rs @@ -885,7 +885,7 @@ impl Bindgen for FunctionBindgen<'_, '_> { self.push_str(&prev_src); let constructor_type = match &func.kind { FunctionKind::Freestanding | FunctionKind::AsyncFreestanding => { - self.push_str(&format!("T::{}", to_rust_ident(&func.name))); + self.push_str(&format!("T::{}", to_rust_ident(func.item_name()))); None } FunctionKind::Method(_) diff --git a/crates/rust/src/interface.rs b/crates/rust/src/interface.rs index 3854fba19..0a42bcbe0 100644 --- a/crates/rust/src/interface.rs +++ b/crates/rust/src/interface.rs @@ -1,8 +1,7 @@ use crate::bindgen::{FunctionBindgen, POINTER_SIZE_EXPRESSION}; use crate::{ - full_wit_type_name, int_repr, to_rust_ident, to_upper_camel_case, wasm_type, AsyncConfig, - FnSig, Identifier, InterfaceName, Ownership, RuntimeItem, RustFlagsRepr, RustWasm, - TypeGeneration, + full_wit_type_name, int_repr, to_rust_ident, to_upper_camel_case, wasm_type, FnSig, Identifier, + InterfaceName, Ownership, RuntimeItem, RustFlagsRepr, RustWasm, TypeGeneration, }; use anyhow::Result; use heck::*; @@ -162,17 +161,9 @@ impl<'i> InterfaceGenerator<'i> { continue; } - let async_ = match &self.r#gen.opts.async_ { - AsyncConfig::None => false, - AsyncConfig::All => true, - AsyncConfig::Some { exports, .. } => { - exports.contains(&if let Some((_, key)) = interface { - format!("{}#{}", self.resolve.name_world_key(key), func.name) - } else { - func.name.clone() - }) - } - }; + let async_ = self + .r#gen + .is_async(self.resolve, interface.map(|p| p.1), func, false); let resource = func.kind.resource(); funcs_to_export.push((func, resource, async_)); @@ -712,15 +703,7 @@ pub mod vtable{ordinal} {{ self.generate_payloads("", func, interface); - let async_ = match &self.r#gen.opts.async_ { - AsyncConfig::None => false, - AsyncConfig::All => true, - AsyncConfig::Some { imports, .. } => imports.contains(&if let Some(key) = interface { - format!("{}#{}", self.resolve.name_world_key(key), func.name) - } else { - func.name.clone() - }), - }; + let async_ = self.r#gen.is_async(self.resolve, interface, func, true); let mut sig = FnSig { async_, ..Default::default() @@ -1227,17 +1210,9 @@ unsafe fn call_import(params: *mut u8, results: *mut u8) -> u32 {{ if self.r#gen.skip.contains(&func.name) { continue; } - let async_ = match &self.r#gen.opts.async_ { - AsyncConfig::None => false, - AsyncConfig::All => true, - AsyncConfig::Some { exports, .. } => { - exports.contains(&if let Some((_, key)) = interface { - format!("{}#{}", self.resolve.name_world_key(key), func.name) - } else { - func.name.clone() - }) - } - }; + let async_ = self + .r#gen + .is_async(self.resolve, interface.map(|p| p.1), func, false); let mut sig = FnSig { async_, use_item_name: true, @@ -1349,7 +1324,7 @@ unsafe fn call_import(params: *mut u8, results: *mut u8) -> u32 {{ func.item_name() } } else { - &func.name + func.item_name() }; self.push_str(&to_rust_ident(func_name)); if let Some(generics) = &sig.generics { diff --git a/crates/rust/src/lib.rs b/crates/rust/src/lib.rs index 6b4f9e0b6..e4b27198a 100644 --- a/crates/rust/src/lib.rs +++ b/crates/rust/src/lib.rs @@ -52,6 +52,7 @@ struct RustWasm { future_payloads: IndexMap, stream_payloads: IndexMap, + used_async_options: HashSet, } #[derive(Default)] @@ -138,55 +139,65 @@ fn parse_with(s: &str) -> Result<(String, WithOption), String> { Ok((k.to_string(), v)) } -#[derive(Default, Debug, Clone)] -#[cfg_attr( - feature = "serde", - derive(serde::Deserialize), - serde(rename_all = "kebab-case") -)] -pub enum AsyncConfig { - #[default] - None, - Some { - imports: Vec, - exports: Vec, - }, +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize))] +pub struct Async { + pub enabled: bool, + pub filter: AsyncFilter, +} + +impl Async { + pub fn parse(s: &str) -> Async { + let (s, enabled) = match s.strip_prefix('-') { + Some(s) => (s, false), + None => (s, true), + }; + let filter = match s { + "all" => AsyncFilter::All, + other => match other.strip_prefix("import:") { + Some(s) => AsyncFilter::Import(s.to_string()), + None => match other.strip_prefix("export:") { + Some(s) => AsyncFilter::Export(s.to_string()), + None => AsyncFilter::Function(s.to_string()), + }, + }, + }; + Async { enabled, filter } + } +} + +impl fmt::Display for Async { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if !self.enabled { + write!(f, "-")?; + } + self.filter.fmt(f) + } +} + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize))] +pub enum AsyncFilter { All, + Function(String), + Import(String), + Export(String), } -#[cfg(feature = "clap")] -fn parse_async(s: &str) -> Result { - Ok(match s { - "none" => AsyncConfig::None, - "all" => AsyncConfig::All, - _ => { - if let Some(values) = s.strip_prefix("some=") { - let mut imports = Vec::new(); - let mut exports = Vec::new(); - for value in values.split(',') { - let error = || { - Err(format!( - "expected string of form `import:` or `export:`; got `{value}`" - )) - }; - if let Some((k, v)) = value.split_once(":") { - match k { - "import" => imports.push(v.into()), - "export" => exports.push(v.into()), - _ => return error(), - } - } else { - return error(); - } - } - AsyncConfig::Some { imports, exports } - } else { - return Err(format!( - "expected string of form `none`, `all`, or `some=[,...]`; got `{s}`" - )); - } +impl fmt::Display for AsyncFilter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AsyncFilter::All => write!(f, "all"), + AsyncFilter::Function(s) => write!(f, "{s}"), + AsyncFilter::Import(s) => write!(f, "import:{s}"), + AsyncFilter::Export(s) => write!(f, "export:{s}"), } - }) + } +} + +#[cfg(feature = "clap")] +fn parse_async(s: &str) -> Result { + Ok(Async::parse(s)) } #[derive(Default, Debug, Clone)] @@ -323,16 +334,33 @@ pub struct Opts { /// Determines which functions to lift or lower `async`, if any. /// - /// Accepted values are: + /// This option can be passed multiple times and additionally accepts + /// comma-separated values for each option passed. Each individual argument + /// passed here can be one of: /// - /// - none - /// - all - /// - `some=[,...]`, where each `` is of the form: - /// - `import:` or - /// - `export:` - #[cfg_attr(feature = "clap", arg(long = "async", value_parser = parse_async, default_value = "none"))] + /// - `all` - all imports and exports will be async + /// - `-all` - force all imports and exports to be sync + /// - `foo:bar/baz#method` - force this method to be async + /// - `import:foo:bar/baz#method` - force this method to be async, but only + /// as an import + /// - `-export:foo:bar/baz#method` - force this export to be sync + /// + /// If a method is not listed in this option then the WIT's default bindings + /// mode will be used. If the WIT function is defined as `async` then async + /// bindings will be generated, otherwise sync bindings will be generated. + /// + /// Options are processed in the order they are passed here, so if a method + /// matches two directives passed the least-specific one should be last. + #[cfg_attr( + feature = "clap", + arg( + long = "async", + value_parser = parse_async, + value_delimiter =',', + ), + )] #[cfg_attr(feature = "serde", serde(rename = "async"))] - pub async_: AsyncConfig, + pub async_: Vec, } impl Opts { @@ -1004,6 +1032,54 @@ macro_rules! __export_{world_name}_impl {{ ); } } + + fn is_async( + &mut self, + resolve: &Resolve, + interface: Option<&WorldKey>, + func: &Function, + is_import: bool, + ) -> bool { + let name_to_test = match interface { + Some(key) => format!("{}#{}", resolve.name_world_key(key), func.name), + None => func.name.clone(), + }; + for (i, opt) in self.opts.async_.iter().enumerate() { + let name = match &opt.filter { + AsyncFilter::All => { + self.used_async_options.insert(i); + return opt.enabled; + } + AsyncFilter::Function(s) => s, + AsyncFilter::Import(s) => { + if !is_import { + continue; + } + s + } + AsyncFilter::Export(s) => { + if is_import { + continue; + } + s + } + }; + if *name == name_to_test { + self.used_async_options.insert(i); + return opt.enabled; + } + } + + match &func.kind { + FunctionKind::Freestanding + | FunctionKind::Method(_) + | FunctionKind::Static(_) + | FunctionKind::Constructor(_) => false, + FunctionKind::AsyncFreestanding + | FunctionKind::AsyncMethod(_) + | FunctionKind::AsyncStatic(_) => true, + } + } } impl WorldGenerator for RustWasm { @@ -1100,6 +1176,9 @@ impl WorldGenerator for RustWasm { "// * disable_custom_section_link_helpers" ); } + for opt in self.opts.async_.iter() { + uwriteln!(self.src_preamble, "// * async: {opt}"); + } self.types.analyze(resolve); self.world = Some(world); @@ -1415,6 +1494,17 @@ impl WorldGenerator for RustWasm { bail!("unused remappings provided via `with`: {unused_keys:?}"); } + // Error about unused async configuration to help catch configuration + // errors. + for (i, opt) in self.opts.async_.iter().enumerate() { + if self.used_async_options.contains(&i) { + continue; + } + if !matches!(opt.filter, AsyncFilter::All) { + bail!("unused async option: {opt}"); + } + } + Ok(()) } }