Skip to content

Commit 327c65e

Browse files
committed
fix(ninetoothed): bound rms norm autotune shapes
1 parent 87e86ab commit 327c65e

1 file changed

Lines changed: 48 additions & 1 deletion

File tree

src/ninetoothed/ops/rms_norm/build.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import itertools
23

34
import ninetoothed
@@ -9,6 +10,9 @@
910

1011
_DEFAULT_NDIMS = (2, 3, 4)
1112

13+
_BATCH_DIM_AUTO_TUNE_SIZE = 1
14+
_NORMALIZED_DIM_AUTO_TUNE_SIZE = 256
15+
1216
_CONFIGS = tuple(
1317
(
1418
(),
@@ -30,11 +34,54 @@
3034
)
3135

3236

37+
def _shape_options(ndim, num_normalized_dims):
38+
batch_dims = ndim - num_normalized_dims
39+
40+
return (
41+
*({"upper_bound": _BATCH_DIM_AUTO_TUNE_SIZE} for _ in range(batch_dims)),
42+
*(
43+
{"upper_bound": _NORMALIZED_DIM_AUTO_TUNE_SIZE}
44+
for _ in range(num_normalized_dims)
45+
),
46+
)
47+
48+
49+
def _premake(
50+
ndim,
51+
num_normalized_dims,
52+
input_dtype=None,
53+
weight_dtype=None,
54+
output_dtype=None,
55+
block_size=None,
56+
):
57+
dims = tuple(-(dim + 1) for dim in range(num_normalized_dims))
58+
arrangement = functools.partial(
59+
ntops.kernels.reduction.arrangement,
60+
dim=dims,
61+
block_size=block_size,
62+
)
63+
shape_options = _shape_options(ndim, num_normalized_dims)
64+
tensors = (
65+
ninetoothed.Tensor(
66+
ndim,
67+
other=0,
68+
dtype=input_dtype,
69+
shape_options=shape_options,
70+
),
71+
ninetoothed.Tensor(ndim, dtype=weight_dtype, shape_options=shape_options),
72+
ninetoothed.Tensor(0, dtype=ninetoothed.float32),
73+
ninetoothed.Tensor(ndim, dtype=output_dtype, shape_options=shape_options),
74+
ninetoothed.Tensor(0, dtype=ninetoothed.uint64),
75+
)
76+
77+
return arrangement, ntops.kernels.rms_norm.application, tensors
78+
79+
3380
def build(output_dir):
3481
variant_dir = output_dir / "rms_norm"
3582
variant_dir.mkdir(parents=True, exist_ok=True)
3683
ninetoothed.build(
37-
ntops.kernels.rms_norm.premake,
84+
_premake,
3885
_CONFIGS,
3986
meta_parameters=("block_size",),
4087
caller="cuda",

0 commit comments

Comments
 (0)