|
17 | 17 | import quack.copy_utils as copy_utils |
18 | 18 | import quack.layout_utils as layout_utils |
19 | 19 | from quack.compile_utils import make_fake_tensor as fake_tensor |
| 20 | +from quack.dsl import cute_op |
20 | 21 | from quack.reduce import row_reduce |
21 | 22 | from quack.reduction_base import ReductionBase |
22 | 23 | from quack.cache_utils import jit_cache |
@@ -316,7 +317,7 @@ def kernel( |
316 | 317 | copy(tXrO, tXgO) |
317 | 318 |
|
318 | 319 |
|
319 | | -@torch.library.custom_op( |
| 320 | +@cute_op( |
320 | 321 | "quack::_rmsnorm_fwd", |
321 | 322 | mutates_args=("out", "rstd", "mean", "residual_out"), |
322 | 323 | device_types="cuda", |
@@ -375,58 +376,6 @@ def _rmsnorm_fwd( |
375 | 376 | )(x, weight, bias, residual, out, residual_out, rstd, mean, eps) |
376 | 377 |
|
377 | 378 |
|
378 | | -@_rmsnorm_fwd.register_fake |
379 | | -def _rmsnorm_fwd_fake( |
380 | | - x: Tensor, |
381 | | - weight: Optional[Tensor], |
382 | | - out: Tensor, |
383 | | - bias: Optional[Tensor] = None, |
384 | | - rstd: Optional[Tensor] = None, |
385 | | - mean: Optional[Tensor] = None, |
386 | | - residual: Optional[Tensor] = None, |
387 | | - residual_out: Optional[Tensor] = None, |
388 | | - eps: float = 1e-6, |
389 | | - is_layernorm: bool = False, |
390 | | -) -> None: |
391 | | - # See softmax.py _softmax_fwd_fake for why register_fake is needed. |
392 | | - from quack.cache_utils import COMPILE_ONLY |
393 | | - |
394 | | - if COMPILE_ONLY and not isinstance(x.size(-1), torch.SymInt): |
395 | | - N = x.size(-1) |
396 | | - per_head = (weight is not None and weight.dim() == 2) or ( |
397 | | - bias is not None and bias.dim() == 2 |
398 | | - ) |
399 | | - dtype, out_dtype, weight_dtype, bias_dtype, res_dtype, res_out_dtype = [ |
400 | | - torch2cute_dtype_map[t.dtype] if t is not None else None |
401 | | - for t in [x, out, weight, bias, residual, residual_out] |
402 | | - ] |
403 | | - _compile_rmsnorm_fwd( |
404 | | - dtype, |
405 | | - out_dtype, |
406 | | - res_dtype, |
407 | | - weight_dtype, |
408 | | - bias_dtype, |
409 | | - res_out_dtype, |
410 | | - N, |
411 | | - rstd is not None, |
412 | | - mean is not None, |
413 | | - is_layernorm, |
414 | | - per_head, |
415 | | - ) |
416 | | - _compile_rmsnorm_bwd( |
417 | | - N, |
418 | | - dtype, |
419 | | - dtype, |
420 | | - dtype, |
421 | | - weight_dtype, |
422 | | - bias is not None, |
423 | | - res_dtype, |
424 | | - res_out_dtype, |
425 | | - weight is not None, |
426 | | - per_head, |
427 | | - ) |
428 | | - |
429 | | - |
430 | 379 | @jit_cache |
431 | 380 | def _compile_rmsnorm_fwd( |
432 | 381 | dtype, |
@@ -921,7 +870,7 @@ def _get_sm_count(N: int, device: torch.device) -> int: |
921 | 870 | return sm_count |
922 | 871 |
|
923 | 872 |
|
924 | | -@torch.library.custom_op( |
| 873 | +@cute_op( |
925 | 874 | "quack::_rmsnorm_bwd", |
926 | 875 | mutates_args={"dx", "dw_partial", "db_partial", "dresidual"}, |
927 | 876 | device_types="cuda", |
@@ -991,45 +940,6 @@ def _rmsnorm_bwd( |
991 | 940 | )(x, weight, dout, dresidual_out, rstd, dx, dw_partial, dresidual, db_partial, sm_count) |
992 | 941 |
|
993 | 942 |
|
994 | | -@_rmsnorm_bwd.register_fake |
995 | | -def _rmsnorm_bwd_fake( |
996 | | - x: Tensor, |
997 | | - weight: Optional[Tensor], |
998 | | - dout: Tensor, |
999 | | - rstd: Tensor, |
1000 | | - dx: Tensor, |
1001 | | - dw_partial: Optional[Tensor], |
1002 | | - db_partial: Optional[Tensor] = None, |
1003 | | - dresidual_out: Optional[Tensor] = None, |
1004 | | - dresidual: Optional[Tensor] = None, |
1005 | | - sm_count: Optional[int] = None, |
1006 | | -) -> None: |
1007 | | - # See softmax.py _softmax_fwd_fake for why register_fake is needed. |
1008 | | - from quack.cache_utils import COMPILE_ONLY |
1009 | | - |
1010 | | - if COMPILE_ONLY and not isinstance(x.size(-1), torch.SymInt): |
1011 | | - N = x.size(-1) |
1012 | | - per_head = x.dim() == 3 |
1013 | | - if dw_partial is None and db_partial is None and sm_count is None: |
1014 | | - return |
1015 | | - dtype, dout_dtype, dx_dtype, weight_dtype, dres_dtype, dres_out_dtype = [ |
1016 | | - torch2cute_dtype_map[t.dtype] if t is not None else None |
1017 | | - for t in [x, dout, dx, weight, dresidual, dresidual_out] |
1018 | | - ] |
1019 | | - _compile_rmsnorm_bwd( |
1020 | | - N, |
1021 | | - dtype, |
1022 | | - dout_dtype, |
1023 | | - dx_dtype, |
1024 | | - weight_dtype, |
1025 | | - db_partial is not None, |
1026 | | - dres_dtype, |
1027 | | - dres_out_dtype, |
1028 | | - dw_partial is not None, |
1029 | | - per_head, |
1030 | | - ) |
1031 | | - |
1032 | | - |
1033 | 943 | @jit_cache |
1034 | 944 | def _compile_rmsnorm_bwd( |
1035 | 945 | N, |
|
0 commit comments