Skip to content

Commit c68fffe

Browse files
authored
Merge pull request #48 from InfiniTensor/update-rms-norm-premake-to-better-support-aot-compilation
Update `rms_norm.premake` to better support AOT compilation
2 parents (4c677af + c3d617f), commit c68fffe

2 files changed

Lines changed: 18 additions & 9 deletions

File tree

src/ntops/kernels/rms_norm.py

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
  1   1   import functools
  2     - import math
  3   2
      3 + import ninetoothed
  4   4   import ninetoothed.language as ntl
  5   5   from ninetoothed import Tensor
  6   6

@@ -20,17 +20,24 @@ def application(input, weight, eps, output, num_normalized_elements):
 20  20           output[i] = input[i] / rms * weight[i]
 21  21
 22  22
 23     - def premake(ndim, normalized_shape, dtype=None, block_size=None):
 24     -     dims = tuple(-(dim + 1) for dim in range(len(normalized_shape)))
     23 + def premake(
     24 +     ndim,
     25 +     num_normalized_dims,
     26 +     input_dtype=None,
     27 +     weight_dtype=None,
     28 +     output_dtype=None,
     29 +     block_size=None,
     30 + ):
     31 +     dims = tuple(-(dim + 1) for dim in range(num_normalized_dims))
 25  32
 26  33       arrangement_ = functools.partial(arrangement, dim=dims, block_size=block_size)
 27  34
 28  35       tensors = (
 29     -         Tensor(ndim, other=0, dtype=dtype),
 30     -         Tensor(ndim, dtype=dtype),
 31     -         Tensor(0, dtype=dtype),
 32     -         Tensor(ndim, dtype=dtype),
 33     -         Tensor(0, dtype=dtype, constexpr=True, value=math.prod(normalized_shape)),
     36 +         Tensor(ndim, other=0, dtype=input_dtype),
     37 +         Tensor(ndim, dtype=weight_dtype),
     38 +         Tensor(0, dtype=ninetoothed.float32),
     39 +         Tensor(ndim, dtype=output_dtype),
     40 +         Tensor(0, dtype=ninetoothed.uint64),
 34  41       )
 35  42
 36  43       return arrangement_, application, tensors

src/ntops/torch.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -394,7 +394,9 @@ def rms_norm(input, normalized_shape, weight=None, eps=None):
394 394
395 395       output = torch.empty_like(input)
396 396
397     -     kernel = _cached_make(ntops.kernels.rms_norm.premake, input.ndim, normalized_shape)
    397 +     kernel = _cached_make(
    398 +         ntops.kernels.rms_norm.premake, input.ndim, len(normalized_shape)
    399 +     )
398 400
399 401       kernel(input, weight, eps, output, math.prod(normalized_shape))
400 402

0 commit comments

Comments
 (0)