Skip to content

Commit b36aba8

Browse files
committed
[anneal] Fix type ambiguity in spec generation
When a Rust function uses a type via use import, Anneal's spec generator only sees the short name from the syn AST. For names like Ordering, that can conflict with Lean's native Ordering brought in by open Aeneas.Std and cause type mismatches in generated Pre/Post structs. After Aeneas generates Funs.lean, parse its def signatures and use the fully-qualified Aeneas parameter and return types as overrides in spec generation. Keep escaped Lean parameter names, such as show1, in the same namespace as the generated specs. Fall back to map_type when no Aeneas signature is available.
1 parent be6f199 commit b36aba8

4 files changed

Lines changed: 358 additions & 8 deletions

File tree

anneal/src/aeneas.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,8 +346,10 @@ pub fn generate_lean_workspace(roots: &LockedRoots, artifacts: &[AnnealArtifact]
346346
let slug = artifact.artifact_slug();
347347
let output_dir = lean_generated_root.join(&slug);
348348

349+
let funs_types = parse_funs_types_in_dir(&output_dir)?;
350+
349351
// Generate Anneal specs
350-
let generated = generate::generate_artifact(artifact);
352+
let generated = generate::generate_artifact_with_funs_types(artifact, &funs_types);
351353
let specs_path = output_dir.join(artifact.lean_spec_file_name());
352354
let map_path = output_dir.join(format!("{}.lean.map", artifact.artifact_slug()));
353355

@@ -591,6 +593,17 @@ pub fn generate_lean_workspace(roots: &LockedRoots, artifacts: &[AnnealArtifact]
591593
Ok(())
592594
}
593595

596+
pub(crate) fn parse_funs_types_in_dir(output_dir: &Path) -> Result<crate::funs_types::FunsTypeMap> {
597+
let funs_path = output_dir.join("Funs.lean");
598+
if !funs_path.exists() {
599+
return Ok(crate::funs_types::FunsTypeMap::new());
600+
}
601+
602+
let content = std::fs::read_to_string(&funs_path)
603+
.with_context(|| format!("Failed to read {}", funs_path.display()))?;
604+
Ok(crate::funs_types::parse_funs_types(&content))
605+
}
606+
594607
/// Completes Lean verification by generating Anneal `Specs.lean`, writing `Generated.lean`,
595608
/// and running `lake build` + diagnostics.
596609
pub fn verify_lean_workspace(roots: &LockedRoots, artifacts: &[AnnealArtifact]) -> Result<()> {

anneal/src/funs_types.rs

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
//! Parses Aeneas-generated `Funs.lean` signatures.
2+
//!
3+
//! Anneal's spec generator re-derives types from the Rust AST, losing
4+
//! qualification for `use`-imported names (e.g., `Ordering` instead of
5+
//! `core.sync.atomic.Ordering`). This module provides a lookup table from the
6+
//! Aeneas output as a corrective.
7+
8+
use std::collections::HashMap;
9+
10+
/// A function signature parsed from Aeneas-generated Lean.
11+
#[derive(Debug, Clone, Default, PartialEq, Eq)]
12+
pub struct FunsSignature {
13+
pub params: Vec<(String, String)>,
14+
pub ret: Option<String>,
15+
}
16+
17+
/// Function name → Aeneas-generated function signature.
18+
pub type FunsTypeMap = HashMap<String, FunsSignature>;
19+
20+
/// Parses `def` signatures from `Funs.lean`.
21+
///
22+
/// This extracts explicit parameter types and return types while skipping
23+
/// implicit `{}` parameters. Malformed signatures are ignored.
24+
pub fn parse_funs_types(content: &str) -> FunsTypeMap {
25+
let mut map = FunsTypeMap::new();
26+
let lines: Vec<&str> = content.lines().collect();
27+
28+
for i in 0..lines.len() {
29+
let Some(rest) = lines[i].trim().strip_prefix("def ") else {
30+
continue;
31+
};
32+
let name = rest.split([' ', '(', '{', ':']).next().unwrap_or("");
33+
if name.is_empty() {
34+
continue;
35+
}
36+
37+
// Collect signature lines until `:=`.
38+
let mut sig = String::from(rest);
39+
for j in (i + 1)..lines.len() {
40+
if sig.contains(":=") {
41+
break;
42+
}
43+
sig.push(' ');
44+
sig.push_str(lines[j].trim());
45+
}
46+
47+
let params = extract_params(&sig);
48+
let ret = extract_return_type(&sig);
49+
if !params.is_empty() || ret.is_some() {
50+
map.insert(name.to_string(), FunsSignature { params, ret });
51+
}
52+
}
53+
map
54+
}
55+
56+
/// Extracts `(name : type)` bindings, skipping `{implicit}` params.
57+
fn extract_params(sig: &str) -> Vec<(String, String)> {
58+
let mut params = Vec::new();
59+
let mut chars = sig.chars().peekable();
60+
61+
while let Some(&c) = chars.peek() {
62+
match c {
63+
'(' => {
64+
chars.next();
65+
if let Some(pair) = parse_binding(&collect_delimited(&mut chars, ')')) {
66+
params.push(pair);
67+
}
68+
}
69+
'{' => {
70+
chars.next();
71+
collect_delimited(&mut chars, '}');
72+
}
73+
':' => break,
74+
_ => {
75+
chars.next();
76+
}
77+
}
78+
}
79+
params
80+
}
81+
82+
/// Extracts the function return type.
83+
fn extract_return_type(sig: &str) -> Option<String> {
84+
let return_start = top_level_return_colon(sig)? + 1;
85+
let return_end = sig[return_start..].find(":=").map(|i| return_start + i).unwrap_or(sig.len());
86+
let ret = sig[return_start..return_end].trim();
87+
if ret.is_empty() {
88+
return None;
89+
}
90+
Some(ret.strip_prefix("Result ").unwrap_or(ret).trim().to_string())
91+
}
92+
93+
/// Finds the colon that separates the parameter list from the return type.
94+
fn top_level_return_colon(sig: &str) -> Option<usize> {
95+
let mut paren_depth = 0u32;
96+
let mut brace_depth = 0u32;
97+
for (i, c) in sig.char_indices() {
98+
match c {
99+
'(' => paren_depth += 1,
100+
')' => paren_depth = paren_depth.saturating_sub(1),
101+
'{' => brace_depth += 1,
102+
'}' => brace_depth = brace_depth.saturating_sub(1),
103+
':' if paren_depth == 0 && brace_depth == 0 => return Some(i),
104+
_ => {}
105+
}
106+
}
107+
None
108+
}
109+
110+
/// Reads chars until the matching `close` delimiter, handling nesting.
111+
fn collect_delimited(chars: &mut std::iter::Peekable<std::str::Chars<'_>>, close: char) -> String {
112+
let open = if close == ')' { '(' } else { '{' };
113+
let mut depth = 1u32;
114+
let mut buf = String::new();
115+
for c in chars.by_ref() {
116+
if c == open {
117+
depth += 1;
118+
} else if c == close {
119+
depth -= 1;
120+
if depth == 0 {
121+
return buf;
122+
}
123+
}
124+
buf.push(c);
125+
}
126+
buf
127+
}
128+
129+
/// Splits `"name : type"` on the first ` : `.
130+
fn parse_binding(s: &str) -> Option<(String, String)> {
131+
let s = s.trim();
132+
let i = s.find(" : ")?;
133+
let (name, ty) = (s[..i].trim(), s[i + 3..].trim());
134+
(!name.is_empty() && !ty.is_empty()).then_some((name.to_string(), ty.to_string()))
135+
}
136+
137+
#[cfg(test)]
138+
mod tests {
139+
use super::*;
140+
141+
#[test]
142+
fn simple() {
143+
let m = parse_funs_types("def hash_key (k : Std.Usize) : Result Std.Usize := do\n");
144+
assert_eq!(m["hash_key"].params, [("k".into(), "Std.Usize".into())]);
145+
assert_eq!(m["hash_key"].ret.as_deref(), Some("Std.Usize"));
146+
}
147+
148+
#[test]
149+
fn multiline() {
150+
let m = parse_funs_types(
151+
"def frame.AtomicFrameState.load\n\
152+
\x20 (self : frame.AtomicFrameState) (order : core.sync.atomic.Ordering) :\n\
153+
\x20 Result frame.FrameState\n\
154+
\x20 := do\n",
155+
);
156+
let p = &m["frame.AtomicFrameState.load"].params;
157+
assert_eq!(p[0], ("self".into(), "frame.AtomicFrameState".into()));
158+
assert_eq!(p[1], ("order".into(), "core.sync.atomic.Ordering".into()));
159+
assert_eq!(m["frame.AtomicFrameState.load"].ret.as_deref(), Some("frame.FrameState"));
160+
}
161+
162+
#[test]
163+
fn skips_implicits() {
164+
let m = parse_funs_types(
165+
"def HashMap.alloc {T : Type} (slots : alloc.vec.Vec T) (n : Std.Usize) :\n\
166+
\x20 Result Unit := do\n",
167+
);
168+
let p = &m["HashMap.alloc"].params;
169+
assert_eq!(p.len(), 2);
170+
assert_eq!(p[0].0, "slots");
171+
assert_eq!(p[1], ("n".into(), "Std.Usize".into()));
172+
}
173+
174+
#[test]
175+
fn multiple_ordering_params() {
176+
let m = parse_funs_types(
177+
"def f.compare_exchange\n\
178+
\x20 (self : f.T) (expected : f.S)\n\
179+
\x20 (success : core.sync.atomic.Ordering)\n\
180+
\x20 (failure : core.sync.atomic.Ordering) :\n\
181+
\x20 Result Unit := do\n",
182+
);
183+
let p = &m["f.compare_exchange"].params;
184+
assert_eq!(p.len(), 4);
185+
assert_eq!(p[2].1, "core.sync.atomic.Ordering");
186+
assert_eq!(p[3].1, "core.sync.atomic.Ordering");
187+
}
188+
189+
#[test]
190+
fn preserves_escaped_keyword_params() {
191+
let m = parse_funs_types("def f (show1 : core.sync.atomic.Ordering) : Result Unit := do\n");
192+
assert_eq!(m["f"].params, [("show1".into(), "core.sync.atomic.Ordering".into())]);
193+
}
194+
195+
#[test]
196+
fn parses_return_type() {
197+
let m = parse_funs_types(
198+
"def f (order : core.sync.atomic.Ordering) :\n\
199+
\x20 Result core.sync.atomic.Ordering\n\
200+
\x20 := do\n",
201+
);
202+
assert_eq!(m["f"].ret.as_deref(), Some("core.sync.atomic.Ordering"));
203+
}
204+
205+
#[test]
206+
fn trait_instance_no_params() {
207+
let m = parse_funs_types("def Foo.Clone : core.clone.Clone Foo := {\n clone := x\n}\n");
208+
assert!(m["Foo.Clone"].params.is_empty());
209+
assert_eq!(m["Foo.Clone"].ret.as_deref(), Some("core.clone.Clone Foo"));
210+
}
211+
}

0 commit comments

Comments
 (0)