We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b4677d5, commit e040da3 (Copy full SHA for e040da3)
1 file changed
vllm_bart_plugin/t5gemma2.py
@@ -300,6 +300,23 @@ def __init__(
300
quant_config=quant_config,
301
prefix=f"{prefix}.k_proj",
302
)
303
+ self.v_proj = ColumnParallelLinear(
304
+ hidden_size,
305
+ self.total_num_kv_heads * self.head_dim,
306
+ bias=False,
307
+ quant_config=quant_config,
308
+ prefix=f"{prefix}.v_proj",
309
+ )
310
+ self.o_proj = RowParallelLinear(
311
+ self.total_num_heads * self.head_dim,
312
313
314
315
+ prefix=f"{prefix}.o_proj",
316
317
+ self.q_norm = GemmaRMSNorm(self.head_dim, eps=1e-6)
318
+ self.k_norm = GemmaRMSNorm(self.head_dim, eps=1e-6)
319
+
320
if rope_parameters:
321
self.rotary_emb = get_rope(
322
self.head_dim,
0 commit comments