@@ -375,6 +375,7 @@ def __call__(
375375 scheduler_args : TileSchedulerOptions ,
376376 varlen_args : Optional [VarlenArguments ],
377377 stream : cuda .CUstream ,
378+ trace_ptr : Optional [cutlass .Int64 ] = None ,
378379 ):
379380 """Execute the GEMM operation in steps:
380381 - Setup static attributes
@@ -542,6 +543,7 @@ class SharedStorage:
542543 self .epi_c_smem_layout_staged ,
543544 tile_sched_params ,
544545 TileSchedulerCls ,
546+ trace_ptr ,
545547 ).launch (
546548 grid = grid ,
547549 block = [self .threads_per_cta , 1 , 1 ],
@@ -573,6 +575,7 @@ def kernel(
573575 epi_c_smem_layout : cute .ComposedLayout ,
574576 tile_sched_params ,
575577 TileSchedulerCls : cutlass .Constexpr [Callable ],
578+ trace_ptr : Optional [cutlass .Int64 ] = None ,
576579 ):
577580 """
578581 GPU device kernel performing the batched GEMM computation.
@@ -601,6 +604,11 @@ def kernel(
601604 :type epi_smem_layout: cute.ComposedLayout
602605 """
603606
607+ from quack .trace import TraceContext
608+
609+ GEMM_REGIONS = ("tma_load" , "mma" , "epilogue" )
610+ tctx = TraceContext .create (trace_ptr , region_names = GEMM_REGIONS )
611+
604612 varlen_m = const_expr (varlen_params .cu_seqlens_m is not None )
605613 varlen_k = const_expr (varlen_params .cu_seqlens_k is not None )
606614 assert not (varlen_m and varlen_k )
@@ -703,6 +711,7 @@ def kernel(
703711 pipeline .PipelineUserType .Producer , self .ab_stage
704712 )
705713 while work_tile .is_valid_tile :
714+ tctx .b ("tma_load" )
706715 tile_coord_mnkl = work_tile .tile_idx
707716 batch_idx = tile_coord_mnkl [3 ]
708717 # Local_tile partition global tensors
@@ -804,6 +813,7 @@ def kernel(
804813 k_tile_cnt ,
805814 varlen_m = varlen_m ,
806815 )
816+ tctx .e ("tma_load" )
807817 tile_scheduler .advance_to_next_work (is_scheduler_warp = is_scheduler_warp )
808818 work_tile = tile_scheduler .get_current_work ()
809819 # End of persistent scheduler loop
@@ -882,16 +892,19 @@ def kernel(
882892 batch_idx = tile_coord_mnkl [3 ]
883893 len_k = varlen_manager .len_k (batch_idx )
884894 k_tile_cnt = cute .ceil_div (len_k , self .cta_tile_shape_mnk [2 ])
895+ tctx .b ("mma" )
885896 ab_read_state = self .mma (
886897 ab_pipeline , ab_read_state , mma_fn , acc , acc_slow , k_tile_cnt , warp_group_idx
887898 )
899+ tctx .e ("mma" )
888900 if const_expr (varlen_k ):
889901 if k_tile_cnt == 0 :
890902 acc .fill (0.0 )
891903
892904 # EPILOGUE
893905 if const_expr (self .pingpong ):
894906 self .pingpong_barrier_sync (warp_group_idx , "epi" )
907+ tctx .b ("epilogue" )
895908
896909 copy_D = None
897910 if const_expr (has_D ):
@@ -966,6 +979,8 @@ def kernel(
966979 epi_store_pipeline .producer_tail ()
967980 self .pingpong_barrier_arrive (1 - warp_group_idx , stage = "epi" )
968981
982+ tctx .e ("epilogue" )
983+
969984 if const_expr (not self .pingpong ):
970985 tile_scheduler .advance_to_next_work ()
971986 work_tile = tile_scheduler .get_current_work ()
@@ -994,6 +1009,8 @@ def kernel(
9941009 if is_tma_warp :
9951010 epi_store_pipeline .producer_tail ()
9961011
1012+ tctx .flush ()
1013+
9971014 @cute .jit
9981015 def load_AB (
9991016 self ,
0 commit comments