Skip to content

Commit ca18f5b

Browse files
wthollidayclaude
andcommitted
Add require clauses and post-monomorph body safety checks
Introduces `require <expr>` clauses on function declarations, checked by the safety checker at each call site and assumed true inside the body. Adds a second safety-check pass after monomorphization so that bodies of `[T; N]` generic functions are verified with concrete sizes. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent ca2c0d2 commit ca18f5b

17 files changed

Lines changed: 557 additions & 8 deletions

cli/src/main.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,23 @@ fn run(args: Args) -> i32 {
149149
return 1;
150150
}
151151

152+
// For `--check`, also run monomorphization so the post-monomorph safety
153+
// check (which catches bounds violations in `[T; N]` function bodies)
154+
// fires here too. The actual compiled code is discarded. Skip if no
155+
// entry point is defined — many test snippets are entry-point-less.
156+
if args.check && compiler.has_decls() {
157+
let has_entry = compiler
158+
.effective_entry_points()
159+
.iter()
160+
.any(|name| !compiler.decls().find(*name).is_empty());
161+
if has_entry {
162+
if let Err(e) = compiler.specialize() {
163+
eprintln!("{}", e);
164+
return 1;
165+
}
166+
}
167+
}
168+
152169
compiler.print_ir = args.ir;
153170

154171
// Select backend via --backend flag, falling back to LYTE_BACKEND env var.

src/checker.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,17 @@ impl Checker {
988988
}
989989
}
990990

991+
// Type-check require clauses: each must have type bool.
992+
for &req in &func_decl.requires {
993+
let req_ty = self.check_expr(req, &func_decl.arena, decls);
994+
self.eq(
995+
req_ty,
996+
mk_type(Type::Bool),
997+
func_decl.arena.locs[req],
998+
"require clause must be a boolean expression",
999+
);
1000+
}
1001+
9911002
// Check the body of the function.
9921003
let ty = self.check_expr(body, &func_decl.arena, decls);
9931004

