Skip to content

Commit 7165183

Browse files
claudecoord-e
authored andcommitted
Support logical implication (==>) in annotations
Add a single formula-token preprocessing layer, the `thrust_macros::formula!` proc-macro, that every annotation wraps its formula body in. Its first pass desugars the implication operator `a ==> b` into `!a || b`: - `==>` is rewritten to assignment (the lowest-precedence, right-associative Rust operator) so `syn` reproduces implication's precedence/associativity for free, including inside closure bodies; - each resulting assignment node is rewritten to `(!(lhs)) || (rhs)` so the generated `#[thrust::formula_fn]` body is valid boolean Rust. `requires`/`ensures`, `predicate`, the `param`/`ret`/`sig` refinement types, and `invariant!` all route their formula through `formula!`. Closes #106. Reject bare `=` in formula! to avoid ambiguity with implication `==>` is desugared to assignment, so once parsed an `Expr::Assign` from a genuinely-written `=` is indistinguishable from an implication. Reject a bare `=` (a lone `Punct('=')` with Alone spacing, not preceded by a joint punct, so `==`/`!=`/`<=`/`>=`/`==>` are unaffected) up front with a clear diagnostic. Also enable syn's `extra-traits` feature explicitly (the crate compares `syn::Path` values), so the crate builds and its unit tests run standalone. https://claude.ai/code/session_017DdTwmnyZefiWR7zfcKsEL Document why formula wrapping isolates the body/tail expression https://claude.ai/code/session_017DdTwmnyZefiWR7zfcKsEL Address review: trim comments, dedupe self-rewrite, test expand() - shorten and de-duplicate doc/inline comments - call crate::formula::wrap_formula directly instead of importing it - rename rewrite_self_in_macro_tokens to rewrite_self_in_tokens and reuse it from rty, dropping rty's duplicate copy - unit-test the public formula::expand, not only reject_bare_assignment https://claude.ai/code/session_017DdTwmnyZefiWR7zfcKsEL cargo fmt https://claude.ai/code/session_017DdTwmnyZefiWR7zfcKsEL style fix
1 parent 624e9e9 commit 7165183

8 files changed

Lines changed: 389 additions & 24 deletions

File tree

tests/ui/fail/annot_implication.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//@error-in-other-file: Unsat
2+
//@compile-flags: -C debug-assertions=off
3+
//@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper
4+
5+
use thrust_models::exists;
6+
7+
#[thrust::trusted]
8+
#[thrust::callable]
9+
fn rand() -> i64 {
10+
unimplemented!()
11+
}
12+
13+
// Same contract as the `pass` counterpart, but the body returns a negative
14+
// value when `x > 0`, so the implication `(x > 0) ==> (result > 0)` is
15+
// violated. This confirms `==>` carries real implication semantics rather than
16+
// being silently accepted.
17+
#[thrust_macros::ensures((x > 0) ==> (result > 0))]
18+
fn f(x: i64) -> i64 {
19+
if x > 0 {
20+
-1
21+
} else {
22+
0
23+
}
24+
}
25+
26+
// An unparenthesized chain must parse right-associatively, i.e.
27+
// `(x > 10) ==> ((x > 5) ==> (result == 1))`. Left-associative parsing would
28+
// make this unprovable, so this case pins down associativity.
29+
#[thrust_macros::ensures((x > 10) ==> (x > 5) ==> (result == 1))]
30+
fn g(x: i64) -> i64 {
31+
if x > 5 {
32+
1
33+
} else {
34+
0
35+
}
36+
}
37+
38+
// Implication nested inside a quantifier's closure body.
39+
#[thrust_macros::ensures(exists(|y: i64| (1 == 1) ==> (result == 2 * y)))]
40+
fn k() -> i64 {
41+
let x = rand();
42+
x + x
43+
}
44+
45+
fn main() {}

