Skip to content

Commit 76e6ab3

Browse files
committed
Represent formula implication explicitly
1 parent 9776e66 commit 76e6ab3

10 files changed

Lines changed: 114 additions & 16 deletions

File tree

src/analyze/annot.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,14 @@ pub fn exists_path() -> [Symbol; 3] {
145145
]
146146
}
147147

148+
pub fn implies_path() -> [Symbol; 3] {
149+
[
150+
Symbol::intern("thrust"),
151+
Symbol::intern("def"),
152+
Symbol::intern("implies"),
153+
]
154+
}
155+
148156
pub fn invariant_marker_path() -> [Symbol; 3] {
149157
[
150158
Symbol::intern("thrust"),

src/analyze/annot_fn.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ enum FormulaOrTerm<T> {
102102
BinOp(chc::Term<T>, AmbiguousBinOp, chc::Term<T>),
103103
And(Box<FormulaOrTerm<T>>, Box<FormulaOrTerm<T>>),
104104
Or(Box<FormulaOrTerm<T>>, Box<FormulaOrTerm<T>>),
105+
Implies(Box<FormulaOrTerm<T>>, Box<FormulaOrTerm<T>>),
105106
Not(Box<FormulaOrTerm<T>>),
106107
Literal(bool),
107108
}
@@ -124,6 +125,7 @@ impl<T> FormulaOrTerm<T> {
124125
}
125126
FormulaOrTerm::And(lhs, rhs) => lhs.into_formula()?.and(rhs.into_formula()?),
126127
FormulaOrTerm::Or(lhs, rhs) => lhs.into_formula()?.or(rhs.into_formula()?),
128+
FormulaOrTerm::Implies(lhs, rhs) => lhs.into_formula()?.implies(rhs.into_formula()?),
127129
FormulaOrTerm::Not(formula_or_term) => formula_or_term.into_formula()?.not(),
128130
FormulaOrTerm::Literal(b) => {
129131
if b {
@@ -148,6 +150,7 @@ impl<T> FormulaOrTerm<T> {
148150
FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Lt, rhs) => lhs.lt(rhs),
149151
FormulaOrTerm::And(lhs, rhs) => lhs.into_term()?.and(rhs.into_term()?),
150152
FormulaOrTerm::Or(lhs, rhs) => lhs.into_term()?.or(rhs.into_term()?),
153+
FormulaOrTerm::Implies(lhs, rhs) => lhs.into_term()?.not().or(rhs.into_term()?),
151154
FormulaOrTerm::Not(formula_or_term) => formula_or_term.into_term()?.not(),
152155
FormulaOrTerm::Literal(b) => chc::Term::bool(b),
153156
};
@@ -607,6 +610,14 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> {
607610
body_formula,
608611
));
609612
}
613+
if Some(def_id) == self.def_ids.implies() {
614+
let [lhs, rhs] = args else {
615+
panic!("implies takes exactly 2 arguments");
616+
};
617+
let lhs = self.to_formula_or_term(lhs);
618+
let rhs = self.to_formula_or_term(rhs);
619+
return FormulaOrTerm::Implies(lhs.into(), rhs.into());
620+
}
610621
if Some(def_id) == self.def_ids.mut_model_new() {
611622
assert_eq!(args.len(), 2, "Mut::new takes exactly 2 arguments");
612623
let t1 = self.to_term(&args[0]);

src/analyze/did_cache.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ struct DefIds {
2525
array_model_store: OnceCell<Option<DefId>>,
2626

2727
exists: OnceCell<Option<DefId>>,
28+
implies: OnceCell<Option<DefId>>,
2829
invariant_marker: OnceCell<Option<DefId>>,
2930

3031
closure_precondition: OnceCell<Option<DefId>>,
@@ -181,6 +182,13 @@ impl<'tcx> DefIdCache<'tcx> {
181182
.get_or_init(|| self.annotated_def(&crate::analyze::annot::exists_path()))
182183
}
183184

185+
pub fn implies(&self) -> Option<DefId> {
186+
*self
187+
.def_ids
188+
.implies
189+
.get_or_init(|| self.annotated_def(&crate::analyze::annot::implies_path()))
190+
}
191+
184192
pub fn invariant_marker(&self) -> Option<DefId> {
185193
*self
186194
.def_ids

src/chc.rs

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,15 +1245,32 @@ impl<V> Atom<V> {
12451245
/// While it allows arbitrary [`Atom`] in its `Atom` variant, we only expect atoms with known
12461246
/// predicates (i.e., predicates other than `Pred::Var`) to appear in formulas. It is our TODO to
12471247
/// enforce this restriction statically. Also see the definition of [`Body`].
1248-
#[derive(Debug, Clone)]
1248+
#[derive(Clone)]
12491249
pub enum Formula<V = TermVarIdx> {
12501250
Atom(Atom<V>),
12511251
Not(Box<Formula<V>>),
12521252
And(Vec<Formula<V>>),
12531253
Or(Vec<Formula<V>>),
1254+
Implies(Box<Formula<V>>, Box<Formula<V>>),
12541255
Exists(Vec<(String, Sort)>, Box<Formula<V>>),
12551256
}
12561257

1258+
impl<V> std::fmt::Debug for Formula<V>
1259+
where
1260+
V: std::fmt::Debug,
1261+
{
1262+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1263+
match self {
1264+
Formula::Atom(atom) => atom.fmt(f),
1265+
Formula::Not(fo) => f.debug_tuple("Not").field(fo).finish(),
1266+
Formula::And(fs) => f.debug_tuple("And").field(fs).finish(),
1267+
Formula::Or(fs) => f.debug_tuple("Or").field(fs).finish(),
1268+
Formula::Implies(lhs, rhs) => write!(f, "({lhs:?} ==> {rhs:?})"),
1269+
Formula::Exists(vars, fo) => f.debug_tuple("Exists").field(vars).field(fo).finish(),
1270+
}
1271+
}
1272+
}
1273+
12571274
impl<V> Default for Formula<V> {
12581275
fn default() -> Self {
12591276
Formula::top()
@@ -1293,6 +1310,13 @@ where
12931310
);
12941311
inner.group()
12951312
}
1313+
Formula::Implies(lhs, rhs) => lhs
1314+
.pretty_atom(allocator)
1315+
.append(allocator.space())
1316+
.append(allocator.text("==>"))
1317+
.append(allocator.line())
1318+
.append(rhs.pretty_atom(allocator))
1319+
.group(),
12961320
Formula::Exists(vars, fo) => {
12971321
let vars = allocator.intersperse(
12981322
vars.iter().map(|(name, sort)| {
@@ -1327,7 +1351,7 @@ impl<V> Formula<V> {
13271351
D::Doc: Clone,
13281352
{
13291353
match self {
1330-
Formula::And(_) | Formula::Or(_) | Formula::Exists { .. } => {
1354+
Formula::And(_) | Formula::Or(_) | Formula::Implies(_, _) | Formula::Exists { .. } => {
13311355
self.pretty(allocator).parens()
13321356
}
13331357
_ => self.pretty(allocator),
@@ -1348,6 +1372,7 @@ impl<V> Formula<V> {
13481372
Formula::Not(fo) => fo.is_bottom(),
13491373
Formula::And(fs) => fs.iter().all(Formula::is_top),
13501374
Formula::Or(fs) => fs.iter().any(Formula::is_top),
1375+
Formula::Implies(lhs, rhs) => lhs.is_bottom() || rhs.is_top(),
13511376
Formula::Exists(_, fo) => fo.is_top(),
13521377
}
13531378
}
@@ -1358,6 +1383,7 @@ impl<V> Formula<V> {
13581383
Formula::Not(fo) => fo.is_top(),
13591384
Formula::And(fs) => fs.iter().any(Formula::is_bottom),
13601385
Formula::Or(fs) => fs.iter().all(Formula::is_bottom),
1386+
Formula::Implies(lhs, rhs) => lhs.is_top() && rhs.is_bottom(),
13611387
Formula::Exists(_, fo) => fo.is_bottom(),
13621388
}
13631389
}
@@ -1389,6 +1415,10 @@ impl<V> Formula<V> {
13891415
}
13901416
}
13911417

1418+
pub fn implies(self, other: Self) -> Self {
1419+
Formula::Implies(Box::new(self), Box::new(other))
1420+
}
1421+
13921422
pub fn exists(vars: Vec<(String, Sort)>, body: Self) -> Self {
13931423
Formula::Exists(vars, Box::new(body))
13941424
}
@@ -1406,6 +1436,9 @@ impl<V> Formula<V> {
14061436
Formula::And(fs.into_iter().map(|fo| fo.subst_var(&mut f)).collect())
14071437
}
14081438
Formula::Or(fs) => Formula::Or(fs.into_iter().map(|fo| fo.subst_var(&mut f)).collect()),
1439+
Formula::Implies(lhs, rhs) => {
1440+
Formula::Implies(Box::new(lhs.subst_var(&mut f)), Box::new(rhs.subst_var(f)))
1441+
}
14091442
Formula::Exists(vars, fo) => Formula::Exists(vars, Box::new(fo.subst_var(f))),
14101443
}
14111444
}
@@ -1421,6 +1454,9 @@ impl<V> Formula<V> {
14211454
Formula::Not(fo) => Formula::Not(Box::new(fo.map_var(&mut f))),
14221455
Formula::And(fs) => Formula::And(fs.into_iter().map(|fo| fo.map_var(&mut f)).collect()),
14231456
Formula::Or(fs) => Formula::Or(fs.into_iter().map(|fo| fo.map_var(&mut f)).collect()),
1457+
Formula::Implies(lhs, rhs) => {
1458+
Formula::Implies(Box::new(lhs.map_var(&mut f)), Box::new(rhs.map_var(f)))
1459+
}
14241460
Formula::Exists(vars, fo) => Formula::Exists(vars, Box::new(fo.map_var(f))),
14251461
}
14261462
}
@@ -1435,6 +1471,7 @@ impl<V> Formula<V> {
14351471
Formula::Not(fo) => Box::new(fo.fv()),
14361472
Formula::And(fs) => Box::new(fs.iter().flat_map(Formula::fv)),
14371473
Formula::Or(fs) => Box::new(fs.iter().flat_map(Formula::fv)),
1474+
Formula::Implies(lhs, rhs) => Box::new(lhs.fv().chain(rhs.fv())),
14381475
Formula::Exists(_, fo) => Box::new(fo.fv()),
14391476
}
14401477
}
@@ -1449,6 +1486,7 @@ impl<V> Formula<V> {
14491486
Formula::Not(fo) => Box::new(fo.iter_atoms()),
14501487
Formula::And(fs) => Box::new(fs.iter().flat_map(Formula::iter_atoms)),
14511488
Formula::Or(fs) => Box::new(fs.iter().flat_map(Formula::iter_atoms)),
1489+
Formula::Implies(lhs, rhs) => Box::new(lhs.iter_atoms().chain(rhs.iter_atoms())),
14521490
Formula::Exists(_, fo) => Box::new(fo.iter_atoms()),
14531491
}
14541492
}
@@ -1469,6 +1507,17 @@ impl<V> Formula<V> {
14691507
match self {
14701508
Formula::Atom(_atom) => {}
14711509
Formula::Not(fo) => fo.simplify(),
1510+
Formula::Implies(lhs, rhs) => {
1511+
lhs.simplify();
1512+
rhs.simplify();
1513+
if lhs.is_bottom() || rhs.is_top() {
1514+
*self = Formula::top();
1515+
} else if lhs.is_top() {
1516+
*self = std::mem::take(&mut **rhs);
1517+
} else if rhs.is_bottom() {
1518+
*self = std::mem::take(&mut **lhs).not();
1519+
}
1520+
}
14721521
Formula::And(fs) => {
14731522
for fo in &mut *fs {
14741523
fo.simplify();
@@ -1624,7 +1673,7 @@ where
16241673
.into_iter()
16251674
.map(|a| a.guarded(guard.clone()))
16261675
.collect(),
1627-
formula: guard.not().or(formula),
1676+
formula: guard.implies(formula),
16281677
}
16291678
}
16301679
}

src/chc/smtlib2.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,11 @@ impl<'ctx, 'a> std::fmt::Display for Formula<'ctx, 'a> {
288288
let fs = List::open(fs.iter().map(|fo| Formula::new(self.ctx, self.clause, fo)));
289289
write!(f, "(or {})", fs)
290290
}
291+
chc::Formula::Implies(lhs, rhs) => {
292+
let lhs = Formula::new(self.ctx, self.clause, lhs);
293+
let rhs = Formula::new(self.ctx, self.clause, rhs);
294+
write!(f, "(=> {lhs} {rhs})")
295+
}
291296
chc::Formula::Exists(vars, fo) => {
292297
let vars =
293298
List::closed(vars.iter().map(|(v, s)| {

src/chc/unbox.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ fn unbox_formula(formula: Formula) -> Formula {
8181
Formula::Not(fo) => Formula::Not(Box::new(unbox_formula(*fo))),
8282
Formula::And(fs) => Formula::And(fs.into_iter().map(unbox_formula).collect()),
8383
Formula::Or(fs) => Formula::Or(fs.into_iter().map(unbox_formula).collect()),
84+
Formula::Implies(lhs, rhs) => {
85+
Formula::Implies(Box::new(unbox_formula(*lhs)), Box::new(unbox_formula(*rhs)))
86+
}
8487
Formula::Exists(vars, fo) => {
8588
let vars = vars.into_iter().map(|(v, s)| (v, unbox_sort(s))).collect();
8689
Formula::Exists(vars, Box::new(unbox_formula(*fo)))

src/rty.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1930,6 +1930,10 @@ fn subst_ty_params_in_formula<T, V>(formula: &mut chc::Formula<V>, subst: &TypeP
19301930
subst_ty_params_in_formula(f, subst);
19311931
}
19321932
}
1933+
chc::Formula::Implies(lhs, rhs) => {
1934+
subst_ty_params_in_formula(lhs, subst);
1935+
subst_ty_params_in_formula(rhs, subst);
1936+
}
19331937
chc::Formula::Exists(vars, f) => {
19341938
for (_, sort) in vars {
19351939
subst_ty_params_in_sort(sort, subst);

std.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,13 @@ mod thrust_models {
325325
unimplemented!()
326326
}
327327

328+
#[allow(dead_code)]
329+
#[thrust::def::implies]
330+
#[thrust::ignored]
331+
pub fn implies(_lhs: bool, _rhs: bool) -> bool {
332+
unimplemented!()
333+
}
334+
328335
#[thrust::def::invariant_marker]
329336
#[thrust::ignored]
330337
#[inline(never)]

thrust-macros/src/formula.rs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//! which parses formulas as [`syn::Expr`] — never chokes on it, and gives one
88
//! place to run preprocessing before the body reaches rustc / HIR lowering.
99
//!
10-
//! The only pass today is implication desugaring; further passes can be appended
10+
//! The only pass today is implication lowering; further passes can be appended
1111
//! in [`expand`].
1212
1313
use proc_macro2::{Group, Punct, Spacing, TokenStream, TokenTree};
@@ -27,17 +27,17 @@ pub fn expand(input: TokenStream) -> TokenStream {
2727

2828
// `==>` is desugared to assignment (the lowest-precedence, right-associative
2929
// operator) so `syn` reproduces its precedence, then each assignment node is
30-
// rewritten into `!lhs || rhs`.
30+
// rewritten into a marker call that the analyzer lowers to `chc::Formula::Implies`.
3131
let desugared = desugar_arrows(input);
3232
let mut expr: syn::Expr = match syn::parse2(desugared) {
3333
Ok(expr) => expr,
3434
Err(e) => return e.to_compile_error(),
3535
};
3636

3737
// Rewrites each assignment `lhs = rhs` (produced by [`desugar_arrows`] from
38-
// `lhs ==> rhs`) into the boolean expression `(!(lhs)) || (rhs)`. Visiting
38+
// `lhs ==> rhs`) into `thrust_models::model::implies(lhs, rhs)`. Visiting
3939
// post-order means nested implications are rewritten innermost-first, so the
40-
// right-associative chain `a ==> b ==> c` becomes `!a || (!b || c)`.
40+
// right-associative chain `a ==> b ==> c` becomes `implies(a, implies(b, c))`.
4141
struct ImplicationRewriter;
4242

4343
impl VisitMut for ImplicationRewriter {
@@ -46,7 +46,7 @@ pub fn expand(input: TokenStream) -> TokenStream {
4646
if let syn::Expr::Assign(assign) = expr {
4747
let left = &assign.left;
4848
let right = &assign.right;
49-
*expr = syn::parse_quote!((!(#left)) || (#right));
49+
*expr = syn::parse_quote!(thrust_models::model::implies((#left), (#right)));
5050
}
5151
}
5252
}
@@ -98,10 +98,10 @@ fn desugar_arrows(input: TokenStream) -> TokenStream {
9898
new_group.set_span(group.span());
9999
out.extend([TokenTree::Group(new_group)]);
100100
}
101-
TokenTree::Punct(p) if p.as_char() == '=' && p.spacing() == Spacing::Joint => {
101+
TokenTree::Punct(p) if p.as_char() == '=' => {
102102
// Look for `=` `=` `>`. `p` is the first `=`.
103103
if let Some(TokenTree::Punct(p2)) = iter.peek() {
104-
if p2.as_char() == '=' && p2.spacing() == Spacing::Joint {
104+
if p2.as_char() == '=' {
105105
let mut lookahead = iter.clone();
106106
lookahead.next(); // consume the second `=`
107107
if let Some(TokenTree::Punct(p3)) = lookahead.peek() {
@@ -140,21 +140,24 @@ mod tests {
140140

141141
#[test]
142142
fn desugars_implication() {
143-
assert_eq!(expand_expr("a ==> b"), expect("(!(a)) || (b)"));
143+
assert_eq!(
144+
expand_expr("a ==> b"),
145+
expect("thrust_models::model::implies((a), (b))")
146+
);
144147
// right-associative
145148
assert_eq!(
146149
expand_expr("a ==> b ==> c"),
147-
expect("(!(a)) || ((!(b)) || (c))")
150+
expect("thrust_models::model::implies((a), (thrust_models::model::implies((b), (c))))")
148151
);
149152
// lower precedence than `||` and `==`
150153
assert_eq!(
151154
expand_expr("a || b ==> c == d"),
152-
expect("(!(a || b)) || (c == d)")
155+
expect("thrust_models::model::implies((a || b), (c == d))")
153156
);
154157
// nested inside a closure argument
155158
assert_eq!(
156159
expand_expr("exists(|x| a ==> b)"),
157-
expect("exists(|x| (!(a)) || (b))")
160+
expect("exists(|x| thrust_models::model::implies((a), (b)))")
158161
);
159162
}
160163

thrust-macros/src/spec.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ fn wrap_trailing_block(item: TokenStream2) -> TokenStream2 {
103103
}
104104

105105
pub fn expand_requires(attr: TokenStream, item: TokenStream) -> TokenStream {
106-
let expr = crate::formula::wrap_formula(attr.into());
106+
let expr = crate::formula::expand(attr.into());
107107
let mut func = parse_macro_input!(item as FnItemWithSignature);
108108

109109
let (req_expr, ens_expr) = match extract_requires_ensures(&mut func) {
@@ -118,7 +118,7 @@ pub fn expand_requires(attr: TokenStream, item: TokenStream) -> TokenStream {
118118
}
119119

120120
pub fn expand_ensures(attr: TokenStream, item: TokenStream) -> TokenStream {
121-
let expr = crate::formula::wrap_formula(attr.into());
121+
let expr = crate::formula::expand(attr.into());
122122
let mut func = parse_macro_input!(item as FnItemWithSignature);
123123

124124
let (req_expr, ens_expr) = match extract_requires_ensures(&mut func) {

0 commit comments

Comments
 (0)