@@ -9,7 +9,7 @@ use serde::{Serialize, ser::SerializeMap};
99use unicase:: Ascii ;
1010
1111use crate :: {
12- Attrs , Binary , Expr , Field , FunArgs , Query , Raw , Source , SourceKind , Type , Value ,
12+ App , Attrs , Binary , Expr , Field , FunArgs , Query , Raw , Source , SourceKind , Type , Value ,
1313 error:: AnalysisError , token:: Operator ,
1414} ;
1515
@@ -36,6 +36,9 @@ pub struct Typed {
3636 /// including bindings from FROM clauses and their associated types.
3737 #[ serde( skip) ]
3838 pub scope : Scope ,
39+
40+ /// Indicates if the query uses aggregate functions.
41+ pub aggregate : bool ,
3942}
4043
4144/// Result type for static analysis operations.
@@ -528,6 +531,9 @@ pub struct AnalysisContext {
528531 /// Set to `true` to allow aggregate functions, `false` to reject them.
529532 /// Defaults to `false`.
530533 pub allow_agg_func : bool ,
534+
535+ /// Indicates if the query uses aggregate functions.
536+ pub use_agg_funcs : bool ,
531537}
532538
533539/// A type checker and static analyzer for EventQL expressions.
@@ -631,7 +637,7 @@ impl<'a> Analysis<'a> {
631637 }
632638
633639 if let Some ( expr) = & query. predicate {
634- self . analyze_expr ( & ctx, expr, Type :: Bool ) ?;
640+ self . analyze_expr ( & mut ctx, expr, Type :: Bool ) ?;
635641 }
636642
637643 if let Some ( group_by) = & query. group_by {
@@ -642,10 +648,10 @@ impl<'a> Analysis<'a> {
642648 ) ) ;
643649 }
644650
645- self . analyze_expr ( & ctx, & group_by. expr , Type :: Unspecified ) ?;
651+ self . analyze_expr ( & mut ctx, & group_by. expr , Type :: Unspecified ) ?;
646652
647653 if let Some ( expr) = & group_by. predicate {
648- self . analyze_expr ( & ctx, expr, Type :: Bool ) ?;
654+ self . analyze_expr ( & mut ctx, expr, Type :: Bool ) ?;
649655 }
650656 }
651657
@@ -656,7 +662,7 @@ impl<'a> Analysis<'a> {
656662 order_by. expr . attrs . pos . col ,
657663 ) ) ;
658664 }
659- self . analyze_expr ( & ctx, & order_by. expr , Type :: Unspecified ) ?;
665+ self . analyze_expr ( & mut ctx, & order_by. expr , Type :: Unspecified ) ?;
660666 }
661667
662668 let project = self . analyze_projection ( & mut ctx, & query. projection ) ?;
@@ -671,7 +677,11 @@ impl<'a> Analysis<'a> {
671677 limit : query. limit ,
672678 projection : query. projection ,
673679 distinct : query. distinct ,
674- meta : Typed { project, scope } ,
680+ meta : Typed {
681+ project,
682+ scope,
683+ aggregate : ctx. use_agg_funcs ,
684+ } ,
675685 } )
676686 }
677687
@@ -732,6 +742,20 @@ impl<'a> Analysis<'a> {
732742 Ok ( tpe)
733743 }
734744
745+ Value :: App ( app) => {
746+ ctx. allow_agg_func = true ;
747+
748+ let tpe = self . analyze_expr ( ctx, expr, Type :: Unspecified ) ?;
749+
750+ if ctx. use_agg_funcs {
751+ self . check_projection_on_field_expr ( & mut CheckContext :: default ( ) , expr) ?;
752+ } else {
753+ self . reject_constant_func ( & expr. attrs , app) ?;
754+ }
755+
756+ Ok ( tpe)
757+ }
758+
735759 Value :: Id ( id) => {
736760 if let Some ( tpe) = self . scope . entries . get ( id. as_str ( ) ) . cloned ( ) {
737761 Ok ( tpe)
@@ -994,6 +1018,87 @@ impl<'a> Analysis<'a> {
9941018 }
9951019 }
9961020
1021+ fn reject_constant_func ( & self , attrs : & Attrs , app : & App ) -> AnalysisResult < ( ) > {
1022+ if app. args . is_empty ( ) {
1023+ return Err ( AnalysisError :: ConstantExprInProjectIntoClause (
1024+ attrs. pos . line ,
1025+ attrs. pos . col ,
1026+ ) ) ;
1027+ }
1028+
1029+ let mut errored = None ;
1030+ for arg in & app. args {
1031+ if let Err ( e) = self . reject_constant_expr ( arg) {
1032+ if errored. is_none ( ) {
1033+ errored = Some ( e) ;
1034+ }
1035+
1036+ continue ;
1037+ }
1038+
1039+ // if at least one arg is sourced-bound is ok
1040+ return Ok ( ( ) ) ;
1041+ }
1042+
1043+ Err ( errored. expect ( "to be defined at that point" ) )
1044+ }
1045+
1046+ fn reject_constant_expr ( & self , expr : & Expr ) -> AnalysisResult < ( ) > {
1047+ match & expr. value {
1048+ Value :: Id ( id) if self . scope . entries . contains_key ( id. as_str ( ) ) => Ok ( ( ) ) ,
1049+
1050+ Value :: Array ( exprs) => {
1051+ let mut errored = None ;
1052+ for expr in exprs {
1053+ if let Err ( e) = self . reject_constant_expr ( expr) {
1054+ if errored. is_none ( ) {
1055+ errored = Some ( e) ;
1056+ }
1057+
1058+ continue ;
1059+ }
1060+
1061+ // if at least one arg is sourced-bound is ok
1062+ return Ok ( ( ) ) ;
1063+ }
1064+
1065+ Err ( errored. expect ( "to be defined at that point" ) )
1066+ }
1067+
1068+ Value :: Record ( fields) => {
1069+ let mut errored = None ;
1070+ for field in fields {
1071+ if let Err ( e) = self . reject_constant_expr ( & field. value ) {
1072+ if errored. is_none ( ) {
1073+ errored = Some ( e) ;
1074+ }
1075+
1076+ continue ;
1077+ }
1078+
1079+ // if at least one arg is sourced-bound is ok
1080+ return Ok ( ( ) ) ;
1081+ }
1082+
1083+ Err ( errored. expect ( "to be defined at that point" ) )
1084+ }
1085+
1086+ Value :: Binary ( binary) => self
1087+ . reject_constant_expr ( & binary. lhs )
1088+ . or_else ( |e| self . reject_constant_expr ( & binary. rhs ) . map_err ( |_| e) ) ,
1089+
1090+ Value :: Access ( access) => self . reject_constant_expr ( access. target . as_ref ( ) ) ,
1091+ Value :: App ( app) => self . reject_constant_func ( & expr. attrs , app) ,
1092+ Value :: Unary ( unary) => self . reject_constant_expr ( & unary. expr ) ,
1093+ Value :: Group ( expr) => self . reject_constant_expr ( expr) ,
1094+
1095+ _ => Err ( AnalysisError :: ConstantExprInProjectIntoClause (
1096+ expr. attrs . pos . line ,
1097+ expr. attrs . pos . col ,
1098+ ) ) ,
1099+ }
1100+ }
1101+
9971102 /// Analyzes an expression and checks it against an expected type.
9981103 ///
9991104 /// This method performs type checking on an expression, verifying that all operations
@@ -1025,7 +1130,7 @@ impl<'a> Analysis<'a> {
10251130 /// ```
10261131 pub fn analyze_expr (
10271132 & mut self ,
1028- ctx : & AnalysisContext ,
1133+ ctx : & mut AnalysisContext ,
10291134 expr : & Expr ,
10301135 mut expect : Type ,
10311136 ) -> AnalysisResult < Type > {
@@ -1147,6 +1252,10 @@ impl<'a> Analysis<'a> {
11471252 ) ) ;
11481253 }
11491254
1255+ if * aggregate && ctx. allow_agg_func {
1256+ ctx. use_agg_funcs = true ;
1257+ }
1258+
11501259 for ( arg, tpe) in app. args . iter ( ) . zip ( args. values . iter ( ) . cloned ( ) ) {
11511260 self . analyze_expr ( ctx, arg, tpe) ?;
11521261 }
0 commit comments