Skip to content

Commit 82b87db

Browse files
committed
[Gemm,Sm100] Fix num_sf_tmem_cols when not blockscaled
1 parent a606f62 commit 82b87db

2 files changed

Lines changed: 23 additions & 10 deletions

File tree

quack/gemm_sm100.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -480,15 +480,18 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: Varle
480480
# to release acc_pipeline early.
481481
# The two approaches perform about the same.
482482
self.overlap_accum_sf = self.blockscaled and self.num_acc_stage == 1
483-
num_sf_tmem_cols = (
484-
(
485-
cute.ceil_div(self.cta_tile_shape_mnk[0], 128)
486-
+ cute.ceil_div(self.cta_tile_shape_mnk[1], 128)
483+
if const_expr(self.overlap_accum_sf):
484+
num_sf_tmem_cols = (
485+
(
486+
cute.ceil_div(self.cta_tile_shape_mnk[0], 128)
487+
+ cute.ceil_div(self.cta_tile_shape_mnk[1], 128)
488+
)
489+
* 4 # 4 cols per stage
490+
* (self.mma_inst_shape_mnk[2] // self.sf_vec_size)
487491
)
488-
* 4 # 4 cols per stage
489-
* (self.mma_inst_shape_mnk[2] // self.sf_vec_size)
490-
)
491-
self.iter_acc_early_release = num_sf_tmem_cols // cute.size(self.epi_tile[1])
492+
self.iter_acc_early_release = num_sf_tmem_cols // cute.size(self.epi_tile[1])
493+
else:
494+
self.iter_acc_early_release = -1
492495

493496
@cute.jit
494497
def __call__(

tools/dump_sass.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#!/usr/bin/env python3
22
"""Dump PTX and SASS of cute-dsl kernels from any script.
33
4-
Sets CUTE_DSL_KEEP_CUBIN=1 and CUTE_DSL_KEEP_PTX=1, runs the target script,
5-
then disassembles all generated .cubin files with nvdisasm.
4+
Disables the QuACK persistent kernel cache, sets CUTE_DSL_KEEP_CUBIN=1 and
5+
CUTE_DSL_KEEP_PTX=1, runs the target script, then disassembles all generated
6+
.cubin files with nvdisasm.
67
78
Usage::
89
@@ -45,6 +46,11 @@ def main():
4546
parser.add_argument("script", help="Python script to run")
4647
parser.add_argument("-o", "--output-dir", default="dump_sass_out", help="Output directory")
4748
parser.add_argument("--ptx-only", action="store_true", help="Skip SASS disassembly")
49+
parser.add_argument(
50+
"--use-cache",
51+
action="store_true",
52+
help="Allow QuACK to use its persistent .o cache instead of forcing recompilation",
53+
)
4854
args = parser.parse_args(our_argv)
4955

5056
script = Path(args.script)
@@ -59,13 +65,17 @@ def main():
5965
f.unlink()
6066

6167
env = os.environ.copy()
68+
if not args.use_cache:
69+
env["QUACK_CACHE_ENABLED"] = "0"
6270
env["CUTE_DSL_KEEP_PTX"] = "1"
6371
env["CUTE_DSL_KEEP_CUBIN"] = "1"
6472
env["CUTE_DSL_DUMP_DIR"] = str(out_dir.resolve())
6573

6674
cmd = [sys.executable, str(script)] + script_args
6775
print(f"Running: {' '.join(cmd)}")
6876
print(f"Dump dir: {out_dir.resolve()}\n")
77+
if not args.use_cache:
78+
print("QuACK cache: disabled via QUACK_CACHE_ENABLED=0\n")
6979
subprocess.run(cmd, env=env)
7080

7181
ptx_files = sorted(out_dir.glob("*.ptx"))

0 commit comments

Comments
 (0)