@@ -788,7 +788,7 @@ pub struct Aggregate {
788788}
789789
790790impl Aggregate {
791- fn new ( _schema : & Schema , name : & str , field_type : & s:: Type , dir : & s:: Directive ) -> Self {
791+ fn new ( schema : & Schema , name : & str , field_type : & s:: Type , dir : & s:: Directive ) -> Self {
792792 let func = dir
793793 . argument ( "fn" )
794794 . unwrap ( )
@@ -818,7 +818,7 @@ impl Aggregate {
818818 arg,
819819 cumulative,
820820 field_type : field_type. clone ( ) ,
821- value_type : field_type . get_base_type ( ) . parse ( ) . unwrap ( ) ,
821+ value_type : Field :: scalar_value_type ( schema , field_type ) ,
822822 }
823823 }
824824
@@ -2366,27 +2366,63 @@ mod validations {
23662366 }
23672367 }
23682368
2369- fn aggregate_fields_are_numbers ( agg_type : & s:: ObjectType , errors : & mut Vec < Err > ) {
2369+ fn aggregate_field_types (
2370+ schema : & Schema ,
2371+ agg_type : & s:: ObjectType ,
2372+ errors : & mut Vec < Err > ,
2373+ ) {
2374+ fn is_first_last ( agg_directive : & s:: Directive ) -> bool {
2375+ match agg_directive. argument ( kw:: FUNC ) {
2376+ Some ( s:: Value :: Enum ( func) | s:: Value :: String ( func) ) => {
2377+ func == AggregateFn :: First . as_str ( )
2378+ || func == AggregateFn :: Last . as_str ( )
2379+ }
2380+ _ => false ,
2381+ }
2382+ }
2383+
23702384 let errs = agg_type
23712385 . fields
23722386 . iter ( )
2373- . filter ( |field| field. find_directive ( kw:: AGGREGATE ) . is_some ( ) )
2374- . map ( |field| match field. field_type . value_type ( ) {
2375- Ok ( vt) => {
2376- if vt. is_numeric ( ) {
2377- Ok ( ( ) )
2378- } else {
2379- Err ( Err :: NonNumericAggregate (
2387+ . filter_map ( |field| {
2388+ field
2389+ . find_directive ( kw:: AGGREGATE )
2390+ . map ( |agg_directive| ( field, agg_directive) )
2391+ } )
2392+ . map ( |( field, agg_directive) | {
2393+ let is_first_last = is_first_last ( agg_directive) ;
2394+
2395+ match field. field_type . value_type ( ) {
2396+ Ok ( value_type) if value_type. is_numeric ( ) => Ok ( ( ) ) ,
2397+ Ok ( ValueType :: Bytes | ValueType :: String ) if is_first_last => Ok ( ( ) ) ,
2398+ Ok ( _) if is_first_last => Err ( Err :: InvalidFirstLastAggregate (
2399+ agg_type. name . clone ( ) ,
2400+ field. name . clone ( ) ,
2401+ ) ) ,
2402+ Ok ( _) => Err ( Err :: NonNumericAggregate (
2403+ agg_type. name . to_owned ( ) ,
2404+ field. name . to_owned ( ) ,
2405+ ) ) ,
2406+ Err ( _) => {
2407+ if is_first_last
2408+ && schema
2409+ . entity_types
2410+ . iter ( )
2411+ . find ( |entity_type| {
2412+ entity_type. name . eq ( field. field_type . get_base_type ( ) )
2413+ } )
2414+ . is_some ( )
2415+ {
2416+ return Ok ( ( ) ) ;
2417+ }
2418+
2419+ Err ( Err :: FieldTypeUnknown (
23802420 agg_type. name . to_owned ( ) ,
23812421 field. name . to_owned ( ) ,
2422+ field. field_type . get_base_type ( ) . to_owned ( ) ,
23822423 ) )
23832424 }
23842425 }
2385- Err ( _) => Err ( Err :: FieldTypeUnknown (
2386- agg_type. name . to_owned ( ) ,
2387- field. name . to_owned ( ) ,
2388- field. field_type . get_base_type ( ) . to_owned ( ) ,
2389- ) ) ,
23902426 } )
23912427 . filter_map ( |err| err. err ( ) ) ;
23922428 errors. extend ( errs) ;
@@ -2519,16 +2555,10 @@ mod validations {
25192555 continue ;
25202556 }
25212557 } ;
2522- let field_type = match field. field_type . value_type ( ) {
2523- Ok ( field_type) => field_type,
2524- Err ( _) => {
2525- errors. push ( Err :: NonNumericAggregate (
2526- agg_type. name . to_owned ( ) ,
2527- field. name . to_owned ( ) ,
2528- ) ) ;
2529- continue ;
2530- }
2531- } ;
2558+
2559+ let is_first_last =
2560+ matches ! ( func, AggregateFn :: First | AggregateFn :: Last ) ;
2561+
25322562 // It would be nicer to use a proper struct here
25332563 // and have that implement
25342564 // `sqlexpr::ExprVisitor` but we need access to
@@ -2539,6 +2569,18 @@ mod validations {
25392569 let arg_type = match source. field ( ident) {
25402570 Some ( arg_field) => match arg_field. field_type . value_type ( ) {
25412571 Ok ( arg_type) if arg_type. is_numeric ( ) => arg_type,
2572+ Ok ( ValueType :: Bytes | ValueType :: String )
2573+ if is_first_last =>
2574+ {
2575+ return Ok ( ( ) ) ;
2576+ }
2577+ Err ( _)
2578+ if is_first_last
2579+ && arg_field. field_type . get_base_type ( )
2580+ == field. field_type . get_base_type ( ) =>
2581+ {
2582+ return Ok ( ( ) ) ;
2583+ }
25422584 Ok ( _) | Err ( _) => {
25432585 return Err ( Err :: AggregationNonNumericArg (
25442586 agg_type. name . to_owned ( ) ,
@@ -2556,15 +2598,27 @@ mod validations {
25562598 ) ) ;
25572599 }
25582600 } ;
2559- if arg_type > field_type {
2560- return Err ( Err :: AggregationNonMatchingArg (
2561- agg_type. name . to_owned ( ) ,
2562- field. name . to_owned ( ) ,
2563- arg. to_owned ( ) ,
2564- arg_type. to_str ( ) . to_owned ( ) ,
2565- field_type. to_str ( ) . to_owned ( ) ,
2566- ) ) ;
2601+
2602+ match field. field_type . value_type ( ) {
2603+ Ok ( field_type) if field_type. is_numeric ( ) => {
2604+ if arg_type > field_type {
2605+ return Err ( Err :: AggregationNonMatchingArg (
2606+ agg_type. name . to_owned ( ) ,
2607+ field. name . to_owned ( ) ,
2608+ arg. to_owned ( ) ,
2609+ arg_type. to_str ( ) . to_owned ( ) ,
2610+ field_type. to_str ( ) . to_owned ( ) ,
2611+ ) ) ;
2612+ }
2613+ }
2614+ Ok ( _) | Err ( _) => {
2615+ return Err ( Err :: NonNumericAggregate (
2616+ agg_type. name . to_owned ( ) ,
2617+ field. name . to_owned ( ) ,
2618+ ) ) ;
2619+ }
25672620 }
2621+
25682622 Ok ( ( ) )
25692623 } ;
25702624 if let Err ( mut errs) = sqlexpr:: parse ( arg, check_ident) {
@@ -2661,7 +2715,7 @@ mod validations {
26612715 errors. push ( err) ;
26622716 }
26632717 no_derived_fields ( agg_type, & mut errors) ;
2664- aggregate_fields_are_numbers ( agg_type, & mut errors) ;
2718+ aggregate_field_types ( self , agg_type, & mut errors) ;
26652719 aggregate_directive ( self , agg_type, & mut errors) ;
26662720 // check timeseries directive has intervals and args
26672721 aggregation_intervals ( agg_type, & mut errors) ;
0 commit comments