tests/ui/pass/annot_implication.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//@check-pass
2+
//@compile-flags: -C debug-assertions=off
3+
//@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper
4+
5+
use thrust_models::exists;
6+
7+
#[thrust::trusted]
8+
#[thrust::callable]
9+
fn rand() -> i64 {
10+
unimplemented!()
11+
}
12+
13+
// Basic implication in a postcondition: when the antecedent holds the
14+
// consequent must too; otherwise the obligation is vacuous.
15+
#[thrust_macros::ensures((x > 0) ==> (result > 0))]
16+
fn f(x: i64) -> i64 {
17+
if x > 0 {
18+
x
19+
} else {
20+
0
21+
}
22+
}
23+
24+
// An unparenthesized chain must parse right-associatively, i.e.
25+
// `(x > 10) ==> ((x > 5) ==> (result == 1))`. Left-associative parsing would
26+
// make this unprovable, so this case pins down associativity.
27+
#[thrust_macros::ensures((x > 10) ==> (x > 5) ==> (result == 1))]
28+
fn g(x: i64) -> i64 {
29+
if x > 5 {
30+
1
31+
} else {
32+
0
33+
}
34+
}
35+
36+
// Implication nested inside a quantifier's closure body.
37+
#[thrust_macros::ensures(exists(|y: i64| (1 == 1) ==> (result == 2 * y)))]
38+
fn k() -> i64 {
39+
let x = rand();
40+
x + x
41+
}
42+
43+
fn main() {}

thrust-macros/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ proc-macro = true
1010
[dependencies]
1111
proc-macro2 = "1"
1212
quote = "1"
13-
syn = { version = "2", features = ["full", "visit", "visit-mut"] }
13+
syn = { version = "2", features = ["extra-traits", "full", "visit", "visit-mut"] }

