11use std:: collections:: hash_map:: Entry ;
2- use std:: collections:: HashMap ;
2+ use std:: collections:: { HashMap , HashSet } ;
33use std:: num:: NonZeroUsize ;
44use std:: str:: FromStr ;
55use std:: sync:: Arc ;
@@ -721,21 +721,57 @@ impl Program {
721721 pub fn analyze ( from : & parse:: Program ) -> Result < Self , RichError > {
722722 let unit = ResolvedType :: unit ( ) ;
723723 let mut scope = Scope :: default ( ) ;
724- let items = from
724+
725+ // Pass 1: Process type aliases in file order.
726+ for item in from. items ( ) {
727+ if let parse:: Item :: TypeAlias ( alias) = item {
728+ scope
729+ . insert_alias ( alias. name ( ) . clone ( ) , alias. ty ( ) . clone ( ) )
730+ . with_span ( alias) ?;
731+ }
732+ }
733+
734+ // Collect all non-main custom function definitions.
735+ let func_items: Vec < & parse:: Function > = from
725736 . items ( )
726737 . iter ( )
727- . map ( |s| Item :: analyze ( s, & unit, & mut scope) )
728- . collect :: < Result < Vec < Item > , RichError > > ( ) ?;
738+ . filter_map ( |item| match item {
739+ parse:: Item :: Function ( f) if f. name ( ) . as_inner ( ) != "main" => Some ( f) ,
740+ _ => None ,
741+ } )
742+ . collect ( ) ;
743+
744+ // Build a call graph, reject calls to main, and reject recursive cycles.
745+ let call_graph = build_call_graph ( & func_items) ;
746+ check_no_calls_to_main ( & func_items, & call_graph) ?;
747+ check_for_cycles ( & func_items, & call_graph) ?;
748+
749+ // Pass 2: Analyze custom functions in dependency order so that a
750+ // callee is always in scope before its callers are analyzed.
751+ for func in topological_sort ( & func_items, & call_graph) {
752+ Function :: analyze ( func, & unit, & mut scope) ?;
753+ }
754+
755+ // Pass 3: Find and analyze the main function.
756+ let mut main_expr: Option < Expression > = None ;
757+ let mut seen_main = false ;
758+ for item in from. items ( ) {
759+ if let parse:: Item :: Function ( f) = item {
760+ if f. name ( ) . as_inner ( ) == "main" {
761+ if seen_main {
762+ return Err ( Error :: FunctionRedefined ( FunctionName :: main ( ) ) ) . with_span ( f) ;
763+ }
764+ seen_main = true ;
765+ if let Function :: Main ( expr) = Function :: analyze ( f, & unit, & mut scope) ? {
766+ main_expr = Some ( expr) ;
767+ }
768+ }
769+ }
770+ }
771+
772+ let main = main_expr. ok_or ( Error :: MainRequired ) . with_span ( from) ?;
729773 debug_assert ! ( scope. is_topmost( ) ) ;
730774 let ( parameters, witness_types, call_tracker) = scope. destruct ( ) ;
731- let mut iter = items. into_iter ( ) . filter_map ( |item| match item {
732- Item :: Function ( Function :: Main ( expr) ) => Some ( expr) ,
733- _ => None ,
734- } ) ;
735- let main = iter. next ( ) . ok_or ( Error :: MainRequired ) . with_span ( from) ?;
736- if iter. next ( ) . is_some ( ) {
737- return Err ( Error :: FunctionRedefined ( FunctionName :: main ( ) ) ) . with_span ( from) ;
738- }
739775 Ok ( Self {
740776 main,
741777 parameters,
@@ -745,6 +781,141 @@ impl Program {
745781 }
746782}
747783
784+ /// Scan a parse expression for every custom function it calls directly.
785+ /// Covers plain calls, `array_fold`, `fold`, and `for_while` references.
786+ fn collect_custom_calls ( expr : & parse:: Expression ) -> HashSet < FunctionName > {
787+ let mut calls = HashSet :: new ( ) ;
788+ for node in parse:: ExprTree :: Expression ( expr) . pre_order_iter ( ) {
789+ if let parse:: ExprTree :: Call ( call) = node {
790+ match call. name ( ) {
791+ parse:: CallName :: Custom ( name)
792+ | parse:: CallName :: ArrayFold ( name, _)
793+ | parse:: CallName :: Fold ( name, _)
794+ | parse:: CallName :: ForWhile ( name) => {
795+ calls. insert ( name. clone ( ) ) ;
796+ }
797+ _ => { }
798+ }
799+ }
800+ }
801+ calls
802+ }
803+
804+ /// Build a map from each function name to the set of custom functions it calls.
805+ fn build_call_graph (
806+ func_items : & [ & parse:: Function ] ,
807+ ) -> HashMap < FunctionName , HashSet < FunctionName > > {
808+ func_items
809+ . iter ( )
810+ . map ( |func| ( func. name ( ) . clone ( ) , collect_custom_calls ( func. body ( ) ) ) )
811+ . collect ( )
812+ }
813+
814+ /// Return an error if any function directly calls `main`.
815+ fn check_no_calls_to_main (
816+ func_items : & [ & parse:: Function ] ,
817+ graph : & HashMap < FunctionName , HashSet < FunctionName > > ,
818+ ) -> Result < ( ) , RichError > {
819+ let main = FunctionName :: main ( ) ;
820+ for func in func_items {
821+ if let Some ( deps) = graph. get ( func. name ( ) ) {
822+ if deps. contains ( & main) {
823+ return Err ( Error :: MainNotCallable ) . with_span ( * func) ;
824+ }
825+ }
826+ }
827+ Ok ( ( ) )
828+ }
829+
830+ /// DFS helper that returns the name of a node involved in a cycle, if any.
831+ fn dfs_find_cycle (
832+ name : & FunctionName ,
833+ graph : & HashMap < FunctionName , HashSet < FunctionName > > ,
834+ visiting : & mut HashSet < FunctionName > ,
835+ visited : & mut HashSet < FunctionName > ,
836+ ) -> Option < FunctionName > {
837+ if visiting. contains ( name) {
838+ return Some ( name. clone ( ) ) ;
839+ }
840+ if visited. contains ( name) {
841+ return None ;
842+ }
843+ visiting. insert ( name. clone ( ) ) ;
844+ if let Some ( deps) = graph. get ( name) {
845+ for dep in deps {
846+ if let Some ( cycle) = dfs_find_cycle ( dep, graph, visiting, visited) {
847+ return Some ( cycle) ;
848+ }
849+ }
850+ }
851+ visiting. remove ( name) ;
852+ visited. insert ( name. clone ( ) ) ;
853+ None
854+ }
855+
856+ /// Return an error if any function in the call graph participates in a cycle
857+ /// (direct self-recursion or mutual recursion).
858+ fn check_for_cycles (
859+ func_items : & [ & parse:: Function ] ,
860+ graph : & HashMap < FunctionName , HashSet < FunctionName > > ,
861+ ) -> Result < ( ) , RichError > {
862+ let mut visiting: HashSet < FunctionName > = HashSet :: new ( ) ;
863+ let mut visited: HashSet < FunctionName > = HashSet :: new ( ) ;
864+ for func in func_items {
865+ if !visited. contains ( func. name ( ) ) {
866+ if let Some ( cycle_name) =
867+ dfs_find_cycle ( func. name ( ) , graph, & mut visiting, & mut visited)
868+ {
869+ let offending = func_items
870+ . iter ( )
871+ . find ( |f| f. name ( ) == & cycle_name)
872+ . copied ( )
873+ . unwrap_or ( func) ;
874+ return Err ( Error :: FunctionRecursive ( cycle_name) ) . with_span ( offending) ;
875+ }
876+ }
877+ }
878+ Ok ( ( ) )
879+ }
880+
881+ /// DFS post-order visitor for topological sort.
882+ fn topo_dfs < ' a > (
883+ name : & FunctionName ,
884+ graph : & HashMap < FunctionName , HashSet < FunctionName > > ,
885+ func_map : & HashMap < & FunctionName , & ' a parse:: Function > ,
886+ visited : & mut HashSet < FunctionName > ,
887+ result : & mut Vec < & ' a parse:: Function > ,
888+ ) {
889+ if visited. contains ( name) {
890+ return ;
891+ }
892+ visited. insert ( name. clone ( ) ) ;
893+ if let Some ( deps) = graph. get ( name) {
894+ for dep in deps {
895+ topo_dfs ( dep, graph, func_map, visited, result) ;
896+ }
897+ }
898+ if let Some ( & func) = func_map. get ( name) {
899+ result. push ( func) ;
900+ }
901+ }
902+
903+ /// Return the non-main functions in an order where every callee appears before
904+ /// its callers. Assumes no cycles (call `check_for_cycles` first).
905+ fn topological_sort < ' a > (
906+ func_items : & [ & ' a parse:: Function ] ,
907+ graph : & HashMap < FunctionName , HashSet < FunctionName > > ,
908+ ) -> Vec < & ' a parse:: Function > {
909+ let func_map: HashMap < & FunctionName , & ' a parse:: Function > =
910+ func_items. iter ( ) . map ( |f| ( f. name ( ) , * f) ) . collect ( ) ;
911+ let mut visited: HashSet < FunctionName > = HashSet :: new ( ) ;
912+ let mut result: Vec < & ' a parse:: Function > = Vec :: new ( ) ;
913+ for func in func_items {
914+ topo_dfs ( func. name ( ) , graph, & func_map, & mut visited, & mut result) ;
915+ }
916+ result
917+ }
918+
748919impl AbstractSyntaxTree for Item {
749920 type From = parse:: Item ;
750921
@@ -1579,3 +1750,175 @@ impl AsRef<Span> for ModuleAssignment {
15791750 & self . span
15801751 }
15811752}
1753+
1754+ #[ cfg( test) ]
1755+ mod tests {
1756+ use crate :: TemplateProgram ;
1757+
1758+ fn compile_ok ( src : & str ) {
1759+ TemplateProgram :: new ( src) . unwrap_or_else ( |e| panic ! ( "Expected success:\n {e}" ) ) ;
1760+ }
1761+
1762+ fn compile_err ( src : & str , expected_msg : & str ) {
1763+ match TemplateProgram :: new ( src) {
1764+ Ok ( _) => {
1765+ panic ! ( "Expected error containing '{expected_msg}', but compilation succeeded" )
1766+ }
1767+ Err ( e) => assert ! (
1768+ e. contains( expected_msg) ,
1769+ "Error message did not contain '{expected_msg}':\n {e}"
1770+ ) ,
1771+ }
1772+ }
1773+
1774+ // -----------------------------------------------------------------------
1775+ // Forward references
1776+ // -----------------------------------------------------------------------
1777+
1778+ #[ test]
1779+ fn forward_reference_simple ( ) {
1780+ // `always_true` is called in `main` but defined after it.
1781+ compile_ok (
1782+ r#"fn main() {
1783+ let x: bool = always_true();
1784+ assert!(x);
1785+ }
1786+
1787+ fn always_true() -> bool {
1788+ true
1789+ }"# ,
1790+ ) ;
1791+ }
1792+
1793+ #[ test]
1794+ fn forward_reference_chain ( ) {
1795+ // `main` calls `outer`, which calls `inner`, both defined after main.
1796+ compile_ok (
1797+ r#"fn main() {
1798+ let x: bool = outer();
1799+ assert!(x);
1800+ }
1801+
1802+ fn outer() -> bool {
1803+ inner()
1804+ }
1805+
1806+ fn inner() -> bool {
1807+ true
1808+ }"# ,
1809+ ) ;
1810+ }
1811+
1812+ #[ test]
1813+ fn forward_reference_array_fold ( ) {
1814+ // `summer` is referenced in `array_fold` before it is defined.
1815+ compile_ok (
1816+ r#"fn main() {
1817+ let arr: [u32; 3] = [1, 2, 3];
1818+ let _total: u32 = array_fold::<summer, 3>(arr, 0);
1819+ }
1820+
1821+ fn summer(elt: u32, acc: u32) -> u32 {
1822+ let (_, result): (bool, u32) = jet::add_32(elt, acc);
1823+ result
1824+ }"# ,
1825+ ) ;
1826+ }
1827+
1828+ // -----------------------------------------------------------------------
1829+ // Recursion detection
1830+ // -----------------------------------------------------------------------
1831+
1832+ #[ test]
1833+ fn direct_self_recursion_rejected ( ) {
1834+ compile_err (
1835+ r#"fn foo(n: u32) -> u32 {
1836+ foo(n)
1837+ }
1838+
1839+ fn main() {
1840+ let _x: u32 = foo(1);
1841+ }"# ,
1842+ "recursive call cycle" ,
1843+ ) ;
1844+ }
1845+
1846+ #[ test]
1847+ fn mutual_recursion_two_functions_rejected ( ) {
1848+ // foo → bar → foo
1849+ compile_err (
1850+ r#"fn foo(n: u32) -> u32 {
1851+ bar(n)
1852+ }
1853+
1854+ fn bar(n: u32) -> u32 {
1855+ foo(n)
1856+ }
1857+
1858+ fn main() {
1859+ let _x: u32 = foo(1);
1860+ }"# ,
1861+ "recursive call cycle" ,
1862+ ) ;
1863+ }
1864+
1865+ #[ test]
1866+ fn mutual_recursion_three_functions_rejected ( ) {
1867+ // a → b → c → a
1868+ compile_err (
1869+ r#"fn a(n: u32) -> u32 {
1870+ b(n)
1871+ }
1872+
1873+ fn b(n: u32) -> u32 {
1874+ c(n)
1875+ }
1876+
1877+ fn c(n: u32) -> u32 {
1878+ a(n)
1879+ }
1880+
1881+ fn main() {
1882+ let _x: u32 = a(1);
1883+ }"# ,
1884+ "recursive call cycle" ,
1885+ ) ;
1886+ }
1887+
1888+ #[ test]
1889+ fn call_to_main_rejected ( ) {
1890+ compile_err (
1891+ r#"fn foo() -> bool {
1892+ main()
1893+ }
1894+
1895+ fn main() {
1896+ assert!(foo());
1897+ }"# ,
1898+ "Function `main` cannot be called" ,
1899+ ) ;
1900+ }
1901+
1902+ #[ test]
1903+ fn non_recursive_chain_ok ( ) {
1904+ // top → mid → bottom, defined in reverse order, no cycle.
1905+ compile_ok (
1906+ r#"fn main() {
1907+ let x: bool = top();
1908+ assert!(x);
1909+ }
1910+
1911+ fn top() -> bool {
1912+ mid()
1913+ }
1914+
1915+ fn mid() -> bool {
1916+ bottom()
1917+ }
1918+
1919+ fn bottom() -> bool {
1920+ true
1921+ }"# ,
1922+ ) ;
1923+ }
1924+ }
0 commit comments