Skip to content

Commit 6aa2699

Browse files
authored
feat(cubesql): Flatten filter rules (#10940)
* feat(cubesql): Flatten filter rules * Linter * Clippy * Linter * Fixes
1 parent 335d079 commit 6aa2699

5 files changed

Lines changed: 303 additions & 3 deletions

File tree

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14454,6 +14454,90 @@ ORDER BY "source"."str0" ASC
1445414454
)
1445514455
}
1445614456

14457+
/// ThoughtSpot-style day-of-quarter expression split across inner/outer query
14458+
/// with MEASURE() and CASE WHEN filter on the measure column.
14459+
///
14460+
/// The inner query projects two parts of the quarter calculation as separate
14461+
/// columns plus a CASE WHEN filtered amount, with no GROUP BY.
14462+
/// The outer query computes day_of_quarter from those columns, wraps the
14463+
/// filtered amount in MEASURE(), and groups.
14464+
///
14465+
/// This exercises the E-graph's ability to:
14466+
/// 1. Flatten the subquery so the quarter expression becomes a single tree
14467+
/// 2. Atomically rewrite the quarter expression to DATE_TRUNC('quarter', ...)
14468+
/// before sub-expression simplification rules break the pattern
14469+
/// 3. Expand MEASURE() at the correct aggregation level
14470+
/// 4. Avoid emitting INTERVAL '1 month' * expr (invalid on Snowflake)
///
/// Note: the only assertion in this test compares the cube scan's load request
/// with the expected `V1LoadRequestQuery`; the generated SQL text itself is not
/// inspected here.
14471+
#[tokio::test]
14472+
async fn test_thoughtspot_pg_day_of_quarter_split_with_measure() {
14473+
// When SQL push down is disabled the test returns early without asserting anything.
if !Rewriter::sql_push_down_enabled() {
14474+
return;
14475+
}
14476+
init_testing_logger();
14477+
14478+
let query_plan = convert_select_to_query_plan(
14479+
r#"
14480+
SELECT
14481+
CAST("inner_query"."order_date" AS date)
14482+
- CAST("inner_query"."quarter_start" AS date)
14483+
+ 1 AS "day_of_quarter",
14484+
MEASURE("inner_query"."sumPrice") AS "revenue"
14485+
FROM (
14486+
SELECT
14487+
"ta_1"."order_date" AS "order_date",
14488+
CAST(
14489+
EXTRACT(YEAR FROM "ta_1"."order_date") || '-'
14490+
|| EXTRACT(MONTH FROM "ta_1"."order_date") || '-01'
14491+
AS DATE)
14492+
+ (((MOD(CAST((EXTRACT(MONTH FROM "ta_1"."order_date") - 1)
14493+
AS numeric), 3) + 1) - 1) * -1)
14494+
* INTERVAL '1 month'
14495+
AS "quarter_start",
14496+
CASE WHEN "ta_1"."customer_gender" = 'female'
14497+
THEN "ta_1"."sumPrice" END AS "sumPrice"
14498+
FROM "db"."public"."KibanaSampleDataEcommerce" AS "ta_1"
14499+
) "inner_query"
14500+
WHERE
14501+
CAST("inner_query"."order_date" AS date)
14502+
- CAST("inner_query"."quarter_start" AS date)
14503+
+ 1 <= 45
14504+
GROUP BY 1
14505+
ORDER BY 1
14506+
;"#
14507+
.to_string(),
14508+
DatabaseProtocol::PostgreSQL,
14509+
)
14510+
.await;
14511+
14512+
let logical_plan = query_plan.as_logical_plan();
14513+
14514+
// Take the load request carried by the cube scan found in the logical plan.
let request = logical_plan.find_cube_scan().request;
14515+
14516+
// The rewriter should recognize the complex quarter expression and
14517+
// simplify it to DATE_TRUNC('quarter', col) via the
14518+
// thoughtspot-pg-quarter-start-to-date-trunc rule, which then gets
14519+
// recognized as a quarter time dimension.
14520+
assert_eq!(
14521+
request,
14522+
V1LoadRequestQuery {
14523+
measures: Some(vec!["KibanaSampleDataEcommerce.sumPrice".to_string(),]),
14524+
dimensions: Some(vec![
14525+
"KibanaSampleDataEcommerce.order_date".to_string(),
14526+
"KibanaSampleDataEcommerce.customer_gender".to_string(),
14527+
]),
14528+
segments: Some(vec![]),
14529+
time_dimensions: Some(vec![V1LoadRequestQueryTimeDimension {
14530+
dimension: "KibanaSampleDataEcommerce.order_date".to_string(),
14531+
granularity: Some("quarter".to_string()),
14532+
date_range: None,
14533+
},]),
14534+
order: Some(vec![]),
14535+
ungrouped: Some(true),
14536+
..Default::default()
14537+
}
14538+
);
14539+
}
14540+
1445714541
#[tokio::test]
1445814542
async fn test_domo_filter_date_gt() {
1445914543
init_testing_logger();

rust/cubesql/cubesql/src/compile/rewrite/mod.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,9 @@ pub enum ListType {
906906
AggregateGroupExpr,
907907
AggregateAggrExpr,
908908
ScalarFunctionExprArgs,
909+
ScalarUDFExprArgs,
910+
AggregateFunctionExprArgs,
911+
AggregateUDFExprArgs,
909912
GroupingSetExprMembers,
910913
WrappedSelectProjectionExpr,
911914
WrappedSelectGroupExpr,
@@ -923,6 +926,11 @@ impl ListType {
923926
Self::AggregateAggrExpr => aggr_aggr_expr_empty_tail(),
924927
Self::GroupingSetExprMembers => grouping_set_expr_members_empty_tail(),
925928
Self::ScalarFunctionExprArgs => scalar_fun_expr_args_empty_tail(),
929+
Self::ScalarUDFExprArgs => udf_fun_expr_args_empty_tail(),
930+
Self::AggregateFunctionExprArgs => {
931+
list_expr("AggregateFunctionExprArgs", Vec::<String>::new())
932+
}
933+
Self::AggregateUDFExprArgs => udaf_fun_expr_args_empty_tail(),
926934
Self::WrappedSelectProjectionExpr => wrapped_select_projection_expr_empty_tail(),
927935
Self::WrappedSelectGroupExpr => wrapped_select_group_expr_empty_tail(),
928936
Self::WrappedSelectAggrExpr => wrapped_select_aggr_expr_empty_tail(),
@@ -987,6 +995,15 @@ impl ListNodeSearcher {
987995
ListType::ScalarFunctionExprArgs => {
988996
matches!(node, LogicalPlanLanguage::ScalarFunctionExprArgs(_))
989997
}
998+
ListType::ScalarUDFExprArgs => {
999+
matches!(node, LogicalPlanLanguage::ScalarUDFExprArgs(_))
1000+
}
1001+
ListType::AggregateFunctionExprArgs => {
1002+
matches!(node, LogicalPlanLanguage::AggregateFunctionExprArgs(_))
1003+
}
1004+
ListType::AggregateUDFExprArgs => {
1005+
matches!(node, LogicalPlanLanguage::AggregateUDFExprArgs(_))
1006+
}
9901007
ListType::WrappedSelectProjectionExpr => {
9911008
matches!(node, LogicalPlanLanguage::WrappedSelectProjectionExpr(_))
9921009
}
@@ -1154,6 +1171,11 @@ impl ListNodeApplierList {
11541171
ListType::AggregateGroupExpr => LogicalPlanLanguage::AggregateGroupExpr(list),
11551172
ListType::AggregateAggrExpr => LogicalPlanLanguage::AggregateAggrExpr(list),
11561173
ListType::ScalarFunctionExprArgs => LogicalPlanLanguage::ScalarFunctionExprArgs(list),
1174+
ListType::ScalarUDFExprArgs => LogicalPlanLanguage::ScalarUDFExprArgs(list),
1175+
ListType::AggregateFunctionExprArgs => {
1176+
LogicalPlanLanguage::AggregateFunctionExprArgs(list)
1177+
}
1178+
ListType::AggregateUDFExprArgs => LogicalPlanLanguage::AggregateUDFExprArgs(list),
11571179
ListType::WrappedSelectProjectionExpr => {
11581180
LogicalPlanLanguage::WrappedSelectProjectionExpr(list)
11591181
}

rust/cubesql/cubesql/src/compile/rewrite/rules/dates.rs

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use crate::{
1515
var, var_iter,
1616
};
1717
use datafusion::{
18-
arrow::datatypes::{DataType, TimeUnit},
18+
arrow::datatypes::{DataType, DataType as ArrowDataType, TimeUnit},
1919
logical_plan::DFSchema,
2020
scalar::ScalarValue,
2121
};
@@ -409,6 +409,53 @@ impl RewriteRules for DateRules {
409409
"?new_granularity",
410410
),
411411
),
412+
// ThoughtSpot's PostgreSQL quarter start calculation uses INTERVAL arithmetic
413+
// that is incompatible with non-PostgreSQL dialects. Recognize the pattern and
414+
// replace with DATE_TRUNC('quarter', col) which all dialects support.
415+
transforming_rewrite(
416+
"thoughtspot-pg-quarter-start-to-date-trunc",
417+
alias_expr(
418+
binary_expr(
419+
cast_expr_explicit(
420+
binary_expr(
421+
binary_expr(
422+
binary_expr(
423+
self.fun_expr(
424+
"DatePart",
425+
vec![literal_string("year"), column_expr("?column")],
426+
),
427+
"||",
428+
literal_string("-"),
429+
),
430+
"||",
431+
self.fun_expr(
432+
"DatePart",
433+
vec![literal_string("month"), column_expr("?column")],
434+
),
435+
),
436+
"||",
437+
literal_string("-01"),
438+
),
439+
ArrowDataType::Date32,
440+
),
441+
"+",
442+
binary_expr(
443+
binary_expr("?mod_part", "*", "?neg_one"),
444+
"*",
445+
"?interval_val",
446+
),
447+
),
448+
"?alias",
449+
),
450+
alias_expr(
451+
self.fun_expr(
452+
"DateTrunc",
453+
vec![literal_string("quarter"), column_expr("?column")],
454+
),
455+
"?alias",
456+
),
457+
Self::transform_quarter_interval_check("?neg_one", "?interval_val"),
458+
),
412459
// AGE function seems to be a popular choice for this date arithmetic,
413460
// but it is not supported in SQL push down by most dialects.
414461
transforming_rewrite_with_root(
@@ -535,6 +582,40 @@ impl DateRules {
535582
}
536583
}
537584

585+
fn transform_quarter_interval_check(
586+
neg_one_var: &'static str,
587+
interval_val_var: &'static str,
588+
) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool {
589+
let neg_one_var = var!(neg_one_var);
590+
let interval_val_var = var!(interval_val_var);
591+
move |egraph, subst| {
592+
let neg_one_ok = match &egraph[subst[neg_one_var]].data.constant {
593+
Some(ConstantFolding::Scalar(v)) => match v {
594+
ScalarValue::Int64(Some(-1))
595+
| ScalarValue::Int32(Some(-1))
596+
| ScalarValue::Decimal128(Some(-1), _, _) => true,
597+
_ => false,
598+
},
599+
_ => false,
600+
};
601+
if !neg_one_ok {
602+
return false;
603+
}
604+
match &egraph[subst[interval_val_var]].data.constant {
605+
Some(ConstantFolding::Scalar(v)) => match v {
606+
ScalarValue::IntervalYearMonth(Some(1)) => true,
607+
ScalarValue::IntervalMonthDayNano(Some(iv)) => {
608+
let months = (*iv >> 96) as i32;
609+
let days = ((*iv >> 64) & 0xFFFF_FFFF) as i32;
610+
months == 1 && days == 0
611+
}
612+
_ => false,
613+
},
614+
_ => false,
615+
}
616+
}
617+
}
618+
538619
fn transform_to_date_to_timestamp(
539620
&self,
540621
format_var: &'static str,

rust/cubesql/cubesql/src/compile/rewrite/rules/flatten/pass_through.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::compile::rewrite::{
22
agg_fun_expr, agg_fun_expr_within_group_empty_tail, alias_expr, binary_expr, cast_expr,
33
flatten_pushdown_replacer, fun_expr_var_arg, is_not_null_expr, is_null_expr, rewrite,
4-
rewriter::CubeRewrite, rules::flatten::FlattenRules, udf_expr_var_arg,
4+
rewriter::CubeRewrite, rules::flatten::FlattenRules, udaf_expr_var_arg, udf_expr_var_arg,
55
};
66

77
impl FlattenRules {
@@ -30,6 +30,11 @@ impl FlattenRules {
3030
|expr| udf_expr_var_arg("?fun", expr),
3131
rules,
3232
);
33+
self.single_arg_pass_through_rules(
34+
"udaf-function",
35+
|expr| udaf_expr_var_arg("?fun", expr, "?distinct"),
36+
rules,
37+
);
3338
self.single_arg_pass_through_rules("is-null", |expr| is_null_expr(expr), rules);
3439
self.single_arg_pass_through_rules("is-not-null", |expr| is_not_null_expr(expr), rules);
3540
rules.push(rewrite(

rust/cubesql/cubesql/src/compile/rewrite/rules/flatten/top_level.rs

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::{
22
compile::rewrite::{
3-
aggregate, cube_scan, flatten_pushdown_replacer, projection,
3+
aggregate, cube_scan, filter, flatten_pushdown_replacer, projection,
44
rewriter::{CubeEGraph, CubeRewrite},
55
rules::{flatten::FlattenRules, replacer_flat_push_down_node, replacer_push_down_node},
66
transforming_chain_rewrite_with_root, FlattenPushdownReplacerInnerAlias, ListType,
@@ -159,6 +159,90 @@ impl FlattenRules {
159159
),
160160
)]);
161161

162+
rules.extend(vec![transforming_chain_rewrite_with_root(
163+
"flatten-filter-pushdown",
164+
aggregate(
165+
"?filter_node",
166+
"?outer_group_expr",
167+
"?outer_aggregate_expr",
168+
"AggregateSplit:false",
169+
),
170+
vec![
171+
("?filter_node", filter("?filter_expr", "?inner_projection")),
172+
(
173+
"?inner_projection",
174+
projection(
175+
"?inner_projection_expr",
176+
"?cube_scan",
177+
"?inner_projection_alias",
178+
"ProjectionSplit:false",
179+
),
180+
),
181+
(
182+
"?cube_scan",
183+
cube_scan(
184+
"?alias_to_cube",
185+
"?members",
186+
"?filters",
187+
"?orders",
188+
"?limit",
189+
"?offset",
190+
"CubeScanSplit:false",
191+
"?can_pushdown_join",
192+
"CubeScanWrapped:false",
193+
"?ungrouped",
194+
"?join_hints",
195+
),
196+
),
197+
],
198+
aggregate(
199+
filter(
200+
flatten_pushdown_replacer(
201+
"?filter_expr",
202+
"?inner_projection_expr",
203+
"?inner_alias",
204+
"FlattenPushdownReplacerTopLevel:false",
205+
),
206+
cube_scan(
207+
"?alias_to_cube",
208+
"?members",
209+
"?filters",
210+
"?orders",
211+
"?limit",
212+
"?offset",
213+
"CubeScanSplit:false",
214+
"?can_pushdown_join",
215+
"CubeScanWrapped:false",
216+
"?ungrouped",
217+
"?join_hints",
218+
),
219+
),
220+
flatten_pushdown_replacer(
221+
"?outer_group_expr",
222+
"?inner_projection_expr",
223+
"?inner_alias",
224+
"FlattenPushdownReplacerTopLevel:false",
225+
),
226+
flatten_pushdown_replacer(
227+
"?outer_aggregate_expr",
228+
"?inner_projection_expr",
229+
"?inner_alias",
230+
"FlattenPushdownReplacerTopLevel:false",
231+
),
232+
"AggregateSplit:false",
233+
),
234+
self.flatten_aggregate(
235+
"?inner_projection",
236+
"?cube_scan",
237+
"?members",
238+
"?inner_projection_expr",
239+
"?outer_group_expr",
240+
"?outer_aggregate_expr",
241+
"?inner_projection_alias",
242+
"?inner_alias",
243+
),
244+
)]);
245+
162246
if self.config_obj.push_down_pull_up_split() {
163247
Self::flat_list_pushdown_rules(
164248
"flatten-projection-expr",
@@ -175,10 +259,34 @@ impl FlattenRules {
175259
ListType::AggregateGroupExpr,
176260
rules,
177261
);
262+
Self::flat_list_pushdown_rules(
263+
"flatten-scalar-fun-args",
264+
ListType::ScalarFunctionExprArgs,
265+
rules,
266+
);
267+
Self::flat_list_pushdown_rules(
268+
"flatten-udf-fun-args",
269+
ListType::ScalarUDFExprArgs,
270+
rules,
271+
);
272+
Self::flat_list_pushdown_rules(
273+
"flatten-agg-fun-args",
274+
ListType::AggregateFunctionExprArgs,
275+
rules,
276+
);
277+
Self::flat_list_pushdown_rules(
278+
"flatten-udaf-fun-args",
279+
ListType::AggregateUDFExprArgs,
280+
rules,
281+
);
178282
} else {
179283
Self::list_pushdown_rules("flatten-projection-expr", "ProjectionExpr", rules);
180284
Self::list_pushdown_rules("flatten-aggregate-expr", "AggregateAggrExpr", rules);
181285
Self::list_pushdown_rules("flatten-group-expr", "AggregateGroupExpr", rules);
286+
Self::list_pushdown_rules("flatten-scalar-fun-args", "ScalarFunctionExprArgs", rules);
287+
Self::list_pushdown_rules("flatten-udf-fun-args", "ScalarUDFExprArgs", rules);
288+
Self::list_pushdown_rules("flatten-agg-fun-args", "AggregateFunctionExprArgs", rules);
289+
Self::list_pushdown_rules("flatten-udaf-fun-args", "AggregateUDFExprArgs", rules);
182290
}
183291
}
184292

0 commit comments

Comments
 (0)