Skip to content

Commit daa9ad6

Browse files
committed
Fix the dtype values of zero-dimensional tensors in the premake functions
1 parent 88ad686 commit daa9ad6

File tree

9 files changed

+22
-17
lines changed

9 files changed

+22
-17
lines changed

src/ntops/kernels/add.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import functools
22

3+
import ninetoothed
34
from ninetoothed import Tensor
45

56
from ntops.kernels.element_wise import arrangement
@@ -15,7 +16,7 @@ def premake(ndim, dtype=None, block_size=None):
1516
tensors = (
1617
Tensor(ndim, dtype=dtype),
1718
Tensor(ndim, dtype=dtype),
18-
Tensor(0, dtype=dtype),
19+
Tensor(0, dtype=ninetoothed.float64),
1920
Tensor(ndim, dtype=dtype),
2021
)
2122

src/ntops/kernels/addmm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import functools
22

3+
import ninetoothed
34
import ninetoothed.language as ntl
45
from ninetoothed import Tensor
56

@@ -84,10 +85,10 @@ def premake(
8485
Tensor(2, dtype=dtype),
8586
Tensor(2, dtype=dtype),
8687
Tensor(2, dtype=dtype),
87-
Tensor(0, dtype=dtype),
88-
Tensor(0, dtype=dtype),
88+
Tensor(0, dtype=ninetoothed.float64),
89+
Tensor(0, dtype=ninetoothed.float64),
8990
Tensor(2, dtype=dtype),
90-
Tensor(0, dtype=dtype, constexpr=True, value=input_precision),
91+
Tensor(0, constexpr=True, value=input_precision),
9192
)
9293

9394
return arrangement_, application, tensors

src/ntops/kernels/bmm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def premake(
6161
Tensor(3, dtype=dtype),
6262
Tensor(3, dtype=dtype),
6363
Tensor(3, dtype=dtype),
64-
Tensor(0, dtype=dtype, constexpr=True, value=input_precision),
64+
Tensor(0, constexpr=True, value=input_precision),
6565
)
6666

6767
return arrangement_, application, tensors

src/ntops/kernels/dropout.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import functools
22

3+
import ninetoothed
34
import ninetoothed.language as ntl
45
from ninetoothed import Tensor
56

@@ -15,8 +16,8 @@ def premake(ndim, dtype=None, block_size=None):
1516

1617
tensors = (
1718
Tensor(ndim, dtype=dtype),
18-
Tensor(0, dtype=dtype),
19-
Tensor(0, dtype=dtype),
19+
Tensor(0, dtype=ninetoothed.float64),
20+
Tensor(0, dtype=ninetoothed.int64),
2021
Tensor(ndim, dtype=dtype),
2122
)
2223

src/ntops/kernels/layer_norm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import functools
22
import math
33

4+
import ninetoothed
45
import ninetoothed.language as ntl
56
from ninetoothed import Tensor
67

@@ -39,9 +40,9 @@ def premake(ndim, normalized_shape, dtype=None, block_size=None):
3940
Tensor(ndim, other=0, dtype=dtype),
4041
Tensor(ndim, dtype=dtype),
4142
Tensor(ndim, dtype=dtype),
42-
Tensor(0, dtype=dtype),
43+
Tensor(0, dtype=ninetoothed.float64),
4344
Tensor(ndim, dtype=dtype),
44-
Tensor(0, dtype=dtype, constexpr=True, value=math.prod(normalized_shape)),
45+
Tensor(0, constexpr=True, value=math.prod(normalized_shape)),
4546
)
4647

4748
return arrangement_, application, tensors

src/ntops/kernels/mm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def premake(
8383
Tensor(2, dtype=dtype),
8484
Tensor(2, dtype=dtype),
8585
Tensor(2, dtype=dtype),
86-
Tensor(0, dtype=dtype, constexpr=True, value=input_precision),
86+
Tensor(0, constexpr=True, value=input_precision),
8787
)
8888

8989
return arrangement_, application, tensors

src/ntops/kernels/rms_norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ def premake(
3535
tensors = (
3636
Tensor(ndim, other=0, dtype=input_dtype),
3737
Tensor(ndim, dtype=weight_dtype),
38-
Tensor(0, dtype=ninetoothed.float32),
38+
Tensor(0, dtype=ninetoothed.float64),
3939
Tensor(ndim, dtype=output_dtype),
40-
Tensor(0, dtype=ninetoothed.uint64),
40+
Tensor(0, dtype=ninetoothed.int64),
4141
)
4242

4343
return arrangement_, application, tensors

src/ntops/kernels/scaled_dot_product_attention.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,10 +229,10 @@ def premake(
229229
present_key, present_value, present_key_slot, present_value_slot = (
230230
Tensor(4, dtype=dtype) for _ in range(4)
231231
)
232-
scale = Tensor(0, dtype=dtype)
233-
is_causal = Tensor(0, dtype=dtype, constexpr=True, value=is_causal)
234-
with_attn_mask = Tensor(0, dtype=dtype, constexpr=True, value=with_attn_mask)
235-
causal_variant = Tensor(0, dtype=dtype, constexpr=True, value=causal_variant)
232+
scale = Tensor(0, dtype=ninetoothed.float64)
233+
is_causal = Tensor(0, constexpr=True, value=is_causal)
234+
with_attn_mask = Tensor(0, constexpr=True, value=with_attn_mask)
235+
causal_variant = Tensor(0, constexpr=True, value=causal_variant)
236236

237237
if emb_dim is not None:
238238
for tensor in (query, key, value, attn_mask, output):

src/ntops/kernels/sub.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import functools
22

3+
import ninetoothed
34
from ninetoothed import Tensor
45

56
from ntops.kernels.element_wise import arrangement
@@ -15,7 +16,7 @@ def premake(ndim, dtype=None, block_size=None):
1516
tensors = (
1617
Tensor(ndim, dtype=dtype),
1718
Tensor(ndim, dtype=dtype),
18-
Tensor(0, dtype=dtype),
19+
Tensor(0, dtype=ninetoothed.float64),
1920
Tensor(ndim, dtype=dtype),
2021
)
2122

0 commit comments

Comments
 (0)