Skip to content

Commit a948d7d

Browse files
committed
Add assertions for enable_gqa instead of forcing it to be set to True
1 parent 57119f0 commit a948d7d

1 file changed

Lines changed: 17 additions & 5 deletions

File tree

src/ntops/torch.py

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -325,23 +325,35 @@ def scaled_dot_product_attention(
325325
dropout_p=0,
326326
is_causal=False,
327327
scale=None,
328-
# The default value here differs from that of
329-
# `torch.nn.functional.scaled_dot_product_attention`
330-
# because GQA cannot be disabled at the moment.
331-
enable_gqa=True,
328+
enable_gqa=False,
332329
present_key=None,
333330
present_value=None,
334331
present_key_slot=None,
335332
present_value_slot=None,
336333
):
337334
# TODO: Support `dropout_p`.
338335
assert dropout_p == 0, "`dropout_p` is not supported yet."
339-
assert enable_gqa, "GQA must be enabled for now."
340336

341337
assert attn_mask is None or not is_causal, (
342338
"Cannot use `attn_mask` and `is_causal` together."
343339
)
344340

341+
num_heads_q = query.shape[-3]
342+
num_heads_kv = key.shape[-3]
343+
344+
assert num_heads_kv == value.shape[-3], (
345+
"Number of heads in `key` and `value` must be the same."
346+
)
347+
348+
if not enable_gqa:
349+
assert num_heads_q == num_heads_kv, (
350+
"Number of heads in `query`, `key`, and `value` must be the same when GQA is not enabled."
351+
)
352+
else:
353+
assert num_heads_q % num_heads_kv == 0, (
354+
"Number of heads in `query` must be divisible by number of heads in `key` and `value` when GQA is enabled."
355+
)
356+
345357
mask_shape = query.shape[:-1] + (key.shape[-2],)
346358

347359
if attn_mask is not None:

0 commit comments

Comments
 (0)