Skip to content

Commit e040da3

Browse files
committed
Fix T5Gemma2Attention in t5gemma2.py by adding missing initializations
TAG=agy CONV=7b651079-7501-4d53-a9a3-6058ec11ea33 Signed-off-by: Akhilesh Kumar <akhilbussiness@gmail.com>
1 parent b4677d5 commit e040da3

1 file changed

Lines changed: 17 additions & 0 deletions

File tree

vllm_bart_plugin/t5gemma2.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,23 @@ def __init__(
300300
quant_config=quant_config,
301301
prefix=f"{prefix}.k_proj",
302302
)
303+
self.v_proj = ColumnParallelLinear(
304+
hidden_size,
305+
self.total_num_kv_heads * self.head_dim,
306+
bias=False,
307+
quant_config=quant_config,
308+
prefix=f"{prefix}.v_proj",
309+
)
310+
self.o_proj = RowParallelLinear(
311+
self.total_num_heads * self.head_dim,
312+
hidden_size,
313+
bias=False,
314+
quant_config=quant_config,
315+
prefix=f"{prefix}.o_proj",
316+
)
317+
self.q_norm = GemmaRMSNorm(self.head_dim, eps=1e-6)
318+
self.k_norm = GemmaRMSNorm(self.head_dim, eps=1e-6)
319+
303320
if rope_parameters:
304321
self.rotary_emb = get_rope(
305322
self.head_dim,

0 commit comments

Comments
 (0)