11import functools
2- import math
32
3+ import ninetoothed
44import ninetoothed .language as ntl
55from ninetoothed import Tensor
66
@@ -20,17 +20,24 @@ def application(input, weight, eps, output, num_normalized_elements):
2020 output [i ] = input [i ] / rms * weight [i ]
2121
2222
23- def premake (ndim , normalized_shape , dtype = None , block_size = None ):
24- dims = tuple (- (dim + 1 ) for dim in range (len (normalized_shape )))
23+ def premake (
24+ ndim ,
25+ num_normalized_dims ,
26+ input_dtype = None ,
27+ weight_dtype = None ,
28+ output_dtype = None ,
29+ block_size = None ,
30+ ):
31+ dims = tuple (- (dim + 1 ) for dim in range (num_normalized_dims ))
2532
2633 arrangement_ = functools .partial (arrangement , dim = dims , block_size = block_size )
2734
2835 tensors = (
29- Tensor (ndim , other = 0 , dtype = dtype ),
30- Tensor (ndim , dtype = dtype ),
31- Tensor (0 , dtype = dtype ),
32- Tensor (ndim , dtype = dtype ),
33- Tensor (0 , dtype = dtype , constexpr = True , value = math . prod ( normalized_shape ) ),
36+ Tensor (ndim , other = 0 , dtype = input_dtype ),
37+ Tensor (ndim , dtype = weight_dtype ),
38+ Tensor (0 , dtype = ninetoothed . float32 ),
39+ Tensor (ndim , dtype = output_dtype ),
40+ Tensor (0 , dtype = ninetoothed . uint64 ),
3441 )
3542
3643 return arrangement_ , application , tensors
0 commit comments