11use std:: collections:: hash_map:: Entry ;
2- use std:: collections:: HashMap ;
2+ use std:: collections:: { HashMap , HashSet } ;
33use std:: num:: NonZeroUsize ;
44use std:: str:: FromStr ;
55use std:: sync:: Arc ;
@@ -721,21 +721,58 @@ impl Program {
721721 pub fn analyze ( from : & parse:: Program ) -> Result < Self , RichError > {
722722 let unit = ResolvedType :: unit ( ) ;
723723 let mut scope = Scope :: default ( ) ;
724- let items = from
724+
725+ // Pass 1: Process type aliases in file order.
726+ for item in from. items ( ) {
727+ if let parse:: Item :: TypeAlias ( alias) = item {
728+ scope
729+ . insert_alias ( alias. name ( ) . clone ( ) , alias. ty ( ) . clone ( ) )
730+ . with_span ( alias) ?;
731+ }
732+ }
733+
734+ // Collect all non-main custom function definitions.
735+ let func_items: Vec < & parse:: Function > = from
725736 . items ( )
726737 . iter ( )
727- . map ( |s| Item :: analyze ( s, & unit, & mut scope) )
728- . collect :: < Result < Vec < Item > , RichError > > ( ) ?;
738+ . filter_map ( |item| match item {
739+ parse:: Item :: Function ( f) if f. name ( ) . as_inner ( ) != "main" => Some ( f) ,
740+ _ => None ,
741+ } )
742+ . collect ( ) ;
743+
744+ // Build a call graph, reject calls to main, and reject recursive cycles.
745+ let call_graph = build_call_graph ( & func_items) ;
746+ check_no_calls_to_main ( & func_items, & call_graph) ?;
747+ check_for_cycles ( & func_items, & call_graph) ?;
748+
749+ // Pass 2: Analyze custom functions in dependency order so that a
750+ // callee is always in scope before its callers are analyzed.
751+ for func in topological_sort ( & func_items, & call_graph) {
752+ Function :: analyze ( func, & unit, & mut scope) ?;
753+ }
754+
755+ // Pass 3: Find and analyze the main function.
756+ let mut main_expr: Option < Expression > = None ;
757+ let mut seen_main = false ;
758+ for item in from. items ( ) {
759+ if let parse:: Item :: Function ( f) = item {
760+ if f. name ( ) . as_inner ( ) == "main" {
761+ if seen_main {
762+ return Err ( Error :: FunctionRedefined ( FunctionName :: main ( ) ) )
763+ . with_span ( f) ;
764+ }
765+ seen_main = true ;
766+ if let Function :: Main ( expr) = Function :: analyze ( f, & unit, & mut scope) ? {
767+ main_expr = Some ( expr) ;
768+ }
769+ }
770+ }
771+ }
772+
773+ let main = main_expr. ok_or ( Error :: MainRequired ) . with_span ( from) ?;
729774 debug_assert ! ( scope. is_topmost( ) ) ;
730775 let ( parameters, witness_types, call_tracker) = scope. destruct ( ) ;
731- let mut iter = items. into_iter ( ) . filter_map ( |item| match item {
732- Item :: Function ( Function :: Main ( expr) ) => Some ( expr) ,
733- _ => None ,
734- } ) ;
735- let main = iter. next ( ) . ok_or ( Error :: MainRequired ) . with_span ( from) ?;
736- if iter. next ( ) . is_some ( ) {
737- return Err ( Error :: FunctionRedefined ( FunctionName :: main ( ) ) ) . with_span ( from) ;
738- }
739776 Ok ( Self {
740777 main,
741778 parameters,
@@ -745,6 +782,141 @@ impl Program {
745782 }
746783}
747784
785+ /// Scan a parse expression for every custom function it calls directly.
786+ /// Covers plain calls, `array_fold`, `fold`, and `for_while` references.
787+ fn collect_custom_calls ( expr : & parse:: Expression ) -> HashSet < FunctionName > {
788+ let mut calls = HashSet :: new ( ) ;
789+ for node in parse:: ExprTree :: Expression ( expr) . pre_order_iter ( ) {
790+ if let parse:: ExprTree :: Call ( call) = node {
791+ match call. name ( ) {
792+ parse:: CallName :: Custom ( name)
793+ | parse:: CallName :: ArrayFold ( name, _)
794+ | parse:: CallName :: Fold ( name, _)
795+ | parse:: CallName :: ForWhile ( name) => {
796+ calls. insert ( name. clone ( ) ) ;
797+ }
798+ _ => { }
799+ }
800+ }
801+ }
802+ calls
803+ }
804+
805+ /// Build a map from each function name to the set of custom functions it calls.
806+ fn build_call_graph (
807+ func_items : & [ & parse:: Function ] ,
808+ ) -> HashMap < FunctionName , HashSet < FunctionName > > {
809+ func_items
810+ . iter ( )
811+ . map ( |func| ( func. name ( ) . clone ( ) , collect_custom_calls ( func. body ( ) ) ) )
812+ . collect ( )
813+ }
814+
815+ /// Return an error if any function directly calls `main`.
816+ fn check_no_calls_to_main (
817+ func_items : & [ & parse:: Function ] ,
818+ graph : & HashMap < FunctionName , HashSet < FunctionName > > ,
819+ ) -> Result < ( ) , RichError > {
820+ let main = FunctionName :: main ( ) ;
821+ for func in func_items {
822+ if let Some ( deps) = graph. get ( func. name ( ) ) {
823+ if deps. contains ( & main) {
824+ return Err ( Error :: MainNotCallable ) . with_span ( * func) ;
825+ }
826+ }
827+ }
828+ Ok ( ( ) )
829+ }
830+
831+ /// DFS helper that returns the name of a node involved in a cycle, if any.
832+ fn dfs_find_cycle (
833+ name : & FunctionName ,
834+ graph : & HashMap < FunctionName , HashSet < FunctionName > > ,
835+ visiting : & mut HashSet < FunctionName > ,
836+ visited : & mut HashSet < FunctionName > ,
837+ ) -> Option < FunctionName > {
838+ if visiting. contains ( name) {
839+ return Some ( name. clone ( ) ) ;
840+ }
841+ if visited. contains ( name) {
842+ return None ;
843+ }
844+ visiting. insert ( name. clone ( ) ) ;
845+ if let Some ( deps) = graph. get ( name) {
846+ for dep in deps {
847+ if let Some ( cycle) = dfs_find_cycle ( dep, graph, visiting, visited) {
848+ return Some ( cycle) ;
849+ }
850+ }
851+ }
852+ visiting. remove ( name) ;
853+ visited. insert ( name. clone ( ) ) ;
854+ None
855+ }
856+
857+ /// Return an error if any function in the call graph participates in a cycle
858+ /// (direct self-recursion or mutual recursion).
859+ fn check_for_cycles (
860+ func_items : & [ & parse:: Function ] ,
861+ graph : & HashMap < FunctionName , HashSet < FunctionName > > ,
862+ ) -> Result < ( ) , RichError > {
863+ let mut visiting: HashSet < FunctionName > = HashSet :: new ( ) ;
864+ let mut visited: HashSet < FunctionName > = HashSet :: new ( ) ;
865+ for func in func_items {
866+ if !visited. contains ( func. name ( ) ) {
867+ if let Some ( cycle_name) =
868+ dfs_find_cycle ( func. name ( ) , graph, & mut visiting, & mut visited)
869+ {
870+ let offending = func_items
871+ . iter ( )
872+ . find ( |f| f. name ( ) == & cycle_name)
873+ . copied ( )
874+ . unwrap_or ( func) ;
875+ return Err ( Error :: FunctionRecursive ( cycle_name) ) . with_span ( & * offending) ;
876+ }
877+ }
878+ }
879+ Ok ( ( ) )
880+ }
881+
882+ /// DFS post-order visitor for topological sort.
883+ fn topo_dfs < ' a > (
884+ name : & FunctionName ,
885+ graph : & HashMap < FunctionName , HashSet < FunctionName > > ,
886+ func_map : & HashMap < & FunctionName , & ' a parse:: Function > ,
887+ visited : & mut HashSet < FunctionName > ,
888+ result : & mut Vec < & ' a parse:: Function > ,
889+ ) {
890+ if visited. contains ( name) {
891+ return ;
892+ }
893+ visited. insert ( name. clone ( ) ) ;
894+ if let Some ( deps) = graph. get ( name) {
895+ for dep in deps {
896+ topo_dfs ( dep, graph, func_map, visited, result) ;
897+ }
898+ }
899+ if let Some ( & func) = func_map. get ( name) {
900+ result. push ( func) ;
901+ }
902+ }
903+
904+ /// Return the non-main functions in an order where every callee appears before
905+ /// its callers. Assumes no cycles (call `check_for_cycles` first).
906+ fn topological_sort < ' a > (
907+ func_items : & [ & ' a parse:: Function ] ,
908+ graph : & HashMap < FunctionName , HashSet < FunctionName > > ,
909+ ) -> Vec < & ' a parse:: Function > {
910+ let func_map: HashMap < & FunctionName , & ' a parse:: Function > =
911+ func_items. iter ( ) . map ( |f| ( f. name ( ) , * f) ) . collect ( ) ;
912+ let mut visited: HashSet < FunctionName > = HashSet :: new ( ) ;
913+ let mut result: Vec < & ' a parse:: Function > = Vec :: new ( ) ;
914+ for func in func_items {
915+ topo_dfs ( func. name ( ) , graph, & func_map, & mut visited, & mut result) ;
916+ }
917+ result
918+ }
919+
748920impl AbstractSyntaxTree for Item {
749921 type From = parse:: Item ;
750922
@@ -1579,3 +1751,173 @@ impl AsRef<Span> for ModuleAssignment {
15791751 & self . span
15801752 }
15811753}
1754+
1755+ #[ cfg( test) ]
1756+ mod tests {
1757+ use crate :: TemplateProgram ;
1758+
1759+ fn compile_ok ( src : & str ) {
1760+ TemplateProgram :: new ( src) . unwrap_or_else ( |e| panic ! ( "Expected success:\n {e}" ) ) ;
1761+ }
1762+
1763+ fn compile_err ( src : & str , expected_msg : & str ) {
1764+ match TemplateProgram :: new ( src) {
1765+ Ok ( _) => panic ! ( "Expected error containing '{expected_msg}', but compilation succeeded" ) ,
1766+ Err ( e) => assert ! (
1767+ e. contains( expected_msg) ,
1768+ "Error message did not contain '{expected_msg}':\n {e}"
1769+ ) ,
1770+ }
1771+ }
1772+
1773+ // -----------------------------------------------------------------------
1774+ // Forward references
1775+ // -----------------------------------------------------------------------
1776+
1777+ #[ test]
1778+ fn forward_reference_simple ( ) {
1779+ // `always_true` is called in `main` but defined after it.
1780+ compile_ok (
1781+ r#"fn main() {
1782+ let x: bool = always_true();
1783+ assert!(x);
1784+ }
1785+
1786+ fn always_true() -> bool {
1787+ true
1788+ }"# ,
1789+ ) ;
1790+ }
1791+
1792+ #[ test]
1793+ fn forward_reference_chain ( ) {
1794+ // `main` calls `outer`, which calls `inner`, both defined after main.
1795+ compile_ok (
1796+ r#"fn main() {
1797+ let x: bool = outer();
1798+ assert!(x);
1799+ }
1800+
1801+ fn outer() -> bool {
1802+ inner()
1803+ }
1804+
1805+ fn inner() -> bool {
1806+ true
1807+ }"# ,
1808+ ) ;
1809+ }
1810+
1811+ #[ test]
1812+ fn forward_reference_array_fold ( ) {
1813+ // `summer` is referenced in `array_fold` before it is defined.
1814+ compile_ok (
1815+ r#"fn main() {
1816+ let arr: [u32; 3] = [1, 2, 3];
1817+ let _total: u32 = array_fold::<summer, 3>(arr, 0);
1818+ }
1819+
1820+ fn summer(elt: u32, acc: u32) -> u32 {
1821+ let (_, result): (bool, u32) = jet::add_32(elt, acc);
1822+ result
1823+ }"# ,
1824+ ) ;
1825+ }
1826+
1827+ // -----------------------------------------------------------------------
1828+ // Recursion detection
1829+ // -----------------------------------------------------------------------
1830+
1831+ #[ test]
1832+ fn direct_self_recursion_rejected ( ) {
1833+ compile_err (
1834+ r#"fn foo(n: u32) -> u32 {
1835+ foo(n)
1836+ }
1837+
1838+ fn main() {
1839+ let _x: u32 = foo(1);
1840+ }"# ,
1841+ "recursive call cycle" ,
1842+ ) ;
1843+ }
1844+
1845+ #[ test]
1846+ fn mutual_recursion_two_functions_rejected ( ) {
1847+ // foo → bar → foo
1848+ compile_err (
1849+ r#"fn foo(n: u32) -> u32 {
1850+ bar(n)
1851+ }
1852+
1853+ fn bar(n: u32) -> u32 {
1854+ foo(n)
1855+ }
1856+
1857+ fn main() {
1858+ let _x: u32 = foo(1);
1859+ }"# ,
1860+ "recursive call cycle" ,
1861+ ) ;
1862+ }
1863+
1864+ #[ test]
1865+ fn mutual_recursion_three_functions_rejected ( ) {
1866+ // a → b → c → a
1867+ compile_err (
1868+ r#"fn a(n: u32) -> u32 {
1869+ b(n)
1870+ }
1871+
1872+ fn b(n: u32) -> u32 {
1873+ c(n)
1874+ }
1875+
1876+ fn c(n: u32) -> u32 {
1877+ a(n)
1878+ }
1879+
1880+ fn main() {
1881+ let _x: u32 = a(1);
1882+ }"# ,
1883+ "recursive call cycle" ,
1884+ ) ;
1885+ }
1886+
1887+ #[ test]
1888+ fn call_to_main_rejected ( ) {
1889+ compile_err (
1890+ r#"fn foo() -> bool {
1891+ main()
1892+ }
1893+
1894+ fn main() {
1895+ assert!(foo());
1896+ }"# ,
1897+ "Function `main` cannot be called" ,
1898+ ) ;
1899+ }
1900+
1901+ #[ test]
1902+ fn non_recursive_chain_ok ( ) {
1903+ // top → mid → bottom, defined in reverse order, no cycle.
1904+ compile_ok (
1905+ r#"fn main() {
1906+ let x: bool = top();
1907+ assert!(x);
1908+ }
1909+
1910+ fn top() -> bool {
1911+ mid()
1912+ }
1913+
1914+ fn mid() -> bool {
1915+ bottom()
1916+ }
1917+
1918+ fn bottom() -> bool {
1919+ true
1920+ }"# ,
1921+ ) ;
1922+ }
1923+ }
0 commit comments