File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -362,23 +362,35 @@ def scaled_dot_product_attention(
362362 dropout_p = 0 ,
363363 is_causal = False ,
364364 scale = None ,
365- # The default value here differs from that of
366- # `torch.nn.functional.scaled_dot_product_attention`
367- # because GQA cannot be disabled at the moment.
368- enable_gqa = True ,
365+ enable_gqa = False ,
369366 present_key = None ,
370367 present_value = None ,
371368 present_key_slot = None ,
372369 present_value_slot = None ,
373370):
374371 # TODO: Support `dropout_p`.
375372 assert dropout_p == 0 , "`dropout_p` is not supported yet."
376- assert enable_gqa , "GQA must be enabled for now."
377373
378374 assert attn_mask is None or not is_causal , (
379375 "Cannot use `attn_mask` and `is_causal` together."
380376 )
381377
378+ num_heads_q = query .shape [- 3 ]
379+ num_heads_kv = key .shape [- 3 ]
380+
381+ assert num_heads_kv == value .shape [- 3 ], (
382+ "Number of heads in `key` and `value` must be the same."
383+ )
384+
385+ if not enable_gqa :
386+ assert num_heads_q == num_heads_kv , (
387+ "Number of heads in `query`, `key`, and `value` must be the same when GQA is not enabled."
388+ )
389+ else :
390+ assert num_heads_q % num_heads_kv == 0 , (
391+ "Number of heads in `query` must be divisible by number of heads in `key` and `value` when GQA is enabled."
392+ )
393+
382394 mask_shape = query .shape [:- 1 ] + (key .shape [- 2 ],)
383395
384396 if attn_mask is not None :
You can’t perform that action at this time.
0 commit comments