Skip to content

Commit b4844c6

Browse files
committed
Support Self and generic variables in loop invariants
Extend `thrust_macros::invariant!` to invariants over generic- and `Self`-typed variables, by extending `#[thrust_macros::context]` to thread the surrounding context into each `invariant!` it encounters. `context` (which already stamps the enclosing `impl`/`trait` header onto methods so method-level `requires`/`ensures` recover the outer generics) now also accepts free functions and modules and threads the in-scope generics, where-predicates, and (in methods) `Self` into every `invariant!`. The macro crate is reorganized into `context`, `invariant`, and `spec` modules; the root keeps only the proc-macro entry points and the few shared helpers. `invariant!` accepts a threaded `[self] [params] [wheres] <closure>` input. The generated formula function re-declares the threaded generics (shadowing the enclosing ones) and is instantiated via turbofish; in methods, `Self` is re-declared the same way as a synthetic type parameter and instantiated with the real `Self` (legal in expression position). Adds pass/fail UI test pairs for the generic and `Self` cases. Note that a struct used as an invariant variable must model structurally (its `Model::Ty` must mirror its fields, e.g. a one-field tuple struct models as a one-tuple); a mismatching model trips an existing subtyping panic unrelated to this change. https://claude.ai/code/session_01WB28auaD8dSQrckqBwJWBt
1 parent c1aebbd commit b4844c6

9 files changed

Lines changed: 1223 additions & 742 deletions

