Skip to content

Commit 1ef949e

Browse files
committed
Prevent add_special_tokens overwrite for BART
1 parent 5265385 commit 1ef949e

1 file changed

Lines changed: 7 additions & 3 deletions

File tree

vllm_bart_plugin/bart.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1041,6 +1041,9 @@ def _call_hf_processor(
10411041
has_encoder_data = mm_data is not None and "texts" in mm_data
10421042
result = {}
10431043

1044+
# vLLM may pass add_special_tokens in tok_kwargs; we set it ourselves
1045+
tok_kwargs = {k: v for k, v in tok_kwargs.items() if k != "add_special_tokens"}
1046+
10441047
if has_encoder_data:
10451048
# Tokenize the encoder text from mm_data
10461049
encoder_texts = mm_data["texts"]
@@ -1152,8 +1155,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
11521155
config.vocab_size, config.d_model, embed_scale=embed_scale
11531156
)
11541157
# Bias added to logits after lm_head, matching HuggingFace approach
1155-
self.register_buffer("final_logits_bias",
1156-
torch.zeros((1, config.vocab_size)))
1158+
self.register_buffer("final_logits_bias", torch.zeros((1, config.vocab_size)))
11571159
self.logits_processor = LogitsProcessor(
11581160
self.unpadded_vocab_size, config.vocab_size
11591161
)
@@ -1341,7 +1343,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
13411343
"Shared weight embedding already loaded with name "
13421344
"%s, skipping. This is expected on facebook/bart-large"
13431345
" like models, where the same shared embedding is "
1344-
"present multiple times.", name)
1346+
"present multiple times.",
1347+
name,
1348+
)
13451349
continue
13461350

13471351
loader = AutoWeightsLoader(

0 commit comments

Comments
 (0)