thrust-macros/src/formula.rs

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
//! The `thrust_macros::formula!` proc-macro: the single preprocessing layer for
2+
//! formula tokens.
3+
//!
4+
//! Annotation macros wrap their formula body in `thrust_macros::formula!(...)`
5+
//! instead of splicing it raw. This hides syntax that is not valid Rust (the
6+
//! `==>` operator) inside a macro call's arguments, so the surrounding pipeline —
7+
//! which parses formulas as [`syn::Expr`] — never chokes on it, and gives one
8+
//! place to run preprocessing before the body reaches rustc / HIR lowering.
9+
//!
10+
//! The only pass today is implication desugaring; further passes can be appended
11+
//! in [`expand`].
12+
13+
use proc_macro2::{Group, Punct, Spacing, TokenStream, TokenTree};
14+
use quote::{quote, ToTokens};
15+
use syn::visit_mut::VisitMut;
16+
17+
/// Wraps formula tokens in a `thrust_macros::formula!(...)` call.
18+
pub fn wrap_formula(tokens: TokenStream) -> TokenStream {
19+
quote!(::thrust_macros::formula!(#tokens))
20+
}
21+
22+
/// Expands `formula!(<tokens>)` into the preprocessed boolean expression.
23+
pub fn expand(input: TokenStream) -> TokenStream {
24+
if let Err(e) = reject_bare_assignment(&input) {
25+
return e.to_compile_error();
26+
}
27+
28+
// `==>` is desugared to assignment (the lowest-precedence, right-associative
29+
// operator) so `syn` reproduces its precedence, then each assignment node is
30+
// rewritten into `!lhs || rhs`.
31+
let desugared = desugar_arrows(input);
32+
let mut expr: syn::Expr = match syn::parse2(desugared) {
33+
Ok(expr) => expr,
34+
Err(e) => return e.to_compile_error(),
35+
};
36+
37+
// Rewrites each assignment `lhs = rhs` (produced by [`desugar_arrows`] from
38+
// `lhs ==> rhs`) into the boolean expression `(!(lhs)) || (rhs)`. Visiting
39+
// post-order means nested implications are rewritten innermost-first, so the
40+
// right-associative chain `a ==> b ==> c` becomes `!a || (!b || c)`.
41+
struct ImplicationRewriter;
42+
43+
impl VisitMut for ImplicationRewriter {
44+
fn visit_expr_mut(&mut self, expr: &mut syn::Expr) {
45+
syn::visit_mut::visit_expr_mut(self, expr);
46+
if let syn::Expr::Assign(assign) = expr {
47+
let left = &assign.left;
48+
let right = &assign.right;
49+
*expr = syn::parse_quote!((!(#left)) || (#right));
50+
}
51+
}
52+
}
53+
54+
ImplicationRewriter.visit_expr_mut(&mut expr);
55+
56+
expr.into_token_stream()
57+
}
58+
59+
/// Rejects a bare assignment `=` anywhere in `input`. Desugaring turns `==>` into
60+
/// `=`, so a genuine `=` would be read as an implication; since assignment is
61+
/// meaningless in a formula we reject it up front. A genuine `=` is a lone
62+
/// `Punct('=')` with [`Spacing::Alone`] not preceded by a joint punct, which
63+
/// excludes the `=`s of `==`, `!=`, `<=`, `>=`, and `==>`.
64+
fn reject_bare_assignment(input: &TokenStream) -> syn::Result<()> {
65+
let mut prev_joint = false;
66+
for tt in input.clone() {
67+
match tt {
68+
TokenTree::Group(g) => {
69+
reject_bare_assignment(&g.stream())?;
70+
prev_joint = false;
71+
}
72+
TokenTree::Punct(p) => {
73+
if p.as_char() == '=' && p.spacing() == Spacing::Alone && !prev_joint {
74+
return Err(syn::Error::new(
75+
p.span(),
76+
"`=` is not allowed in a formula; use `==` for equality or `==>` for implication",
77+
));
78+
}
79+
prev_joint = p.spacing() == Spacing::Joint;
80+
}
81+
_ => prev_joint = false,
82+
}
83+
}
84+
Ok(())
85+
}
86+
87+
/// Replaces every contiguous `==>` (the puncts `=`, `=`, `>`) with a single `=`,
88+
/// recursing into every delimiter group. `==`, `=>`, `>=`, and `->` do not match
89+
/// the triple and are left untouched.
90+
fn desugar_arrows(input: TokenStream) -> TokenStream {
91+
let mut out = TokenStream::new();
92+
let mut iter = input.into_iter().peekable();
93+
while let Some(tt) = iter.next() {
94+
match tt {
95+
TokenTree::Group(group) => {
96+
let inner = desugar_arrows(group.stream());
97+
let mut new_group = Group::new(group.delimiter(), inner);
98+
new_group.set_span(group.span());
99+
out.extend([TokenTree::Group(new_group)]);
100+
}
101+
TokenTree::Punct(p) if p.as_char() == '=' && p.spacing() == Spacing::Joint => {
102+
// Look for `=` `=` `>`. `p` is the first `=`.
103+
if let Some(TokenTree::Punct(p2)) = iter.peek() {
104+
if p2.as_char() == '=' && p2.spacing() == Spacing::Joint {
105+
let mut lookahead = iter.clone();
106+
lookahead.next(); // consume the second `=`
107+
if let Some(TokenTree::Punct(p3)) = lookahead.peek() {
108+
if p3.as_char() == '>' {
109+
// Matched `==>`; consume `=` and `>`, emit `=`.
110+
iter.next(); // second `=`
111+
iter.next(); // `>`
112+
out.extend([TokenTree::Punct(Punct::new('=', Spacing::Alone))]);
113+
continue;
114+
}
115+
}
116+
}
117+
}
118+
out.extend([TokenTree::Punct(p)]);
119+
}
120+
other => out.extend([other]),
121+
}
122+
}
123+
out
124+
}
125+
126+
#[cfg(test)]
127+
mod tests {
128+
use super::*;
129+
130+
/// `expand`s `s`, then renders it through `syn` so spacing matches `expect`.
131+
fn expand_expr(s: &str) -> String {
132+
let expr: syn::Expr = syn::parse2(expand(s.parse().unwrap())).unwrap();
133+
quote!(#expr).to_string()
134+
}
135+
136+
fn expect(s: &str) -> String {
137+
let expr: syn::Expr = syn::parse_str(s).unwrap();
138+
quote!(#expr).to_string()
139+
}
140+
141+
#[test]
142+
fn desugars_implication() {
143+
assert_eq!(expand_expr("a ==> b"), expect("(!(a)) || (b)"));
144+
// right-associative
145+
assert_eq!(
146+
expand_expr("a ==> b ==> c"),
147+
expect("(!(a)) || ((!(b)) || (c))")
148+
);
149+
// lower precedence than `||` and `==`
150+
assert_eq!(
151+
expand_expr("a || b ==> c == d"),
152+
expect("(!(a || b)) || (c == d)")
153+
);
154+
// nested inside a closure argument
155+
assert_eq!(
156+
expand_expr("exists(|x| a ==> b)"),
157+
expect("exists(|x| (!(a)) || (b))")
158+
);
159+
}
160+
161+
#[test]
162+
fn leaves_comparisons_untouched() {
163+
for s in ["a == b", "a != b", "a >= b", "a <= b"] {
164+
assert_eq!(expand_expr(s), expect(s), "{s}");
165+
}
166+
}
167+
168+
#[test]
169+
fn rejects_bare_assignment() {
170+
// `expand` emits a `compile_error!` instead of a valid expression.
171+
for s in ["a = b", "f(x = y)"] {
172+
assert!(
173+
expand(s.parse().unwrap())
174+
.to_string()
175+
.contains("compile_error"),
176+
"{s}"
177+
);
178+
}
179+
}
180+
}

thrust-macros/src/invariant.rs

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,10 @@
3030
use std::sync::atomic::{AtomicUsize, Ordering};
3131

3232
use proc_macro::TokenStream;
33-
use proc_macro2::TokenStream as TokenStream2;
33+
use proc_macro2::{TokenStream as TokenStream2, TokenTree as TokenTree2};
3434
use quote::{format_ident, quote, ToTokens};
3535
use syn::{
3636
parse::{Parse, ParseStream},
37-
parse_macro_input,
3837
visit_mut::VisitMut,
3938
FnArg, GenericParam, Signature, WherePredicate,
4039
};
@@ -46,13 +45,38 @@ static COUNTER: AtomicUsize = AtomicUsize::new(0);
4645
/// Expands `invariant!(CLOSURE)`: a bare predicate closure with no threaded
4746
/// context.
4847
pub fn expand(input: TokenStream) -> TokenStream {
49-
let closure = parse_macro_input!(input as syn::ExprClosure);
48+
let input = wrap_closure_body(input.into());
49+
let closure = match syn::parse2::<syn::ExprClosure>(input) {
50+
Ok(closure) => closure,
51+
Err(e) => return e.to_compile_error().into(),
52+
};
5053
match expand_invariant(&closure, None) {
5154
Ok(expr) => expr.into_token_stream().into(),
5255
Err(e) => e.to_compile_error().into(),
5356
}
5457
}
5558

59+
/// Wraps the closure body in `input` in `::thrust_macros::formula!(...)`, taking
60+
/// everything past the closure header (the first two top-level `|`) as the body.
61+
/// Only the body is wrapped: the closure must stay a parseable `ExprClosure` for
62+
/// [`expand_invariant`] to read its `.inputs`/`.body`, and the context form's
63+
/// threaded signature must not be preprocessed. Non-closure inputs pass through.
64+
fn wrap_closure_body(input: TokenStream2) -> TokenStream2 {
65+
let tokens: Vec<TokenTree2> = input.into_iter().collect();
66+
let bars: Vec<usize> = tokens
67+
.iter()
68+
.enumerate()
69+
.filter(|(_, tt)| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '|'))
70+
.map(|(i, _)| i)
71+
.collect();
72+
let [_open, close, ..] = bars[..] else {
73+
return tokens.into_iter().collect();
74+
};
75+
let header: TokenStream2 = tokens[..=close].iter().cloned().collect();
76+
let body = crate::formula::wrap_formula(tokens[close + 1..].iter().cloned().collect());
77+
quote!(#header #body)
78+
}
79+
5680
/// Expands `_invariant_with_context!(#outer_attr #sig; CLOSURE)`, the form
5781
/// `#[thrust_macros::invariant_context]` rewrites each `invariant!` into.
5882
pub fn expand_with_context(input: TokenStream) -> TokenStream {
@@ -75,7 +99,11 @@ pub fn expand_with_context(input: TokenStream) -> TokenStream {
7599
}
76100
}
77101

78-
let WithContext { closure, context } = parse_macro_input!(input as WithContext);
102+
let input = wrap_closure_body(input.into());
103+
let WithContext { closure, context } = match syn::parse2::<WithContext>(input) {
104+
Ok(parsed) => parsed,
105+
Err(e) => return e.to_compile_error().into(),
106+
};
79107
match expand_invariant(&closure, Some(&context)) {
80108
Ok(expr) => expr.into_token_stream().into(),
81109
Err(e) => e.to_compile_error().into(),

thrust-macros/src/lib.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use proc_macro2::{TokenStream as TokenStream2, TokenTree as TokenTree2};
33

44
mod context;
55
mod fn_outer_item;
6+
mod formula;
67
mod formula_fn_type_lowering;
78
mod invariant;
89
mod invariant_context;
@@ -32,6 +33,12 @@ pub fn context(_attr: TokenStream, item: TokenStream) -> TokenStream {
3233
context::expand(item)
3334
}
3435

36+
/// Preprocesses a formula body (see [`mod@formula`]); not written by hand.
37+
#[proc_macro]
38+
pub fn formula(input: TokenStream) -> TokenStream {
39+
formula::expand(input.into()).into()
40+
}
41+
3542
/// Declares a loop invariant inside a loop body:
3643
///
3744
/// ```ignore

0 commit comments

Comments
 (0)