Skip to content

Commit fc68b48

Browse files
committed
feat: add out of order function calls
1 parent 0445ef9 commit fc68b48

2 files changed

Lines changed: 364 additions & 12 deletions

File tree

src/ast.rs

Lines changed: 354 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use std::collections::hash_map::Entry;
2-
use std::collections::HashMap;
2+
use std::collections::{HashMap, HashSet};
33
use std::num::NonZeroUsize;
44
use std::str::FromStr;
55
use std::sync::Arc;
@@ -721,21 +721,58 @@ impl Program {
721721
pub fn analyze(from: &parse::Program) -> Result<Self, RichError> {
722722
let unit = ResolvedType::unit();
723723
let mut scope = Scope::default();
724-
let items = from
724+
725+
// Pass 1: Process type aliases in file order.
726+
for item in from.items() {
727+
if let parse::Item::TypeAlias(alias) = item {
728+
scope
729+
.insert_alias(alias.name().clone(), alias.ty().clone())
730+
.with_span(alias)?;
731+
}
732+
}
733+
734+
// Collect all non-main custom function definitions.
735+
let func_items: Vec<&parse::Function> = from
725736
.items()
726737
.iter()
727-
.map(|s| Item::analyze(s, &unit, &mut scope))
728-
.collect::<Result<Vec<Item>, RichError>>()?;
738+
.filter_map(|item| match item {
739+
parse::Item::Function(f) if f.name().as_inner() != "main" => Some(f),
740+
_ => None,
741+
})
742+
.collect();
743+
744+
// Build a call graph, reject calls to main, and reject recursive cycles.
745+
let call_graph = build_call_graph(&func_items);
746+
check_no_calls_to_main(&func_items, &call_graph)?;
747+
check_for_cycles(&func_items, &call_graph)?;
748+
749+
// Pass 2: Analyze custom functions in dependency order so that a
750+
// callee is always in scope before its callers are analyzed.
751+
for func in topological_sort(&func_items, &call_graph) {
752+
Function::analyze(func, &unit, &mut scope)?;
753+
}
754+
755+
// Pass 3: Find and analyze the main function.
756+
let mut main_expr: Option<Expression> = None;
757+
let mut seen_main = false;
758+
for item in from.items() {
759+
if let parse::Item::Function(f) = item {
760+
if f.name().as_inner() == "main" {
761+
if seen_main {
762+
return Err(Error::FunctionRedefined(FunctionName::main()))
763+
.with_span(f);
764+
}
765+
seen_main = true;
766+
if let Function::Main(expr) = Function::analyze(f, &unit, &mut scope)? {
767+
main_expr = Some(expr);
768+
}
769+
}
770+
}
771+
}
772+
773+
let main = main_expr.ok_or(Error::MainRequired).with_span(from)?;
729774
debug_assert!(scope.is_topmost());
730775
let (parameters, witness_types, call_tracker) = scope.destruct();
731-
let mut iter = items.into_iter().filter_map(|item| match item {
732-
Item::Function(Function::Main(expr)) => Some(expr),
733-
_ => None,
734-
});
735-
let main = iter.next().ok_or(Error::MainRequired).with_span(from)?;
736-
if iter.next().is_some() {
737-
return Err(Error::FunctionRedefined(FunctionName::main())).with_span(from);
738-
}
739776
Ok(Self {
740777
main,
741778
parameters,
@@ -745,6 +782,141 @@ impl Program {
745782
}
746783
}
747784

785+
/// Scan a parse expression for every custom function it calls directly.
786+
/// Covers plain calls, `array_fold`, `fold`, and `for_while` references.
787+
fn collect_custom_calls(expr: &parse::Expression) -> HashSet<FunctionName> {
788+
let mut calls = HashSet::new();
789+
for node in parse::ExprTree::Expression(expr).pre_order_iter() {
790+
if let parse::ExprTree::Call(call) = node {
791+
match call.name() {
792+
parse::CallName::Custom(name)
793+
| parse::CallName::ArrayFold(name, _)
794+
| parse::CallName::Fold(name, _)
795+
| parse::CallName::ForWhile(name) => {
796+
calls.insert(name.clone());
797+
}
798+
_ => {}
799+
}
800+
}
801+
}
802+
calls
803+
}
804+
805+
/// Build a map from each function name to the set of custom functions it calls.
806+
fn build_call_graph(
807+
func_items: &[&parse::Function],
808+
) -> HashMap<FunctionName, HashSet<FunctionName>> {
809+
func_items
810+
.iter()
811+
.map(|func| (func.name().clone(), collect_custom_calls(func.body())))
812+
.collect()
813+
}
814+
815+
/// Return an error if any function directly calls `main`.
816+
fn check_no_calls_to_main(
817+
func_items: &[&parse::Function],
818+
graph: &HashMap<FunctionName, HashSet<FunctionName>>,
819+
) -> Result<(), RichError> {
820+
let main = FunctionName::main();
821+
for func in func_items {
822+
if let Some(deps) = graph.get(func.name()) {
823+
if deps.contains(&main) {
824+
return Err(Error::MainNotCallable).with_span(*func);
825+
}
826+
}
827+
}
828+
Ok(())
829+
}
830+
831+
/// DFS helper that returns the name of a node involved in a cycle, if any.
832+
fn dfs_find_cycle(
833+
name: &FunctionName,
834+
graph: &HashMap<FunctionName, HashSet<FunctionName>>,
835+
visiting: &mut HashSet<FunctionName>,
836+
visited: &mut HashSet<FunctionName>,
837+
) -> Option<FunctionName> {
838+
if visiting.contains(name) {
839+
return Some(name.clone());
840+
}
841+
if visited.contains(name) {
842+
return None;
843+
}
844+
visiting.insert(name.clone());
845+
if let Some(deps) = graph.get(name) {
846+
for dep in deps {
847+
if let Some(cycle) = dfs_find_cycle(dep, graph, visiting, visited) {
848+
return Some(cycle);
849+
}
850+
}
851+
}
852+
visiting.remove(name);
853+
visited.insert(name.clone());
854+
None
855+
}
856+
857+
/// Return an error if any function in the call graph participates in a cycle
858+
/// (direct self-recursion or mutual recursion).
859+
fn check_for_cycles(
860+
func_items: &[&parse::Function],
861+
graph: &HashMap<FunctionName, HashSet<FunctionName>>,
862+
) -> Result<(), RichError> {
863+
let mut visiting: HashSet<FunctionName> = HashSet::new();
864+
let mut visited: HashSet<FunctionName> = HashSet::new();
865+
for func in func_items {
866+
if !visited.contains(func.name()) {
867+
if let Some(cycle_name) =
868+
dfs_find_cycle(func.name(), graph, &mut visiting, &mut visited)
869+
{
870+
let offending = func_items
871+
.iter()
872+
.find(|f| f.name() == &cycle_name)
873+
.copied()
874+
.unwrap_or(func);
875+
return Err(Error::FunctionRecursive(cycle_name)).with_span(&*offending);
876+
}
877+
}
878+
}
879+
Ok(())
880+
}
881+
882+
/// DFS post-order visitor for topological sort.
883+
fn topo_dfs<'a>(
884+
name: &FunctionName,
885+
graph: &HashMap<FunctionName, HashSet<FunctionName>>,
886+
func_map: &HashMap<&FunctionName, &'a parse::Function>,
887+
visited: &mut HashSet<FunctionName>,
888+
result: &mut Vec<&'a parse::Function>,
889+
) {
890+
if visited.contains(name) {
891+
return;
892+
}
893+
visited.insert(name.clone());
894+
if let Some(deps) = graph.get(name) {
895+
for dep in deps {
896+
topo_dfs(dep, graph, func_map, visited, result);
897+
}
898+
}
899+
if let Some(&func) = func_map.get(name) {
900+
result.push(func);
901+
}
902+
}
903+
904+
/// Return the non-main functions in an order where every callee appears before
905+
/// its callers. Assumes no cycles (call `check_for_cycles` first).
906+
fn topological_sort<'a>(
907+
func_items: &[&'a parse::Function],
908+
graph: &HashMap<FunctionName, HashSet<FunctionName>>,
909+
) -> Vec<&'a parse::Function> {
910+
let func_map: HashMap<&FunctionName, &'a parse::Function> =
911+
func_items.iter().map(|f| (f.name(), *f)).collect();
912+
let mut visited: HashSet<FunctionName> = HashSet::new();
913+
let mut result: Vec<&'a parse::Function> = Vec::new();
914+
for func in func_items {
915+
topo_dfs(func.name(), graph, &func_map, &mut visited, &mut result);
916+
}
917+
result
918+
}
919+
748920
impl AbstractSyntaxTree for Item {
749921
type From = parse::Item;
750922

@@ -1579,3 +1751,173 @@ impl AsRef<Span> for ModuleAssignment {
15791751
&self.span
15801752
}
15811753
}
1754+
1755+
#[cfg(test)]
1756+
mod tests {
1757+
use crate::TemplateProgram;
1758+
1759+
fn compile_ok(src: &str) {
1760+
TemplateProgram::new(src).unwrap_or_else(|e| panic!("Expected success:\n{e}"));
1761+
}
1762+
1763+
fn compile_err(src: &str, expected_msg: &str) {
1764+
match TemplateProgram::new(src) {
1765+
Ok(_) => panic!("Expected error containing '{expected_msg}', but compilation succeeded"),
1766+
Err(e) => assert!(
1767+
e.contains(expected_msg),
1768+
"Error message did not contain '{expected_msg}':\n{e}"
1769+
),
1770+
}
1771+
}
1772+
1773+
// -----------------------------------------------------------------------
1774+
// Forward references
1775+
// -----------------------------------------------------------------------
1776+
1777+
#[test]
1778+
fn forward_reference_simple() {
1779+
// `always_true` is called in `main` but defined after it.
1780+
compile_ok(
1781+
r#"fn main() {
1782+
let x: bool = always_true();
1783+
assert!(x);
1784+
}
1785+
1786+
fn always_true() -> bool {
1787+
true
1788+
}"#,
1789+
);
1790+
}
1791+
1792+
#[test]
1793+
fn forward_reference_chain() {
1794+
// `main` calls `outer`, which calls `inner`, both defined after main.
1795+
compile_ok(
1796+
r#"fn main() {
1797+
let x: bool = outer();
1798+
assert!(x);
1799+
}
1800+
1801+
fn outer() -> bool {
1802+
inner()
1803+
}
1804+
1805+
fn inner() -> bool {
1806+
true
1807+
}"#,
1808+
);
1809+
}
1810+
1811+
#[test]
1812+
fn forward_reference_array_fold() {
1813+
// `summer` is referenced in `array_fold` before it is defined.
1814+
compile_ok(
1815+
r#"fn main() {
1816+
let arr: [u32; 3] = [1, 2, 3];
1817+
let _total: u32 = array_fold::<summer, 3>(arr, 0);
1818+
}
1819+
1820+
fn summer(elt: u32, acc: u32) -> u32 {
1821+
let (_, result): (bool, u32) = jet::add_32(elt, acc);
1822+
result
1823+
}"#,
1824+
);
1825+
}
1826+
1827+
// -----------------------------------------------------------------------
1828+
// Recursion detection
1829+
// -----------------------------------------------------------------------
1830+
1831+
#[test]
1832+
fn direct_self_recursion_rejected() {
1833+
compile_err(
1834+
r#"fn foo(n: u32) -> u32 {
1835+
foo(n)
1836+
}
1837+
1838+
fn main() {
1839+
let _x: u32 = foo(1);
1840+
}"#,
1841+
"recursive call cycle",
1842+
);
1843+
}
1844+
1845+
#[test]
1846+
fn mutual_recursion_two_functions_rejected() {
1847+
// foo → bar → foo
1848+
compile_err(
1849+
r#"fn foo(n: u32) -> u32 {
1850+
bar(n)
1851+
}
1852+
1853+
fn bar(n: u32) -> u32 {
1854+
foo(n)
1855+
}
1856+
1857+
fn main() {
1858+
let _x: u32 = foo(1);
1859+
}"#,
1860+
"recursive call cycle",
1861+
);
1862+
}
1863+
1864+
#[test]
1865+
fn mutual_recursion_three_functions_rejected() {
1866+
// a → b → c → a
1867+
compile_err(
1868+
r#"fn a(n: u32) -> u32 {
1869+
b(n)
1870+
}
1871+
1872+
fn b(n: u32) -> u32 {
1873+
c(n)
1874+
}
1875+
1876+
fn c(n: u32) -> u32 {
1877+
a(n)
1878+
}
1879+
1880+
fn main() {
1881+
let _x: u32 = a(1);
1882+
}"#,
1883+
"recursive call cycle",
1884+
);
1885+
}
1886+
1887+
#[test]
1888+
fn call_to_main_rejected() {
1889+
compile_err(
1890+
r#"fn foo() -> bool {
1891+
main()
1892+
}
1893+
1894+
fn main() {
1895+
assert!(foo());
1896+
}"#,
1897+
"Function `main` cannot be called",
1898+
);
1899+
}
1900+
1901+
#[test]
1902+
fn non_recursive_chain_ok() {
1903+
// top → mid → bottom, defined in reverse order, no cycle.
1904+
compile_ok(
1905+
r#"fn main() {
1906+
let x: bool = top();
1907+
assert!(x);
1908+
}
1909+
1910+
fn top() -> bool {
1911+
mid()
1912+
}
1913+
1914+
fn mid() -> bool {
1915+
bottom()
1916+
}
1917+
1918+
fn bottom() -> bool {
1919+
true
1920+
}"#,
1921+
);
1922+
}
1923+
}

0 commit comments

Comments
 (0)