Skip to content

Commit d00b714

Browse files
committed
feat: add out of order function calls
1 parent 0445ef9 commit d00b714

2 files changed

Lines changed: 365 additions & 12 deletions

File tree

src/ast.rs

Lines changed: 355 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use std::collections::hash_map::Entry;
2-
use std::collections::HashMap;
2+
use std::collections::{HashMap, HashSet};
33
use std::num::NonZeroUsize;
44
use std::str::FromStr;
55
use std::sync::Arc;
@@ -721,21 +721,57 @@ impl Program {
721721
pub fn analyze(from: &parse::Program) -> Result<Self, RichError> {
722722
let unit = ResolvedType::unit();
723723
let mut scope = Scope::default();
724-
let items = from
724+
725+
// Pass 1: Process type aliases in file order.
726+
for item in from.items() {
727+
if let parse::Item::TypeAlias(alias) = item {
728+
scope
729+
.insert_alias(alias.name().clone(), alias.ty().clone())
730+
.with_span(alias)?;
731+
}
732+
}
733+
734+
// Collect all non-main custom function definitions.
735+
let func_items: Vec<&parse::Function> = from
725736
.items()
726737
.iter()
727-
.map(|s| Item::analyze(s, &unit, &mut scope))
728-
.collect::<Result<Vec<Item>, RichError>>()?;
738+
.filter_map(|item| match item {
739+
parse::Item::Function(f) if f.name().as_inner() != "main" => Some(f),
740+
_ => None,
741+
})
742+
.collect();
743+
744+
// Build a call graph, reject calls to main, and reject recursive cycles.
745+
let call_graph = build_call_graph(&func_items);
746+
check_no_calls_to_main(&func_items, &call_graph)?;
747+
check_for_cycles(&func_items, &call_graph)?;
748+
749+
// Pass 2: Analyze custom functions in dependency order so that a
750+
// callee is always in scope before its callers are analyzed.
751+
for func in topological_sort(&func_items, &call_graph) {
752+
Function::analyze(func, &unit, &mut scope)?;
753+
}
754+
755+
// Pass 3: Find and analyze the main function.
756+
let mut main_expr: Option<Expression> = None;
757+
let mut seen_main = false;
758+
for item in from.items() {
759+
if let parse::Item::Function(f) = item {
760+
if f.name().as_inner() == "main" {
761+
if seen_main {
762+
return Err(Error::FunctionRedefined(FunctionName::main())).with_span(f);
763+
}
764+
seen_main = true;
765+
if let Function::Main(expr) = Function::analyze(f, &unit, &mut scope)? {
766+
main_expr = Some(expr);
767+
}
768+
}
769+
}
770+
}
771+
772+
let main = main_expr.ok_or(Error::MainRequired).with_span(from)?;
729773
debug_assert!(scope.is_topmost());
730774
let (parameters, witness_types, call_tracker) = scope.destruct();
731-
let mut iter = items.into_iter().filter_map(|item| match item {
732-
Item::Function(Function::Main(expr)) => Some(expr),
733-
_ => None,
734-
});
735-
let main = iter.next().ok_or(Error::MainRequired).with_span(from)?;
736-
if iter.next().is_some() {
737-
return Err(Error::FunctionRedefined(FunctionName::main())).with_span(from);
738-
}
739775
Ok(Self {
740776
main,
741777
parameters,
@@ -745,6 +781,141 @@ impl Program {
745781
}
746782
}
747783

784+
/// Scan a parse expression for every custom function it calls directly.
785+
/// Covers plain calls, `array_fold`, `fold`, and `for_while` references.
786+
fn collect_custom_calls(expr: &parse::Expression) -> HashSet<FunctionName> {
787+
let mut calls = HashSet::new();
788+
for node in parse::ExprTree::Expression(expr).pre_order_iter() {
789+
if let parse::ExprTree::Call(call) = node {
790+
match call.name() {
791+
parse::CallName::Custom(name)
792+
| parse::CallName::ArrayFold(name, _)
793+
| parse::CallName::Fold(name, _)
794+
| parse::CallName::ForWhile(name) => {
795+
calls.insert(name.clone());
796+
}
797+
_ => {}
798+
}
799+
}
800+
}
801+
calls
802+
}
803+
804+
/// Build a map from each function name to the set of custom functions it calls.
805+
fn build_call_graph(
806+
func_items: &[&parse::Function],
807+
) -> HashMap<FunctionName, HashSet<FunctionName>> {
808+
func_items
809+
.iter()
810+
.map(|func| (func.name().clone(), collect_custom_calls(func.body())))
811+
.collect()
812+
}
813+
814+
/// Return an error if any function directly calls `main`.
815+
fn check_no_calls_to_main(
816+
func_items: &[&parse::Function],
817+
graph: &HashMap<FunctionName, HashSet<FunctionName>>,
818+
) -> Result<(), RichError> {
819+
let main = FunctionName::main();
820+
for func in func_items {
821+
if let Some(deps) = graph.get(func.name()) {
822+
if deps.contains(&main) {
823+
return Err(Error::MainNotCallable).with_span(*func);
824+
}
825+
}
826+
}
827+
Ok(())
828+
}
829+
830+
/// DFS helper that returns the name of a node involved in a cycle, if any.
831+
fn dfs_find_cycle(
832+
name: &FunctionName,
833+
graph: &HashMap<FunctionName, HashSet<FunctionName>>,
834+
visiting: &mut HashSet<FunctionName>,
835+
visited: &mut HashSet<FunctionName>,
836+
) -> Option<FunctionName> {
837+
if visiting.contains(name) {
838+
return Some(name.clone());
839+
}
840+
if visited.contains(name) {
841+
return None;
842+
}
843+
visiting.insert(name.clone());
844+
if let Some(deps) = graph.get(name) {
845+
for dep in deps {
846+
if let Some(cycle) = dfs_find_cycle(dep, graph, visiting, visited) {
847+
return Some(cycle);
848+
}
849+
}
850+
}
851+
visiting.remove(name);
852+
visited.insert(name.clone());
853+
None
854+
}
855+
856+
/// Return an error if any function in the call graph participates in a cycle
857+
/// (direct self-recursion or mutual recursion).
858+
fn check_for_cycles(
859+
func_items: &[&parse::Function],
860+
graph: &HashMap<FunctionName, HashSet<FunctionName>>,
861+
) -> Result<(), RichError> {
862+
let mut visiting: HashSet<FunctionName> = HashSet::new();
863+
let mut visited: HashSet<FunctionName> = HashSet::new();
864+
for func in func_items {
865+
if !visited.contains(func.name()) {
866+
if let Some(cycle_name) =
867+
dfs_find_cycle(func.name(), graph, &mut visiting, &mut visited)
868+
{
869+
let offending = func_items
870+
.iter()
871+
.find(|f| f.name() == &cycle_name)
872+
.copied()
873+
.unwrap_or(func);
874+
return Err(Error::FunctionRecursive(cycle_name)).with_span(offending);
875+
}
876+
}
877+
}
878+
Ok(())
879+
}
880+
881+
/// DFS post-order visitor for topological sort.
882+
fn topo_dfs<'a>(
883+
name: &FunctionName,
884+
graph: &HashMap<FunctionName, HashSet<FunctionName>>,
885+
func_map: &HashMap<&FunctionName, &'a parse::Function>,
886+
visited: &mut HashSet<FunctionName>,
887+
result: &mut Vec<&'a parse::Function>,
888+
) {
889+
if visited.contains(name) {
890+
return;
891+
}
892+
visited.insert(name.clone());
893+
if let Some(deps) = graph.get(name) {
894+
for dep in deps {
895+
topo_dfs(dep, graph, func_map, visited, result);
896+
}
897+
}
898+
if let Some(&func) = func_map.get(name) {
899+
result.push(func);
900+
}
901+
}
902+
903+
/// Return the non-main functions in an order where every callee appears before
904+
/// its callers. Assumes no cycles (call `check_for_cycles` first).
905+
fn topological_sort<'a>(
906+
func_items: &[&'a parse::Function],
907+
graph: &HashMap<FunctionName, HashSet<FunctionName>>,
908+
) -> Vec<&'a parse::Function> {
909+
let func_map: HashMap<&FunctionName, &'a parse::Function> =
910+
func_items.iter().map(|f| (f.name(), *f)).collect();
911+
let mut visited: HashSet<FunctionName> = HashSet::new();
912+
let mut result: Vec<&'a parse::Function> = Vec::new();
913+
for func in func_items {
914+
topo_dfs(func.name(), graph, &func_map, &mut visited, &mut result);
915+
}
916+
result
917+
}
918+
748919
impl AbstractSyntaxTree for Item {
749920
type From = parse::Item;
750921

@@ -1579,3 +1750,175 @@ impl AsRef<Span> for ModuleAssignment {
15791750
&self.span
15801751
}
15811752
}
1753+
1754+
#[cfg(test)]
1755+
mod tests {
1756+
use crate::TemplateProgram;
1757+
1758+
fn compile_ok(src: &str) {
1759+
TemplateProgram::new(src).unwrap_or_else(|e| panic!("Expected success:\n{e}"));
1760+
}
1761+
1762+
fn compile_err(src: &str, expected_msg: &str) {
1763+
match TemplateProgram::new(src) {
1764+
Ok(_) => {
1765+
panic!("Expected error containing '{expected_msg}', but compilation succeeded")
1766+
}
1767+
Err(e) => assert!(
1768+
e.contains(expected_msg),
1769+
"Error message did not contain '{expected_msg}':\n{e}"
1770+
),
1771+
}
1772+
}
1773+
1774+
// -----------------------------------------------------------------------
1775+
// Forward references
1776+
// -----------------------------------------------------------------------
1777+
1778+
#[test]
1779+
fn forward_reference_simple() {
1780+
// `always_true` is called in `main` but defined after it.
1781+
compile_ok(
1782+
r#"fn main() {
1783+
let x: bool = always_true();
1784+
assert!(x);
1785+
}
1786+
1787+
fn always_true() -> bool {
1788+
true
1789+
}"#,
1790+
);
1791+
}
1792+
1793+
#[test]
1794+
fn forward_reference_chain() {
1795+
// `main` calls `outer`, which calls `inner`, both defined after main.
1796+
compile_ok(
1797+
r#"fn main() {
1798+
let x: bool = outer();
1799+
assert!(x);
1800+
}
1801+
1802+
fn outer() -> bool {
1803+
inner()
1804+
}
1805+
1806+
fn inner() -> bool {
1807+
true
1808+
}"#,
1809+
);
1810+
}
1811+
1812+
#[test]
1813+
fn forward_reference_array_fold() {
1814+
// `summer` is referenced in `array_fold` before it is defined.
1815+
compile_ok(
1816+
r#"fn main() {
1817+
let arr: [u32; 3] = [1, 2, 3];
1818+
let _total: u32 = array_fold::<summer, 3>(arr, 0);
1819+
}
1820+
1821+
fn summer(elt: u32, acc: u32) -> u32 {
1822+
let (_, result): (bool, u32) = jet::add_32(elt, acc);
1823+
result
1824+
}"#,
1825+
);
1826+
}
1827+
1828+
// -----------------------------------------------------------------------
1829+
// Recursion detection
1830+
// -----------------------------------------------------------------------
1831+
1832+
#[test]
1833+
fn direct_self_recursion_rejected() {
1834+
compile_err(
1835+
r#"fn foo(n: u32) -> u32 {
1836+
foo(n)
1837+
}
1838+
1839+
fn main() {
1840+
let _x: u32 = foo(1);
1841+
}"#,
1842+
"recursive call cycle",
1843+
);
1844+
}
1845+
1846+
#[test]
1847+
fn mutual_recursion_two_functions_rejected() {
1848+
// foo → bar → foo
1849+
compile_err(
1850+
r#"fn foo(n: u32) -> u32 {
1851+
bar(n)
1852+
}
1853+
1854+
fn bar(n: u32) -> u32 {
1855+
foo(n)
1856+
}
1857+
1858+
fn main() {
1859+
let _x: u32 = foo(1);
1860+
}"#,
1861+
"recursive call cycle",
1862+
);
1863+
}
1864+
1865+
#[test]
1866+
fn mutual_recursion_three_functions_rejected() {
1867+
// a → b → c → a
1868+
compile_err(
1869+
r#"fn a(n: u32) -> u32 {
1870+
b(n)
1871+
}
1872+
1873+
fn b(n: u32) -> u32 {
1874+
c(n)
1875+
}
1876+
1877+
fn c(n: u32) -> u32 {
1878+
a(n)
1879+
}
1880+
1881+
fn main() {
1882+
let _x: u32 = a(1);
1883+
}"#,
1884+
"recursive call cycle",
1885+
);
1886+
}
1887+
1888+
#[test]
1889+
fn call_to_main_rejected() {
1890+
compile_err(
1891+
r#"fn foo() -> bool {
1892+
main()
1893+
}
1894+
1895+
fn main() {
1896+
assert!(foo());
1897+
}"#,
1898+
"Function `main` cannot be called",
1899+
);
1900+
}
1901+
1902+
#[test]
1903+
fn non_recursive_chain_ok() {
1904+
// top → mid → bottom, defined in reverse order, no cycle.
1905+
compile_ok(
1906+
r#"fn main() {
1907+
let x: bool = top();
1908+
assert!(x);
1909+
}
1910+
1911+
fn top() -> bool {
1912+
mid()
1913+
}
1914+
1915+
fn mid() -> bool {
1916+
bottom()
1917+
}
1918+
1919+
fn bottom() -> bool {
1920+
true
1921+
}"#,
1922+
);
1923+
}
1924+
}

0 commit comments

Comments
 (0)