Skip to content

Commit 707a3cd

Browse files
coord-eclaude
authored andcommitted
Represent formula implication explicitly
style fix
1 parent 5a22dd0 commit 707a3cd

10 files changed

Lines changed: 96 additions & 14 deletions

File tree

src/analyze/annot.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,14 @@ pub fn exists_path() -> [Symbol; 3] {
145145
]
146146
}
147147

148+
pub fn implies_path() -> [Symbol; 3] {
149+
[
150+
Symbol::intern("thrust"),
151+
Symbol::intern("def"),
152+
Symbol::intern("implies"),
153+
]
154+
}
155+
148156
pub fn invariant_marker_path() -> [Symbol; 3] {
149157
[
150158
Symbol::intern("thrust"),

src/analyze/annot_fn.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ enum FormulaOrTerm<T> {
102102
BinOp(chc::Term<T>, AmbiguousBinOp, chc::Term<T>),
103103
And(Box<FormulaOrTerm<T>>, Box<FormulaOrTerm<T>>),
104104
Or(Box<FormulaOrTerm<T>>, Box<FormulaOrTerm<T>>),
105+
Implies(Box<FormulaOrTerm<T>>, Box<FormulaOrTerm<T>>),
105106
Not(Box<FormulaOrTerm<T>>),
106107
Literal(bool),
107108
}
@@ -124,6 +125,7 @@ impl<T> FormulaOrTerm<T> {
124125
}
125126
FormulaOrTerm::And(lhs, rhs) => lhs.into_formula()?.and(rhs.into_formula()?),
126127
FormulaOrTerm::Or(lhs, rhs) => lhs.into_formula()?.or(rhs.into_formula()?),
128+
FormulaOrTerm::Implies(lhs, rhs) => lhs.into_formula()?.implies(rhs.into_formula()?),
127129
FormulaOrTerm::Not(formula_or_term) => formula_or_term.into_formula()?.not(),
128130
FormulaOrTerm::Literal(b) => {
129131
if b {
@@ -148,6 +150,7 @@ impl<T> FormulaOrTerm<T> {
148150
FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Lt, rhs) => lhs.lt(rhs),
149151
FormulaOrTerm::And(lhs, rhs) => lhs.into_term()?.and(rhs.into_term()?),
150152
FormulaOrTerm::Or(lhs, rhs) => lhs.into_term()?.or(rhs.into_term()?),
153+
FormulaOrTerm::Implies(lhs, rhs) => lhs.into_term()?.not().or(rhs.into_term()?),
151154
FormulaOrTerm::Not(formula_or_term) => formula_or_term.into_term()?.not(),
152155
FormulaOrTerm::Literal(b) => chc::Term::bool(b),
153156
};
@@ -607,6 +610,14 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> {
607610
body_formula,
608611
));
609612
}
613+
if Some(def_id) == self.def_ids.implies() {
614+
let [lhs, rhs] = args else {
615+
panic!("implies takes exactly 2 arguments");
616+
};
617+
let lhs = self.to_formula_or_term(lhs);
618+
let rhs = self.to_formula_or_term(rhs);
619+
return FormulaOrTerm::Implies(lhs.into(), rhs.into());
620+
}
610621
if Some(def_id) == self.def_ids.mut_model_new() {
611622
assert_eq!(args.len(), 2, "Mut::new takes exactly 2 arguments");
612623
let t1 = self.to_term(&args[0]);

src/analyze/did_cache.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ struct DefIds {
2525
array_model_store: OnceCell<Option<DefId>>,
2626

2727
exists: OnceCell<Option<DefId>>,
28+
implies: OnceCell<Option<DefId>>,
2829
invariant_marker: OnceCell<Option<DefId>>,
2930

3031
closure_precondition: OnceCell<Option<DefId>>,
@@ -181,6 +182,13 @@ impl<'tcx> DefIdCache<'tcx> {
181182
.get_or_init(|| self.annotated_def(&crate::analyze::annot::exists_path()))
182183
}
183184

185+
pub fn implies(&self) -> Option<DefId> {
186+
*self
187+
.def_ids
188+
.implies
189+
.get_or_init(|| self.annotated_def(&crate::analyze::annot::implies_path()))
190+
}
191+
184192
pub fn invariant_marker(&self) -> Option<DefId> {
185193
*self
186194
.def_ids

src/chc.rs

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,6 +1251,7 @@ pub enum Formula<V = TermVarIdx> {
12511251
Not(Box<Formula<V>>),
12521252
And(Vec<Formula<V>>),
12531253
Or(Vec<Formula<V>>),
1254+
Implies(Box<Formula<V>>, Box<Formula<V>>),
12541255
Exists(Vec<(String, Sort)>, Box<Formula<V>>),
12551256
}
12561257

@@ -1293,6 +1294,13 @@ where
12931294
);
12941295
inner.group()
12951296
}
1297+
Formula::Implies(lhs, rhs) => lhs
1298+
.pretty_atom(allocator)
1299+
.append(allocator.space())
1300+
.append(allocator.text("==>"))
1301+
.append(allocator.line())
1302+
.append(rhs.pretty_atom(allocator))
1303+
.group(),
12961304
Formula::Exists(vars, fo) => {
12971305
let vars = allocator.intersperse(
12981306
vars.iter().map(|(name, sort)| {
@@ -1327,7 +1335,7 @@ impl<V> Formula<V> {
13271335
D::Doc: Clone,
13281336
{
13291337
match self {
1330-
Formula::And(_) | Formula::Or(_) | Formula::Exists { .. } => {
1338+
Formula::And(_) | Formula::Or(_) | Formula::Implies(_, _) | Formula::Exists { .. } => {
13311339
self.pretty(allocator).parens()
13321340
}
13331341
_ => self.pretty(allocator),
@@ -1348,6 +1356,7 @@ impl<V> Formula<V> {
13481356
Formula::Not(fo) => fo.is_bottom(),
13491357
Formula::And(fs) => fs.iter().all(Formula::is_top),
13501358
Formula::Or(fs) => fs.iter().any(Formula::is_top),
1359+
Formula::Implies(lhs, rhs) => lhs.is_bottom() || rhs.is_top(),
13511360
Formula::Exists(_, fo) => fo.is_top(),
13521361
}
13531362
}
@@ -1358,6 +1367,7 @@ impl<V> Formula<V> {
13581367
Formula::Not(fo) => fo.is_top(),
13591368
Formula::And(fs) => fs.iter().any(Formula::is_bottom),
13601369
Formula::Or(fs) => fs.iter().all(Formula::is_bottom),
1370+
Formula::Implies(lhs, rhs) => lhs.is_top() && rhs.is_bottom(),
13611371
Formula::Exists(_, fo) => fo.is_bottom(),
13621372
}
13631373
}
@@ -1389,6 +1399,10 @@ impl<V> Formula<V> {
13891399
}
13901400
}
13911401

1402+
pub fn implies(self, other: Self) -> Self {
1403+
Formula::Implies(Box::new(self), Box::new(other))
1404+
}
1405+
13921406
pub fn exists(vars: Vec<(String, Sort)>, body: Self) -> Self {
13931407
Formula::Exists(vars, Box::new(body))
13941408
}
@@ -1406,6 +1420,9 @@ impl<V> Formula<V> {
14061420
Formula::And(fs.into_iter().map(|fo| fo.subst_var(&mut f)).collect())
14071421
}
14081422
Formula::Or(fs) => Formula::Or(fs.into_iter().map(|fo| fo.subst_var(&mut f)).collect()),
1423+
Formula::Implies(lhs, rhs) => {
1424+
Formula::Implies(Box::new(lhs.subst_var(&mut f)), Box::new(rhs.subst_var(f)))
1425+
}
14091426
Formula::Exists(vars, fo) => Formula::Exists(vars, Box::new(fo.subst_var(f))),
14101427
}
14111428
}
@@ -1421,6 +1438,9 @@ impl<V> Formula<V> {
14211438
Formula::Not(fo) => Formula::Not(Box::new(fo.map_var(&mut f))),
14221439
Formula::And(fs) => Formula::And(fs.into_iter().map(|fo| fo.map_var(&mut f)).collect()),
14231440
Formula::Or(fs) => Formula::Or(fs.into_iter().map(|fo| fo.map_var(&mut f)).collect()),
1441+
Formula::Implies(lhs, rhs) => {
1442+
Formula::Implies(Box::new(lhs.map_var(&mut f)), Box::new(rhs.map_var(f)))
1443+
}
14241444
Formula::Exists(vars, fo) => Formula::Exists(vars, Box::new(fo.map_var(f))),
14251445
}
14261446
}
@@ -1435,6 +1455,7 @@ impl<V> Formula<V> {
14351455
Formula::Not(fo) => Box::new(fo.fv()),
14361456
Formula::And(fs) => Box::new(fs.iter().flat_map(Formula::fv)),
14371457
Formula::Or(fs) => Box::new(fs.iter().flat_map(Formula::fv)),
1458+
Formula::Implies(lhs, rhs) => Box::new(lhs.fv().chain(rhs.fv())),
14381459
Formula::Exists(_, fo) => Box::new(fo.fv()),
14391460
}
14401461
}
@@ -1449,6 +1470,7 @@ impl<V> Formula<V> {
14491470
Formula::Not(fo) => Box::new(fo.iter_atoms()),
14501471
Formula::And(fs) => Box::new(fs.iter().flat_map(Formula::iter_atoms)),
14511472
Formula::Or(fs) => Box::new(fs.iter().flat_map(Formula::iter_atoms)),
1473+
Formula::Implies(lhs, rhs) => Box::new(lhs.iter_atoms().chain(rhs.iter_atoms())),
14521474
Formula::Exists(_, fo) => Box::new(fo.iter_atoms()),
14531475
}
14541476
}
@@ -1469,6 +1491,17 @@ impl<V> Formula<V> {
14691491
match self {
14701492
Formula::Atom(_atom) => {}
14711493
Formula::Not(fo) => fo.simplify(),
1494+
Formula::Implies(lhs, rhs) => {
1495+
lhs.simplify();
1496+
rhs.simplify();
1497+
if lhs.is_bottom() || rhs.is_top() {
1498+
*self = Formula::top();
1499+
} else if lhs.is_top() {
1500+
*self = std::mem::take(&mut **rhs);
1501+
} else if rhs.is_bottom() {
1502+
*self = std::mem::take(&mut **lhs).not();
1503+
}
1504+
}
14721505
Formula::And(fs) => {
14731506
for fo in &mut *fs {
14741507
fo.simplify();
@@ -1624,7 +1657,7 @@ where
16241657
.into_iter()
16251658
.map(|a| a.guarded(guard.clone()))
16261659
.collect(),
1627-
formula: guard.not().or(formula),
1660+
formula: guard.implies(formula),
16281661
}
16291662
}
16301663
}

src/chc/smtlib2.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,11 @@ impl<'ctx, 'a> std::fmt::Display for Formula<'ctx, 'a> {
288288
let fs = List::open(fs.iter().map(|fo| Formula::new(self.ctx, self.clause, fo)));
289289
write!(f, "(or {})", fs)
290290
}
291+
chc::Formula::Implies(lhs, rhs) => {
292+
let lhs = Formula::new(self.ctx, self.clause, lhs);
293+
let rhs = Formula::new(self.ctx, self.clause, rhs);
294+
write!(f, "(=> {lhs} {rhs})")
295+
}
291296
chc::Formula::Exists(vars, fo) => {
292297
let vars =
293298
List::closed(vars.iter().map(|(v, s)| {

src/chc/unbox.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ fn unbox_formula(formula: Formula) -> Formula {
8181
Formula::Not(fo) => Formula::Not(Box::new(unbox_formula(*fo))),
8282
Formula::And(fs) => Formula::And(fs.into_iter().map(unbox_formula).collect()),
8383
Formula::Or(fs) => Formula::Or(fs.into_iter().map(unbox_formula).collect()),
84+
Formula::Implies(lhs, rhs) => {
85+
Formula::Implies(Box::new(unbox_formula(*lhs)), Box::new(unbox_formula(*rhs)))
86+
}
8487
Formula::Exists(vars, fo) => {
8588
let vars = vars.into_iter().map(|(v, s)| (v, unbox_sort(s))).collect();
8689
Formula::Exists(vars, Box::new(unbox_formula(*fo)))

src/rty.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1930,6 +1930,10 @@ fn subst_ty_params_in_formula<T, V>(formula: &mut chc::Formula<V>, subst: &TypeP
19301930
subst_ty_params_in_formula(f, subst);
19311931
}
19321932
}
1933+
chc::Formula::Implies(lhs, rhs) => {
1934+
subst_ty_params_in_formula(lhs, subst);
1935+
subst_ty_params_in_formula(rhs, subst);
1936+
}
19331937
chc::Formula::Exists(vars, f) => {
19341938
for (_, sort) in vars {
19351939
subst_ty_params_in_sort(sort, subst);

std.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,13 @@ mod thrust_models {
325325
unimplemented!()
326326
}
327327

328+
#[allow(dead_code)]
329+
#[thrust::def::implies]
330+
#[thrust::ignored]
331+
pub fn implies(_lhs: bool, _rhs: bool) -> bool {
332+
unimplemented!()
333+
}
334+
328335
#[thrust::def::invariant_marker]
329336
#[thrust::ignored]
330337
#[inline(never)]

thrust-macros/src/formula.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//! which parses formulas as [`syn::Expr`] — never chokes on it, and gives one
88
//! place to run preprocessing before the body reaches rustc / HIR lowering.
99
//!
10-
//! The only pass today is implication desugaring; further passes can be appended
10+
//! The only pass today is implication lowering; further passes can be appended
1111
//! in [`expand`].
1212
1313
use proc_macro2::{Group, Punct, Spacing, TokenStream, TokenTree};
@@ -96,17 +96,17 @@ pub fn expand(input: TokenStream) -> TokenStream {
9696

9797
// `==>` is desugared to assignment (the lowest-precedence, right-associative
9898
// operator) so `syn` reproduces its precedence, then each assignment node is
99-
// rewritten into `!lhs || rhs`.
99+
// rewritten into a marker call that the analyzer lowers to `chc::Formula::Implies`.
100100
let desugared = desugar_arrows(input);
101101
let mut expr: syn::Expr = match syn::parse2(desugared) {
102102
Ok(expr) => expr,
103103
Err(e) => return e.to_compile_error(),
104104
};
105105

106106
// Rewrites each assignment `lhs = rhs` (produced by [`desugar_arrows`] from
107-
// `lhs ==> rhs`) into `(!(lhs)) || (rhs)`. Visiting post-order means nested
108-
// implications are rewritten innermost-first, so the right-associative chain
109-
// `a ==> b ==> c` becomes `!a || (!b || c)`.
107+
// `lhs ==> rhs`) into `thrust_models::implies(lhs, rhs)`. Visiting
108+
// post-order means nested implications are rewritten innermost-first, so the
109+
// right-associative chain `a ==> b ==> c` becomes `implies(a, implies(b, c))`.
110110
struct ImplicationRewriter;
111111

112112
impl VisitMut for ImplicationRewriter {
@@ -115,7 +115,7 @@ pub fn expand(input: TokenStream) -> TokenStream {
115115
if let syn::Expr::Assign(assign) = expr {
116116
let left = &assign.left;
117117
let right = &assign.right;
118-
*expr = syn::parse_quote!((!(#left)) || (#right));
118+
*expr = syn::parse_quote!(thrust_models::implies((#left), (#right)));
119119
}
120120
}
121121
}
@@ -213,21 +213,24 @@ mod tests {
213213

214214
#[test]
215215
fn desugars_implication() {
216-
assert_eq!(expand_expr("a ==> b"), expect("(!(a)) || (b)"));
216+
assert_eq!(
217+
expand_expr("a ==> b"),
218+
expect("thrust_models::implies((a), (b))")
219+
);
217220
// right-associative
218221
assert_eq!(
219222
expand_expr("a ==> b ==> c"),
220-
expect("(!(a)) || ((!(b)) || (c))")
223+
expect("thrust_models::implies((a), (thrust_models::implies((b), (c))))")
221224
);
222225
// lower precedence than `||` and `==`
223226
assert_eq!(
224227
expand_expr("a || b ==> c == d"),
225-
expect("(!(a || b)) || (c == d)")
228+
expect("thrust_models::implies((a || b), (c == d))")
226229
);
227230
// nested inside a closure argument
228231
assert_eq!(
229232
expand_expr("exists(|x| a ==> b)"),
230-
expect("exists(|x| (!(a)) || (b))")
233+
expect("exists(|x| thrust_models::implies((a), (b)))")
231234
);
232235
}
233236

thrust-macros/src/spec.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ pub fn expand_predicate(item: TokenStream) -> TokenStream {
5454
}
5555

5656
pub fn expand_requires(attr: TokenStream, item: TokenStream) -> TokenStream {
57-
let expr = crate::formula::wrap_expr(attr.into());
57+
let expr = crate::formula::expand(attr.into());
5858
let mut func = parse_macro_input!(item as FnItemWithSignature);
5959

6060
let (req_expr, ens_expr) = match extract_requires_ensures(&mut func) {
@@ -69,7 +69,7 @@ pub fn expand_requires(attr: TokenStream, item: TokenStream) -> TokenStream {
6969
}
7070

7171
pub fn expand_ensures(attr: TokenStream, item: TokenStream) -> TokenStream {
72-
let expr = crate::formula::wrap_expr(attr.into());
72+
let expr = crate::formula::expand(attr.into());
7373
let mut func = parse_macro_input!(item as FnItemWithSignature);
7474

7575
let (req_expr, ens_expr) = match extract_requires_ensures(&mut func) {

0 commit comments

Comments
 (0)