Skip to content

Commit 1f32628

Browse files
committed
Keep up with NineToothed 0.16.0 updates
1 parent 358a627 commit 1f32628

3 files changed

Lines changed: 13 additions & 7 deletions

File tree

attention.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,10 @@ def application(q, k, v, o):
5656
o = acc # noqa: F841
5757

5858

59-
q, k, v, o = (Tensor(4, constexpr_shape=True) for _ in range(4))
59+
q, k, v, o = (
60+
Tensor(4, shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128}))
61+
for _ in range(4)
62+
)
6063
attention_kernel = ninetoothed.make(arrangement, application, (q, k, v, o))
6164

6265

conv2d.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,14 @@ def arrangement(input, filter, output):
2323
return matmul.arrangement(input_flattened, filter_permuted, output_flattened)
2424

2525

26-
conv2d_kernel = ninetoothed.make(
27-
arrangement,
28-
matmul.application,
29-
(Tensor(4), Tensor(4, constexpr_shape=True), Tensor(4)),
26+
filter_shape_options = (
27+
None,
28+
None,
29+
{"constexpr": True, "upper_bound": 16},
30+
{"constexpr": True, "upper_bound": 16},
3031
)
32+
tensors = (Tensor(4), Tensor(4, shape_options=filter_shape_options), Tensor(4))
33+
conv2d_kernel = ninetoothed.make(arrangement, matmul.application, tensors)
3134

3235

3336
def conv2d(input, filter):

max_pool2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
def arrangement(input, output):
1212
BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)
1313

14-
WINDOW_HEIGHT = Symbol("WINDOW_HEIGHT", constexpr=True)
15-
WINDOW_WIDTH = Symbol("WINDOW_WIDTH", constexpr=True)
14+
WINDOW_HEIGHT = Symbol("WINDOW_HEIGHT", constexpr=True, upper_bound=16)
15+
WINDOW_WIDTH = Symbol("WINDOW_WIDTH", constexpr=True, upper_bound=16)
1616

1717
input_arranged = input.tile((1, 1, WINDOW_HEIGHT, WINDOW_WIDTH))
1818
input_arranged = input_arranged.ravel()

0 commit comments

Comments
 (0)