Skip to content

Commit f37ea85

Browse files
committed
[Style] Fix lint
1 parent 0dd15dc commit f37ea85

7 files changed

Lines changed: 31 additions & 30 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ torch = [
4242

4343
[tool.ruff]
4444
line-length = 100
45+
exclude = [".#*"]
4546

4647
[tool.ruff.lint]
4748
ignore = [

quack/gemm.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,13 @@ def gemm(
170170
device_capacity = get_device_capacity(A.device)
171171
assert device_capacity[0] in [9, 10, 11], "Only SM90, SM100, and SM110 are supported"
172172
if rounding_mode == RoundingMode.RS:
173-
assert (
174-
device_capacity[0] >= 10
175-
), "Stochastic rounding (RoundingMode.RS) requires SM100+ (Blackwell)"
173+
assert device_capacity[0] >= 10, (
174+
"Stochastic rounding (RoundingMode.RS) requires SM100+ (Blackwell)"
175+
)
176176
if is_dynamic_persistent and device_capacity[0] == 9:
177-
assert (
178-
tile_count_semaphore is not None
179-
), "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
177+
assert tile_count_semaphore is not None, (
178+
"Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
179+
)
180180

181181
A_p, B_p, D_p, C_p = perm3d(A, B, D, C, varlen_m=varlen_m, varlen_k=varlen_k)
182182
a_major, b_major, d_major, c_major = get_majors(A_p, B_p, D_p, C_p)

quack/gemm_act.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,15 @@ class GemmGatedMixin(GemmActMixin):
187187
def epi_to_underlying_arguments(
188188
self, args: GemmActMixin.EpilogueArguments, *, loc=None, ip=None
189189
) -> GemmActMixin.EpilogueParams:
190-
assert (
191-
args.mPostAct.element_type.width == 16
192-
), "GemmGated only supports 16bit postact for now"
190+
assert args.mPostAct.element_type.width == 16, (
191+
"GemmGated only supports 16bit postact for now"
192+
)
193193
assert self.d_layout is None or self.d_layout.is_n_major_c()
194194
assert cutlass.utils.LayoutEnum.from_tensor(args.mPostAct).is_n_major_c()
195195
if self.arch == 90:
196-
assert (
197-
self.cta_tile_shape_mnk[1] % 32 == 0
198-
), "GemmGatedSm90 requires tileN to be divisible by 32"
196+
assert self.cta_tile_shape_mnk[1] % 32 == 0, (
197+
"GemmGatedSm90 requires tileN to be divisible by 32"
198+
)
199199
self.rounding_mode = args.rounding_mode
200200
self.postact_dtype = args.mPostAct.element_type
201201
self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
@@ -415,14 +415,14 @@ def gemm_act(
415415
device_capacity = get_device_capacity(A.device)
416416
assert device_capacity[0] in [9, 10, 11], "Only SM90, SM100, and SM110 are supported"
417417
if rounding_mode == RoundingMode.RS:
418-
assert (
419-
device_capacity[0] >= 10
420-
), "Stochastic rounding (RoundingMode.RS) requires SM100+ (Blackwell)"
418+
assert device_capacity[0] >= 10, (
419+
"Stochastic rounding (RoundingMode.RS) requires SM100+ (Blackwell)"
420+
)
421421

422422
if is_dynamic_persistent and device_capacity[0] == 9:
423-
assert (
424-
tile_count_semaphore is not None
425-
), "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
423+
assert tile_count_semaphore is not None, (
424+
"Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
425+
)
426426

427427
sr_seed_mode = (
428428
2 if isinstance(sr_seed, Tensor) else (1 if rounding_mode == RoundingMode.RS else 0)

quack/gemm_dact.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -416,9 +416,9 @@ def gemm_dact(
416416
assert device_capacity[0] in [9, 10, 11], "Only SM90, SM100, and SM110 are supported"
417417

418418
if is_dynamic_persistent and device_capacity[0] == 9:
419-
assert (
420-
tile_count_semaphore is not None
421-
), "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
419+
assert tile_count_semaphore is not None, (
420+
"Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
421+
)
422422

423423
compiled_fn = _compile_gemm_dact(
424424
a_dtype,

quack/gemm_norm_act.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,9 @@ def gemm_norm_act_fn(
312312
assert device_capacity[0] >= 10, "Stochastic rounding requires SM100+"
313313

314314
if is_dynamic_persistent and device_capacity[0] == 9:
315-
assert (
316-
tile_count_semaphore is not None
317-
), "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
315+
assert tile_count_semaphore is not None, (
316+
"Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
317+
)
318318

319319
sr_seed_mode = (
320320
2 if isinstance(sr_seed, Tensor) else (1 if rounding_mode == RoundingMode.RS else 0)

quack/gemm_sq_reduce.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,9 @@ def gemm_sq_reduce(
199199
a_dtype, b_dtype, d_dtype, c_dtype = get_dtypes(A, B, D, C)
200200

201201
if is_dynamic_persistent and device_capacity[0] == 9:
202-
assert (
203-
tile_count_semaphore is not None
204-
), "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
202+
assert tile_count_semaphore is not None, (
203+
"Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
204+
)
205205

206206
compiled_fn = _compile_gemm_sq_reduce(
207207
a_dtype,

quack/gemm_symmetric.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,9 +317,9 @@ def gemm_symmetric(
317317
assert device_capacity[0] in [9, 10, 11], "Only SM90, SM100, and SM110 are supported"
318318

319319
if is_dynamic_persistent and device_capacity[0] == 9:
320-
assert (
321-
tile_count_semaphore is not None
322-
), "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
320+
assert tile_count_semaphore is not None, (
321+
"Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
322+
)
323323

324324
tile_shape_mn = (tile_M, tile_N)
325325
cluster_shape_mnk = (cluster_M, cluster_N, 1)

0 commit comments

Comments
 (0)