Skip to content

Commit 0927ac4

Browse files
committed
Add M2M100/NLLB support (nllb-200-distilled-600M, 1.3B, 3.3B)
Adds M2M100ForConditionalGeneration support for three NLLB models: facebook/nllb-200-distilled-600M, facebook/nllb-200-distilled-1.3B, and facebook/nllb-200-3.3B.

Architecture differences from BART implemented in nllb.py:
- Sinusoidal (fixed) positional embeddings instead of learned
- PRE-LayerNorm (norm before sublayer) instead of POST-LayerNorm
- Additional layer_norm after all encoder/decoder layers
- ReLU activation instead of GELU
- No final_logits_bias

Language routing:
- Decoder starts with the target language token via create_decoder_prompt
- Source language token is prepended to the encoder input via src_lang in mm_processor_kwargs (make_nllb_prompt helper provided)

Also fixes BartMultiModalProcessor for vLLM >=0.18 compatibility:
- create_encoder_prompt: always return a [0] placeholder (inputs.prompt is the decoder prompt in 0.18, not the encoder text)
- _call_hf_processor: handle already-tokenized token ID lists
- _is_empty was removed in 0.18; replaced with an inline emptiness check

Tests: 12 unit tests (no GPU) + 13 integration tests covering English/non-English sources, 4 target scripts, batch, determinism, and max_tokens.
1 parent 331e24a commit 0927ac4

8 files changed

Lines changed: 1629 additions & 29 deletions

File tree

