Skip to content

Commit c75fb75

Browse files
committed
[DSL] Implement cute_op to avoid having to write *_fake impl
1 parent 32534e4 commit c75fb75

16 files changed

Lines changed: 258 additions & 491 deletions

quack/cross_entropy.py

Lines changed: 3 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import quack.copy_utils as copy_utils
1818
import quack.layout_utils as layout_utils
1919
from quack.compile_utils import make_fake_tensor as fake_tensor
20+
from quack.dsl import cute_op
2021
from quack.reduce import row_reduce, online_softmax_reduce
2122
from quack.reduction_base import ReductionBase
2223
from quack.cache_utils import jit_cache
@@ -308,7 +309,7 @@ def _compile_cross_entropy_fwd(
308309
)
309310

310311

311-
@torch.library.custom_op("quack::cross_entropy_fwd_out", mutates_args={"loss", "lse", "dx"})
312+
@cute_op("quack::cross_entropy_fwd_out", mutates_args={"loss", "lse", "dx"})
312313
def cross_entropy_fwd_out(
313314
x: Tensor,
314315
target: Tensor,
@@ -363,42 +364,6 @@ def cross_entropy_fwd_out(
363364
)(x, target, target_logit, loss, lse, dx, weight, Int32(ignore_index))
364365

365366

366-
@cross_entropy_fwd_out.register_fake
def _cross_entropy_fwd_out_fake(
    x: Tensor,
    target: Tensor,
    target_logit: Optional[Tensor],
    loss: Tensor,
    lse: Optional[Tensor],
    dx: Optional[Tensor],
    weight: Optional[Tensor],
    ignore_index: int = -100,
) -> None:
    """Fake implementation registered for ``cross_entropy_fwd_out``.

    Never writes to any of its tensor arguments. When ``COMPILE_ONLY`` is set
    and ``x.size(1)`` is a concrete int, it calls the forward and backward
    compile helpers with the configuration derived from the arguments;
    in every other case it returns immediately.

    Args:
        x: Input tensor; its dtype and ``size(1)`` select the kernel.
        target: Target tensor; only its dtype is read.
        target_logit: Optional tensor; its dtype and ``ndim`` are read if given.
        loss: Unused here.
        lse: Optional tensor; only its presence is read.
        dx: Optional tensor; only its presence is read.
        weight: Optional tensor; only its dtype is read.
        ignore_index: Unused here.
    """
    # See softmax.py _softmax_fwd_fake for why register_fake is needed.
    from quack.cache_utils import COMPILE_ONLY

    if not COMPILE_ONLY:
        return
    num_cols = x.size(1)
    if isinstance(num_cols, torch.SymInt):
        # Symbolic column count: nothing is compiled for it.
        return

    x_dtype = torch2cute_dtype_map[x.dtype]
    tgt_dtype = torch2cute_dtype_map[target.dtype]

    has_target_logit = target_logit is not None
    tl_dtype = torch2cute_dtype_map[target_logit.dtype] if has_target_logit else None
    tl_ndim = target_logit.ndim if has_target_logit else None

    w_dtype = None
    if weight is not None:
        w_dtype = torch2cute_dtype_map[weight.dtype]

    _compile_cross_entropy_fwd(
        x_dtype,
        tgt_dtype,
        tl_dtype,
        num_cols,
        lse is not None,
        dx is not None,
        w_dtype,
        tl_ndim,
    )
    # The backward kernel is compiled alongside the forward one.
    _compile_cross_entropy_backward(x_dtype, tgt_dtype, num_cols, w_dtype)
400-
401-
402367
def cross_entropy_fwd(
403368
x: torch.Tensor,
404369
target: torch.Tensor,
@@ -649,7 +614,7 @@ def _cross_entropy_backward(
649614
)
650615

651616

652-
@torch.library.custom_op("quack::cross_entropy_bwd_out", mutates_args={"dx"})
617+
@cute_op("quack::cross_entropy_bwd_out", mutates_args={"dx"})
653618
def cross_entropy_bwd_out(
654619
x: torch.Tensor,
655620
target: torch.Tensor,
@@ -662,27 +627,6 @@ def cross_entropy_bwd_out(
662627
_cross_entropy_backward(x, target, dloss, lse, dx, weight, ignore_index)
663628

664629

665-
@cross_entropy_bwd_out.register_fake
def _cross_entropy_bwd_out_fake(
    x: torch.Tensor,
    target: torch.Tensor,
    dloss: torch.Tensor,
    lse: torch.Tensor,
    dx: torch.Tensor,
    weight: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
) -> None:
    """Fake implementation registered for ``cross_entropy_bwd_out``.

    Never writes to ``dx``. When ``COMPILE_ONLY`` is set and ``x.size(1)`` is
    a concrete int, it calls ``_compile_cross_entropy_backward`` with the
    dtypes of ``x``, ``target`` and ``weight`` (if given) plus the column
    count; otherwise it returns immediately.
    """
    # See softmax.py _softmax_fwd_fake for why register_fake is needed.
    from quack.cache_utils import COMPILE_ONLY

    if not COMPILE_ONLY:
        return
    num_cols = x.size(1)
    if isinstance(num_cols, torch.SymInt):
        # Symbolic column count: nothing is compiled for it.
        return

    x_dtype = torch2cute_dtype_map[x.dtype]
    tgt_dtype = torch2cute_dtype_map[target.dtype]
    w_dtype = None if weight is None else torch2cute_dtype_map[weight.dtype]
    _compile_cross_entropy_backward(x_dtype, tgt_dtype, num_cols, w_dtype)
684-
685-
686630
def cross_entropy_bwd(
687631
x: torch.Tensor,
688632
target: torch.Tensor,

quack/dsl/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2+
3+
from quack.dsl.torch_library_op import cute_op
4+
5+
__all__ = ["cute_op"]

quack/dsl/torch_library_op.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2+
"""``cute_op``: ``torch.library.custom_op`` for CuTe DSL kernels.
3+
4+
Same trick as ``torch.library.triton_op`` (register the impl as the fake/meta
5+
kernel too), specialized for our setup:
6+
7+
* Under ``torch.compile`` we stay a complete no-op (matches prior behavior;
8+
also avoids moving compile latency into dynamo trace time).
9+
* Under ``FakeTensorMode`` with SymInt shapes (dynamic-shape tracing), skip:
10+
``@jit_cache`` is an ``lru_cache`` and SymInts are unhashable.
11+
* Otherwise (``FakeTensorMode`` with concrete shapes, e.g. the COMPILE_ONLY
12+
worker) flip ``cache_utils.COMPILE_ONLY`` for the duration of the call so
13+
``@jit_cache`` returns ``_noop_kernel`` for every ``_compile_*(...)`` it
14+
populates. The body runs end-to-end, the .o cache is filled, and no kernel
15+
is actually launched.
16+
17+
This removes the need for hand-written ``_*_fake`` twins on each op.
18+
"""
19+
20+
from __future__ import annotations
21+
22+
from typing import Any, Callable, Iterable, Optional, Union
23+
24+
import torch
25+
26+
from quack import cache_utils
27+
28+
29+
__all__ = ["cute_op"]
30+
31+
32+
def _has_symint_shape(args: Iterable[Any]) -> bool:
33+
for a in args:
34+
if isinstance(a, torch.Tensor) and any(isinstance(s, torch.SymInt) for s in a.shape):
35+
return True
36+
return False
37+
38+
39+
def cute_op(
40+
name: str,
41+
*,
42+
mutates_args: Union[str, Iterable[str]],
43+
schema: Optional[str] = None,
44+
device_types: Optional[Union[str, Iterable[str]]] = None,
45+
) -> Callable:
46+
"""Like ``torch.library.triton_op``, but for CuTe DSL kernels.
47+
48+
Args:
49+
name: ``"namespace::op_name"``.
50+
mutates_args: Names of mutated tensor args.
51+
schema: Optional explicit schema. Required when mutating an
52+
``Optional[Tensor]`` arg (PyTorch can't infer those).
53+
device_types: Optional device-type restriction.
54+
"""
55+
56+
def dec(fn: Callable) -> Any:
57+
kwargs: dict[str, Any] = {"mutates_args": mutates_args}
58+
if schema is not None:
59+
kwargs["schema"] = schema
60+
if device_types is not None:
61+
kwargs["device_types"] = device_types
62+
op = torch.library.custom_op(name, fn, **kwargs)
63+
64+
@op.register_fake
65+
def _fake(*args, **kw):
66+
if torch.compiler.is_compiling():
67+
return
68+
if _has_symint_shape(args) or _has_symint_shape(kw.values()):
69+
return
70+
saved = cache_utils.COMPILE_ONLY
71+
cache_utils.COMPILE_ONLY = True
72+
try:
73+
fn(*args, **kw)
74+
finally:
75+
cache_utils.COMPILE_ONLY = saved
76+
77+
return op
78+
79+
return dec

quack/hadamard.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from quack.cache_utils import jit_cache
1818
from quack.compile_utils import make_fake_tensor as fake_tensor
1919
from quack.cute_dsl_utils import torch2cute_dtype_map
20+
from quack.dsl import cute_op
2021

2122

2223
def _next_power_of_2(n: int) -> int:
@@ -280,7 +281,7 @@ def _compile_hadamard_transform_fwd(dtype, N):
280281
)
281282

282283

283-
@torch.library.custom_op(
284+
@cute_op(
284285
"quack::_hadamard_transform_fwd",
285286
mutates_args={"out"},
286287
device_types="cuda",
@@ -298,16 +299,6 @@ def _hadamard_transform_fwd(x: Tensor, out: Tensor, scale: float) -> None:
298299
_compile_hadamard_transform_fwd(dtype, N)(x, out, scale)
299300

300301

301-
@_hadamard_transform_fwd.register_fake
def _hadamard_transform_fwd_fake(x: Tensor, out: Tensor, scale: float) -> None:
    """Fake implementation registered for ``_hadamard_transform_fwd``.

    Never writes to ``out``. When ``COMPILE_ONLY`` is set and ``x.size(1)`` is
    a concrete int, it calls ``_compile_hadamard_transform_fwd`` for the dtype
    of ``x`` and that size; otherwise it returns immediately.
    """
    from quack.cache_utils import COMPILE_ONLY

    if not COMPILE_ONLY:
        return
    num_cols = x.size(1)
    if isinstance(num_cols, torch.SymInt):
        # Symbolic size: nothing is compiled for it.
        return
    _compile_hadamard_transform_fwd(torch2cute_dtype_map[x.dtype], num_cols)
309-
310-
311302
def hadamard_transform_fwd(x: Tensor, scale: float = 1.0) -> Tensor:
312303
assert x.dim() >= 1, "Input must have at least one dimension"
313304
x = _ensure_last_dim_contiguous(x)

quack/rms_final_reduce.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from quack.reduction_base import ReductionBase
2222
from quack.cache_utils import jit_cache
2323
from quack.cute_dsl_utils import torch2cute_dtype_map
24+
from quack.dsl import cute_op
2425

2526

2627
class RmsFinalReduce(ReductionBase):
@@ -134,7 +135,7 @@ def _compile_rms_final_reduce(dtype, N):
134135
)
135136

136137

137-
@torch.library.custom_op(
138+
@cute_op(
138139
"quack::rms_final_reduce_out",
139140
mutates_args=("rstd",),
140141
device_types="cuda",
@@ -152,15 +153,6 @@ def _rms_final_reduce_out(
152153
compiled_fn(x, rstd, scale, eps)
153154

154155

155-
@_rms_final_reduce_out.register_fake
def _rms_final_reduce_out_fake(x, rstd, scale, eps):
    """Fake implementation registered for ``_rms_final_reduce_out``.

    Never writes to ``rstd``. When ``COMPILE_ONLY`` is set and every dimension
    of ``x`` is a concrete int, it calls ``_compile_rms_final_reduce`` for the
    dtype of ``x`` and ``x.shape[1]``; otherwise it returns immediately.
    """
    from quack.cache_utils import COMPILE_ONLY

    # The compile helper is keyed on x.shape[1], so that is the dimension that
    # must be concrete. The previous guard looked only at x.shape[0], which let
    # a symbolic shape[1] through. Checking every dim keeps the old skip for a
    # symbolic shape[0] and adds the missing one for shape[1].
    if COMPILE_ONLY and not any(isinstance(dim, torch.SymInt) for dim in x.shape):
        x_dtype = torch2cute_dtype_map[x.dtype]
        _compile_rms_final_reduce(x_dtype, x.shape[1])
162-
163-
164156
def rms_final_reduce(
165157
x: Tensor, # (M, N) partial squared sums
166158
scale: float, # typically 1.0 / total_columns

quack/rmsnorm.py

Lines changed: 3 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import quack.copy_utils as copy_utils
1818
import quack.layout_utils as layout_utils
1919
from quack.compile_utils import make_fake_tensor as fake_tensor
20+
from quack.dsl import cute_op
2021
from quack.reduce import row_reduce
2122
from quack.reduction_base import ReductionBase
2223
from quack.cache_utils import jit_cache
@@ -316,7 +317,7 @@ def kernel(
316317
copy(tXrO, tXgO)
317318

318319

319-
@torch.library.custom_op(
320+
@cute_op(
320321
"quack::_rmsnorm_fwd",
321322
mutates_args=("out", "rstd", "mean", "residual_out"),
322323
device_types="cuda",
@@ -375,58 +376,6 @@ def _rmsnorm_fwd(
375376
)(x, weight, bias, residual, out, residual_out, rstd, mean, eps)
376377

377378

378-
@_rmsnorm_fwd.register_fake
def _rmsnorm_fwd_fake(
    x: Tensor,
    weight: Optional[Tensor],
    out: Tensor,
    bias: Optional[Tensor] = None,
    rstd: Optional[Tensor] = None,
    mean: Optional[Tensor] = None,
    residual: Optional[Tensor] = None,
    residual_out: Optional[Tensor] = None,
    eps: float = 1e-6,
    is_layernorm: bool = False,
) -> None:
    """Fake implementation registered for ``_rmsnorm_fwd``.

    Never writes to any output tensor. When ``COMPILE_ONLY`` is set and
    ``x.size(-1)`` is a concrete int, it calls ``_compile_rmsnorm_fwd`` and
    ``_compile_rmsnorm_bwd`` with the configuration derived from the
    arguments; otherwise it returns immediately.
    """
    # See softmax.py _softmax_fwd_fake for why register_fake is needed.
    from quack.cache_utils import COMPILE_ONLY

    if not COMPILE_ONLY:
        return
    hidden = x.size(-1)
    if isinstance(hidden, torch.SymInt):
        # Symbolic last dim: nothing is compiled for it.
        return

    def cute_dtype(t):
        # Missing tensors map to None.
        return None if t is None else torch2cute_dtype_map[t.dtype]

    # A 2-D weight or bias selects the per-head variant.
    per_head = any(p is not None and p.dim() == 2 for p in (weight, bias))

    x_dtype = cute_dtype(x)
    out_dtype = cute_dtype(out)
    w_dtype = cute_dtype(weight)
    b_dtype = cute_dtype(bias)
    res_dtype = cute_dtype(residual)
    res_out_dtype = cute_dtype(residual_out)

    _compile_rmsnorm_fwd(
        x_dtype,
        out_dtype,
        res_dtype,
        w_dtype,
        b_dtype,
        res_out_dtype,
        hidden,
        rstd is not None,
        mean is not None,
        is_layernorm,
        per_head,
    )
    # The backward kernel is compiled too, using x's dtype for the three
    # dtype arguments that follow ``hidden``.
    _compile_rmsnorm_bwd(
        hidden,
        x_dtype,
        x_dtype,
        x_dtype,
        w_dtype,
        bias is not None,
        res_dtype,
        res_out_dtype,
        weight is not None,
        per_head,
    )
428-
429-
430379
@jit_cache
431380
def _compile_rmsnorm_fwd(
432381
dtype,
@@ -921,7 +870,7 @@ def _get_sm_count(N: int, device: torch.device) -> int:
921870
return sm_count
922871

923872

924-
@torch.library.custom_op(
873+
@cute_op(
925874
"quack::_rmsnorm_bwd",
926875
mutates_args={"dx", "dw_partial", "db_partial", "dresidual"},
927876
device_types="cuda",
@@ -991,45 +940,6 @@ def _rmsnorm_bwd(
991940
)(x, weight, dout, dresidual_out, rstd, dx, dw_partial, dresidual, db_partial, sm_count)
992941

993942

994-
@_rmsnorm_bwd.register_fake
def _rmsnorm_bwd_fake(
    x: Tensor,
    weight: Optional[Tensor],
    dout: Tensor,
    rstd: Tensor,
    dx: Tensor,
    dw_partial: Optional[Tensor],
    db_partial: Optional[Tensor] = None,
    dresidual_out: Optional[Tensor] = None,
    dresidual: Optional[Tensor] = None,
    sm_count: Optional[int] = None,
) -> None:
    """Fake implementation registered for ``_rmsnorm_bwd``.

    Never writes to any gradient tensor. When ``COMPILE_ONLY`` is set,
    ``x.size(-1)`` is a concrete int, and at least one of ``dw_partial``,
    ``db_partial`` or ``sm_count`` is given, it calls ``_compile_rmsnorm_bwd``
    with the configuration derived from the arguments; otherwise it returns
    immediately.
    """
    # See softmax.py _softmax_fwd_fake for why register_fake is needed.
    from quack.cache_utils import COMPILE_ONLY

    if not COMPILE_ONLY:
        return
    hidden = x.size(-1)
    if isinstance(hidden, torch.SymInt):
        # Symbolic last dim: nothing is compiled for it.
        return

    # A 3-D input selects the per-head variant.
    per_head = x.dim() == 3
    if all(v is None for v in (dw_partial, db_partial, sm_count)):
        return

    def cute_dtype(t):
        # Missing tensors map to None.
        return None if t is None else torch2cute_dtype_map[t.dtype]

    _compile_rmsnorm_bwd(
        hidden,
        cute_dtype(x),
        cute_dtype(dout),
        cute_dtype(dx),
        cute_dtype(weight),
        db_partial is not None,
        cute_dtype(dresidual),
        cute_dtype(dresidual_out),
        dw_partial is not None,
        per_head,
    )
1031-
1032-
1033943
@jit_cache
1034944
def _compile_rmsnorm_bwd(
1035945
N,

0 commit comments

Comments
 (0)