Skip to content

Commit 54f9308

Browse files
committed
Relax the constexpr constraints on the shapes of present_key, present_value, present_key_slot, and present_value_slot
1 parent 351c342 commit 54f9308

1 file changed

Lines changed: 2 additions & 11 deletions

File tree

src/ntops/kernels/scaled_dot_product_attention.py

Lines changed: 2 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ def arrange_key_or_value(input):
4646
return arranged
4747

4848
def arrange_present_key_or_present_value(input):
49-
arranged = input.tile((1, 1, -1, -1))
49+
arranged = input.tile((1, 1, BLOCK_SIZE_M, BLOCK_SIZE_N))
5050
arranged.dtype = arranged.dtype.squeeze((0, 1))
5151

5252
return arranged
@@ -165,16 +165,7 @@ def make(with_kv_cache):
165165
for _ in range(5)
166166
)
167167
present_key, present_value, present_key_slot, present_value_slot = (
168-
Tensor(
169-
4,
170-
shape_options=(
171-
None,
172-
None,
173-
{"constexpr": True, "upper_bound": 1},
174-
{"constexpr": True, "upper_bound": 128},
175-
),
176-
)
177-
for _ in range(4)
168+
Tensor(4) for _ in range(4)
178169
)
179170
scale = Tensor(0)
180171
is_causal, with_attn_mask = (Tensor(0, constexpr=True) for _ in range(2))

0 commit comments

Comments
 (0)