Skip to content

Commit 58d7007

Browse files
authored
Rewrite TPC-DS Q14 plan to workaround Polars optimizer CSE limitation (rapidsai#21885)
PR is an alternate plan for Q14. Current implementation single GPU took 500s+. This PR on an H100 takes 31s cold ~20s hot cc @Matt711 Authors: - Benjamin Zaitlen (https://github.com/quasiben) Approvers: - Matthew Murray (https://github.com/Matt711) URL: rapidsai#21885
1 parent ef308bc commit 58d7007

1 file changed

Lines changed: 70 additions & 100 deletions

File tree

  • python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries

python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q14.py

Lines changed: 70 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -278,61 +278,6 @@ def build_average_sales( # noqa: D103
278278
)
279279

280280

281-
def build_channel_result( # noqa: D103
282-
sales: pl.LazyFrame,
283-
item: pl.LazyFrame,
284-
date_dim: pl.LazyFrame,
285-
cross_items: pl.LazyFrame,
286-
*,
287-
item_key: str,
288-
date_key: str,
289-
qty_col: str,
290-
price_col: str,
291-
channel_label: str,
292-
year: int,
293-
moy: int,
294-
dom: int,
295-
average_sales: pl.LazyFrame,
296-
) -> pl.LazyFrame:
297-
# DuckDB uses d_week_seq to filter, which includes all days in the target week.
298-
# Find the d_week_seq for the specific date, then join on that week.
299-
target_week = (
300-
date_dim.filter(
301-
(pl.col("d_year") == year)
302-
& (pl.col("d_moy") == moy)
303-
& (pl.col("d_dom") == dom)
304-
)
305-
.select("d_week_seq")
306-
.unique()
307-
)
308-
week_dates = date_dim.join(target_week, on="d_week_seq").select("d_date_sk")
309-
return (
310-
sales.join(cross_items, left_on=item_key, right_on="ss_item_sk")
311-
.join(item, left_on=item_key, right_on="i_item_sk")
312-
.join(week_dates, left_on=date_key, right_on="d_date_sk")
313-
.group_by(["i_brand_id", "i_class_id", "i_category_id"])
314-
.agg(
315-
[
316-
(pl.col(qty_col) * pl.col(price_col)).sum().alias("sales"),
317-
pl.len().alias("number_sales"),
318-
]
319-
)
320-
.join(average_sales, how="cross")
321-
.filter(pl.col("sales") > pl.col("average_sales"))
322-
.with_columns(pl.lit(channel_label).alias("channel"))
323-
.select(
324-
[
325-
"channel",
326-
"i_brand_id",
327-
"i_class_id",
328-
"i_category_id",
329-
"sales",
330-
"number_sales",
331-
]
332-
)
333-
)
334-
335-
336281
def rollup_level(y: pl.LazyFrame, group_cols: list[str]) -> pl.LazyFrame: # noqa: D103
337282
if group_cols:
338283
lf = y.group_by(group_cols).agg(
@@ -390,60 +335,85 @@ def polars_impl(run_config: RunConfig) -> QueryResult:
390335
item = get_data(run_config.dataset_path, "item", run_config.suffix)
391336
date_dim = get_data(run_config.dataset_path, "date_dim", run_config.suffix)
392337

338+
all_sales = pl.concat(
339+
[
340+
store_sales.select(
341+
[
342+
pl.lit("store").alias("channel"),
343+
pl.col("ss_item_sk").alias("item_sk"),
344+
pl.col("ss_quantity").alias("quantity"),
345+
pl.col("ss_list_price").alias("list_price"),
346+
pl.col("ss_sold_date_sk").alias("date_sk"),
347+
]
348+
),
349+
catalog_sales.select(
350+
[
351+
pl.lit("catalog").alias("channel"),
352+
pl.col("cs_item_sk").alias("item_sk"),
353+
pl.col("cs_quantity").alias("quantity"),
354+
pl.col("cs_list_price").alias("list_price"),
355+
pl.col("cs_sold_date_sk").alias("date_sk"),
356+
]
357+
),
358+
web_sales.select(
359+
[
360+
pl.lit("web").alias("channel"),
361+
pl.col("ws_item_sk").alias("item_sk"),
362+
pl.col("ws_quantity").alias("quantity"),
363+
pl.col("ws_list_price").alias("list_price"),
364+
pl.col("ws_sold_date_sk").alias("date_sk"),
365+
]
366+
),
367+
]
368+
)
369+
393370
cross_items = build_cross_items(
394371
store_sales, catalog_sales, web_sales, item, date_dim, year=year
395372
)
396373
average_sales = build_average_sales(
397374
store_sales, catalog_sales, web_sales, date_dim, year=year
398375
)
399376

400-
y_store = build_channel_result(
401-
store_sales,
402-
item,
403-
date_dim,
404-
cross_items,
405-
item_key="ss_item_sk",
406-
date_key="ss_sold_date_sk",
407-
qty_col="ss_quantity",
408-
price_col="ss_list_price",
409-
channel_label="store",
410-
year=year + 1,
411-
moy=12,
412-
dom=day,
413-
average_sales=average_sales,
414-
)
415-
y_catalog = build_channel_result(
416-
catalog_sales,
417-
item,
418-
date_dim,
419-
cross_items,
420-
item_key="cs_item_sk",
421-
date_key="cs_sold_date_sk",
422-
qty_col="cs_quantity",
423-
price_col="cs_list_price",
424-
channel_label="catalog",
425-
year=year + 1,
426-
moy=12,
427-
dom=day,
428-
average_sales=average_sales,
429-
)
430-
y_web = build_channel_result(
431-
web_sales,
432-
item,
433-
date_dim,
434-
cross_items,
435-
item_key="ws_item_sk",
436-
date_key="ws_sold_date_sk",
437-
qty_col="ws_quantity",
438-
price_col="ws_list_price",
439-
channel_label="web",
440-
year=year + 1,
441-
moy=12,
442-
dom=day,
443-
average_sales=average_sales,
377+
# d_week_seq target is the same for all 3 channels; compute it once.
378+
target_week = (
379+
date_dim.filter(
380+
(pl.col("d_year") == year + 1)
381+
& (pl.col("d_moy") == 12)
382+
& (pl.col("d_dom") == day)
383+
)
384+
.select("d_week_seq")
385+
.unique()
444386
)
387+
week_dates = date_dim.join(target_week, on="d_week_seq").select("d_date_sk")
445388

446-
y = pl.concat([y_store, y_catalog, y_web])
389+
# Build y: all 3 channels in a single pipeline.
390+
# cross_items and average_sales each appear once — no CSE needed.
391+
# After group_by the frame is tiny, so the cross join with the 1-row
392+
# average_sales frame is negligible even if Polars fuses it into an IEJoin.
393+
y = (
394+
all_sales.join(cross_items, left_on="item_sk", right_on="ss_item_sk")
395+
.join(item, left_on="item_sk", right_on="i_item_sk")
396+
.join(week_dates, left_on="date_sk", right_on="d_date_sk")
397+
.group_by(["channel", "i_brand_id", "i_class_id", "i_category_id"])
398+
.agg(
399+
[
400+
(pl.col("quantity") * pl.col("list_price")).sum().alias("sales"),
401+
pl.len().alias("number_sales"),
402+
]
403+
)
404+
.join(average_sales, how="cross")
405+
.filter(pl.col("sales") > pl.col("average_sales"))
406+
.select(
407+
[
408+
"channel",
409+
"i_brand_id",
410+
"i_class_id",
411+
"i_category_id",
412+
"sales",
413+
"number_sales",
414+
]
415+
)
416+
)
447417

448418
level1 = rollup_level(y, ["channel", "i_brand_id", "i_class_id", "i_category_id"])
449419
level2 = rollup_level(y, ["channel", "i_brand_id", "i_class_id"])

0 commit comments

Comments
 (0)