Skip to content

Commit 3d93645

Browse files
committed
Add loop invariant annotations (minimal)
Introduce `thrust_macros::invariant!`, a Prusti-style loop invariant placed inside the loop body: fn f() { while cond { thrust_macros::invariant!(|x: i64, y: i64| x >= 1 && y >= 1); ... } } `invariant!(|x: T, ...| pred)` expands to a `#[thrust::formula_fn]` over `Model::Ty` parameters and a marker call (`thrust_models::__invariant_marker`) referencing it, so invariants share the static semantics of `requires`/`ensures`. The closure parameters name the live variables the invariant constrains. Generic- and `Self`-typed variables are out of scope here (the in-body macro cannot see the enclosing generics, and `Self` cannot be named inside a nested formula function); support for those will follow by extending `#[thrust_macros::context]` to thread that context in. Analyzer: - Collect invariant markers, map each to its enclosing loop header, and turn the formula function into that header's precondition by binding each named formula parameter to the matching live basic-block parameter. - Resolving a name to a live local errors when several distinct live locals share it (e.g. shadowing) rather than silently picking one. - The marker terminator is elaborated to a plain goto so it is not type-checked as a real call. https://claude.ai/code/session_01WB28auaD8dSQrckqBwJWBt
1 parent e25bfb0 commit 3d93645

16 files changed

Lines changed: 439 additions & 1 deletion

src/analyze/annot.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,14 @@ pub fn exists_path() -> [Symbol; 3] {
141141
]
142142
}
143143

144+
pub fn invariant_marker_path() -> [Symbol; 3] {
145+
[
146+
Symbol::intern("thrust"),
147+
Symbol::intern("def"),
148+
Symbol::intern("invariant_marker"),
149+
]
150+
}
151+
144152
/// A [`annot::Resolver`] implementation for resolving function parameters.
145153
///
146154
/// The parameter names and their sorts needs to be configured via

src/analyze/annot_fn.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ where
4040
}
4141

4242
impl<'tcx> FormulaFn<'tcx> {
43+
pub fn formula(&self) -> &chc::Formula<rty::FunctionParamIdx> {
44+
&self.formula
45+
}
46+
4347
pub fn to_require_annot(&self) -> AnnotFormula<rty::FunctionParamIdx> {
4448
AnnotFormula::Formula(self.formula.clone())
4549
}

src/analyze/basic_block.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,19 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
10241024
None
10251025
}
10261026

1027+
/// Whether this block's terminator is a loop-invariant marker call.
1028+
fn terminator_is_invariant_marker(&self) -> Option<BasicBlock> {
1029+
let term = &self.body.basic_blocks[self.basic_block].terminator().kind;
1030+
if let TerminatorKind::Call { func, target, .. } = term {
1031+
if let Some((def_id, _)) = func.const_fn_def() {
1032+
if Some(def_id) == self.ctx.def_ids().invariant_marker() {
1033+
return Some(target.expect("invariant marker call must have a target"));
1034+
}
1035+
}
1036+
}
1037+
None
1038+
}
1039+
10271040
fn analyze_statements(&mut self) {
10281041
for local in self.drop_points.before_statements.clone() {
10291042
tracing::info!(?local, "implicitly dropped before statements");
@@ -1065,6 +1078,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
10651078
source_info: term.source_info,
10661079
};
10671080
}
1081+
if let Some(target) = self.terminator_is_invariant_marker() {
1082+
tracing::debug!(?term, "skip invariant marker");
1083+
return mir::Terminator {
1084+
kind: TerminatorKind::Goto { target },
1085+
source_info: term.source_info,
1086+
};
1087+
}
10681088
self.rust_call_visitor().visit_terminator(&mut term);
10691089
self.reborrow_visitor().visit_terminator(&mut term);
10701090
tracing::debug!(term = ?term.kind);

src/analyze/did_cache.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ struct DefIds {
2525
array_model_store: OnceCell<Option<DefId>>,
2626

2727
exists: OnceCell<Option<DefId>>,
28+
invariant_marker: OnceCell<Option<DefId>>,
2829
}
2930

3031
/// Retrieves and caches well-known [`DefId`]s.
@@ -176,4 +177,11 @@ impl<'tcx> DefIdCache<'tcx> {
176177
.exists
177178
.get_or_init(|| self.annotated_def(&crate::analyze::annot::exists_path()))
178179
}
180+
181+
pub fn invariant_marker(&self) -> Option<DefId> {
182+
*self
183+
.def_ids
184+
.invariant_marker
185+
.get_or_init(|| self.annotated_def(&crate::analyze::annot::invariant_marker_path()))
186+
}
179187
}

src/analyze/local_def.rs

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,8 +805,132 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
805805
}
806806
}
807807

