Skip to content

Commit 1a91955

Browse files
authored
feat: enforce using aggregate function in order by when group by is used (#18)
1 parent 620bf78 commit 1a91955

7 files changed

Lines changed: 205 additions & 11 deletions

src/analysis.rs

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -655,17 +655,21 @@ impl<'a> Analysis<'a> {
655655
}
656656
}
657657

658+
let project = self.analyze_projection(&mut ctx, &query.projection)?;
659+
658660
if let Some(order_by) = &query.order_by {
659-
if !matches!(&order_by.expr.value, Value::Access(_)) {
661+
self.analyze_expr(&mut ctx, &order_by.expr, Type::Unspecified)?;
662+
663+
if query.group_by.is_none() && !matches!(&order_by.expr.value, Value::Access(_)) {
660664
return Err(AnalysisError::ExpectFieldLiteral(
661665
order_by.expr.attrs.pos.line,
662666
order_by.expr.attrs.pos.col,
663667
));
668+
} else if query.group_by.is_some() {
669+
self.expect_agg_expr(&order_by.expr)?;
664670
}
665-
self.analyze_expr(&mut ctx, &order_by.expr, Type::Unspecified)?;
666671
}
667672

668-
let project = self.analyze_projection(&mut ctx, &query.projection)?;
669673
let scope = self.exit_scope();
670674

671675
Ok(Query {
@@ -875,11 +879,11 @@ impl<'a> Analysis<'a> {
875879
));
876880
}
877881

878-
for arg in &app.args {
879-
if *aggregate {
880-
self.ensure_agg_param_is_source_bound(arg)?;
881-
}
882+
if *aggregate {
883+
return self.expect_agg_expr(expr);
884+
}
882885

886+
for arg in &app.args {
883887
self.invalidate_agg_func_usage(arg)?;
884888
}
885889
}
@@ -897,7 +901,27 @@ impl<'a> Analysis<'a> {
897901
}
898902
}
899903

