|
| 1 | +//! Expansion of `#[thrust_macros::context]`. |
| 2 | +//! |
| 3 | +//! `context` supplies the surrounding context that thrust annotations within an |
| 4 | +//! item cannot see by themselves: it stamps the enclosing `impl`/`trait` header |
| 5 | +//! onto each method (so method-level `requires`/`ensures` recover the outer |
| 6 | +//! generics) and threads that generic context into every `invariant!(...)`. |
| 7 | +
|
| 8 | +use proc_macro::TokenStream; |
| 9 | +use proc_macro2::TokenStream as TokenStream2; |
| 10 | +use quote::{quote, ToTokens}; |
| 11 | +use syn::{parse_macro_input, GenericParam, Generics, WherePredicate}; |
| 12 | + |
| 13 | +use crate::{tokens_contain_self, FnOuterItem}; |
| 14 | + |
| 15 | +pub(super) fn expand(item: TokenStream) -> TokenStream { |
| 16 | + let mut item = parse_macro_input!(item as syn::Item); |
| 17 | + process_context_item(&mut item); |
| 18 | + item.into_token_stream().into() |
| 19 | +} |
| 20 | + |
| 21 | +/// Stamps outer context onto methods and threads the generic context into the |
| 22 | +/// invariants of every function body in the item (recursing through modules). |
| 23 | +fn process_context_item(item: &mut syn::Item) { |
| 24 | + match item { |
| 25 | + syn::Item::Fn(item_fn) => { |
| 26 | + let generics = item_fn.sig.generics.clone(); |
| 27 | + let threaded = thread_invariants(&mut item_fn.block, &generics, None); |
| 28 | + if threaded.found { |
| 29 | + inject_model_bounds(&mut item_fn.sig.generics, None, false); |
| 30 | + } |
| 31 | + } |
| 32 | + syn::Item::Impl(item_impl) => { |
| 33 | + let outer = FnOuterItem::ItemImpl(item_impl.clone()).into_header_only(); |
| 34 | + for impl_item in &mut item_impl.items { |
| 35 | + let syn::ImplItem::Fn(method) = impl_item else { |
| 36 | + continue; |
| 37 | + }; |
| 38 | + method |
| 39 | + .attrs |
| 40 | + .push(syn::parse_quote!(#[thrust::_outer_context(#outer)])); |
| 41 | + let generics = method.sig.generics.clone(); |
| 42 | + let threaded = thread_invariants(&mut method.block, &generics, Some(&outer)); |
| 43 | + if threaded.found { |
| 44 | + inject_model_bounds(&mut method.sig.generics, Some(&outer), threaded.self_used); |
| 45 | + } |
| 46 | + } |
| 47 | + } |
| 48 | + syn::Item::Trait(item_trait) => { |
| 49 | + let outer = FnOuterItem::ItemTrait(item_trait.clone()).into_header_only(); |
| 50 | + for trait_item in &mut item_trait.items { |
| 51 | + let syn::TraitItem::Fn(method) = trait_item else { |
| 52 | + continue; |
| 53 | + }; |
| 54 | + method |
| 55 | + .attrs |
| 56 | + .push(syn::parse_quote!(#[thrust::_outer_context(#outer)])); |
| 57 | + if let Some(block) = &mut method.default { |
| 58 | + let generics = method.sig.generics.clone(); |
| 59 | + let threaded = thread_invariants(block, &generics, Some(&outer)); |
| 60 | + if threaded.found { |
| 61 | + inject_model_bounds( |
| 62 | + &mut method.sig.generics, |
| 63 | + Some(&outer), |
| 64 | + threaded.self_used, |
| 65 | + ); |
| 66 | + } |
| 67 | + } |
| 68 | + } |
| 69 | + } |
| 70 | + syn::Item::Mod(item_mod) => { |
| 71 | + if let Some((_, items)) = &mut item_mod.content { |
| 72 | + for inner in items { |
| 73 | + process_context_item(inner); |
| 74 | + } |
| 75 | + } |
| 76 | + } |
| 77 | + _ => {} |
| 78 | + } |
| 79 | +} |
| 80 | + |
| 81 | +struct Threaded { |
| 82 | + found: bool, |
| 83 | + self_used: bool, |
| 84 | +} |
| 85 | + |
| 86 | +/// Prepends the generic context to every `invariant!(...)` in a function body. |
| 87 | +fn thread_invariants( |
| 88 | + block: &mut syn::Block, |
| 89 | + generics: &Generics, |
| 90 | + outer: Option<&FnOuterItem>, |
| 91 | +) -> Threaded { |
| 92 | + use syn::visit_mut::VisitMut as _; |
| 93 | + |
| 94 | + let context = invariant_context_tokens(generics, outer); |
| 95 | + let mut threader = InvariantThreader { |
| 96 | + context, |
| 97 | + is_method: outer.is_some(), |
| 98 | + found: false, |
| 99 | + self_used: false, |
| 100 | + }; |
| 101 | + threader.visit_block_mut(block); |
| 102 | + Threaded { |
| 103 | + found: threader.found, |
| 104 | + self_used: threader.self_used, |
| 105 | + } |
| 106 | +} |
| 107 | + |
| 108 | +/// Builds the `[generic-params] [where-predicates]` prefix that `invariant!` |
| 109 | +/// consumes: every generic parameter in scope (the function's own and, for |
| 110 | +/// methods, the outer ones), the existing where predicates, and the |
| 111 | +/// `Model`/`PartialEq` bounds those parameters require. |
| 112 | +fn invariant_context_tokens(generics: &Generics, outer: Option<&FnOuterItem>) -> TokenStream2 { |
| 113 | + let mut params: Vec<GenericParam> = generics.params.iter().cloned().collect(); |
| 114 | + let mut preds: Vec<WherePredicate> = generics |
| 115 | + .where_clause |
| 116 | + .as_ref() |
| 117 | + .map(|wc| wc.predicates.iter().cloned().collect()) |
| 118 | + .unwrap_or_default(); |
| 119 | + if let Some(outer) = outer { |
| 120 | + params.extend(outer.generics().params.iter().cloned()); |
| 121 | + if let Some(wc) = &outer.generics().where_clause { |
| 122 | + preds.extend(wc.predicates.iter().cloned()); |
| 123 | + } |
| 124 | + } |
| 125 | + for param in ¶ms { |
| 126 | + if let GenericParam::Type(tp) = param { |
| 127 | + let ident = &tp.ident; |
| 128 | + preds.push(syn::parse_quote!(#ident: thrust_models::Model)); |
| 129 | + preds.push(syn::parse_quote!(<#ident as thrust_models::Model>::Ty: PartialEq)); |
| 130 | + } |
| 131 | + } |
| 132 | + quote! { [#(#params),*] [#(#preds),*] } |
| 133 | +} |
| 134 | + |
| 135 | +struct InvariantThreader { |
| 136 | + context: TokenStream2, |
| 137 | + is_method: bool, |
| 138 | + found: bool, |
| 139 | + self_used: bool, |
| 140 | +} |
| 141 | + |
| 142 | +impl syn::visit_mut::VisitMut for InvariantThreader { |
| 143 | + fn visit_macro_mut(&mut self, mac: &mut syn::Macro) { |
| 144 | + syn::visit_mut::visit_macro_mut(self, mac); |
| 145 | + if is_invariant_macro(&mac.path) { |
| 146 | + self.found = true; |
| 147 | + // An invariant in a method may name `Self` in its variable types. |
| 148 | + // A nested item cannot, so signal `invariant!` to re-declare `Self` |
| 149 | + // as a synthetic generic, but only when it is actually used (so we |
| 150 | + // do not over-constrain the host method with `Self: Model`). |
| 151 | + let uses_self = self.is_method && tokens_contain_self(&mac.tokens); |
| 152 | + self.self_used |= uses_self; |
| 153 | + let self_marker = if uses_self { quote!(Self) } else { quote!() }; |
| 154 | + let context = &self.context; |
| 155 | + let original = &mac.tokens; |
| 156 | + mac.tokens = quote! { [#self_marker] #context #original }; |
| 157 | + } |
| 158 | + } |
| 159 | +} |
| 160 | + |
| 161 | +/// Adds `T: Model` and `<T as Model>::Ty: PartialEq` bounds for every type |
| 162 | +/// parameter in scope to a function's where clause. The marker call generated |
| 163 | +/// for an invariant instantiates a `Model`-bounded formula function, so the |
| 164 | +/// function hosting the call must itself satisfy those bounds. When an |
| 165 | +/// invariant names `Self`, `invariant!` instantiates the formula function with |
| 166 | +/// `Self`, so the same bounds are added for `Self` (`with_self`). |
| 167 | +fn inject_model_bounds(generics: &mut Generics, outer: Option<&FnOuterItem>, with_self: bool) { |
| 168 | + let mut tys: Vec<TokenStream2> = generics |
| 169 | + .params |
| 170 | + .iter() |
| 171 | + .filter_map(|p| match p { |
| 172 | + GenericParam::Type(tp) => Some(tp.ident.to_token_stream()), |
| 173 | + _ => None, |
| 174 | + }) |
| 175 | + .collect(); |
| 176 | + if let Some(outer) = outer { |
| 177 | + for param in &outer.generics().params { |
| 178 | + if let GenericParam::Type(tp) = param { |
| 179 | + tys.push(tp.ident.to_token_stream()); |
| 180 | + } |
| 181 | + } |
| 182 | + } |
| 183 | + if with_self { |
| 184 | + tys.push(quote!(Self)); |
| 185 | + } |
| 186 | + if tys.is_empty() { |
| 187 | + return; |
| 188 | + } |
| 189 | + let where_clause = generics.make_where_clause(); |
| 190 | + for ty in tys { |
| 191 | + where_clause |
| 192 | + .predicates |
| 193 | + .push(syn::parse_quote!(#ty: thrust_models::Model)); |
| 194 | + where_clause |
| 195 | + .predicates |
| 196 | + .push(syn::parse_quote!(<#ty as thrust_models::Model>::Ty: PartialEq)); |
| 197 | + } |
| 198 | +} |
| 199 | + |
| 200 | +fn is_invariant_macro(path: &syn::Path) -> bool { |
| 201 | + path.segments.last().is_some_and(|s| s.ident == "invariant") |
| 202 | +} |
0 commit comments