Skip to content

Commit 5a22dd0

Browse files
committed
Support logical implication (==>) in annotations
Add a single formula-token preprocessing layer, the `thrust_macros::formula!` proc-macro, that every annotation wraps its formula body in. Its first pass desugars the implication operator `a ==> b` into `!a || b`: - `==>` is rewritten to assignment (the lowest-precedence, right-associative Rust operator) so `syn` reproduces implication's precedence/associativity for free, including inside closure bodies; - each resulting assignment node is rewritten to `(!(lhs)) || (rhs)` so the generated `#[thrust::formula_fn]` body is valid boolean Rust. `requires`/`ensures`, `predicate`, the `param`/`ret`/`sig` refinement types, and `invariant!` all route their formula through `formula!`. Closes #106. Reject bare `=` in formula! to avoid ambiguity with implication `==>` is desugared to assignment, so once parsed an `Expr::Assign` from a genuinely-written `=` is indistinguishable from an implication. Reject a bare `=` (a lone `Punct('=')` with Alone spacing, not preceded by a joint punct, so `==`/`!=`/`<=`/`>=`/`==>` are unaffected) up front with a clear diagnostic. Also enable syn's `extra-traits` feature explicitly (the crate compares `syn::Path` values), so the crate builds and its unit tests run standalone. https://claude.ai/code/session_017DdTwmnyZefiWR7zfcKsEL Document why formula wrapping isolates the body/tail expression https://claude.ai/code/session_017DdTwmnyZefiWR7zfcKsEL Address review: trim comments, dedupe self-rewrite, test expand() - shorten and de-duplicate doc/inline comments - call crate::formula::wrap_formula directly instead of importing it - rename rewrite_self_in_macro_tokens to rewrite_self_in_tokens and reuse it from rty, dropping rty's duplicate copy - unit-test the public formula::expand, not only reject_bare_assignment https://claude.ai/code/session_017DdTwmnyZefiWR7zfcKsEL cargo fmt https://claude.ai/code/session_017DdTwmnyZefiWR7zfcKsEL style fix
1 parent 624e9e9 commit 5a22dd0

8 files changed

Lines changed: 413 additions & 24 deletions

File tree

tests/ui/fail/annot_implication.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//@error-in-other-file: Unsat
2+
//@compile-flags: -C debug-assertions=off
3+
//@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper
4+
5+
use thrust_models::exists;
6+
7+
#[thrust::trusted]
8+
#[thrust::callable]
9+
fn rand() -> i64 {
10+
unimplemented!()
11+
}
12+
13+
// Same contract as the `pass` counterpart, but the body returns a negative
14+
// value when `x > 0`, so the implication `(x > 0) ==> (result > 0)` is
15+
// violated. This confirms `==>` carries real implication semantics rather than
16+
// being silently accepted.
17+
#[thrust_macros::ensures((x > 0) ==> (result > 0))]
18+
fn f(x: i64) -> i64 {
19+
if x > 0 {
20+
-1
21+
} else {
22+
0
23+
}
24+
}
25+
26+
// An unparenthesized chain must parse right-associatively, i.e.
27+
// `(x > 10) ==> ((x > 5) ==> (result == 1))`. Left-associative parsing would
28+
// make this unprovable, so this case pins down associativity.
29+
#[thrust_macros::ensures((x > 10) ==> (x > 5) ==> (result == 1))]
30+
fn g(x: i64) -> i64 {
31+
if x > 5 {
32+
1
33+
} else {
34+
0
35+
}
36+
}
37+
38+
// Implication nested inside a quantifier's closure body.
39+
#[thrust_macros::ensures(exists(|y: i64| (1 == 1) ==> (result == 2 * y)))]
40+
fn k() -> i64 {
41+
let x = rand();
42+
x + x
43+
}
44+
45+
fn main() {}