src/compiler.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ fn builtin_decls() -> Vec<Decl> {
6565
body: None,
6666
ret: mk_type(Type::Void),
6767
constraints: vec![],
68+
requires: vec![],
6869
loc: test_loc(),
6970
arena: ExprArena::new(),
7071
types: vec![],
@@ -83,6 +84,7 @@ fn builtin_decls() -> Vec<Decl> {
8384
body: None,
8485
ret: mk_type(Type::Void),
8586
constraints: vec![],
87+
requires: vec![],
8688
loc: test_loc(),
8789
arena: ExprArena::new(),
8890
types: vec![],
@@ -101,6 +103,7 @@ fn builtin_decls() -> Vec<Decl> {
101103
body: None,
102104
ret: mk_type(Type::Void),
103105
constraints: vec![],
106+
requires: vec![],
104107
loc: test_loc(),
105108
arena: ExprArena::new(),
106109
types: vec![],
@@ -130,6 +133,7 @@ fn builtin_decls() -> Vec<Decl> {
130133
body: None,
131134
ret: ret_ty,
132135
constraints: vec![],
136+
requires: vec![],
133137
loc: test_loc(),
134138
arena: ExprArena::new(),
135139
types: vec![],
@@ -155,6 +159,7 @@ fn builtin_decls() -> Vec<Decl> {
155159
body: None,
156160
ret: bool_ty,
157161
constraints: vec![],
162+
requires: vec![],
158163
loc: test_loc(),
159164
arena: ExprArena::new(),
160165
types: vec![],
@@ -188,6 +193,7 @@ fn builtin_decls() -> Vec<Decl> {
188193
body: None,
189194
ret: ret_ty,
190195
constraints: vec![],
196+
requires: vec![],
191197
loc: test_loc(),
192198
arena: ExprArena::new(),
193199
types: vec![],
@@ -225,6 +231,7 @@ fn builtin_decls() -> Vec<Decl> {
225231
body: None,
226232
ret: f32x4_ty,
227233
constraints: vec![],
234+
requires: vec![],
228235
loc: test_loc(),
229236
arena: ExprArena::new(),
230237
types: vec![],
@@ -244,6 +251,7 @@ fn builtin_decls() -> Vec<Decl> {
244251
body: None,
245252
ret: f32x4_ty,
246253
constraints: vec![],
254+
requires: vec![],
247255
loc: test_loc(),
248256
arena: ExprArena::new(),
249257
types: vec![],
@@ -627,7 +635,11 @@ impl Compiler {
627635
let mut pass = MonomorphPass::new();
628636
let entry_points = self.effective_entry_points();
629637
let all_decls = pass.monomorphize_multi(&self.decls, &entry_points)?;
630-
// monomorphize now returns all decls (original + specialized)
638+
// Capture the names of newly-generated specialized decls so we can
639+
// run a focused safety-check pass on them (their bodies were skipped
640+
// pre-monomorph because they contained size variables).
641+
let specialized_names: std::collections::HashSet<Name> =
642+
pass.instantiated_names().collect();
631643
self.decls = DeclTable::new(all_decls);
632644

633645
// Rename non-generic overloaded functions to unique symbols.
@@ -636,6 +648,32 @@ impl Compiler {
636648
rename_overloaded_functions(&mut self.decls);
637649
self.decls = DeclTable::new(self.decls.decls.clone());
638650

651+
// Re-run the safety checker on specialized declarations only.
652+
// Their bodies were skipped pre-monomorph because they had non-empty
653+
// size_vars; now sizes are concrete (`[T; Known(K)]`) so bounds and
654+
// require clauses can be verified properly.
655+
if !specialized_names.is_empty() {
656+
let mut sc = SafetyChecker::new();
657+
for decl in &self.decls.decls {
658+
if let Decl::Func(f) = decl {
659+
if specialized_names.contains(&f.name) {
660+
sc.check_decl(decl, &self.decls);
661+
}
662+
}
663+
}
664+
if !self.quiet {
665+
sc.print_errors();
666+
}
667+
for err in &sc.errors {
668+
self.last_errors
669+
.push(format_error(err.location, &err.message));
670+
}
671+
self.last_safety_errors.extend(sc.errors.iter().cloned());
672+
if !sc.errors.is_empty() {
673+
return Err(format!("safety check failed for {} call(s)", sc.errors.len()));
674+
}
675+
}
676+
639677
// Hoist loop-invariant struct field reads (after monomorphization
640678
// so we operate on concrete types, and after safety checking).
641679
{

src/decl.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ pub struct FuncDecl {
5656
/// The interfaces that must be available for this function.
5757
pub constraints: Vec<InterfaceConstraint>,
5858

59+
/// Precondition expressions checked by the safety checker:
60+
/// assumed true inside the body, must be provable at every call site.
61+
/// Each ExprID indexes into `arena`.
62+
pub requires: Vec<ExprID>,
63+
5964
/// Location of the declaration in source code.
6065
pub loc: Loc,
6166

@@ -312,6 +317,12 @@ fn format_func_decl(func: &FuncDecl, is_macro: bool) -> String {
312317
format!(" → {}", func.ret.pretty_print())
313318
};
314319
let constraints = format_constraints(&func.constraints);
320+
let requires = func
321+
.requires
322+
.iter()
323+
.map(|&r| format!(" require {}", func.arena.exprs[r].pretty_print(&func.arena, 0)))
324+
.collect::<Vec<_>>()
325+
.join("");
315326

316327
let signature = if is_macro {
317328
format!("macro {}{}", func.name, typevars)
@@ -324,11 +335,14 @@ fn format_func_decl(func: &FuncDecl, is_macro: bool) -> String {
324335
if let Some(body_id) = func.body {
325336
let body_str = func.arena.exprs[body_id].pretty_print(&func.arena, 0);
326337
format!(
327-
"{}({}){}{} {}",
328-
signature, params, ret_type, constraints, body_str
338+
"{}({}){}{}{} {}",
339+
signature, params, ret_type, constraints, requires, body_str
329340
)
330341
} else {
331-
format!("{}({}){}{}", signature, params, ret_type, constraints)
342+
format!(
343+
"{}({}){}{}{}",
344+
signature, params, ret_type, constraints, requires
345+
)
332346
}
333347
}
334348

@@ -400,6 +414,7 @@ mod tests {
400414
body: None,
401415
ret: mk_type(Type::Int32),
402416
constraints: vec![],
417+
requires: vec![],
403418
loc: test_loc(),
404419
arena: ExprArena::new(),
405420
types: vec![],
@@ -428,6 +443,7 @@ mod tests {
428443
body: Some(body_id),
429444
ret: mk_type(Type::Var(Name::str("T"))),
430445
constraints: vec![],
446+
requires: vec![],
431447
loc: test_loc(),
432448
arena,
433449
types: vec![],
@@ -486,6 +502,7 @@ mod tests {
486502
body: None,
487503
ret: mk_type(Type::Int32),
488504
constraints: vec![],
505+
requires: vec![],
489506
loc: test_loc(),
490507
arena: ExprArena::new(),
491508
types: vec![],
@@ -546,6 +563,7 @@ mod tests {
546563
body: Some(body_id),
547564
ret: mk_type(Type::Int32),
548565
constraints: vec![],
566+
requires: vec![],
549567
loc: test_loc(),
550568
arena,
551569
types: vec![],

src/jit.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,6 +1761,7 @@ impl<'a> FunctionTranslator<'a> {
17611761
body: Some(*body),
17621762
ret: rng,
17631763
constraints: vec![],
1764+
requires: vec![],
17641765
loc: decl.loc,
17651766
arena: decl.arena.clone(),
17661767
types: decl.types.clone(),

src/lexer.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ pub enum Token {
8383
Assume,
8484
Const,
8585
Extern,
86+
Require,
8687
End,
8788
Error,
8889
}
@@ -211,6 +212,7 @@ impl Lexer {
211212
"assume" => Token::Assume,
212213
"const" => Token::Const,
213214
"extern" => Token::Extern,
215+
"require" => Token::Require,
214216
_ => Token::Id(Name::str(&id)),
215217
};
216218
}

src/llvm_jit.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3400,6 +3400,7 @@ impl<'a, 'ctx> FunctionTranslator<'a, 'ctx> {
34003400
body: Some(body),
34013401
ret: rng,
34023402
constraints: vec![],
3403+
requires: vec![],
34033404
loc: decl.loc,
34043405
arena: decl.arena.clone(),
34053406
types: decl.types.clone(),

src/monomorph_pass.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,13 @@ impl MonomorphPass {
618618
&self.out_decls
619619
}
620620

621+
/// The mangled names of all functions newly created by monomorphization
622+
/// (i.e., specialized versions of generic functions). Excludes entry
623+
/// points and any pre-existing non-generic functions.
624+
pub fn instantiated_names(&self) -> impl Iterator<Item = Name> + '_ {
625+
self.instantiations.values().copied()
626+
}
627+
621628
/// Get the mangled name for a specific instantiation, if it exists
622629
pub fn get_instantiation(&self, key: &MonomorphKey) -> Option<Name> {
623630
self.instantiations.get(key).copied()
@@ -721,6 +728,7 @@ mod tests {
721728
size_vars: vec![],
722729
params: Vec::new(),
723730
constraints: Vec::new(),
731+
requires: vec![],
724732
ret: mk_type(Type::Void),
725733
body: Some(body_expr),
726734
arena,
@@ -834,6 +842,7 @@ mod tests {
834842
size_vars: Vec::new(),
835843
params: Vec::new(),
836844
constraints: Vec::new(),
845+
requires: vec![],
837846
ret: mk_type(Type::Void),
838847
body: Some(block),
839848
arena,
@@ -863,6 +872,7 @@ mod tests {
863872
size_vars: Vec::new(),
864873
params: Vec::new(),
865874
constraints: Vec::new(),
875+
requires: vec![],
866876
ret: mk_type(Type::Void),
867877
body: Some(binop),
868878
arena,
@@ -892,6 +902,7 @@ mod tests {
892902
size_vars: Vec::new(),
893903
params: Vec::new(),
894904
constraints: Vec::new(),
905+
requires: vec![],
895906
ret: mk_type(Type::Void),
896907
body: Some(array),
897908
arena,
@@ -1127,6 +1138,7 @@ mod tests {
11271138
size_vars: Vec::new(),
11281139
params: Vec::new(),
11291140
constraints: Vec::new(),
1141+
requires: vec![],
11301142
ret: mk_type(Type::Void),
11311143
body: Some(if_expr),
11321144
arena,
@@ -1160,6 +1172,7 @@ mod tests {
11601172
size_vars: Vec::new(),
11611173
params: Vec::new(),
11621174
constraints: Vec::new(),
1175+
requires: vec![],
11631176
ret: mk_type(Type::Int32),
11641177
body: Some(body_expr),
11651178
arena,
@@ -1199,6 +1212,7 @@ mod tests {
11991212
ty: Some(t_var),
12001213
}],
12011214
constraints: Vec::new(),
1215+
requires: vec![],
12021216
ret: t_var,
12031217
body: Some(id_param_expr),
12041218
arena: id_arena,
@@ -1224,6 +1238,7 @@ mod tests {
12241238
size_vars: Vec::new(),
12251239
params: Vec::new(),
12261240
constraints: Vec::new(),
1241+
requires: vec![],
12271242
ret: i32_type,
12281243
body: Some(call_expr),
12291244
arena: main_arena,

src/parser.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,7 @@ fn parse_func_decl(name: Name, cx: &mut ParseContext) -> FuncDecl {
945945
let mut params = vec![];
946946
let mut all_vars = vec![];
947947
let mut constraints = vec![];
948+
let mut requires = vec![];
948949
let loc = cx.lex.loc;
949950
let mut arena = ExprArena::new();
950951

@@ -984,6 +985,16 @@ fn parse_func_decl(name: Name, cx: &mut ParseContext) -> FuncDecl {
984985
}
985986
}
986987

988+
skip_newlines(cx.lex);
989+
990+
// Parse zero or more `require <expr>` clauses.
991+
while cx.lex.tok == Token::Require {
992+
cx.next();
993+
let cond = parse_expr(&mut arena, &all_vars, cx);
994+
requires.push(cond);
995+
skip_newlines(cx.lex);
996+
}
997+
987998
// Separate size vars (used as array sizes) from type vars.
988999
let mut size_vars: Vec<Name> = vec![];
9891000
for param in &params {
@@ -1013,6 +1024,7 @@ fn parse_func_decl(name: Name, cx: &mut ParseContext) -> FuncDecl {
10131024
body,
10141025
ret,
10151026
constraints,
1027+
requires,
10161028
loc,
10171029
arena,
10181030
types: vec![],

0 commit comments

Comments
 (0)