900-
fn ensure_agg_param_is_source_bound(&mut self, expr: &Expr) -> AnalysisResult<()> {
904+
fn expect_agg_expr(&self, expr: &Expr) -> AnalysisResult<()> {
905+
if let Value::App(app) = &expr.value
906+
&& let Some(Type::App {
907+
aggregate: true, ..
908+
}) = self.options.default_scope.entries.get(app.func.as_str())
909+
{
910+
for arg in &app.args {
911+
self.ensure_agg_param_is_source_bound(arg)?;
912+
self.invalidate_agg_func_usage(arg)?;
913+
}
914+
915+
return Ok(());
916+
}
917+
918+
Err(AnalysisError::ExpectAggExpr(
919+
expr.attrs.pos.line,
920+
expr.attrs.pos.col,
921+
))
922+
}
923+
924+
fn ensure_agg_param_is_source_bound(&self, expr: &Expr) -> AnalysisResult<()> {
901925
match &expr.value {
902926
Value::Id(id) if !self.options.default_scope.entries.contains_key(id.as_str()) => {
903927
Ok(())
@@ -914,7 +938,7 @@ impl<'a> Analysis<'a> {
914938
}
915939

916940
fn ensure_agg_binary_op_is_source_bound(
917-
&mut self,
941+
&self,
918942
attrs: &Attrs,
919943
binary: &Binary,
920944
) -> AnalysisResult<()> {
@@ -930,7 +954,7 @@ impl<'a> Analysis<'a> {
930954
Ok(())
931955
}
932956

933-
fn ensure_agg_binary_op_branch_is_source_bound(&mut self, expr: &Expr) -> bool {
957+
fn ensure_agg_binary_op_branch_is_source_bound(&self, expr: &Expr) -> bool {
934958
match &expr.value {
935959
Value::Id(id) => !self.options.default_scope.entries.contains_key(id.as_str()),
936960
Value::Array(exprs) => {
@@ -965,7 +989,7 @@ impl<'a> Analysis<'a> {
965989
}
966990
}
967991

968-
fn invalidate_agg_func_usage(&mut self, expr: &Expr) -> AnalysisResult<()> {
992+
fn invalidate_agg_func_usage(&self, expr: &Expr) -> AnalysisResult<()> {
969993
match &expr.value {
970994
Value::Number(_)
971995
| Value::String(_)

src/error.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,29 @@ pub enum AnalysisError {
339339
/// ```
340340
#[error("{0}:{1}: constant expressions are forbidden in PROJECT INTO clause")]
341341
ConstantExprInProjectIntoClause(u32, u32),
342+
343+
/// Expect an aggregate expression at this position in the query.
344+
///
345+
/// Fields: `(line, column)`
346+
///
347+
/// Invalid usage:
348+
/// ```eql
349+
/// FROM e IN events
350+
/// GROUP BY e.data.department
351+
/// // ERROR: the order by clause should use an aggregage expresion because GROUP Bys
352+
/// // require an aggregate expression in this context.
353+
/// ORDER BY e.data.salary
354+
/// PROJECT INTO AVG(e.data.salary)
355+
/// ```
356+
/// Valid usage:
357+
/// ```eql
358+
/// FROM e IN events
359+
/// GROUP BY e.data.department
360+
/// ORDER BY AVG(e.data.salary)
361+
/// PROJECT INTO AVG(e.data.salary)
362+
/// ```
363+
#[error("{0}:{1}: expect aggregate expression")]
364+
ExpectAggExpr(u32, u32),
342365
}
343366

344367
impl From<LexerError> for Error {

src/tests/analysis.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,21 @@ fn test_analyze_allow_constant_agg_func() {
217217
let query = parse_query(include_str!("./resources/allow_constant_agg_func.eql")).unwrap();
218218
insta::assert_yaml_snapshot!(query.run_static_analysis(&Default::default()));
219219
}
220+
221+
#[test]
222+
fn test_analyze_reject_group_by_with_order_by_no_agg() {
223+
let query = parse_query(include_str!(
224+
"./resources/reject_group_by_with_order_by_no_agg.eql"
225+
))
226+
.unwrap();
227+
insta::assert_yaml_snapshot!(query.run_static_analysis(&Default::default()));
228+
}
229+
230+
#[test]
231+
fn test_analyze_accept_group_by_with_order_by_with_agg() {
232+
let query = parse_query(include_str!(
233+
"./resources/accept_group_by_with_order_by_with_agg.eql"
234+
))
235+
.unwrap();
236+
insta::assert_yaml_snapshot!(query.run_static_analysis(&Default::default()));
237+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
FROM e IN events
2+
GROUP BY e.data.department
3+
ORDER BY AVG(e.data.salary)
4+
PROJECT INTO AVG(e.data.salary)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
FROM e IN events
2+
GROUP BY e.data.department
3+
ORDER BY e.data.salary
4+
PROJECT INTO AVG(e.data.salary)
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
---
2+
source: src/tests/analysis.rs
3+
expression: "query.run_static_analysis(&Default::default())"
4+
---
5+
Ok:
6+
attrs:
7+
pos:
8+
line: 1
9+
col: 1
10+
sources:
11+
- binding:
12+
name: e
13+
pos:
14+
line: 1
15+
col: 6
16+
kind:
17+
Name: events
18+
predicate: ~
19+
group_by:
20+
expr:
21+
attrs:
22+
pos:
23+
line: 2
24+
col: 10
25+
value:
26+
Access:
27+
target:
28+
attrs:
29+
pos:
30+
line: 2
31+
col: 10
32+
value:
33+
Access:
34+
target:
35+
attrs:
36+
pos:
37+
line: 2
38+
col: 10
39+
value:
40+
Id: e
41+
field: data
42+
field: department
43+
predicate: ~
44+
order_by:
45+
expr:
46+
attrs:
47+
pos:
48+
line: 3
49+
col: 10
50+
value:
51+
App:
52+
func: AVG
53+
args:
54+
- attrs:
55+
pos:
56+
line: 3
57+
col: 14
58+
value:
59+
Access:
60+
target:
61+
attrs:
62+
pos:
63+
line: 3
64+
col: 14
65+
value:
66+
Access:
67+
target:
68+
attrs:
69+
pos:
70+
line: 3
71+
col: 14
72+
value:
73+
Id: e
74+
field: data
75+
field: salary
76+
order: Asc
77+
limit: ~
78+
projection:
79+
attrs:
80+
pos:
81+
line: 4
82+
col: 14
83+
value:
84+
App:
85+
func: AVG
86+
args:
87+
- attrs:
88+
pos:
89+
line: 4
90+
col: 18
91+
value:
92+
Access:
93+
target:
94+
attrs:
95+
pos:
96+
line: 4
97+
col: 18
98+
value:
99+
Access:
100+
target:
101+
attrs:
102+
pos:
103+
line: 4
104+
col: 18
105+
value:
106+
Id: e
107+
field: data
108+
field: salary
109+
distinct: false
110+
meta:
111+
project: Number
112+
aggregate: true
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
---
2+
source: src/tests/analysis.rs
3+
expression: "query.run_static_analysis(&Default::default())"
4+
---
5+
Err:
6+
Analysis:
7+
ExpectAggExpr:
8+
- 3
9+
- 10

0 commit comments

Comments
 (0)