tests/ui/pass/annot_implication.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//@check-pass
2+
//@compile-flags: -C debug-assertions=off
3+
//@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper
4+
5+
use thrust_models::exists;
6+
7+
#[thrust::trusted]
8+
#[thrust::callable]
9+
fn rand() -> i64 {
10+
unimplemented!()
11+
}
12+
13+
// Basic implication in a postcondition: when the antecedent holds the
14+
// consequent must too; otherwise the obligation is vacuous.
15+
#[thrust_macros::ensures((x > 0) ==> (result > 0))]
16+
fn f(x: i64) -> i64 {
17+
if x > 0 {
18+
x
19+
} else {
20+
0
21+
}
22+
}
23+
24+
// An unparenthesized chain must parse right-associatively, i.e.
25+
// `(x > 10) ==> ((x > 5) ==> (result == 1))`. Left-associative parsing would
26+
// make this unprovable, so this case pins down associativity.
27+
#[thrust_macros::ensures((x > 10) ==> (x > 5) ==> (result == 1))]
28+
fn g(x: i64) -> i64 {
29+
if x > 5 {
30+
1
31+
} else {
32+
0
33+
}
34+
}
35+
36+
// Implication nested inside a quantifier's closure body.
37+
#[thrust_macros::ensures(exists(|y: i64| (1 == 1) ==> (result == 2 * y)))]
38+
fn k() -> i64 {
39+
let x = rand();
40+
x + x
41+
}
42+
43+
fn main() {}

thrust-macros/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ proc-macro = true
1010
[dependencies]
1111
proc-macro2 = "1"
1212
quote = "1"
13-
syn = { version = "2", features = ["full", "visit", "visit-mut"] }
13+
syn = { version = "2", features = ["extra-traits", "full", "visit", "visit-mut"] }

