Skip to content

Commit 6516af8

Browse files
committed
Add assertions for enable_gqa instead of forcing it to be set to True
1 parent 95c5310 commit 6516af8

1 file changed

Lines changed: 17 additions & 5 deletions

File tree

src/ntops/torch.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -362,23 +362,35 @@ def scaled_dot_product_attention(
362362
dropout_p=0,
363363
is_causal=False,
364364
scale=None,
365-
# The default value here differs from that of
366-
# `torch.nn.functional.scaled_dot_product_attention`
367-
# because GQA cannot be disabled at the moment.
368-
enable_gqa=True,
365+
enable_gqa=False,
369366
present_key=None,
370367
present_value=None,
371368
present_key_slot=None,
372369
present_value_slot=None,
373370
):
374371
# TODO: Support `dropout_p`.
375372
assert dropout_p == 0, "`dropout_p` is not supported yet."
376-
assert enable_gqa, "GQA must be enabled for now."
377373

378374
assert attn_mask is None or not is_causal, (
379375
"Cannot use `attn_mask` and `is_causal` together."
380376
)
381377

378+
num_heads_q = query.shape[-3]
379+
num_heads_kv = key.shape[-3]
380+
381+
assert num_heads_kv == value.shape[-3], (
382+
"Number of heads in `key` and `value` must be the same."
383+
)
384+
385+
if not enable_gqa:
386+
assert num_heads_q == num_heads_kv, (
387+
"Number of heads in `query`, `key`, and `value` must be the same when GQA is not enabled."
388+
)
389+
else:
390+
assert num_heads_q % num_heads_kv == 0, (
391+
"Number of heads in `query` must be divisible by number of heads in `key` and `value` when GQA is enabled."
392+
)
393+
382394
mask_shape = query.shape[:-1] + (key.shape[-2],)
383395

384396
if attn_mask is not None:

0 commit comments

Comments
 (0)