|
| 1 | +//! The `thrust_macros::formula!` proc-macro: the single preprocessing layer for |
| 2 | +//! formula tokens. |
| 3 | +//! |
| 4 | +//! Annotation macros wrap their formula body in `thrust_macros::formula!(...)` |
| 5 | +//! instead of splicing it raw. This hides syntax that is not valid Rust (the |
| 6 | +//! `==>` operator) inside a macro call's arguments, so the surrounding pipeline — |
| 7 | +//! which parses formulas as [`syn::Expr`] — never chokes on it, and gives one |
| 8 | +//! place to run preprocessing before the body reaches rustc / HIR lowering. |
| 9 | +//! |
| 10 | +//! The only pass today is implication desugaring; further passes can be appended |
| 11 | +//! in [`expand`]. |
| 12 | +
|
| 13 | +use proc_macro2::{Group, Punct, Spacing, TokenStream, TokenTree}; |
| 14 | +use quote::{quote, ToTokens}; |
| 15 | +use syn::visit_mut::VisitMut; |
| 16 | + |
| 17 | +/// Wraps formula tokens in a `thrust_macros::formula!(...)` call. |
| 18 | +pub fn wrap_formula(tokens: TokenStream) -> TokenStream { |
| 19 | + quote!(::thrust_macros::formula!(#tokens)) |
| 20 | +} |
| 21 | + |
| 22 | +/// Expands `formula!(<tokens>)` into the preprocessed boolean expression. |
| 23 | +pub fn expand(input: TokenStream) -> TokenStream { |
| 24 | + if let Err(e) = reject_bare_assignment(&input) { |
| 25 | + return e.to_compile_error(); |
| 26 | + } |
| 27 | + |
| 28 | + // `==>` is desugared to assignment (the lowest-precedence, right-associative |
| 29 | + // operator) so `syn` reproduces its precedence, then each assignment node is |
| 30 | + // rewritten into `!lhs || rhs`. |
| 31 | + let desugared = desugar_arrows(input); |
| 32 | + let mut expr: syn::Expr = match syn::parse2(desugared) { |
| 33 | + Ok(expr) => expr, |
| 34 | + Err(e) => return e.to_compile_error(), |
| 35 | + }; |
| 36 | + |
| 37 | + // Rewrites each assignment `lhs = rhs` (produced by [`desugar_arrows`] from |
| 38 | + // `lhs ==> rhs`) into the boolean expression `(!(lhs)) || (rhs)`. Visiting |
| 39 | + // post-order means nested implications are rewritten innermost-first, so the |
| 40 | + // right-associative chain `a ==> b ==> c` becomes `!a || (!b || c)`. |
| 41 | + struct ImplicationRewriter; |
| 42 | + |
| 43 | + impl VisitMut for ImplicationRewriter { |
| 44 | + fn visit_expr_mut(&mut self, expr: &mut syn::Expr) { |
| 45 | + syn::visit_mut::visit_expr_mut(self, expr); |
| 46 | + if let syn::Expr::Assign(assign) = expr { |
| 47 | + let left = &assign.left; |
| 48 | + let right = &assign.right; |
| 49 | + *expr = syn::parse_quote!((!(#left)) || (#right)); |
| 50 | + } |
| 51 | + } |
| 52 | + } |
| 53 | + |
| 54 | + ImplicationRewriter.visit_expr_mut(&mut expr); |
| 55 | + |
| 56 | + expr.into_token_stream() |
| 57 | +} |
| 58 | + |
| 59 | +/// Rejects a bare assignment `=` anywhere in `input`. Desugaring turns `==>` into |
| 60 | +/// `=`, so a genuine `=` would be read as an implication; since assignment is |
| 61 | +/// meaningless in a formula we reject it up front. A genuine `=` is a lone |
| 62 | +/// `Punct('=')` with [`Spacing::Alone`] not preceded by a joint punct, which |
| 63 | +/// excludes the `=`s of `==`, `!=`, `<=`, `>=`, and `==>`. |
| 64 | +fn reject_bare_assignment(input: &TokenStream) -> syn::Result<()> { |
| 65 | + let mut prev_joint = false; |
| 66 | + for tt in input.clone() { |
| 67 | + match tt { |
| 68 | + TokenTree::Group(g) => { |
| 69 | + reject_bare_assignment(&g.stream())?; |
| 70 | + prev_joint = false; |
| 71 | + } |
| 72 | + TokenTree::Punct(p) => { |
| 73 | + if p.as_char() == '=' && p.spacing() == Spacing::Alone && !prev_joint { |
| 74 | + return Err(syn::Error::new( |
| 75 | + p.span(), |
| 76 | + "`=` is not allowed in a formula; use `==` for equality or `==>` for implication", |
| 77 | + )); |
| 78 | + } |
| 79 | + prev_joint = p.spacing() == Spacing::Joint; |
| 80 | + } |
| 81 | + _ => prev_joint = false, |
| 82 | + } |
| 83 | + } |
| 84 | + Ok(()) |
| 85 | +} |
| 86 | + |
| 87 | +/// Replaces every contiguous `==>` (the puncts `=`, `=`, `>`) with a single `=`, |
| 88 | +/// recursing into every delimiter group. `==`, `=>`, `>=`, and `->` do not match |
| 89 | +/// the triple and are left untouched. |
| 90 | +fn desugar_arrows(input: TokenStream) -> TokenStream { |
| 91 | + let mut out = TokenStream::new(); |
| 92 | + let mut iter = input.into_iter().peekable(); |
| 93 | + while let Some(tt) = iter.next() { |
| 94 | + match tt { |
| 95 | + TokenTree::Group(group) => { |
| 96 | + let inner = desugar_arrows(group.stream()); |
| 97 | + let mut new_group = Group::new(group.delimiter(), inner); |
| 98 | + new_group.set_span(group.span()); |
| 99 | + out.extend([TokenTree::Group(new_group)]); |
| 100 | + } |
| 101 | + TokenTree::Punct(p) if p.as_char() == '=' && p.spacing() == Spacing::Joint => { |
| 102 | + // Look for `=` `=` `>`. `p` is the first `=`. |
| 103 | + if let Some(TokenTree::Punct(p2)) = iter.peek() { |
| 104 | + if p2.as_char() == '=' && p2.spacing() == Spacing::Joint { |
| 105 | + let mut lookahead = iter.clone(); |
| 106 | + lookahead.next(); // consume the second `=` |
| 107 | + if let Some(TokenTree::Punct(p3)) = lookahead.peek() { |
| 108 | + if p3.as_char() == '>' { |
| 109 | + // Matched `==>`; consume `=` and `>`, emit `=`. |
| 110 | + iter.next(); // second `=` |
| 111 | + iter.next(); // `>` |
| 112 | + out.extend([TokenTree::Punct(Punct::new('=', Spacing::Alone))]); |
| 113 | + continue; |
| 114 | + } |
| 115 | + } |
| 116 | + } |
| 117 | + } |
| 118 | + out.extend([TokenTree::Punct(p)]); |
| 119 | + } |
| 120 | + other => out.extend([other]), |
| 121 | + } |
| 122 | + } |
| 123 | + out |
| 124 | +} |
| 125 | + |
| 126 | +#[cfg(test)] |
| 127 | +mod tests { |
| 128 | + use super::*; |
| 129 | + |
| 130 | + /// `expand`s `s`, then renders it through `syn` so spacing matches `expect`. |
| 131 | + fn expand_expr(s: &str) -> String { |
| 132 | + let expr: syn::Expr = syn::parse2(expand(s.parse().unwrap())).unwrap(); |
| 133 | + quote!(#expr).to_string() |
| 134 | + } |
| 135 | + |
| 136 | + fn expect(s: &str) -> String { |
| 137 | + let expr: syn::Expr = syn::parse_str(s).unwrap(); |
| 138 | + quote!(#expr).to_string() |
| 139 | + } |
| 140 | + |
| 141 | + #[test] |
| 142 | + fn desugars_implication() { |
| 143 | + assert_eq!(expand_expr("a ==> b"), expect("(!(a)) || (b)")); |
| 144 | + // right-associative |
| 145 | + assert_eq!( |
| 146 | + expand_expr("a ==> b ==> c"), |
| 147 | + expect("(!(a)) || ((!(b)) || (c))") |
| 148 | + ); |
| 149 | + // lower precedence than `||` and `==` |
| 150 | + assert_eq!( |
| 151 | + expand_expr("a || b ==> c == d"), |
| 152 | + expect("(!(a || b)) || (c == d)") |
| 153 | + ); |
| 154 | + // nested inside a closure argument |
| 155 | + assert_eq!( |
| 156 | + expand_expr("exists(|x| a ==> b)"), |
| 157 | + expect("exists(|x| (!(a)) || (b))") |
| 158 | + ); |
| 159 | + } |
| 160 | + |
| 161 | + #[test] |
| 162 | + fn leaves_comparisons_untouched() { |
| 163 | + for s in ["a == b", "a != b", "a >= b", "a <= b"] { |
| 164 | + assert_eq!(expand_expr(s), expect(s), "{s}"); |
| 165 | + } |
| 166 | + } |
| 167 | + |
| 168 | + #[test] |
| 169 | + fn rejects_bare_assignment() { |
| 170 | + // `expand` emits a `compile_error!` instead of a valid expression. |
| 171 | + for s in ["a = b", "f(x = y)"] { |
| 172 | + assert!( |
| 173 | + expand(s.parse().unwrap()) |
| 174 | + .to_string() |
| 175 | + .contains("compile_error"), |
| 176 | + "{s}" |
| 177 | + ); |
| 178 | + } |
| 179 | + } |
| 180 | +} |
0 commit comments