thrust-macros/src/formula.rs

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
//! The `thrust_macros::formula!` proc-macro: the single preprocessing layer for
2+
//! formula tokens.
3+
//!
4+
//! Annotation macros wrap their formula body in `thrust_macros::formula!(...)`
5+
//! instead of splicing it raw. This hides syntax that is not valid Rust (the
6+
//! `==>` operator) inside a macro call's arguments, so the surrounding pipeline —
7+
//! which parses formulas as [`syn::Expr`] — never chokes on it, and gives one
8+
//! place to run preprocessing before the body reaches rustc / HIR lowering.
9+
//!
10+
//! The only pass today is implication desugaring; further passes can be appended
11+
//! in [`expand`].
12+
13+
use proc_macro2::{Group, Punct, Spacing, TokenStream, TokenTree};
14+
use quote::{quote, ToTokens};
15+
use syn::visit_mut::VisitMut;
16+
17+
/// Wraps formula tokens in a `thrust_macros::formula!(...)` call.
18+
pub fn wrap_expr(tokens: TokenStream) -> TokenStream {
19+
quote!(::thrust_macros::formula!(#tokens))
20+
}
21+
22+
/// Wraps an invariant closure's body in `::thrust_macros::formula!(...)`, taking
23+
/// everything past the closure header (the first two top-level `|`) as the body.
24+
/// Only the body is wrapped: the closure must stay a parseable `ExprClosure` for
25+
/// `invariant::expand_invariant` to read its `.inputs`/`.body`, and the context
26+
/// form's threaded signature must not be preprocessed. Non-closure inputs pass
27+
/// through.
28+
///
29+
/// A closure with an explicit return type (`|x| -> Ty { .. }`) must keep its body
30+
/// a block, so the `-> Ty {` / `}` is preserved and only the block's tail
31+
/// expression is wrapped.
32+
pub fn wrap_closure_body(input: TokenStream) -> TokenStream {
33+
let tokens: Vec<TokenTree> = input.into_iter().collect();
34+
let bars: Vec<usize> = tokens
35+
.iter()
36+
.enumerate()
37+
.filter(|(_, tt)| matches!(tt, TokenTree::Punct(p) if p.as_char() == '|'))
38+
.map(|(i, _)| i)
39+
.collect();
40+
let [_open, close, ..] = bars[..] else {
41+
return tokens.into_iter().collect();
42+
};
43+
let header: TokenStream = tokens[..=close].iter().cloned().collect();
44+
let tail = &tokens[close + 1..];
45+
46+
// `-> Ty { body }`: wrap the block's tail expression in place.
47+
let has_return_type = matches!(tail.first(), Some(TokenTree::Punct(p)) if p.as_char() == '-')
48+
&& matches!(tail.get(1), Some(TokenTree::Punct(p)) if p.as_char() == '>');
49+
if has_return_type {
50+
if let Some((TokenTree::Group(block), prefix)) = tail.split_last() {
51+
if block.delimiter() == proc_macro2::Delimiter::Brace {
52+
let prefix: TokenStream = prefix.iter().cloned().collect();
53+
let body = wrap_block_tail(block);
54+
return quote!(#header #prefix #body);
55+
}
56+
}
57+
}
58+
59+
// `|args| body`: the body is a bare expression; wrap it whole.
60+
let body = wrap_expr(tail.iter().cloned().collect());
61+
quote!(#header #body)
62+
}
63+
64+
/// Returns a brace group like `block` but with its tail expression (the final
65+
/// `;`-separated segment) wrapped in `::thrust_macros::formula!(...)`. Leading
66+
/// statements are left untouched; a block with no tail expression is unchanged.
67+
fn wrap_block_tail(block: &Group) -> TokenStream {
68+
let stmts: Vec<TokenTree> = block.stream().into_iter().collect();
69+
let mut segments: Vec<Vec<TokenTree>> = vec![Vec::new()];
70+
for tt in &stmts {
71+
segments.last_mut().unwrap().push(tt.clone());
72+
if matches!(tt, TokenTree::Punct(p) if p.as_char() == ';') {
73+
segments.push(Vec::new());
74+
}
75+
}
76+
let tail = segments.pop().unwrap();
77+
78+
let mut inner = TokenStream::new();
79+
for seg in &segments {
80+
inner.extend(seg.iter().cloned());
81+
}
82+
if !tail.is_empty() {
83+
inner.extend(wrap_expr(tail.into_iter().collect()));
84+
}
85+
86+
let mut wrapped = Group::new(proc_macro2::Delimiter::Brace, inner);
87+
wrapped.set_span(block.span());
88+
quote!(#wrapped)
89+
}
90+
91+
/// Expands `formula!(<tokens>)` into the preprocessed boolean expression.
92+
pub fn expand(input: TokenStream) -> TokenStream {
93+
if let Err(e) = reject_bare_assignment(&input) {
94+
return e.to_compile_error();
95+
}
96+
97+
// `==>` is desugared to assignment (the lowest-precedence, right-associative
98+
// operator) so `syn` reproduces its precedence, then each assignment node is
99+
// rewritten into `!lhs || rhs`.
100+
let desugared = desugar_arrows(input);
101+
let mut expr: syn::Expr = match syn::parse2(desugared) {
102+
Ok(expr) => expr,
103+
Err(e) => return e.to_compile_error(),
104+
};
105+
106+
// Rewrites each assignment `lhs = rhs` (produced by [`desugar_arrows`] from
107+
// `lhs ==> rhs`) into `(!(lhs)) || (rhs)`. Visiting post-order means nested
108+
// implications are rewritten innermost-first, so the right-associative chain
109+
// `a ==> b ==> c` becomes `!a || (!b || c)`.
110+
struct ImplicationRewriter;
111+
112+
impl VisitMut for ImplicationRewriter {
113+
fn visit_expr_mut(&mut self, expr: &mut syn::Expr) {
114+
syn::visit_mut::visit_expr_mut(self, expr);
115+
if let syn::Expr::Assign(assign) = expr {
116+
let left = &assign.left;
117+
let right = &assign.right;
118+
*expr = syn::parse_quote!((!(#left)) || (#right));
119+
}
120+
}
121+
}
122+
123+
ImplicationRewriter.visit_expr_mut(&mut expr);
124+
125+
expr.into_token_stream()
126+
}
127+
128+
/// Rejects a bare assignment `=` anywhere in `input`. Desugaring turns `==>` into
129+
/// `=`, so a genuine `=` would be read as an implication; since assignment is
130+
/// meaningless in a formula we reject it up front. A genuine `=` is a lone
131+
/// `Punct('=')` with [`Spacing::Alone`] not preceded by a joint punct, which
132+
/// excludes the `=`s of `==`, `!=`, `<=`, `>=`, and `==>`.
133+
fn reject_bare_assignment(input: &TokenStream) -> syn::Result<()> {
134+
let mut prev_joint = false;
135+
for tt in input.clone() {
136+
match tt {
137+
TokenTree::Group(g) => {
138+
reject_bare_assignment(&g.stream())?;
139+
prev_joint = false;
140+
}
141+
TokenTree::Punct(p) => {
142+
if p.as_char() == '=' && p.spacing() == Spacing::Alone && !prev_joint {
143+
return Err(syn::Error::new(
144+
p.span(),
145+
"`=` is not allowed in a formula; use `==` for equality or `==>` for implication",
146+
));
147+
}
148+
prev_joint = p.spacing() == Spacing::Joint;
149+
}
150+
_ => prev_joint = false,
151+
}
152+
}
153+
Ok(())
154+
}
155+
156+
/// Replaces every contiguous `==>` (the puncts `=`, `=`, `>`) with a single `=`,
157+
/// recursing into every delimiter group. `==`, `=>`, `>=`, and `->` do not match
158+
/// the triple and are left untouched.
159+
fn desugar_arrows(input: TokenStream) -> TokenStream {
160+
let mut out = TokenStream::new();
161+
let mut iter = input.into_iter().peekable();
162+
while let Some(tt) = iter.next() {
163+
match tt {
164+
TokenTree::Group(group) => {
165+
let inner = desugar_arrows(group.stream());
166+
let mut new_group = Group::new(group.delimiter(), inner);
167+
new_group.set_span(group.span());
168+
out.extend([TokenTree::Group(new_group)]);
169+
}
170+
TokenTree::Punct(p) if p.as_char() == '=' && p.spacing() == Spacing::Joint => {
171+
// Look for a contiguous `==>`: `p` is the first `=`, joint to a
172+
// second joint `=`, then a `>`. Requiring joint spacing avoids
173+
// matching a spaced-out `= = >` / `== >`.
174+
if let Some(TokenTree::Punct(p2)) = iter.peek() {
175+
if p2.as_char() == '=' && p2.spacing() == Spacing::Joint {
176+
let mut lookahead = iter.clone();
177+
lookahead.next(); // consume the second `=`
178+
if let Some(TokenTree::Punct(p3)) = lookahead.peek() {
179+
if p3.as_char() == '>' {
180+
// Matched `==>`; consume `=` and `>`, emit `=`.
181+
iter.next(); // second `=`
182+
iter.next(); // `>`
183+
let mut eq = Punct::new('=', Spacing::Alone);
184+
eq.set_span(p.span());
185+
out.extend([TokenTree::Punct(eq)]);
186+
continue;
187+
}
188+
}
189+
}
190+
}
191+
out.extend([TokenTree::Punct(p)]);
192+
}
193+
other => out.extend([other]),
194+
}
195+
}
196+
out
197+
}
198+
199+
#[cfg(test)]
200+
mod tests {
201+
use super::*;
202+
203+
/// `expand`s `s`, then renders it through `syn` so spacing matches `expect`.
204+
fn expand_expr(s: &str) -> String {
205+
let expr: syn::Expr = syn::parse2(expand(s.parse().unwrap())).unwrap();
206+
quote!(#expr).to_string()
207+
}
208+
209+
fn expect(s: &str) -> String {
210+
let expr: syn::Expr = syn::parse_str(s).unwrap();
211+
quote!(#expr).to_string()
212+
}
213+
214+
#[test]
215+
fn desugars_implication() {
216+
assert_eq!(expand_expr("a ==> b"), expect("(!(a)) || (b)"));
217+
// right-associative
218+
assert_eq!(
219+
expand_expr("a ==> b ==> c"),
220+
expect("(!(a)) || ((!(b)) || (c))")
221+
);
222+
// lower precedence than `||` and `==`
223+
assert_eq!(
224+
expand_expr("a || b ==> c == d"),
225+
expect("(!(a || b)) || (c == d)")
226+
);
227+
// nested inside a closure argument
228+
assert_eq!(
229+
expand_expr("exists(|x| a ==> b)"),
230+
expect("exists(|x| (!(a)) || (b))")
231+
);
232+
}
233+
234+
#[test]
235+
fn leaves_comparisons_untouched() {
236+
for s in ["a == b", "a != b", "a >= b", "a <= b"] {
237+
assert_eq!(expand_expr(s), expect(s), "{s}");
238+
}
239+
}
240+
241+
#[test]
242+
fn rejects_bare_assignment() {
243+
// `expand` emits a `compile_error!` instead of a valid expression.
244+
for s in ["a = b", "f(x = y)"] {
245+
assert!(
246+
expand(s.parse().unwrap())
247+
.to_string()
248+
.contains("compile_error"),
249+
"{s}"
250+
);
251+
}
252+
}
253+
254+
/// `wrap_closure_body` must leave the input a parseable `ExprClosure`, both
255+
/// for a bare-expression body and one with an explicit return type.
256+
fn assert_wraps_to_closure(s: &str) {
257+
let wrapped = wrap_closure_body(s.parse().unwrap());
258+
syn::parse2::<syn::ExprClosure>(wrapped).unwrap_or_else(|e| panic!("{s}: {e}"));
259+
}
260+
261+
#[test]
262+
fn preserves_closure_forms() {
263+
assert_wraps_to_closure("|x: i64| x >= 1 ==> x >= 0");
264+
assert_wraps_to_closure("|x: i64| -> bool { x >= 1 ==> x >= 0 }");
265+
}
266+
}

thrust-macros/src/invariant.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ use proc_macro2::TokenStream as TokenStream2;
3434
use quote::{format_ident, quote, ToTokens};
3535
use syn::{
3636
parse::{Parse, ParseStream},
37-
parse_macro_input,
3837
visit_mut::VisitMut,
3938
FnArg, GenericParam, Signature, WherePredicate,
4039
};
@@ -46,7 +45,11 @@ static COUNTER: AtomicUsize = AtomicUsize::new(0);
4645
/// Expands `invariant!(CLOSURE)`: a bare predicate closure with no threaded
4746
/// context.
4847
pub fn expand(input: TokenStream) -> TokenStream {
49-
let closure = parse_macro_input!(input as syn::ExprClosure);
48+
let input = crate::formula::wrap_closure_body(input.into());
49+
let closure = match syn::parse2::<syn::ExprClosure>(input) {
50+
Ok(closure) => closure,
51+
Err(e) => return e.to_compile_error().into(),
52+
};
5053
match expand_invariant(&closure, None) {
5154
Ok(expr) => expr.into_token_stream().into(),
5255
Err(e) => e.to_compile_error().into(),
@@ -75,7 +78,11 @@ pub fn expand_with_context(input: TokenStream) -> TokenStream {
7578
}
7679
}
7780

78-
let WithContext { closure, context } = parse_macro_input!(input as WithContext);
81+
let input = crate::formula::wrap_closure_body(input.into());
82+
let WithContext { closure, context } = match syn::parse2::<WithContext>(input) {
83+
Ok(parsed) => parsed,
84+
Err(e) => return e.to_compile_error().into(),
85+
};
7986
match expand_invariant(&closure, Some(&context)) {
8087
Ok(expr) => expr.into_token_stream().into(),
8188
Err(e) => e.to_compile_error().into(),

0 commit comments

Comments
 (0)