|
2 | 2 | // SPDX-License-Identifier: Apache-2.0 |
3 | 3 |
|
4 | 4 | use proc_macro::TokenStream; |
5 | | -use quote::{format_ident, quote}; |
6 | | -use syn::FnArg::Typed; |
7 | | -use syn::__private::Span; |
8 | | -use syn::parse::{Parse, ParseStream}; |
9 | | -use syn::{parse_macro_input, parse_quote, Arm, Ident, ItemTrait, Pat, TraitItem}; |
10 | | - |
11 | | -fn snake_to_camel(ident_str: &str) -> String { |
12 | | - let mut camel_ty = String::with_capacity(ident_str.len()); |
13 | | - |
14 | | - let mut last_char_was_underscore = true; |
15 | | - for c in ident_str.chars() { |
16 | | - match c { |
17 | | - '_' => last_char_was_underscore = true, |
18 | | - c if last_char_was_underscore => { |
19 | | - camel_ty.extend(c.to_uppercase()); |
20 | | - last_char_was_underscore = false; |
21 | | - } |
22 | | - c => camel_ty.extend(c.to_lowercase()), |
23 | | - } |
24 | | - } |
25 | | - |
26 | | - camel_ty.shrink_to_fit(); |
27 | | - camel_ty |
28 | | -} |
29 | | - |
30 | | -#[proc_macro_attribute] |
31 | | -pub fn extract_request_id(_attr: TokenStream, input: TokenStream) -> TokenStream { |
32 | | - let mut item: ItemTrait = syn::parse(input).unwrap(); |
33 | | - let name = &format_ident!("{}Request", item.ident); |
34 | | - let mut arms: Vec<Arm> = vec![]; |
35 | | - let mut backpressure_variants: Vec<Ident> = vec![]; |
36 | | - |
37 | | - for inner in item.items.iter_mut() { |
38 | | - if let TraitItem::Fn(func) = inner { |
39 | | - // Strip #[force_backpressure] and record which methods carry it. |
40 | | - let had_force_backpressure = func.attrs.iter().any(|attr| { |
41 | | - attr.meta |
42 | | - .path() |
43 | | - .get_ident() |
44 | | - .is_some_and(|i| i == "force_backpressure") |
45 | | - }); |
46 | | - func.attrs.retain(|attr| { |
47 | | - attr.meta |
48 | | - .path() |
49 | | - .get_ident() |
50 | | - .is_none_or(|i| i != "force_backpressure") |
51 | | - }); |
52 | | - |
53 | | - let method = Ident::new( |
54 | | - &snake_to_camel(&func.sig.ident.to_string()), |
55 | | - Span::mixed_site(), |
56 | | - ); |
57 | | - |
58 | | - if had_force_backpressure { |
59 | | - backpressure_variants.push(method.clone()); |
60 | | - } |
61 | | - |
62 | | - for any_arg in &func.sig.inputs { |
63 | | - if let Typed(arg) = any_arg { |
64 | | - if let Pat::Ident(ident) = &*arg.pat { |
65 | | - let matched_enum_type = match ident.ident.to_string().as_str() { |
66 | | - "session_id" => Some(format_ident!("SessionId")), |
67 | | - "instance_id" => Some(format_ident!("InstanceId")), |
68 | | - _ => None, |
69 | | - }; |
70 | | - if let Some(enum_type) = matched_enum_type { |
71 | | - arms.push(parse_quote! { |
72 | | - #name::#method { #ident, .. } => RequestIdentifier::#enum_type(#ident.clone()) |
73 | | - }); |
74 | | - } |
75 | | - } |
76 | | - } |
77 | | - } |
78 | | - } |
79 | | - } |
80 | | - |
81 | | - let backpressure_body = if backpressure_variants.is_empty() { |
82 | | - quote! { false } |
83 | | - } else { |
84 | | - quote! { matches!(self, #(#name::#backpressure_variants { .. })|*) } |
85 | | - }; |
86 | | - |
87 | | - TokenStream::from(quote! { |
88 | | - #item |
89 | | - |
90 | | - impl RequestIdentification for tarpc::Request<#name> { |
91 | | - fn extract_identifier(&self) -> RequestIdentifier { |
92 | | - match &self.message { |
93 | | - #( |
94 | | - #arms, |
95 | | - )* |
96 | | - _ => RequestIdentifier::None, |
97 | | - } |
98 | | - } |
99 | | - } |
100 | | - |
101 | | - impl #name { |
102 | | - /// Returns true if this request variant was annotated with `#[force_backpressure]`. |
103 | | - pub fn requires_backpressure(&self) -> bool { |
104 | | - #backpressure_body |
105 | | - } |
106 | | - } |
107 | | - }) |
108 | | -} |
| 5 | +use quote::quote; |
| 6 | +use syn::{parse_macro_input, parse::{Parse, ParseStream}}; |
109 | 7 |
|
110 | 8 | struct EnvOrDefault { |
111 | 9 | name: syn::LitStr, |
|
0 commit comments