File tree

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//@error-in-other-file: Unsat
2+
//@compile-flags: -C debug-assertions=off
3+
4+
#[thrust_macros::requires(true)]
5+
#[thrust_macros::ensures(true)]
6+
#[thrust::trusted]
7+
fn rand() -> i64 { unimplemented!() }
8+
9+
#[thrust_macros::context]
10+
fn keep<T: Copy + PartialEq>(v: T) {
11+
let mut x = v;
12+
while rand() == 0 {
13+
thrust_macros::invariant!(|v: T| v == v);
14+
x = v;
15+
}
16+
assert!(x == v);
17+
}
18+
19+
fn main() {
20+
keep(0_i64);
21+
keep(true);
22+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//@error-in-other-file: Unsat
2+
//@compile-flags: -C debug-assertions=off
3+
4+
#[thrust_macros::requires(true)]
5+
#[thrust_macros::ensures(true)]
6+
#[thrust::trusted]
7+
fn rand() -> i64 { unimplemented!() }
8+
9+
struct Counter(i64);
10+
impl thrust_models::Model for Counter {
11+
type Ty = (thrust_models::model::Int,);
12+
}
13+
14+
#[thrust_macros::context]
15+
impl Counter {
16+
fn run(self) {
17+
let mut c = self;
18+
let mut x = 1_i64;
19+
while rand() == 0 {
20+
thrust_macros::invariant!(|x: i64, c: Self| x >= 2 && c == c);
21+
x = x + 1;
22+
c = Counter(0);
23+
}
24+
let _last = c;
25+
assert!(x >= 1);
26+
}
27+
}
28+
29+
fn main() { Counter(0).run(); }
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//@check-pass
2+
//@compile-flags: -C debug-assertions=off
3+
4+
#[thrust_macros::requires(true)]
5+
#[thrust_macros::ensures(true)]
6+
#[thrust::trusted]
7+
fn rand() -> i64 { unimplemented!() }
8+
9+
#[thrust_macros::context]
10+
fn keep<T: Copy + PartialEq>(v: T) {
11+
let mut x = v;
12+
while rand() == 0 {
13+
thrust_macros::invariant!(|x: T, v: T| x == v);
14+
x = v;
15+
}
16+
assert!(x == v);
17+
}
18+
19+
fn main() {
20+
keep(0_i64);
21+
keep(true);
22+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//@check-pass
2+
//@compile-flags: -C debug-assertions=off
3+
4+
// A concrete, non-generic invariant works without `#[thrust_macros::context]`.
5+
6+
#[thrust_macros::requires(true)]
7+
#[thrust_macros::ensures(true)]
8+
#[thrust::trusted]
9+
fn rand() -> i64 { unimplemented!() }
10+
11+
fn main() {
12+
let mut x = 1_i64;
13+
while rand() == 0 {
14+
thrust_macros::invariant!(|x: i64| x >= 1);
15+
x = x + 1;
16+
}
17+
assert!(x >= 1);
18+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//@check-pass
2+
//@compile-flags: -C debug-assertions=off
3+
4+
#[thrust_macros::requires(true)]
5+
#[thrust_macros::ensures(true)]
6+
#[thrust::trusted]
7+
fn rand() -> i64 { unimplemented!() }
8+
9+
struct Counter(i64);
10+
impl thrust_models::Model for Counter {
11+
type Ty = (thrust_models::model::Int,);
12+
}
13+
14+
#[thrust_macros::context]
15+
impl Counter {
16+
fn run(self) {
17+
let mut c = self;
18+
let mut x = 1_i64;
19+
while rand() == 0 {
20+
thrust_macros::invariant!(|x: i64, c: Self| x >= 1 && c == c);
21+
x = x + 1;
22+
c = Counter(0);
23+
}
24+
let _last = c;
25+
assert!(x >= 1);
26+
}
27+
}
28+
29+
fn main() { Counter(0).run(); }

thrust-macros/src/context.rs

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
//! Expansion of `#[thrust_macros::context]`.
2+
//!
3+
//! `context` supplies the surrounding context that thrust annotations within an
4+
//! item cannot see by themselves: it stamps the enclosing `impl`/`trait` header
5+
//! onto each method (so method-level `requires`/`ensures` recover the outer
6+
//! generics) and threads that generic context into every `invariant!(...)`.
7+
8+
use proc_macro::TokenStream;
9+
use proc_macro2::TokenStream as TokenStream2;
10+
use quote::{quote, ToTokens};
11+
use syn::{parse_macro_input, GenericParam, Generics, WherePredicate};
12+
13+
use crate::{tokens_contain_self, FnOuterItem};
14+
15+
pub(super) fn expand(item: TokenStream) -> TokenStream {
16+
let mut item = parse_macro_input!(item as syn::Item);
17+
process_context_item(&mut item);
18+
item.into_token_stream().into()
19+
}
20+
21+
/// Stamps outer context onto methods and threads the generic context into the
22+
/// invariants of every function body in the item (recursing through modules).
23+
fn process_context_item(item: &mut syn::Item) {
24+
match item {
25+
syn::Item::Fn(item_fn) => {
26+
let generics = item_fn.sig.generics.clone();
27+
let threaded = thread_invariants(&mut item_fn.block, &generics, None);
28+
if threaded.found {
29+
inject_model_bounds(&mut item_fn.sig.generics, None, false);
30+
}
31+
}
32+
syn::Item::Impl(item_impl) => {
33+
let outer = FnOuterItem::ItemImpl(item_impl.clone()).into_header_only();
34+
for impl_item in &mut item_impl.items {
35+
let syn::ImplItem::Fn(method) = impl_item else {
36+
continue;
37+
};
38+
method
39+
.attrs
40+
.push(syn::parse_quote!(#[thrust::_outer_context(#outer)]));
41+
let generics = method.sig.generics.clone();
42+
let threaded = thread_invariants(&mut method.block, &generics, Some(&outer));
43+
if threaded.found {
44+
inject_model_bounds(&mut method.sig.generics, Some(&outer), threaded.self_used);
45+
}
46+
}
47+
}
48+
syn::Item::Trait(item_trait) => {
49+
let outer = FnOuterItem::ItemTrait(item_trait.clone()).into_header_only();
50+
for trait_item in &mut item_trait.items {
51+
let syn::TraitItem::Fn(method) = trait_item else {
52+
continue;
53+
};
54+
method
55+
.attrs
56+
.push(syn::parse_quote!(#[thrust::_outer_context(#outer)]));
57+
if let Some(block) = &mut method.default {
58+
let generics = method.sig.generics.clone();
59+
let threaded = thread_invariants(block, &generics, Some(&outer));
60+
if threaded.found {
61+
inject_model_bounds(
62+
&mut method.sig.generics,
63+
Some(&outer),
64+
threaded.self_used,
65+
);
66+
}
67+
}
68+
}
69+
}
70+
syn::Item::Mod(item_mod) => {
71+
if let Some((_, items)) = &mut item_mod.content {
72+
for inner in items {
73+
process_context_item(inner);
74+
}
75+
}
76+
}
77+
_ => {}
78+
}
79+
}
80+
81+
struct Threaded {
82+
found: bool,
83+
self_used: bool,
84+
}
85+
86+
/// Prepends the generic context to every `invariant!(...)` in a function body.
87+
fn thread_invariants(
88+
block: &mut syn::Block,
89+
generics: &Generics,
90+
outer: Option<&FnOuterItem>,
91+
) -> Threaded {
92+
use syn::visit_mut::VisitMut as _;
93+
94+
let context = invariant_context_tokens(generics, outer);
95+
let mut threader = InvariantThreader {
96+
context,
97+
is_method: outer.is_some(),
98+
found: false,
99+
self_used: false,
100+
};
101+
threader.visit_block_mut(block);
102+
Threaded {
103+
found: threader.found,
104+
self_used: threader.self_used,
105+
}
106+
}
107+
108+
/// Builds the `[generic-params] [where-predicates]` prefix that `invariant!`
109+
/// consumes: every generic parameter in scope (the function's own and, for
110+
/// methods, the outer ones), the existing where predicates, and the
111+
/// `Model`/`PartialEq` bounds those parameters require.
112+
fn invariant_context_tokens(generics: &Generics, outer: Option<&FnOuterItem>) -> TokenStream2 {
113+
let mut params: Vec<GenericParam> = generics.params.iter().cloned().collect();
114+
let mut preds: Vec<WherePredicate> = generics
115+
.where_clause
116+
.as_ref()
117+
.map(|wc| wc.predicates.iter().cloned().collect())
118+
.unwrap_or_default();
119+
if let Some(outer) = outer {
120+
params.extend(outer.generics().params.iter().cloned());
121+
if let Some(wc) = &outer.generics().where_clause {
122+
preds.extend(wc.predicates.iter().cloned());
123+
}
124+
}
125+
for param in &params {
126+
if let GenericParam::Type(tp) = param {
127+
let ident = &tp.ident;
128+
preds.push(syn::parse_quote!(#ident: thrust_models::Model));
129+
preds.push(syn::parse_quote!(<#ident as thrust_models::Model>::Ty: PartialEq));
130+
}
131+
}
132+
quote! { [#(#params),*] [#(#preds),*] }
133+
}
134+
135+
struct InvariantThreader {
136+
context: TokenStream2,
137+
is_method: bool,
138+
found: bool,
139+
self_used: bool,
140+
}
141+
142+
impl syn::visit_mut::VisitMut for InvariantThreader {
143+
fn visit_macro_mut(&mut self, mac: &mut syn::Macro) {
144+
syn::visit_mut::visit_macro_mut(self, mac);
145+
if is_invariant_macro(&mac.path) {
146+
self.found = true;
147+
// An invariant in a method may name `Self` in its variable types.
148+
// A nested item cannot, so signal `invariant!` to re-declare `Self`
149+
// as a synthetic generic, but only when it is actually used (so we
150+
// do not over-constrain the host method with `Self: Model`).
151+
let uses_self = self.is_method && tokens_contain_self(&mac.tokens);
152+
self.self_used |= uses_self;
153+
let self_marker = if uses_self { quote!(Self) } else { quote!() };
154+
let context = &self.context;
155+
let original = &mac.tokens;
156+
mac.tokens = quote! { [#self_marker] #context #original };
157+
}
158+
}
159+
}
160+
161+
/// Adds `T: Model` and `<T as Model>::Ty: PartialEq` bounds for every type
162+
/// parameter in scope to a function's where clause. The marker call generated
163+
/// for an invariant instantiates a `Model`-bounded formula function, so the
164+
/// function hosting the call must itself satisfy those bounds. When an
165+
/// invariant names `Self`, `invariant!` instantiates the formula function with
166+
/// `Self`, so the same bounds are added for `Self` (`with_self`).
167+
fn inject_model_bounds(generics: &mut Generics, outer: Option<&FnOuterItem>, with_self: bool) {
168+
let mut tys: Vec<TokenStream2> = generics
169+
.params
170+
.iter()
171+
.filter_map(|p| match p {
172+
GenericParam::Type(tp) => Some(tp.ident.to_token_stream()),
173+
_ => None,
174+
})
175+
.collect();
176+
if let Some(outer) = outer {
177+
for param in &outer.generics().params {
178+
if let GenericParam::Type(tp) = param {
179+
tys.push(tp.ident.to_token_stream());
180+
}
181+
}
182+
}
183+
if with_self {
184+
tys.push(quote!(Self));
185+
}
186+
if tys.is_empty() {
187+
return;
188+
}
189+
let where_clause = generics.make_where_clause();
190+
for ty in tys {
191+
where_clause
192+
.predicates
193+
.push(syn::parse_quote!(#ty: thrust_models::Model));
194+
where_clause
195+
.predicates
196+
.push(syn::parse_quote!(<#ty as thrust_models::Model>::Ty: PartialEq));
197+
}
198+
}
199+
200+
fn is_invariant_macro(path: &syn::Path) -> bool {
201+
path.segments.last().is_some_and(|s| s.ident == "invariant")
202+
}

0 commit comments

Comments
 (0)