diff --git a/src/ast.rs b/src/ast.rs index 20dd2a7..70e773e 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -1,5 +1,5 @@ use std::collections::hash_map::Entry; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::num::NonZeroUsize; use std::str::FromStr; use std::sync::Arc; @@ -721,21 +721,60 @@ impl Program { pub fn analyze(from: &parse::Program) -> Result { let unit = ResolvedType::unit(); let mut scope = Scope::default(); - let items = from + + // Pass 1: Process type aliases in file order. + for item in from.items() { + if let parse::Item::TypeAlias(alias) = item { + scope + .insert_alias(alias.name().clone(), alias.ty().clone()) + .with_span(alias)?; + } + } + + // Collect all non-main custom function definitions. + let func_items: Vec<&parse::Function> = from .items() .iter() - .map(|s| Item::analyze(s, &unit, &mut scope)) - .collect::, RichError>>()?; + .filter_map(|item| match item { + parse::Item::Function(f) if f.name().as_inner() != "main" => Some(f), + _ => None, + }) + .collect(); + + // Reject duplicate non-main function names before building the call graph. + check_no_duplicate_functions(&func_items)?; + + // Build a call graph, reject calls to main, and reject recursive cycles. + let call_graph = build_call_graph(&func_items); + check_no_calls_to_main(&func_items, &call_graph)?; + check_for_cycles(&func_items, &call_graph)?; + + // Pass 2: Analyze custom functions in dependency order so that a + // callee is always in scope before its callers are analyzed. + for func in topological_sort(&func_items, &call_graph) { + Function::analyze(func, &unit, &mut scope)?; + } + + // Pass 3: Find and analyze the main function. + let mut main_expr: Option = None; + let mut seen_main = false; + for item in from.items() { + if let parse::Item::Function(f) = item { + if f.name().as_inner() == "main" { + if seen_main { + return Err(Error::FunctionRedefined(FunctionName::main())).with_span(f); + } + seen_main = true; + if let Function::Main(expr) = Function::analyze(f, &unit, &mut scope)? { + main_expr = Some(expr); + } + } + } + } + + let main = main_expr.ok_or(Error::MainRequired).with_span(from)?; debug_assert!(scope.is_topmost()); let (parameters, witness_types, call_tracker) = scope.destruct(); - let mut iter = items.into_iter().filter_map(|item| match item { - Item::Function(Function::Main(expr)) => Some(expr), - _ => None, - }); - let main = iter.next().ok_or(Error::MainRequired).with_span(from)?; - if iter.next().is_some() { - return Err(Error::FunctionRedefined(FunctionName::main())).with_span(from); - } Ok(Self { main, parameters, @@ -745,6 +784,152 @@ impl Program { } } +/// Scan a parse expression for every custom function it calls directly. +/// Covers plain calls, `array_fold`, `fold`, and `for_while` references. +fn collect_custom_calls(expr: &parse::Expression) -> HashSet { + let mut calls = HashSet::new(); + for node in parse::ExprTree::Expression(expr).pre_order_iter() { + if let parse::ExprTree::Call(call) = node { + match call.name() { + parse::CallName::Custom(name) + | parse::CallName::ArrayFold(name, _) + | parse::CallName::Fold(name, _) + | parse::CallName::ForWhile(name) => { + calls.insert(name.clone()); + } + _ => {} + } + } + } + calls +} + +/// Return an error if any two non-main functions share the same name. +fn check_no_duplicate_functions(func_items: &[&parse::Function]) -> Result<(), RichError> { + let mut seen: HashSet<&FunctionName> = HashSet::new(); + for func in func_items { + if !seen.insert(func.name()) { + return Err(Error::FunctionRedefined(func.name().clone())).with_span(*func); + } + } + Ok(()) +} + +/// Build a map from each function name to the set of custom functions it calls. +fn build_call_graph( + func_items: &[&parse::Function], +) -> HashMap> { + func_items + .iter() + .map(|func| (func.name().clone(), collect_custom_calls(func.body()))) + .collect() +} + +/// Return an error if any function directly calls `main`. +fn check_no_calls_to_main( + func_items: &[&parse::Function], + graph: &HashMap>, +) -> Result<(), RichError> { + let main = FunctionName::main(); + for func in func_items { + if let Some(deps) = graph.get(func.name()) { + if deps.contains(&main) { + return Err(Error::MainNotCallable).with_span(*func); + } + } + } + Ok(()) +} + +/// DFS helper that returns the name of a node involved in a cycle, if any. +fn dfs_find_cycle( + name: &FunctionName, + graph: &HashMap>, + visiting: &mut HashSet, + visited: &mut HashSet, +) -> Option { + if visiting.contains(name) { + return Some(name.clone()); + } + if visited.contains(name) { + return None; + } + visiting.insert(name.clone()); + if let Some(deps) = graph.get(name) { + for dep in deps { + if let Some(cycle) = dfs_find_cycle(dep, graph, visiting, visited) { + return Some(cycle); + } + } + } + visiting.remove(name); + visited.insert(name.clone()); + None +} + +/// Return an error if any function in the call graph participates in a cycle +/// (direct self-recursion or mutual recursion). +fn check_for_cycles( + func_items: &[&parse::Function], + graph: &HashMap>, +) -> Result<(), RichError> { + let mut visiting: HashSet = HashSet::new(); + let mut visited: HashSet = HashSet::new(); + for func in func_items { + if !visited.contains(func.name()) { + if let Some(cycle_name) = + dfs_find_cycle(func.name(), graph, &mut visiting, &mut visited) + { + let offending = func_items + .iter() + .find(|f| f.name() == &cycle_name) + .copied() + .unwrap_or(func); + return Err(Error::FunctionRecursive(cycle_name)).with_span(offending); + } + } + } + Ok(()) +} + +/// DFS post-order visitor for topological sort. +fn topo_dfs<'a>( + name: &FunctionName, + graph: &HashMap>, + func_map: &HashMap<&FunctionName, &'a parse::Function>, + visited: &mut HashSet, + result: &mut Vec<&'a parse::Function>, +) { + if visited.contains(name) { + return; + } + visited.insert(name.clone()); + if let Some(deps) = graph.get(name) { + for dep in deps { + topo_dfs(dep, graph, func_map, visited, result); + } + } + if let Some(&func) = func_map.get(name) { + result.push(func); + } +} + +/// Return the non-main functions in an order where every callee appears before +/// its callers. Assumes no cycles (call `check_for_cycles` first). +fn topological_sort<'a>( + func_items: &[&'a parse::Function], + graph: &HashMap>, +) -> Vec<&'a parse::Function> { + let func_map: HashMap<&FunctionName, &'a parse::Function> = + func_items.iter().map(|f| (f.name(), *f)).collect(); + let mut visited: HashSet = HashSet::new(); + let mut result: Vec<&'a parse::Function> = Vec::new(); + for func in func_items { + topo_dfs(func.name(), graph, &func_map, &mut visited, &mut result); + } + result +} + impl AbstractSyntaxTree for Item { type From = parse::Item; @@ -1579,3 +1764,193 @@ impl AsRef for ModuleAssignment { &self.span } } + +#[cfg(test)] +mod tests { + use crate::TemplateProgram; + + fn compile_ok(src: &str) { + TemplateProgram::new(src).unwrap_or_else(|e| panic!("Expected success:\n{e}")); + } + + fn compile_err(src: &str, expected_msg: &str) { + match TemplateProgram::new(src) { + Ok(_) => { + panic!("Expected error containing '{expected_msg}', but compilation succeeded") + } + Err(e) => assert!( + e.contains(expected_msg), + "Error message did not contain '{expected_msg}':\n{e}" + ), + } + } + + // ----------------------------------------------------------------------- + // Forward references + // ----------------------------------------------------------------------- + + #[test] + fn forward_reference_simple() { + // `always_true` is called in `main` but defined after it. + compile_ok( + r#"fn main() { + let x: bool = always_true(); + assert!(x); +} + +fn always_true() -> bool { + true +}"#, + ); + } + + #[test] + fn forward_reference_chain() { + // `main` calls `outer`, which calls `inner`, both defined after main. + compile_ok( + r#"fn main() { + let x: bool = outer(); + assert!(x); +} + +fn outer() -> bool { + inner() +} + +fn inner() -> bool { + true +}"#, + ); + } + + #[test] + fn forward_reference_array_fold() { + // `summer` is referenced in `array_fold` before it is defined. + compile_ok( + r#"fn main() { + let arr: [u32; 3] = [1, 2, 3]; + let _total: u32 = array_fold::(arr, 0); +} + +fn summer(elt: u32, acc: u32) -> u32 { + let (_, result): (bool, u32) = jet::add_32(elt, acc); + result +}"#, + ); + } + + // ----------------------------------------------------------------------- + // Recursion detection + // ----------------------------------------------------------------------- + + #[test] + fn direct_self_recursion_rejected() { + compile_err( + r#"fn foo(n: u32) -> u32 { + foo(n) +} + +fn main() { + let _x: u32 = foo(1); +}"#, + "recursive call cycle", + ); + } + + #[test] + fn mutual_recursion_two_functions_rejected() { + // foo → bar → foo + compile_err( + r#"fn foo(n: u32) -> u32 { + bar(n) +} + +fn bar(n: u32) -> u32 { + foo(n) +} + +fn main() { + let _x: u32 = foo(1); +}"#, + "recursive call cycle", + ); + } + + #[test] + fn mutual_recursion_three_functions_rejected() { + // a → b → c → a + compile_err( + r#"fn a(n: u32) -> u32 { + b(n) +} + +fn b(n: u32) -> u32 { + c(n) +} + +fn c(n: u32) -> u32 { + a(n) +} + +fn main() { + let _x: u32 = a(1); +}"#, + "recursive call cycle", + ); + } + + #[test] + fn call_to_main_rejected() { + compile_err( + r#"fn foo() -> bool { + main() +} + +fn main() { + assert!(foo()); +}"#, + "Function `main` cannot be called", + ); + } + + #[test] + fn duplicate_helper_function_rejected() { + compile_err( + r#"fn foo() -> bool { + true +} + +fn foo() -> bool { + false +} + +fn main() { + assert!(foo()); +}"#, + "Function `foo` was defined multiple times", + ); + } + + #[test] + fn non_recursive_chain_ok() { + // top → mid → bottom, defined in reverse order, no cycle. + compile_ok( + r#"fn main() { + let x: bool = top(); + assert!(x); +} + +fn top() -> bool { + mid() +} + +fn mid() -> bool { + bottom() +} + +fn bottom() -> bool { + true +}"#, + ); + } +} diff --git a/src/error.rs b/src/error.rs index b4d5a4b..86ec13e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -426,8 +426,10 @@ pub enum Error { MainNoInputs, MainNoOutput, MainRequired, + MainNotCallable, FunctionRedefined(FunctionName), FunctionUndefined(FunctionName), + FunctionRecursive(FunctionName), InvalidNumberOfArguments(usize, usize), FunctionNotFoldable(FunctionName), FunctionNotLoopable(FunctionName), @@ -516,6 +518,10 @@ impl fmt::Display for Error { f, "Main function is required" ), + Error::MainNotCallable => write!( + f, + "Function `main` cannot be called" + ), Error::FunctionRedefined(name) => write!( f, "Function `{name}` was defined multiple times" @@ -524,6 +530,10 @@ impl fmt::Display for Error { f, "Function `{name}` was called but not defined" ), + Error::FunctionRecursive(name) => write!( + f, + "Function `{name}` is part of a recursive call cycle, which is not allowed" + ), Error::InvalidNumberOfArguments(expected, found) => write!( f, "Expected {expected} arguments, found {found} arguments"