|
1 | 1 | use proc_macro::TokenStream; |
2 | 2 | use quote::quote; |
3 | | -use syn::{parse_macro_input, FnArg, ItemFn, ReturnType, Type}; |
| 3 | +use syn::{parse_macro_input, DeriveInput, Data, Fields, ItemImpl, ItemFn, ImplItem, FnArg, ReturnType, Type}; |
| 4 | + |
| 5 | +#[proc_macro_derive(Contract)] |
| 6 | +pub fn derive_contract(input: TokenStream) -> TokenStream { |
| 7 | + let input = parse_macro_input!(input as DeriveInput); |
| 8 | + let name = &input.ident; |
| 9 | + let Data::Struct(data) = &input.data else { return TokenStream::new() }; |
| 10 | + let Fields::Named(fields) = &data.fields else { return TokenStream::new() }; |
| 11 | + |
| 12 | + let init_calls = fields.named.iter().map(|f| { |
| 13 | + let field = f.ident.as_ref().unwrap(); |
| 14 | + let key = field.to_string(); |
| 15 | + quote! { self.#field = LazyCell::new(b!("__state__::", #key)); } |
| 16 | + }); |
| 17 | + |
| 18 | + let flush_calls = fields.named.iter().map(|f| { |
| 19 | + let field = f.ident.as_ref().unwrap(); |
| 20 | + quote! { self.#field.flush(); } |
| 21 | + }); |
| 22 | + |
| 23 | + TokenStream::from(quote! { |
| 24 | + impl #name { |
| 25 | + pub fn __init_lazy_fields(&mut self) { #(#init_calls)* } |
| 26 | + pub fn __flush_lazy_fields(&self) { #(#flush_calls)* } |
| 27 | + } |
| 28 | + }) |
| 29 | +} |
4 | 30 |
|
5 | 31 | #[proc_macro_attribute] |
6 | 32 | pub fn contract(_attr: TokenStream, item: TokenStream) -> TokenStream { |
7 | | - let input = parse_macro_input!(item as ItemFn); |
| 33 | + if let Ok(impl_block) = syn::parse::<ItemImpl>(item.clone()) { |
| 34 | + return handle_impl_block(impl_block); |
| 35 | + } |
| 36 | + if let Ok(function) = syn::parse::<ItemFn>(item.clone()) { |
| 37 | + return handle_function(function); |
| 38 | + } |
| 39 | + item |
| 40 | +} |
| 41 | + |
| 42 | +fn handle_impl_block(impl_block: ItemImpl) -> TokenStream { |
| 43 | + let self_ty = &impl_block.self_ty; |
| 44 | + let mut methods = Vec::new(); |
| 45 | + let mut wrappers = Vec::new(); |
| 46 | + |
| 47 | + for item in impl_block.items.iter() { |
| 48 | + let ImplItem::Fn(method) = item else { continue }; |
| 49 | + if !matches!(method.vis, syn::Visibility::Public(_)) { continue }; |
| 50 | + |
| 51 | + let name = &method.sig.ident; |
| 52 | + let has_return = !matches!(method.sig.output, ReturnType::Default); |
| 53 | + let has_self = method.sig.inputs.iter().any(|arg| matches!(arg, FnArg::Receiver(_))); |
| 54 | + |
| 55 | + if !has_self { |
| 56 | + methods.push(method); |
| 57 | + continue; |
| 58 | + } |
| 59 | + |
| 60 | + let args: Vec<_> = method.sig.inputs.iter() |
| 61 | + .filter_map(|arg| match arg { |
| 62 | + FnArg::Typed(pat_type) => { |
| 63 | + let param = &pat_type.pat; |
| 64 | + let ptr = syn::Ident::new(&format!("{}_ptr", quote!(#param)), name.span()); |
| 65 | + let deser = match &*pat_type.ty { |
| 66 | + Type::Path(tp) if quote!(#tp).to_string().contains("String") => quote!(read_string), |
| 67 | + _ => quote!(read_bytes), |
| 68 | + }; |
| 69 | + Some((quote!(#ptr: i32), quote!(let #param = #deser(#ptr);), quote!(#param))) |
| 70 | + } |
| 71 | + _ => None, |
| 72 | + }) |
| 73 | + .collect(); |
| 74 | + |
| 75 | + let params: Vec<_> = args.iter().map(|(p, _, _)| p).collect(); |
| 76 | + let deserializations: Vec<_> = args.iter().map(|(_, d, _)| d).collect(); |
| 77 | + let call_args: Vec<_> = args.iter().map(|(_, _, c)| c).collect(); |
| 78 | + |
| 79 | + let sig = if params.is_empty() { |
| 80 | + quote!(#[no_mangle] pub extern "C" fn #name()) |
| 81 | + } else { |
| 82 | + quote!(#[no_mangle] pub extern "C" fn #name(#(#params),*)) |
| 83 | + }; |
| 84 | + |
| 85 | + let call = if call_args.is_empty() { |
| 86 | + quote!(#name()) |
| 87 | + } else { |
| 88 | + quote!(#name(#(#call_args),*)) |
| 89 | + }; |
| 90 | + |
| 91 | + let body = if has_return { |
| 92 | + quote! { |
| 93 | + #(#deserializations)* |
| 94 | + let mut state = #self_ty::default(); |
| 95 | + state.__init_lazy_fields(); |
| 96 | + let result = state.#call; |
| 97 | + state.__flush_lazy_fields(); |
| 98 | + ret(result); |
| 99 | + } |
| 100 | + } else { |
| 101 | + quote! { |
| 102 | + #(#deserializations)* |
| 103 | + let mut state = #self_ty::default(); |
| 104 | + state.__init_lazy_fields(); |
| 105 | + state.#call; |
| 106 | + state.__flush_lazy_fields(); |
| 107 | + } |
| 108 | + }; |
| 109 | + |
| 110 | + wrappers.push(quote! { #sig { #body } }); |
| 111 | + methods.push(method); |
| 112 | + } |
| 113 | + |
| 114 | + TokenStream::from(quote! { |
| 115 | + impl #self_ty { #(#methods)* } |
| 116 | + #(#wrappers)* |
| 117 | + }) |
| 118 | +} |
| 119 | + |
| 120 | +fn handle_function(input: ItemFn) -> TokenStream { |
8 | 121 | let vis = &input.vis; |
9 | | - let fn_name = &input.sig.ident; |
10 | | - let impl_fn_name = syn::Ident::new(&format!("{}_impl", fn_name), fn_name.span()); |
| 122 | + let name = &input.sig.ident; |
| 123 | + let impl_name = syn::Ident::new(&format!("{}_impl", name), name.span()); |
11 | 124 | let inputs = &input.sig.inputs; |
12 | 125 | let output = &input.sig.output; |
13 | 126 | let block = &input.block; |
14 | 127 | let attrs = &input.attrs; |
15 | 128 | let has_return = !matches!(output, ReturnType::Default); |
16 | 129 |
|
17 | | - let mut param_count = 0; |
18 | | - let mut wrapper_params = quote!{}; |
| 130 | + let mut idx = 0; |
| 131 | + let mut params = quote!{}; |
19 | 132 | let mut deserializations = quote!{}; |
20 | 133 | let mut call_args = quote!{}; |
21 | 134 |
|
22 | 135 | for arg in inputs.iter() { |
23 | 136 | if let FnArg::Typed(pat_type) = arg { |
24 | | - let param_name = &pat_type.pat; |
25 | | - let ptr_name = syn::Ident::new(&format!("arg{}_ptr", param_count), fn_name.span()); |
26 | | - |
27 | | - let deserialize_fn = match &*pat_type.ty { |
| 137 | + let param = &pat_type.pat; |
| 138 | + let ptr = syn::Ident::new(&format!("arg{}_ptr", idx), name.span()); |
| 139 | + let deser = match &*pat_type.ty { |
28 | 140 | Type::Path(tp) if quote!(#tp).to_string().contains("String") => quote!(read_string), |
29 | 141 | _ => quote!(read_bytes), |
30 | 142 | }; |
31 | 143 |
|
32 | | - if param_count > 0 { |
33 | | - wrapper_params.extend(quote!(, #ptr_name: i32)); |
34 | | - call_args.extend(quote!(, #param_name)); |
| 144 | + if idx > 0 { |
| 145 | + params.extend(quote!(, #ptr: i32)); |
| 146 | + call_args.extend(quote!(, #param)); |
35 | 147 | } else { |
36 | | - wrapper_params.extend(quote!(#ptr_name: i32)); |
37 | | - call_args.extend(quote!(#param_name)); |
| 148 | + params.extend(quote!(#ptr: i32)); |
| 149 | + call_args.extend(quote!(#param)); |
38 | 150 | } |
39 | 151 |
|
40 | | - deserializations.extend(quote! { let #param_name = #deserialize_fn(#ptr_name); }); |
41 | | - param_count += 1; |
| 152 | + deserializations.extend(quote! { let #param = #deser(#ptr); }); |
| 153 | + idx += 1; |
42 | 154 | } |
43 | 155 | } |
44 | 156 |
|
45 | | - let wrapper_sig = if param_count == 0 { |
46 | | - quote!(#[no_mangle] pub extern "C" fn #fn_name()) |
| 157 | + let sig = if idx == 0 { |
| 158 | + quote!(#[no_mangle] pub extern "C" fn #name()) |
47 | 159 | } else { |
48 | | - quote!(#[no_mangle] pub extern "C" fn #fn_name(#wrapper_params)) |
| 160 | + quote!(#[no_mangle] pub extern "C" fn #name(#params)) |
49 | 161 | }; |
50 | 162 |
|
51 | | - let wrapper_call = if has_return { |
52 | | - quote!(ret(#impl_fn_name(#call_args));) |
| 163 | + let call = if has_return { |
| 164 | + quote!(ret(#impl_name(#call_args));) |
53 | 165 | } else { |
54 | | - quote!(#impl_fn_name(#call_args);) |
| 166 | + quote!(#impl_name(#call_args);) |
55 | 167 | }; |
56 | 168 |
|
57 | 169 | TokenStream::from(quote! { |
58 | | - #wrapper_sig { #deserializations #wrapper_call } |
59 | | - #(#attrs)* #vis fn #impl_fn_name(#inputs) #output #block |
| 170 | + #sig { #deserializations #call } |
| 171 | + #(#attrs)* #vis fn #impl_name(#inputs) #output #block |
60 | 172 | }) |
61 | 173 | } |
0 commit comments