example_nllb_usage.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""Example: NLLB translation with vLLM via the bart-plugin.
2+
3+
Supported models (all use model_type=m2m_100):
4+
facebook/nllb-200-distilled-600M (~1.2 GB)
5+
facebook/nllb-200-distilled-1.3B (~2.6 GB)
6+
facebook/nllb-200-3.3B (~6.6 GB)
7+
8+
Language codes follow the FLORES-200 format: <language>_<script>
9+
English → eng_Latn
10+
French → fra_Latn
11+
German → deu_Latn
12+
Arabic → arb_Arab
13+
Chinese → zho_Hans
14+
Amharic → amh_Ethi
15+
Hindi → hin_Deva
16+
(200+ languages supported)
17+
18+
Run:
19+
python example_nllb_usage.py
20+
21+
Required:
22+
pip install vllm-bart-plugin
23+
"""
24+
25+
import os
26+
27+
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
28+
29+
from vllm import LLM, SamplingParams
30+
from vllm_bart_plugin.nllb import make_nllb_prompt
31+
32+
MODEL_NAME = "facebook/nllb-200-distilled-600M"
33+
34+
# ---------------------------------------------------------------------------
35+
# Demo 1: English → multiple target languages
36+
# ---------------------------------------------------------------------------
37+
38+
ENGLISH_TEXTS = [
39+
"The United Nations was founded in 1945.",
40+
"Machine translation has improved significantly in recent years.",
41+
"Hello, how are you doing today?",
42+
]
43+
44+
TARGET_LANGS = [
45+
("French", "fra_Latn"),
46+
("German", "deu_Latn"),
47+
("Spanish", "spa_Latn"),
48+
("Arabic", "arb_Arab"),
49+
("Chinese", "zho_Hans"),
50+
]
51+
52+
# ---------------------------------------------------------------------------
53+
# Demo 2: Non-English sources (target language is set per example)
54+
# ---------------------------------------------------------------------------
55+
56+
NON_ENGLISH_TEXTS = [
57+
# Amharic (Ge'ez script) → English
58+
("amh_Ethi", "eng_Latn", "ሰላም፣ ዓለም! የተባበሩት መንግሥታት ድርጅት በ1945 ዓ.ም ተቋቋመ።"),
59+
# French → German
60+
("fra_Latn", "deu_Latn", "La traduction automatique s'est beaucoup améliorée."),
61+
# Hindi → English
62+
("hin_Deva", "eng_Latn", "संयुक्त राष्ट्र की स्थापना 1945 में हुई थी।"),
63+
]
64+
65+
66+
def main():
    """Run the two NLLB translation demos and print their results.

    Demo 1 translates every sentence in ``ENGLISH_TEXTS`` into each language
    listed in ``TARGET_LANGS``, submitting one batch of prompts per target
    language. Demo 2 walks ``NON_ENGLISH_TEXTS`` and translates each
    ``(src_lang, tgt_lang, text)`` entry with its own ``generate`` call.
    """
    engine = LLM(
        model=MODEL_NAME,
        enforce_eager=True,
        max_model_len=512,
        gpu_memory_utilization=0.15,
        dtype="float16",
    )
    # temperature=0.0 with a 60-token cap on every generated translation.
    sampling = SamplingParams(temperature=0.0, max_tokens=60)
    rule = "=" * 60

    # --- Demo 1: English source -------------------------------------------
    print(rule)
    print("Demo 1: English → multiple languages")
    print(rule)

    for language_name, language_code in TARGET_LANGS:
        batch = []
        for sentence in ENGLISH_TEXTS:
            batch.append(
                make_nllb_prompt(
                    sentence, src_lang="eng_Latn", tgt_lang=language_code
                )
            )
        results = engine.generate(batch, sampling_params=sampling)

        print(f"\n{language_name} ({language_code})")
        # Short label for the output lines: first three letters of the code.
        tag = language_code[:3].upper()
        for sentence, result in zip(ENGLISH_TEXTS, results):
            print(f" [EN] {sentence}")
            print(f" [{tag}] {result.outputs[0].text}")

    # --- Demo 2: Non-English sources --------------------------------------
    print("\n" + rule)
    print("Demo 2: Non-English sources")
    print(rule)

    for source_code, target_code, sentence in NON_ENGLISH_TEXTS:
        request = make_nllb_prompt(
            sentence, src_lang=source_code, tgt_lang=target_code
        )
        results = engine.generate([request], sampling_params=sampling)
        result = results[0]
        print(f"\n[{source_code}] {sentence}")
        print(f"[{target_code}] {result.outputs[0].text}")
102+
103+
104+
if __name__ == "__main__":
105+
main()

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "vllm-bart-plugin"
7-
version = "0.3.3"
8-
description = "BART model plugin for vLLM"
7+
version = "0.3.4"
8+
description = "BART, Florence-2, and NLLB/M2M-100 (translation) model plugin for vLLM"
99
readme = "README.md"
1010
requires-python = ">=3.10"
1111
license = {text = "Apache-2.0"}
1212
authors = [
1313
{name = "Nicolò Lucchesi", email = "nick.lucche@redhat.com"}
1414
]
15-
keywords = ["vllm", "bart", "language-model", "inference", "plugin"]
15+
keywords = ["vllm", "bart", "nllb", "m2m100", "translation", "language-model", "inference", "plugin"]
1616
classifiers = [
1717
"Development Status :: 3 - Alpha",
1818
"Intended Audience :: Developers",

tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,20 @@ def cuda_available():
1010
return torch.cuda.is_available()
1111

1212

13+
@pytest.fixture
def vllm_config_ctx():
    """Yield a default ``VllmConfig`` installed as the current vLLM config.

    The config stays active (via ``set_current_vllm_config``) while the test
    body runs and the context is exited once the fixture resumes.

    Intended for tests that build vLLM attention layers directly: Attention,
    MMEncoderAttention and CrossAttention all call
    ``get_current_vllm_config()`` during ``__init__``.
    """
    # Imported lazily, inside the fixture, rather than at module import time.
    from vllm import config as vllm_config_module

    default_config = vllm_config_module.VllmConfig()
    with vllm_config_module.set_current_vllm_config(default_config):
        yield default_config
25+
26+
1327
@pytest.fixture(scope="session")
1428
def device():
1529
"""Get the device to use for tests."""

0 commit comments

Comments
 (0)