diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 7eb23862..05264f71 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -145,6 +145,14 @@ pub fn exists_path() -> [Symbol; 3] { ] } +pub fn implies_path() -> [Symbol; 3] { + [ + Symbol::intern("thrust"), + Symbol::intern("def"), + Symbol::intern("implies"), + ] +} + pub fn invariant_marker_path() -> [Symbol; 3] { [ Symbol::intern("thrust"), diff --git a/src/analyze/annot_fn.rs b/src/analyze/annot_fn.rs index 9e418169..2a289f78 100644 --- a/src/analyze/annot_fn.rs +++ b/src/analyze/annot_fn.rs @@ -102,6 +102,7 @@ enum FormulaOrTerm { BinOp(chc::Term, AmbiguousBinOp, chc::Term), And(Box>, Box>), Or(Box>, Box>), + Implies(Box>, Box>), Not(Box>), Literal(bool), } @@ -124,6 +125,7 @@ impl FormulaOrTerm { } FormulaOrTerm::And(lhs, rhs) => lhs.into_formula()?.and(rhs.into_formula()?), FormulaOrTerm::Or(lhs, rhs) => lhs.into_formula()?.or(rhs.into_formula()?), + FormulaOrTerm::Implies(lhs, rhs) => lhs.into_formula()?.implies(rhs.into_formula()?), FormulaOrTerm::Not(formula_or_term) => formula_or_term.into_formula()?.not(), FormulaOrTerm::Literal(b) => { if b { @@ -148,6 +150,7 @@ impl FormulaOrTerm { FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Lt, rhs) => lhs.lt(rhs), FormulaOrTerm::And(lhs, rhs) => lhs.into_term()?.and(rhs.into_term()?), FormulaOrTerm::Or(lhs, rhs) => lhs.into_term()?.or(rhs.into_term()?), + FormulaOrTerm::Implies(lhs, rhs) => lhs.into_term()?.not().or(rhs.into_term()?), FormulaOrTerm::Not(formula_or_term) => formula_or_term.into_term()?.not(), FormulaOrTerm::Literal(b) => chc::Term::bool(b), }; @@ -607,6 +610,14 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> { body_formula, )); } + if Some(def_id) == self.def_ids.implies() { + let [lhs, rhs] = args else { + panic!("implies takes exactly 2 arguments"); + }; + let lhs = self.to_formula_or_term(lhs); + let rhs = self.to_formula_or_term(rhs); + return FormulaOrTerm::Implies(lhs.into(), rhs.into()); + } if Some(def_id) == self.def_ids.mut_model_new() { assert_eq!(args.len(), 2, "Mut::new takes exactly 2 arguments"); let t1 = self.to_term(&args[0]); diff --git a/src/analyze/did_cache.rs b/src/analyze/did_cache.rs index ee08a576..4e009574 100644 --- a/src/analyze/did_cache.rs +++ b/src/analyze/did_cache.rs @@ -25,6 +25,7 @@ struct DefIds { array_model_store: OnceCell>, exists: OnceCell>, + implies: OnceCell>, invariant_marker: OnceCell>, closure_precondition: OnceCell>, @@ -181,6 +182,13 @@ impl<'tcx> DefIdCache<'tcx> { .get_or_init(|| self.annotated_def(&crate::analyze::annot::exists_path())) } + pub fn implies(&self) -> Option { + *self + .def_ids + .implies + .get_or_init(|| self.annotated_def(&crate::analyze::annot::implies_path())) + } + pub fn invariant_marker(&self) -> Option { *self .def_ids diff --git a/src/chc.rs b/src/chc.rs index 58166391..5435a136 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -1251,6 +1251,7 @@ pub enum Formula { Not(Box>), And(Vec>), Or(Vec>), + Implies(Box>, Box>), Exists(Vec<(String, Sort)>, Box>), } @@ -1293,6 +1294,13 @@ where ); inner.group() } + Formula::Implies(lhs, rhs) => lhs + .pretty_atom(allocator) + .append(allocator.space()) + .append(allocator.text("==>")) + .append(allocator.line()) + .append(rhs.pretty_atom(allocator)) + .group(), Formula::Exists(vars, fo) => { let vars = allocator.intersperse( vars.iter().map(|(name, sort)| { @@ -1327,7 +1335,7 @@ impl Formula { D::Doc: Clone, { match self { - Formula::And(_) | Formula::Or(_) | Formula::Exists { .. } => { + Formula::And(_) | Formula::Or(_) | Formula::Implies(_, _) | Formula::Exists { .. } => { self.pretty(allocator).parens() } _ => self.pretty(allocator), @@ -1348,6 +1356,7 @@ impl Formula { Formula::Not(fo) => fo.is_bottom(), Formula::And(fs) => fs.iter().all(Formula::is_top), Formula::Or(fs) => fs.iter().any(Formula::is_top), + Formula::Implies(lhs, rhs) => lhs.is_bottom() || rhs.is_top(), Formula::Exists(_, fo) => fo.is_top(), } } @@ -1358,6 +1367,7 @@ impl Formula { Formula::Not(fo) => fo.is_top(), Formula::And(fs) => fs.iter().any(Formula::is_bottom), Formula::Or(fs) => fs.iter().all(Formula::is_bottom), + Formula::Implies(lhs, rhs) => lhs.is_top() && rhs.is_bottom(), Formula::Exists(_, fo) => fo.is_bottom(), } } @@ -1389,6 +1399,10 @@ impl Formula { } } + pub fn implies(self, other: Self) -> Self { + Formula::Implies(Box::new(self), Box::new(other)) + } + pub fn exists(vars: Vec<(String, Sort)>, body: Self) -> Self { Formula::Exists(vars, Box::new(body)) } @@ -1406,6 +1420,9 @@ impl Formula { Formula::And(fs.into_iter().map(|fo| fo.subst_var(&mut f)).collect()) } Formula::Or(fs) => Formula::Or(fs.into_iter().map(|fo| fo.subst_var(&mut f)).collect()), + Formula::Implies(lhs, rhs) => { + Formula::Implies(Box::new(lhs.subst_var(&mut f)), Box::new(rhs.subst_var(f))) + } Formula::Exists(vars, fo) => Formula::Exists(vars, Box::new(fo.subst_var(f))), } } @@ -1421,6 +1438,9 @@ impl Formula { Formula::Not(fo) => Formula::Not(Box::new(fo.map_var(&mut f))), Formula::And(fs) => Formula::And(fs.into_iter().map(|fo| fo.map_var(&mut f)).collect()), Formula::Or(fs) => Formula::Or(fs.into_iter().map(|fo| fo.map_var(&mut f)).collect()), + Formula::Implies(lhs, rhs) => { + Formula::Implies(Box::new(lhs.map_var(&mut f)), Box::new(rhs.map_var(f))) + } Formula::Exists(vars, fo) => Formula::Exists(vars, Box::new(fo.map_var(f))), } } @@ -1435,6 +1455,7 @@ impl Formula { Formula::Not(fo) => Box::new(fo.fv()), Formula::And(fs) => Box::new(fs.iter().flat_map(Formula::fv)), Formula::Or(fs) => Box::new(fs.iter().flat_map(Formula::fv)), + Formula::Implies(lhs, rhs) => Box::new(lhs.fv().chain(rhs.fv())), Formula::Exists(_, fo) => Box::new(fo.fv()), } } @@ -1449,6 +1470,7 @@ impl Formula { Formula::Not(fo) => Box::new(fo.iter_atoms()), Formula::And(fs) => Box::new(fs.iter().flat_map(Formula::iter_atoms)), Formula::Or(fs) => Box::new(fs.iter().flat_map(Formula::iter_atoms)), + Formula::Implies(lhs, rhs) => Box::new(lhs.iter_atoms().chain(rhs.iter_atoms())), Formula::Exists(_, fo) => Box::new(fo.iter_atoms()), } } @@ -1469,6 +1491,17 @@ impl Formula { match self { Formula::Atom(_atom) => {} Formula::Not(fo) => fo.simplify(), + Formula::Implies(lhs, rhs) => { + lhs.simplify(); + rhs.simplify(); + if lhs.is_bottom() || rhs.is_top() { + *self = Formula::top(); + } else if lhs.is_top() { + *self = std::mem::take(&mut **rhs); + } else if rhs.is_bottom() { + *self = std::mem::take(&mut **lhs).not(); + } + } Formula::And(fs) => { for fo in &mut *fs { fo.simplify(); @@ -1624,7 +1657,7 @@ where .into_iter() .map(|a| a.guarded(guard.clone())) .collect(), - formula: guard.not().or(formula), + formula: guard.implies(formula), } } } diff --git a/src/chc/smtlib2.rs b/src/chc/smtlib2.rs index e8886ed6..c00782c0 100644 --- a/src/chc/smtlib2.rs +++ b/src/chc/smtlib2.rs @@ -288,6 +288,11 @@ impl<'ctx, 'a> std::fmt::Display for Formula<'ctx, 'a> { let fs = List::open(fs.iter().map(|fo| Formula::new(self.ctx, self.clause, fo))); write!(f, "(or {})", fs) } + chc::Formula::Implies(lhs, rhs) => { + let lhs = Formula::new(self.ctx, self.clause, lhs); + let rhs = Formula::new(self.ctx, self.clause, rhs); + write!(f, "(=> {lhs} {rhs})") + } chc::Formula::Exists(vars, fo) => { let vars = List::closed(vars.iter().map(|(v, s)| { diff --git a/src/chc/unbox.rs b/src/chc/unbox.rs index b3b24d52..64f79aa4 100644 --- a/src/chc/unbox.rs +++ b/src/chc/unbox.rs @@ -81,6 +81,9 @@ fn unbox_formula(formula: Formula) -> Formula { Formula::Not(fo) => Formula::Not(Box::new(unbox_formula(*fo))), Formula::And(fs) => Formula::And(fs.into_iter().map(unbox_formula).collect()), Formula::Or(fs) => Formula::Or(fs.into_iter().map(unbox_formula).collect()), + Formula::Implies(lhs, rhs) => { + Formula::Implies(Box::new(unbox_formula(*lhs)), Box::new(unbox_formula(*rhs))) + } Formula::Exists(vars, fo) => { let vars = vars.into_iter().map(|(v, s)| (v, unbox_sort(s))).collect(); Formula::Exists(vars, Box::new(unbox_formula(*fo))) diff --git a/src/rty.rs b/src/rty.rs index 7311c4d7..6a98e7e5 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -1930,6 +1930,10 @@ fn subst_ty_params_in_formula(formula: &mut chc::Formula, subst: &TypeP subst_ty_params_in_formula(f, subst); } } + chc::Formula::Implies(lhs, rhs) => { + subst_ty_params_in_formula(lhs, subst); + subst_ty_params_in_formula(rhs, subst); + } chc::Formula::Exists(vars, f) => { for (_, sort) in vars { subst_ty_params_in_sort(sort, subst); diff --git a/std.rs b/std.rs index 132e3859..a66766c9 100644 --- a/std.rs +++ b/std.rs @@ -325,6 +325,13 @@ mod thrust_models { unimplemented!() } + #[allow(dead_code)] + #[thrust::def::implies] + #[thrust::ignored] + pub fn implies(_lhs: bool, _rhs: bool) -> bool { + unimplemented!() + } + #[thrust::def::invariant_marker] #[thrust::ignored] #[inline(never)] diff --git a/tests/ui/fail/annot_implication.rs b/tests/ui/fail/annot_implication.rs new file mode 100644 index 00000000..2f55bfb8 --- /dev/null +++ b/tests/ui/fail/annot_implication.rs @@ -0,0 +1,45 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off +//@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper + +use thrust_models::exists; + +#[thrust::trusted] +#[thrust::callable] +fn rand() -> i64 { + unimplemented!() +} + +// Same contract as the `pass` counterpart, but the body returns a negative +// value when `x > 0`, so the implication `(x > 0) ==> (result > 0)` is +// violated. This confirms `==>` carries real implication semantics rather than +// being silently accepted. +#[thrust_macros::ensures((x > 0) ==> (result > 0))] +fn f(x: i64) -> i64 { + if x > 0 { + -1 + } else { + 0 + } +} + +// An unparenthesized chain must parse right-associatively, i.e. +// `(x > 10) ==> ((x > 5) ==> (result == 1))`. Left-associative parsing would +// make this unprovable, so this case pins down associativity. +#[thrust_macros::ensures((x > 10) ==> (x > 5) ==> (result == 1))] +fn g(x: i64) -> i64 { + if x > 5 { + 1 + } else { + 0 + } +} + +// Implication nested inside a quantifier's closure body. +#[thrust_macros::ensures(exists(|y: i64| (1 == 1) ==> (result == 2 * y)))] +fn k() -> i64 { + let x = rand(); + x + x +} + +fn main() {} diff --git a/tests/ui/pass/annot_implication.rs b/tests/ui/pass/annot_implication.rs new file mode 100644 index 00000000..d111425d --- /dev/null +++ b/tests/ui/pass/annot_implication.rs @@ -0,0 +1,43 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off +//@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper + +use thrust_models::exists; + +#[thrust::trusted] +#[thrust::callable] +fn rand() -> i64 { + unimplemented!() +} + +// Basic implication in a postcondition: when the antecedent holds the +// consequent must too; otherwise the obligation is vacuous. +#[thrust_macros::ensures((x > 0) ==> (result > 0))] +fn f(x: i64) -> i64 { + if x > 0 { + x + } else { + 0 + } +} + +// An unparenthesized chain must parse right-associatively, i.e. +// `(x > 10) ==> ((x > 5) ==> (result == 1))`. Left-associative parsing would +// make this unprovable, so this case pins down associativity. +#[thrust_macros::ensures((x > 10) ==> (x > 5) ==> (result == 1))] +fn g(x: i64) -> i64 { + if x > 5 { + 1 + } else { + 0 + } +} + +// Implication nested inside a quantifier's closure body. +#[thrust_macros::ensures(exists(|y: i64| (1 == 1) ==> (result == 2 * y)))] +fn k() -> i64 { + let x = rand(); + x + x +} + +fn main() {} diff --git a/thrust-macros/Cargo.toml b/thrust-macros/Cargo.toml index 9fec73b2..a0898642 100644 --- a/thrust-macros/Cargo.toml +++ b/thrust-macros/Cargo.toml @@ -10,4 +10,4 @@ proc-macro = true [dependencies] proc-macro2 = "1" quote = "1" -syn = { version = "2", features = ["full", "visit", "visit-mut"] } +syn = { version = "2", features = ["extra-traits", "full", "visit", "visit-mut"] } diff --git a/thrust-macros/src/formula.rs b/thrust-macros/src/formula.rs new file mode 100644 index 00000000..f2112c82 --- /dev/null +++ b/thrust-macros/src/formula.rs @@ -0,0 +1,269 @@ +//! The `thrust_macros::formula!` proc-macro: the single preprocessing layer for +//! formula tokens. +//! +//! Annotation macros wrap their formula body in `thrust_macros::formula!(...)` +//! instead of splicing it raw. This hides syntax that is not valid Rust (the +//! `==>` operator) inside a macro call's arguments, so the surrounding pipeline — +//! which parses formulas as [`syn::Expr`] — never chokes on it, and gives one +//! place to run preprocessing before the body reaches rustc / HIR lowering. +//! +//! The only pass today is implication lowering; further passes can be appended +//! in [`expand`]. + +use proc_macro2::{Group, Punct, Spacing, TokenStream, TokenTree}; +use quote::{quote, ToTokens}; +use syn::visit_mut::VisitMut; + +/// Wraps formula tokens in a `thrust_macros::formula!(...)` call. +pub fn wrap_expr(tokens: TokenStream) -> TokenStream { + quote!(::thrust_macros::formula!(#tokens)) +} + +/// Wraps an invariant closure's body in `::thrust_macros::formula!(...)`, taking +/// everything past the closure header (the first two top-level `|`) as the body. +/// Only the body is wrapped: the closure must stay a parseable `ExprClosure` for +/// `invariant::expand_invariant` to read its `.inputs`/`.body`, and the context +/// form's threaded signature must not be preprocessed. Non-closure inputs pass +/// through. +/// +/// A closure with an explicit return type (`|x| -> Ty { .. }`) must keep its body +/// a block, so the `-> Ty {` / `}` is preserved and only the block's tail +/// expression is wrapped. +pub fn wrap_closure_body(input: TokenStream) -> TokenStream { + let tokens: Vec = input.into_iter().collect(); + let bars: Vec = tokens + .iter() + .enumerate() + .filter(|(_, tt)| matches!(tt, TokenTree::Punct(p) if p.as_char() == '|')) + .map(|(i, _)| i) + .collect(); + let [_open, close, ..] = bars[..] else { + return tokens.into_iter().collect(); + }; + let header: TokenStream = tokens[..=close].iter().cloned().collect(); + let tail = &tokens[close + 1..]; + + // `-> Ty { body }`: wrap the block's tail expression in place. + let has_return_type = matches!(tail.first(), Some(TokenTree::Punct(p)) if p.as_char() == '-') + && matches!(tail.get(1), Some(TokenTree::Punct(p)) if p.as_char() == '>'); + if has_return_type { + if let Some((TokenTree::Group(block), prefix)) = tail.split_last() { + if block.delimiter() == proc_macro2::Delimiter::Brace { + let prefix: TokenStream = prefix.iter().cloned().collect(); + let body = wrap_block_tail(block); + return quote!(#header #prefix #body); + } + } + } + + // `|args| body`: the body is a bare expression; wrap it whole. + let body = wrap_expr(tail.iter().cloned().collect()); + quote!(#header #body) +} + +/// Returns a brace group like `block` but with its tail expression (the final +/// `;`-separated segment) wrapped in `::thrust_macros::formula!(...)`. Leading +/// statements are left untouched; a block with no tail expression is unchanged. +fn wrap_block_tail(block: &Group) -> TokenStream { + let stmts: Vec = block.stream().into_iter().collect(); + let mut segments: Vec> = vec![Vec::new()]; + for tt in &stmts { + segments.last_mut().unwrap().push(tt.clone()); + if matches!(tt, TokenTree::Punct(p) if p.as_char() == ';') { + segments.push(Vec::new()); + } + } + let tail = segments.pop().unwrap(); + + let mut inner = TokenStream::new(); + for seg in &segments { + inner.extend(seg.iter().cloned()); + } + if !tail.is_empty() { + inner.extend(wrap_expr(tail.into_iter().collect())); + } + + let mut wrapped = Group::new(proc_macro2::Delimiter::Brace, inner); + wrapped.set_span(block.span()); + quote!(#wrapped) +} + +/// Expands `formula!()` into the preprocessed boolean expression. +pub fn expand(input: TokenStream) -> TokenStream { + if let Err(e) = reject_bare_assignment(&input) { + return e.to_compile_error(); + } + + // `==>` is desugared to assignment (the lowest-precedence, right-associative + // operator) so `syn` reproduces its precedence, then each assignment node is + // rewritten into a marker call that the analyzer lowers to `chc::Formula::Implies`. + let desugared = desugar_arrows(input); + let mut expr: syn::Expr = match syn::parse2(desugared) { + Ok(expr) => expr, + Err(e) => return e.to_compile_error(), + }; + + // Rewrites each assignment `lhs = rhs` (produced by [`desugar_arrows`] from + // `lhs ==> rhs`) into `thrust_models::implies(lhs, rhs)`. Visiting + // post-order means nested implications are rewritten innermost-first, so the + // right-associative chain `a ==> b ==> c` becomes `implies(a, implies(b, c))`. + struct ImplicationRewriter; + + impl VisitMut for ImplicationRewriter { + fn visit_expr_mut(&mut self, expr: &mut syn::Expr) { + syn::visit_mut::visit_expr_mut(self, expr); + if let syn::Expr::Assign(assign) = expr { + let left = &assign.left; + let right = &assign.right; + *expr = syn::parse_quote!(thrust_models::implies((#left), (#right))); + } + } + } + + ImplicationRewriter.visit_expr_mut(&mut expr); + + expr.into_token_stream() +} + +/// Rejects a bare assignment `=` anywhere in `input`. Desugaring turns `==>` into +/// `=`, so a genuine `=` would be read as an implication; since assignment is +/// meaningless in a formula we reject it up front. A genuine `=` is a lone +/// `Punct('=')` with [`Spacing::Alone`] not preceded by a joint punct, which +/// excludes the `=`s of `==`, `!=`, `<=`, `>=`, and `==>`. +fn reject_bare_assignment(input: &TokenStream) -> syn::Result<()> { + let mut prev_joint = false; + for tt in input.clone() { + match tt { + TokenTree::Group(g) => { + reject_bare_assignment(&g.stream())?; + prev_joint = false; + } + TokenTree::Punct(p) => { + if p.as_char() == '=' && p.spacing() == Spacing::Alone && !prev_joint { + return Err(syn::Error::new( + p.span(), + "`=` is not allowed in a formula; use `==` for equality or `==>` for implication", + )); + } + prev_joint = p.spacing() == Spacing::Joint; + } + _ => prev_joint = false, + } + } + Ok(()) +} + +/// Replaces every contiguous `==>` (the puncts `=`, `=`, `>`) with a single `=`, +/// recursing into every delimiter group. `==`, `=>`, `>=`, and `->` do not match +/// the triple and are left untouched. +fn desugar_arrows(input: TokenStream) -> TokenStream { + let mut out = TokenStream::new(); + let mut iter = input.into_iter().peekable(); + while let Some(tt) = iter.next() { + match tt { + TokenTree::Group(group) => { + let inner = desugar_arrows(group.stream()); + let mut new_group = Group::new(group.delimiter(), inner); + new_group.set_span(group.span()); + out.extend([TokenTree::Group(new_group)]); + } + TokenTree::Punct(p) if p.as_char() == '=' && p.spacing() == Spacing::Joint => { + // Look for a contiguous `==>`: `p` is the first `=`, joint to a + // second joint `=`, then a `>`. Requiring joint spacing avoids + // matching a spaced-out `= = >` / `== >`. + if let Some(TokenTree::Punct(p2)) = iter.peek() { + if p2.as_char() == '=' && p2.spacing() == Spacing::Joint { + let mut lookahead = iter.clone(); + lookahead.next(); // consume the second `=` + if let Some(TokenTree::Punct(p3)) = lookahead.peek() { + if p3.as_char() == '>' { + // Matched `==>`; consume `=` and `>`, emit `=`. + iter.next(); // second `=` + iter.next(); // `>` + let mut eq = Punct::new('=', Spacing::Alone); + eq.set_span(p.span()); + out.extend([TokenTree::Punct(eq)]); + continue; + } + } + } + } + out.extend([TokenTree::Punct(p)]); + } + other => out.extend([other]), + } + } + out +} + +#[cfg(test)] +mod tests { + use super::*; + + /// `expand`s `s`, then renders it through `syn` so spacing matches `expect`. + fn expand_expr(s: &str) -> String { + let expr: syn::Expr = syn::parse2(expand(s.parse().unwrap())).unwrap(); + quote!(#expr).to_string() + } + + fn expect(s: &str) -> String { + let expr: syn::Expr = syn::parse_str(s).unwrap(); + quote!(#expr).to_string() + } + + #[test] + fn desugars_implication() { + assert_eq!( + expand_expr("a ==> b"), + expect("thrust_models::implies((a), (b))") + ); + // right-associative + assert_eq!( + expand_expr("a ==> b ==> c"), + expect("thrust_models::implies((a), (thrust_models::implies((b), (c))))") + ); + // lower precedence than `||` and `==` + assert_eq!( + expand_expr("a || b ==> c == d"), + expect("thrust_models::implies((a || b), (c == d))") + ); + // nested inside a closure argument + assert_eq!( + expand_expr("exists(|x| a ==> b)"), + expect("exists(|x| thrust_models::implies((a), (b)))") + ); + } + + #[test] + fn leaves_comparisons_untouched() { + for s in ["a == b", "a != b", "a >= b", "a <= b"] { + assert_eq!(expand_expr(s), expect(s), "{s}"); + } + } + + #[test] + fn rejects_bare_assignment() { + // `expand` emits a `compile_error!` instead of a valid expression. + for s in ["a = b", "f(x = y)"] { + assert!( + expand(s.parse().unwrap()) + .to_string() + .contains("compile_error"), + "{s}" + ); + } + } + + /// `wrap_closure_body` must leave the input a parseable `ExprClosure`, both + /// for a bare-expression body and one with an explicit return type. + fn assert_wraps_to_closure(s: &str) { + let wrapped = wrap_closure_body(s.parse().unwrap()); + syn::parse2::(wrapped).unwrap_or_else(|e| panic!("{s}: {e}")); + } + + #[test] + fn preserves_closure_forms() { + assert_wraps_to_closure("|x: i64| x >= 1 ==> x >= 0"); + assert_wraps_to_closure("|x: i64| -> bool { x >= 1 ==> x >= 0 }"); + } +} diff --git a/thrust-macros/src/invariant.rs b/thrust-macros/src/invariant.rs index a400da85..e538fd6a 100644 --- a/thrust-macros/src/invariant.rs +++ b/thrust-macros/src/invariant.rs @@ -34,7 +34,6 @@ use proc_macro2::TokenStream as TokenStream2; use quote::{format_ident, quote, ToTokens}; use syn::{ parse::{Parse, ParseStream}, - parse_macro_input, visit_mut::VisitMut, FnArg, GenericParam, Signature, WherePredicate, }; @@ -46,7 +45,11 @@ static COUNTER: AtomicUsize = AtomicUsize::new(0); /// Expands `invariant!(CLOSURE)`: a bare predicate closure with no threaded /// context. pub fn expand(input: TokenStream) -> TokenStream { - let closure = parse_macro_input!(input as syn::ExprClosure); + let input = crate::formula::wrap_closure_body(input.into()); + let closure = match syn::parse2::(input) { + Ok(closure) => closure, + Err(e) => return e.to_compile_error().into(), + }; match expand_invariant(&closure, None) { Ok(expr) => expr.into_token_stream().into(), Err(e) => e.to_compile_error().into(), @@ -75,7 +78,11 @@ pub fn expand_with_context(input: TokenStream) -> TokenStream { } } - let WithContext { closure, context } = parse_macro_input!(input as WithContext); + let input = crate::formula::wrap_closure_body(input.into()); + let WithContext { closure, context } = match syn::parse2::(input) { + Ok(parsed) => parsed, + Err(e) => return e.to_compile_error().into(), + }; match expand_invariant(&closure, Some(&context)) { Ok(expr) => expr.into_token_stream().into(), Err(e) => e.to_compile_error().into(), diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index ee41b0b3..2c89f573 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -3,6 +3,7 @@ use proc_macro2::{TokenStream as TokenStream2, TokenTree as TokenTree2}; mod context; mod fn_outer_item; +mod formula; mod formula_fn_type_lowering; mod invariant; mod invariant_context; @@ -32,6 +33,12 @@ pub fn context(_attr: TokenStream, item: TokenStream) -> TokenStream { context::expand(item) } +/// Preprocesses a formula body (see [`mod@formula`]); not written by hand. +#[proc_macro] +pub fn formula(input: TokenStream) -> TokenStream { + formula::expand(input.into()).into() +} + /// Declares a loop invariant inside a loop body: /// /// ```ignore diff --git a/thrust-macros/src/rty.rs b/thrust-macros/src/rty.rs index 2abf2075..93a49f55 100644 --- a/thrust-macros/src/rty.rs +++ b/thrust-macros/src/rty.rs @@ -171,8 +171,11 @@ fn expand_with_annotations( let mut path_stmts = Vec::new(); for mut annotation in annotations { if has_receiver { - annotation.refined_type.formula = - rewrite_self_in_tokens(annotation.refined_type.formula); + let self_ = format_ident!("self_"); + annotation.refined_type.formula = std::mem::take(&mut annotation.refined_type.formula) + .into_iter() + .map(|tt| crate::spec::rewrite_self_in_tokens(tt, &self_)) + .collect(); } formula_fns.push(build_formula_fn(&func, outer_context.as_ref(), &annotation)); path_stmts.push(build_refinement_path_stmt(&func, &annotation)); @@ -404,20 +407,6 @@ fn param_index(func: &FnItemWithSignature, name: &syn::Ident) -> syn::Result TokenStream2 { - tokens - .into_iter() - .map(|tt| match tt { - TokenTree2::Ident(id) if id == "self" => TokenTree2::Ident(format_ident!("self_")), - TokenTree2::Group(g) => { - let inner = rewrite_self_in_tokens(g.stream()); - TokenTree2::Group(proc_macro2::Group::new(g.delimiter(), inner)) - } - other => other, - }) - .collect() -} - fn formula_fn_name(func: &FnItemWithSignature, ann: &RefinedTypeAnnotation) -> syn::Ident { let pos = ann .position @@ -452,7 +441,7 @@ fn build_formula_fn( let extended_where = extended_where_clause(func, &model_preds); let binder = &ann.refined_type.binder; let binder_ty = &ann.refined_type.binder_ty; - let formula = &ann.refined_type.formula; + let formula = crate::formula::wrap_expr(ann.refined_type.formula.clone()); quote! { #[allow(unused_variables)] diff --git a/thrust-macros/src/spec.rs b/thrust-macros/src/spec.rs index 943087de..edf56df8 100644 --- a/thrust-macros/src/spec.rs +++ b/thrust-macros/src/spec.rs @@ -7,7 +7,7 @@ //! references them. use proc_macro::TokenStream; -use proc_macro2::TokenStream as TokenStream2; +use proc_macro2::{Group, Ident, TokenStream as TokenStream2, TokenTree}; use quote::{format_ident, quote, ToTokens}; use syn::{ parse_macro_input, punctuated::Punctuated, FnArg, GenericParam, Generics, WherePredicate, @@ -16,6 +16,9 @@ use syn::{ use crate::{fn_outer_item::FnOuterItem, FormulaFnTypeLowering}; pub fn expand_predicate(item: TokenStream) -> TokenStream { + // Predicate bodies are consumed by the plugin as a raw SMT-LIB string literal + // (see `analyze::local_def::define_as_predicate`), not as formula expressions, + // so they are not routed through `formula!`. let func = parse_macro_input!(item as FnItemWithSignature); let outer_context = match extract_outer_context(&func) { Ok(ctx) => ctx, @@ -51,7 +54,7 @@ pub fn expand_predicate(item: TokenStream) -> TokenStream { } pub fn expand_requires(attr: TokenStream, item: TokenStream) -> TokenStream { - let expr = TokenStream2::from(attr); + let expr = crate::formula::expand(attr.into()); let mut func = parse_macro_input!(item as FnItemWithSignature); let (req_expr, ens_expr) = match extract_requires_ensures(&mut func) { @@ -66,7 +69,7 @@ pub fn expand_requires(attr: TokenStream, item: TokenStream) -> TokenStream { } pub fn expand_ensures(attr: TokenStream, item: TokenStream) -> TokenStream { - let expr = TokenStream2::from(attr); + let expr = crate::formula::expand(attr.into()); let mut func = parse_macro_input!(item as FnItemWithSignature); let (req_expr, ens_expr) = match extract_requires_ensures(&mut func) { @@ -489,12 +492,42 @@ fn rewrite_self_in_expr(expr: &mut syn::Expr) { *ident = format_ident!("self_"); } } + + // syn skips macro token streams, so rewrite `self` inside the + // `formula!(..)` wrapper by hand. + fn visit_macro_mut(&mut self, mac: &mut syn::Macro) { + let self_ = format_ident!("self_"); + mac.tokens = std::mem::take(&mut mac.tokens) + .into_iter() + .map(|tt| rewrite_self_in_tokens(tt, &self_)) + .collect(); + } } use syn::visit_mut::VisitMut as _; Visitor.visit_expr_mut(expr); } +/// Replaces a `self` identifier with `self_`, recursing into groups. Takes a +/// single token tree (so callers can `map` it over a stream) and the `self_` +/// replacement to substitute in. +pub fn rewrite_self_in_tokens(token: impl Into, self_: &Ident) -> TokenTree { + match token.into() { + TokenTree::Ident(id) if id == "self" => TokenTree::Ident(self_.clone()), + TokenTree::Group(g) => { + let inner = g + .stream() + .into_iter() + .map(|tt| rewrite_self_in_tokens(tt, self_)) + .collect(); + let mut new_group = Group::new(g.delimiter(), inner); + new_group.set_span(g.span()); + TokenTree::Group(new_group) + } + other => other, + } +} + /// Returns `` — the generic param list for function definitions, /// without a where clause. pub fn generic_params_tokens(generics: &Generics) -> TokenStream2 {