|
| 1 | +//! The `thrust_macros::formula!` proc-macro: the single preprocessing layer for |
| 2 | +//! formula tokens. |
| 3 | +//! |
| 4 | +//! Annotation macros wrap their formula body in `thrust_macros::formula!(...)` |
| 5 | +//! instead of splicing it raw. This hides syntax that is not valid Rust (the |
| 6 | +//! `==>` operator) inside a macro call's arguments, so the surrounding pipeline — |
| 7 | +//! which parses formulas as [`syn::Expr`] — never chokes on it, and gives one |
| 8 | +//! place to run preprocessing before the body reaches rustc / HIR lowering. |
| 9 | +//! |
| 10 | +//! The only pass today is implication desugaring; further passes can be appended |
| 11 | +//! in [`expand`]. |
| 12 | +
|
| 13 | +use proc_macro2::{Group, Punct, Spacing, TokenStream, TokenTree}; |
| 14 | +use quote::{quote, ToTokens}; |
| 15 | +use syn::visit_mut::VisitMut; |
| 16 | + |
| 17 | +/// Wraps formula tokens in a `thrust_macros::formula!(...)` call. |
| 18 | +pub fn wrap_expr(tokens: TokenStream) -> TokenStream { |
| 19 | + quote!(::thrust_macros::formula!(#tokens)) |
| 20 | +} |
| 21 | + |
| 22 | +/// Wraps an invariant closure's body in `::thrust_macros::formula!(...)`, taking |
| 23 | +/// everything past the closure header (the first two top-level `|`) as the body. |
| 24 | +/// Only the body is wrapped: the closure must stay a parseable `ExprClosure` for |
| 25 | +/// `invariant::expand_invariant` to read its `.inputs`/`.body`, and the context |
| 26 | +/// form's threaded signature must not be preprocessed. Non-closure inputs pass |
| 27 | +/// through. |
| 28 | +/// |
| 29 | +/// A closure with an explicit return type (`|x| -> Ty { .. }`) must keep its body |
| 30 | +/// a block, so the `-> Ty {` / `}` is preserved and only the block's tail |
| 31 | +/// expression is wrapped. |
| 32 | +pub fn wrap_closure_body(input: TokenStream) -> TokenStream { |
| 33 | + let tokens: Vec<TokenTree> = input.into_iter().collect(); |
| 34 | + let bars: Vec<usize> = tokens |
| 35 | + .iter() |
| 36 | + .enumerate() |
| 37 | + .filter(|(_, tt)| matches!(tt, TokenTree::Punct(p) if p.as_char() == '|')) |
| 38 | + .map(|(i, _)| i) |
| 39 | + .collect(); |
| 40 | + let [_open, close, ..] = bars[..] else { |
| 41 | + return tokens.into_iter().collect(); |
| 42 | + }; |
| 43 | + let header: TokenStream = tokens[..=close].iter().cloned().collect(); |
| 44 | + let tail = &tokens[close + 1..]; |
| 45 | + |
| 46 | + // `-> Ty { body }`: wrap the block's tail expression in place. |
| 47 | + let has_return_type = matches!(tail.first(), Some(TokenTree::Punct(p)) if p.as_char() == '-') |
| 48 | + && matches!(tail.get(1), Some(TokenTree::Punct(p)) if p.as_char() == '>'); |
| 49 | + if has_return_type { |
| 50 | + if let Some((TokenTree::Group(block), prefix)) = tail.split_last() { |
| 51 | + if block.delimiter() == proc_macro2::Delimiter::Brace { |
| 52 | + let prefix: TokenStream = prefix.iter().cloned().collect(); |
| 53 | + let body = wrap_block_tail(block); |
| 54 | + return quote!(#header #prefix #body); |
| 55 | + } |
| 56 | + } |
| 57 | + } |
| 58 | + |
| 59 | + // `|args| body`: the body is a bare expression; wrap it whole. |
| 60 | + let body = wrap_expr(tail.iter().cloned().collect()); |
| 61 | + quote!(#header #body) |
| 62 | +} |
| 63 | + |
| 64 | +/// Returns a brace group like `block` but with its tail expression (the final |
| 65 | +/// `;`-separated segment) wrapped in `::thrust_macros::formula!(...)`. Leading |
| 66 | +/// statements are left untouched; a block with no tail expression is unchanged. |
| 67 | +fn wrap_block_tail(block: &Group) -> TokenStream { |
| 68 | + let stmts: Vec<TokenTree> = block.stream().into_iter().collect(); |
| 69 | + let mut segments: Vec<Vec<TokenTree>> = vec![Vec::new()]; |
| 70 | + for tt in &stmts { |
| 71 | + segments.last_mut().unwrap().push(tt.clone()); |
| 72 | + if matches!(tt, TokenTree::Punct(p) if p.as_char() == ';') { |
| 73 | + segments.push(Vec::new()); |
| 74 | + } |
| 75 | + } |
| 76 | + let tail = segments.pop().unwrap(); |
| 77 | + |
| 78 | + let mut inner = TokenStream::new(); |
| 79 | + for seg in &segments { |
| 80 | + inner.extend(seg.iter().cloned()); |
| 81 | + } |
| 82 | + if !tail.is_empty() { |
| 83 | + inner.extend(wrap_expr(tail.into_iter().collect())); |
| 84 | + } |
| 85 | + |
| 86 | + let mut wrapped = Group::new(proc_macro2::Delimiter::Brace, inner); |
| 87 | + wrapped.set_span(block.span()); |
| 88 | + quote!(#wrapped) |
| 89 | +} |
| 90 | + |
| 91 | +/// Expands `formula!(<tokens>)` into the preprocessed boolean expression. |
| 92 | +pub fn expand(input: TokenStream) -> TokenStream { |
| 93 | + if let Err(e) = reject_bare_assignment(&input) { |
| 94 | + return e.to_compile_error(); |
| 95 | + } |
| 96 | + |
| 97 | + // `==>` is desugared to assignment (the lowest-precedence, right-associative |
| 98 | + // operator) so `syn` reproduces its precedence, then each assignment node is |
| 99 | + // rewritten into `!lhs || rhs`. |
| 100 | + let desugared = desugar_arrows(input); |
| 101 | + let mut expr: syn::Expr = match syn::parse2(desugared) { |
| 102 | + Ok(expr) => expr, |
| 103 | + Err(e) => return e.to_compile_error(), |
| 104 | + }; |
| 105 | + |
| 106 | + // Rewrites each assignment `lhs = rhs` (produced by [`desugar_arrows`] from |
| 107 | + // `lhs ==> rhs`) into `(!(lhs)) || (rhs)`. Visiting post-order means nested |
| 108 | + // implications are rewritten innermost-first, so the right-associative chain |
| 109 | + // `a ==> b ==> c` becomes `!a || (!b || c)`. |
| 110 | + struct ImplicationRewriter; |
| 111 | + |
| 112 | + impl VisitMut for ImplicationRewriter { |
| 113 | + fn visit_expr_mut(&mut self, expr: &mut syn::Expr) { |
| 114 | + syn::visit_mut::visit_expr_mut(self, expr); |
| 115 | + if let syn::Expr::Assign(assign) = expr { |
| 116 | + let left = &assign.left; |
| 117 | + let right = &assign.right; |
| 118 | + *expr = syn::parse_quote!((!(#left)) || (#right)); |
| 119 | + } |
| 120 | + } |
| 121 | + } |
| 122 | + |
| 123 | + ImplicationRewriter.visit_expr_mut(&mut expr); |
| 124 | + |
| 125 | + expr.into_token_stream() |
| 126 | +} |
| 127 | + |
| 128 | +/// Rejects a bare assignment `=` anywhere in `input`. Desugaring turns `==>` into |
| 129 | +/// `=`, so a genuine `=` would be read as an implication; since assignment is |
| 130 | +/// meaningless in a formula we reject it up front. A genuine `=` is a lone |
| 131 | +/// `Punct('=')` with [`Spacing::Alone`] not preceded by a joint punct, which |
| 132 | +/// excludes the `=`s of `==`, `!=`, `<=`, `>=`, and `==>`. |
| 133 | +fn reject_bare_assignment(input: &TokenStream) -> syn::Result<()> { |
| 134 | + let mut prev_joint = false; |
| 135 | + for tt in input.clone() { |
| 136 | + match tt { |
| 137 | + TokenTree::Group(g) => { |
| 138 | + reject_bare_assignment(&g.stream())?; |
| 139 | + prev_joint = false; |
| 140 | + } |
| 141 | + TokenTree::Punct(p) => { |
| 142 | + if p.as_char() == '=' && p.spacing() == Spacing::Alone && !prev_joint { |
| 143 | + return Err(syn::Error::new( |
| 144 | + p.span(), |
| 145 | + "`=` is not allowed in a formula; use `==` for equality or `==>` for implication", |
| 146 | + )); |
| 147 | + } |
| 148 | + prev_joint = p.spacing() == Spacing::Joint; |
| 149 | + } |
| 150 | + _ => prev_joint = false, |
| 151 | + } |
| 152 | + } |
| 153 | + Ok(()) |
| 154 | +} |
| 155 | + |
| 156 | +/// Replaces every contiguous `==>` (the puncts `=`, `=`, `>`) with a single `=`, |
| 157 | +/// recursing into every delimiter group. `==`, `=>`, `>=`, and `->` do not match |
| 158 | +/// the triple and are left untouched. |
| 159 | +fn desugar_arrows(input: TokenStream) -> TokenStream { |
| 160 | + let mut out = TokenStream::new(); |
| 161 | + let mut iter = input.into_iter().peekable(); |
| 162 | + while let Some(tt) = iter.next() { |
| 163 | + match tt { |
| 164 | + TokenTree::Group(group) => { |
| 165 | + let inner = desugar_arrows(group.stream()); |
| 166 | + let mut new_group = Group::new(group.delimiter(), inner); |
| 167 | + new_group.set_span(group.span()); |
| 168 | + out.extend([TokenTree::Group(new_group)]); |
| 169 | + } |
| 170 | + TokenTree::Punct(p) if p.as_char() == '=' && p.spacing() == Spacing::Joint => { |
| 171 | + // Look for a contiguous `==>`: `p` is the first `=`, joint to a |
| 172 | + // second joint `=`, then a `>`. Requiring joint spacing avoids |
| 173 | + // matching a spaced-out `= = >` / `== >`. |
| 174 | + if let Some(TokenTree::Punct(p2)) = iter.peek() { |
| 175 | + if p2.as_char() == '=' && p2.spacing() == Spacing::Joint { |
| 176 | + let mut lookahead = iter.clone(); |
| 177 | + lookahead.next(); // consume the second `=` |
| 178 | + if let Some(TokenTree::Punct(p3)) = lookahead.peek() { |
| 179 | + if p3.as_char() == '>' { |
| 180 | + // Matched `==>`; consume `=` and `>`, emit `=`. |
| 181 | + iter.next(); // second `=` |
| 182 | + iter.next(); // `>` |
| 183 | + let mut eq = Punct::new('=', Spacing::Alone); |
| 184 | + eq.set_span(p.span()); |
| 185 | + out.extend([TokenTree::Punct(eq)]); |
| 186 | + continue; |
| 187 | + } |
| 188 | + } |
| 189 | + } |
| 190 | + } |
| 191 | + out.extend([TokenTree::Punct(p)]); |
| 192 | + } |
| 193 | + other => out.extend([other]), |
| 194 | + } |
| 195 | + } |
| 196 | + out |
| 197 | +} |
| 198 | + |
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `tokens` as a `syn::Expr` and prints it back, so that two
    /// renderings can be compared independent of the input's spacing.
    fn render(tokens: TokenStream) -> String {
        let parsed: syn::Expr = syn::parse2(tokens).unwrap();
        parsed.into_token_stream().to_string()
    }

    /// Runs `expand` on `s` and renders the result.
    fn expand_expr(s: &str) -> String {
        render(expand(s.parse().unwrap()))
    }

    /// Renders the plain Rust expression `s` the same way `expand_expr` does.
    fn expect(s: &str) -> String {
        render(s.parse().unwrap())
    }

    #[test]
    fn desugars_implication() {
        let cases = [
            ("a ==> b", "(!(a)) || (b)"),
            // right-associative
            ("a ==> b ==> c", "(!(a)) || ((!(b)) || (c))"),
            // lower precedence than `||` and `==`
            ("a || b ==> c == d", "(!(a || b)) || (c == d)"),
            // nested inside a closure argument
            ("exists(|x| a ==> b)", "exists(|x| (!(a)) || (b))"),
        ];
        for (formula, desugared) in cases {
            assert_eq!(expand_expr(formula), expect(desugared));
        }
    }

    #[test]
    fn leaves_comparisons_untouched() {
        for s in ["a == b", "a != b", "a >= b", "a <= b"] {
            assert_eq!(expand_expr(s), expect(s), "{s}");
        }
    }

    #[test]
    fn rejects_bare_assignment() {
        // `expand` emits a `compile_error!` instead of a valid expression.
        for s in ["a = b", "f(x = y)"] {
            let expanded = expand(s.parse().unwrap()).to_string();
            assert!(expanded.contains("compile_error"), "{s}");
        }
    }

    /// `wrap_closure_body` must leave the input a parseable `ExprClosure`, both
    /// for a bare-expression body and one with an explicit return type.
    fn assert_wraps_to_closure(s: &str) {
        let wrapped = wrap_closure_body(s.parse().unwrap());
        if let Err(e) = syn::parse2::<syn::ExprClosure>(wrapped) {
            panic!("{s}: {e}");
        }
    }

    #[test]
    fn preserves_closure_forms() {
        for closure in [
            "|x: i64| x >= 1 ==> x >= 0",
            "|x: i64| -> bool { x >= 1 ==> x >= 0 }",
        ] {
            assert_wraps_to_closure(closure);
        }
    }
}
0 commit comments