Skip to content

Commit be8a72a

Browse files
authored
fix paddle optional get assert in sm103 (#7816)
1 parent cb2d7c0 commit be8a72a

1 file changed

Lines changed: 11 additions & 4 deletions

File tree

custom_ops/setup_ops.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -174,6 +174,12 @@ def get_gencode_flags(archs):
174174
"-gencode",
175175
f"arch=compute_{arch_code},code=sm_{arch_code}",
176176
]
177+
elif cc_val == 103:
178+
arch_code = "103a"
179+
flags += [
180+
"-gencode",
181+
f"arch=compute_{arch_code},code=sm_{arch_code}",
182+
]
177183
else:
178184
flags += ["-gencode", f"arch=compute_{cc_val},code=sm_{cc_val}"]
179185
return flags
@@ -478,9 +484,10 @@ def find_end_files(directory, end_str):
478484
# of them instead of only the highest one.
479485
has_sm90 = 90 in sm_versions
480486
has_sm100 = 100 in sm_versions and nvcc_version >= 12.9
481-
has_generic_fp8 = not has_sm90 and not has_sm100 # SM89 or other
487+
has_sm103 = 103 in sm_versions and nvcc_version >= 13.0
488+
has_generic_fp8 = not has_sm90 and not has_sm100 and not has_sm103 # SM89 or other
482489

483-
if has_sm90 or has_sm100:
490+
if has_sm90 or has_sm100 or has_sm103:
484491
nvcc_compile_args += [
485492
"-O3",
486493
"-DNDEBUG",
@@ -503,8 +510,8 @@ def find_end_files(directory, end_str):
503510
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
504511
]
505512

506-
if has_sm100:
507-
print("SM100 (Blackwell): Applying SM100 configurations.")
513+
if has_sm100 or has_sm103:
514+
print("SM100 / 103 (Blackwell): Applying SM100 / SM103 configurations.")
508515
# Placeholder for SM100-specific kernel auto-generation scripts
509516
# These might be needed if Blackwell has new FP8 hardware features
510517
# not covered by existing generic CUTLASS templates or SM90 scripts.

0 commit comments

Comments
 (0)