808+
/// Scans the body for loop-invariant marker calls and maps, for each
809+
/// enclosing loop header, the formula function carrying the invariant
810+
/// (together with its already-monomorphized generic arguments).
811+
fn collect_loop_invariant_annotations(
812+
&self,
813+
) -> HashMap<BasicBlock, (LocalDefId, mir_ty::GenericArgsRef<'tcx>)> {
814+
let mut loop_invariants = HashMap::new();
815+
for (bb, data) in self.body.basic_blocks.iter_enumerated() {
816+
let Some(term) = &data.terminator else {
817+
continue;
818+
};
819+
let mir::TerminatorKind::Call { func, args, .. } = &term.kind else {
820+
continue;
821+
};
822+
let Some((def_id, _)) = func.const_fn_def() else {
823+
continue;
824+
};
825+
if Some(def_id) != self.ctx.def_ids().invariant_marker() {
826+
continue;
827+
}
828+
829+
let arg_ty = args[0].node.ty(&self.body.local_decls, self.tcx);
830+
let mir_ty::TyKind::FnDef(formula_def_id, generic_args) = arg_ty.kind() else {
831+
panic!("invariant marker argument must be a formula function item");
832+
};
833+
let formula_def_id = formula_def_id
834+
.as_local()
835+
.expect("invariant formula function must be local");
836+
let header = Self::loop_header_of(&self.body, bb).unwrap_or_else(|| {
837+
panic!("no enclosing loop header for invariant marker at {bb:?}")
838+
});
839+
loop_invariants.insert(header, (formula_def_id, *generic_args));
840+
}
841+
loop_invariants
842+
}
843+
844+
/// Walks up the dominator tree from the marker block to the innermost
845+
/// enclosing loop header: the first dominator that needs its own
846+
/// precondition (in-degree >= 2) and has a back edge.
847+
fn loop_header_of(body: &Body<'_>, marker_bb: BasicBlock) -> Option<BasicBlock> {
848+
let doms = body.basic_blocks.dominators();
849+
let preds = body.basic_blocks.predecessors();
850+
let mut cur = Some(marker_bb);
851+
while let Some(bb) = cur {
852+
if analyze::basic_block::needs_own_precondition(body, bb)
853+
&& preds[bb].iter().any(|&p| doms.dominates(bb, p))
854+
{
855+
return Some(bb);
856+
}
857+
cur = doms.immediate_dominator(bb);
858+
}
859+
None
860+
}
861+
862+
/// Resolves the live local matching a source variable name at the given
863+
/// basic block, among the locals that are parameters of `bty`.
864+
///
865+
/// When several distinct live locals share the name (e.g. two shadowed
866+
/// variables that are both loop-carried), the mapping is ambiguous; rather
867+
/// than silently pick one, this raises a fatal error. Disambiguating which
868+
/// shadow an invariant refers to is left as future work.
869+
fn local_of_name_in_bb(&self, name: rustc_span::Symbol, bty: &BasicBlockType) -> Option<Local> {
870+
let mut found: Option<Local> = None;
871+
for vdi in &self.body.var_debug_info {
872+
if vdi.name != name {
873+
continue;
874+
}
875+
let mir::VarDebugInfoContents::Place(place) = vdi.value else {
876+
continue;
877+
};
878+
if !place.projection.is_empty() {
879+
continue;
880+
}
881+
if bty.param_of_local(place.local).is_none() {
882+
continue;
883+
}
884+
match found {
885+
None => found = Some(place.local),
886+
Some(prev) if prev == place.local => {}
887+
Some(_) => self.tcx.dcx().fatal(format!(
888+
"loop invariant refers to `{name}`, which is ambiguous at the loop header: \
889+
multiple live variables share this name (e.g. through shadowing). \
890+
Rename the variables to disambiguate."
891+
)),
892+
}
893+
}
894+
found
895+
}
896+
897+
/// Translates a user-provided loop invariant (a formula function over named
898+
/// live variables) into a precondition refinement over `bty`'s parameters.
899+
/// Each formula parameter names a live variable at the loop header and is
900+
/// mapped to the corresponding basic-block parameter.
901+
fn build_invariant_precondition(
902+
&self,
903+
formula_def_id: LocalDefId,
904+
generic_args: mir_ty::GenericArgsRef<'tcx>,
905+
bty: &BasicBlockType,
906+
) -> rty::Refinement<rty::FunctionParamIdx> {
907+
let formula_fn = self
908+
.ctx
909+
.formula_fn_with_args(formula_def_id, generic_args)
910+
.expect("invariant formula function is not registered");
911+
let idents = self.tcx.fn_arg_idents(formula_def_id.to_def_id());
912+
913+
let mut mapping: Vec<rty::FunctionParamIdx> = Vec::with_capacity(idents.len());
914+
for ident in idents {
915+
let name = ident.expect("invariant parameters must be named").name;
916+
let local = self.local_of_name_in_bb(name, bty).unwrap_or_else(|| {
917+
self.tcx.dcx().fatal(format!(
918+
"loop invariant refers to `{name}`, which is not a live variable at the loop header"
919+
))
920+
});
921+
mapping.push(bty.param_of_local(local).unwrap());
922+
}
923+
924+
formula_fn
925+
.formula()
926+
.clone()
927+
.subst_var(|idx| chc::Term::var(rty::RefinedTypeVar::Free(mapping[idx.index()])))
928+
.into()
929+
}
930+
808931
fn refine_basic_blocks(&mut self) {
809932
use rustc_mir_dataflow::Analysis as _;
933+
let loop_invariants = self.collect_loop_invariant_annotations();
810934
let mut results = rustc_mir_dataflow::impls::MaybeLiveLocals
811935
.iterate_to_fixpoint(self.tcx, &self.body, None)
812936
.into_results_cursor(&self.body);
@@ -851,7 +975,18 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
851975
}
852976
// function return type is basic block return type
853977
let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty;
854-
if analyze::basic_block::needs_own_precondition(&self.body, bb) {
978+
if let Some(&(formula_def_id, generic_args)) = loop_invariants.get(&bb) {
979+
// A user-supplied loop invariant fully replaces inference: build
980+
// the block type without a precondition pvar and install the
981+
// invariant as its precondition.
982+
let mut bty = self
983+
.type_builder
984+
.build_basic_block(&self.body, live_locals, ret_ty);
985+
let inv = self.build_invariant_precondition(formula_def_id, generic_args, &bty);
986+
bty.set_precondition(inv);
987+
self.ctx
988+
.register_basic_block_ty_with_precondition(self.local_def_id, bb, bty);
989+
} else if analyze::basic_block::needs_own_precondition(&self.body, bb) {
855990
let bty = self
856991
.type_builder
857992
.for_template(&mut self.ctx)

std.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,13 @@ mod thrust_models {
302302
pub fn exists<T>(_x: T) -> bool {
303303
unimplemented!()
304304
}
305+
306+
#[thrust::def::invariant_marker]
307+
#[thrust::ignored]
308+
#[inline(never)]
309+
pub fn __invariant_marker<F>(_f: F) {
310+
unimplemented!()
311+
}
305312
}
306313

307314
#[thrust::extern_spec_fn]

tests/ui/fail/loop_invariant.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//@error-in-other-file: Unsat
2+
//@compile-flags: -C debug-assertions=off
3+
4+
#[thrust_macros::requires(true)]
5+
#[thrust_macros::ensures(true)]
6+
#[thrust::trusted]
7+
fn rand() -> i64 { unimplemented!() }
8+
9+
fn main() {
10+
let mut x = 1_i64;
11+
let mut y = 1_i64;
12+
while rand() == 0 {
13+
thrust_macros::invariant!(|x: i64| x >= 1);
14+
let t1 = x;
15+
let t2 = y;
16+
x = t1 + t2;
17+
y = t1 + t2;
18+
}
19+
assert!(y >= 1);
20+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//@error-in-other-file: Unsat
2+
//@compile-flags: -C debug-assertions=off
3+
4+
#[thrust_macros::requires(true)]
5+
#[thrust_macros::ensures(true)]
6+
#[thrust::trusted]
7+
fn rand() -> i64 { unimplemented!() }
8+
9+
struct Counter;
10+
11+
impl Counter {
12+
fn run(&self) {
13+
let mut x = 5_i64;
14+
while rand() == 0 {
15+
thrust_macros::invariant!(|x: i64| x >= 1);
16+
x = x - 1;
17+
}
18+
assert!(x >= 1);
19+
}
20+
}
21+
22+
fn main() {
23+
Counter.run();
24+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
//@error-in-other-file: Unsat
2+
//@compile-flags: -C debug-assertions=off
3+
4+
#[thrust_macros::requires(true)]
5+
#[thrust_macros::ensures(true)]
6+
#[thrust::trusted]
7+
fn rand() -> i64 { unimplemented!() }
8+
9+
fn main() {
10+
let mut x = 5_i64;
11+
let p = &mut x;
12+
while rand() == 0 {
13+
thrust_macros::invariant!(|p: &mut i64| *p >= 1);
14+
*p = *p - 1;
15+
}
16+
assert!(*p >= 1);
17+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//@error-in-other-file: Unsat
2+
//@compile-flags: -C debug-assertions=off
3+
4+
#[thrust_macros::requires(true)]
5+
#[thrust_macros::ensures(true)]
6+
#[thrust::trusted]
7+
fn rand() -> i64 { unimplemented!() }
8+
9+
fn main() {
10+
let mut x = 1_i64;
11+
while rand() == 0 {
12+
let mut y = 1_i64;
13+
while rand() == 0 {
14+
thrust_macros::invariant!(|x: i64| x >= 1);
15+
y = x + y;
16+
}
17+
x = x + y;
18+
}
19+
assert!(x >= 1);
20+
}

0 commit comments

Comments
 (0)