File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -325,23 +325,35 @@ def scaled_dot_product_attention(
325325 dropout_p = 0 ,
326326 is_causal = False ,
327327 scale = None ,
328- # The default value here differs from that of
329- # `torch.nn.functional.scaled_dot_product_attention`
330- # because GQA cannot be disabled at the moment.
331- enable_gqa = True ,
328+ enable_gqa = False ,
332329 present_key = None ,
333330 present_value = None ,
334331 present_key_slot = None ,
335332 present_value_slot = None ,
336333):
337334 # TODO: Support `dropout_p`.
338335 assert dropout_p == 0 , "`dropout_p` is not supported yet."
339- assert enable_gqa , "GQA must be enabled for now."
340336
341337 assert attn_mask is None or not is_causal , (
342338 "Cannot use `attn_mask` and `is_causal` together."
343339 )
344340
341+ num_heads_q = query .shape [- 3 ]
342+ num_heads_kv = key .shape [- 3 ]
343+
344+ assert num_heads_kv == value .shape [- 3 ], (
345+ "Number of heads in `key` and `value` must be the same."
346+ )
347+
348+ if not enable_gqa :
349+ assert num_heads_q == num_heads_kv , (
350+ "Number of heads in `query`, `key`, and `value` must be the same when GQA is not enabled."
351+ )
352+ else :
353+ assert num_heads_q % num_heads_kv == 0 , (
354+ "Number of heads in `query` must be divisible by number of heads in `key` and `value` when GQA is enabled."
355+ )
356+
345357 mask_shape = query .shape [:- 1 ] + (key .shape [- 2 ],)
346358
347359 if attn_mask is not None :
You can’t perform that action at this time.
0 commit comments