@@ -278,61 +278,6 @@ def build_average_sales( # noqa: D103
278278 )
279279
280280
281- def build_channel_result ( # noqa: D103
282- sales : pl .LazyFrame ,
283- item : pl .LazyFrame ,
284- date_dim : pl .LazyFrame ,
285- cross_items : pl .LazyFrame ,
286- * ,
287- item_key : str ,
288- date_key : str ,
289- qty_col : str ,
290- price_col : str ,
291- channel_label : str ,
292- year : int ,
293- moy : int ,
294- dom : int ,
295- average_sales : pl .LazyFrame ,
296- ) -> pl .LazyFrame :
297- # DuckDB uses d_week_seq to filter, which includes all days in the target week.
298- # Find the d_week_seq for the specific date, then join on that week.
299- target_week = (
300- date_dim .filter (
301- (pl .col ("d_year" ) == year )
302- & (pl .col ("d_moy" ) == moy )
303- & (pl .col ("d_dom" ) == dom )
304- )
305- .select ("d_week_seq" )
306- .unique ()
307- )
308- week_dates = date_dim .join (target_week , on = "d_week_seq" ).select ("d_date_sk" )
309- return (
310- sales .join (cross_items , left_on = item_key , right_on = "ss_item_sk" )
311- .join (item , left_on = item_key , right_on = "i_item_sk" )
312- .join (week_dates , left_on = date_key , right_on = "d_date_sk" )
313- .group_by (["i_brand_id" , "i_class_id" , "i_category_id" ])
314- .agg (
315- [
316- (pl .col (qty_col ) * pl .col (price_col )).sum ().alias ("sales" ),
317- pl .len ().alias ("number_sales" ),
318- ]
319- )
320- .join (average_sales , how = "cross" )
321- .filter (pl .col ("sales" ) > pl .col ("average_sales" ))
322- .with_columns (pl .lit (channel_label ).alias ("channel" ))
323- .select (
324- [
325- "channel" ,
326- "i_brand_id" ,
327- "i_class_id" ,
328- "i_category_id" ,
329- "sales" ,
330- "number_sales" ,
331- ]
332- )
333- )
334-
335-
336281def rollup_level (y : pl .LazyFrame , group_cols : list [str ]) -> pl .LazyFrame : # noqa: D103
337282 if group_cols :
338283 lf = y .group_by (group_cols ).agg (
@@ -390,60 +335,85 @@ def polars_impl(run_config: RunConfig) -> QueryResult:
390335 item = get_data (run_config .dataset_path , "item" , run_config .suffix )
391336 date_dim = get_data (run_config .dataset_path , "date_dim" , run_config .suffix )
392337
338+ all_sales = pl .concat (
339+ [
340+ store_sales .select (
341+ [
342+ pl .lit ("store" ).alias ("channel" ),
343+ pl .col ("ss_item_sk" ).alias ("item_sk" ),
344+ pl .col ("ss_quantity" ).alias ("quantity" ),
345+ pl .col ("ss_list_price" ).alias ("list_price" ),
346+ pl .col ("ss_sold_date_sk" ).alias ("date_sk" ),
347+ ]
348+ ),
349+ catalog_sales .select (
350+ [
351+ pl .lit ("catalog" ).alias ("channel" ),
352+ pl .col ("cs_item_sk" ).alias ("item_sk" ),
353+ pl .col ("cs_quantity" ).alias ("quantity" ),
354+ pl .col ("cs_list_price" ).alias ("list_price" ),
355+ pl .col ("cs_sold_date_sk" ).alias ("date_sk" ),
356+ ]
357+ ),
358+ web_sales .select (
359+ [
360+ pl .lit ("web" ).alias ("channel" ),
361+ pl .col ("ws_item_sk" ).alias ("item_sk" ),
362+ pl .col ("ws_quantity" ).alias ("quantity" ),
363+ pl .col ("ws_list_price" ).alias ("list_price" ),
364+ pl .col ("ws_sold_date_sk" ).alias ("date_sk" ),
365+ ]
366+ ),
367+ ]
368+ )
369+
393370 cross_items = build_cross_items (
394371 store_sales , catalog_sales , web_sales , item , date_dim , year = year
395372 )
396373 average_sales = build_average_sales (
397374 store_sales , catalog_sales , web_sales , date_dim , year = year
398375 )
399376
400- y_store = build_channel_result (
401- store_sales ,
402- item ,
403- date_dim ,
404- cross_items ,
405- item_key = "ss_item_sk" ,
406- date_key = "ss_sold_date_sk" ,
407- qty_col = "ss_quantity" ,
408- price_col = "ss_list_price" ,
409- channel_label = "store" ,
410- year = year + 1 ,
411- moy = 12 ,
412- dom = day ,
413- average_sales = average_sales ,
414- )
415- y_catalog = build_channel_result (
416- catalog_sales ,
417- item ,
418- date_dim ,
419- cross_items ,
420- item_key = "cs_item_sk" ,
421- date_key = "cs_sold_date_sk" ,
422- qty_col = "cs_quantity" ,
423- price_col = "cs_list_price" ,
424- channel_label = "catalog" ,
425- year = year + 1 ,
426- moy = 12 ,
427- dom = day ,
428- average_sales = average_sales ,
429- )
430- y_web = build_channel_result (
431- web_sales ,
432- item ,
433- date_dim ,
434- cross_items ,
435- item_key = "ws_item_sk" ,
436- date_key = "ws_sold_date_sk" ,
437- qty_col = "ws_quantity" ,
438- price_col = "ws_list_price" ,
439- channel_label = "web" ,
440- year = year + 1 ,
441- moy = 12 ,
442- dom = day ,
443- average_sales = average_sales ,
377+ # d_week_seq target is the same for all 3 channels; compute it once.
378+ target_week = (
379+ date_dim .filter (
380+ (pl .col ("d_year" ) == year + 1 )
381+ & (pl .col ("d_moy" ) == 12 )
382+ & (pl .col ("d_dom" ) == day )
383+ )
384+ .select ("d_week_seq" )
385+ .unique ()
444386 )
387+ week_dates = date_dim .join (target_week , on = "d_week_seq" ).select ("d_date_sk" )
445388
446- y = pl .concat ([y_store , y_catalog , y_web ])
389+ # Build y: all 3 channels in a single pipeline.
390+ # cross_items and average_sales each appear once — no CSE needed.
391+ # After group_by the frame is tiny, so the cross join with the 1-row
392+ # average_sales frame is negligible even if Polars fuses it into an IEJoin.
393+ y = (
394+ all_sales .join (cross_items , left_on = "item_sk" , right_on = "ss_item_sk" )
395+ .join (item , left_on = "item_sk" , right_on = "i_item_sk" )
396+ .join (week_dates , left_on = "date_sk" , right_on = "d_date_sk" )
397+ .group_by (["channel" , "i_brand_id" , "i_class_id" , "i_category_id" ])
398+ .agg (
399+ [
400+ (pl .col ("quantity" ) * pl .col ("list_price" )).sum ().alias ("sales" ),
401+ pl .len ().alias ("number_sales" ),
402+ ]
403+ )
404+ .join (average_sales , how = "cross" )
405+ .filter (pl .col ("sales" ) > pl .col ("average_sales" ))
406+ .select (
407+ [
408+ "channel" ,
409+ "i_brand_id" ,
410+ "i_class_id" ,
411+ "i_category_id" ,
412+ "sales" ,
413+ "number_sales" ,
414+ ]
415+ )
416+ )
447417
448418 level1 = rollup_level (y , ["channel" , "i_brand_id" , "i_class_id" , "i_category_id" ])
449419 level2 = rollup_level (y , ["channel" , "i_brand_id" , "i_class_id" ])
0 commit comments