Skip to content

Commit 06deeef

Browse files
committed
Merge origin/master into ak/t5gemma2 and resolve conflicts
TAG=agy CONV=7b651079-7501-4d53-a9a3-6058ec11ea33 Signed-off-by: Akhilesh Kumar <akhilbussiness@gmail.com>
2 parents e040da3 + d16cb22 commit 06deeef

6 files changed

Lines changed: 277 additions & 40 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "vllm-bart-plugin"
7-
version = "0.2.0"
7+
version = "0.3.4"
88
description = "BART model plugin for vLLM"
99
readme = "README.md"
1010
requires-python = ">=3.10"
@@ -26,7 +26,7 @@ classifiers = [
2626
]
2727

2828
dependencies = [
29-
"vllm>=0.14.0",
29+
"vllm>=0.13.0,<=0.18",
3030
"torch>=2.9.0",
3131
"transformers>=4.56.0,<5",
3232
]

scripts/bump_version.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#!/usr/bin/env python3
2+
"""Simple script to bump the package version in setup.py and pyproject.toml."""
3+
4+
import argparse
5+
import re
6+
import sys
7+
from pathlib import Path
8+
9+
10+
def get_current_version(content: str) -> str | None:
11+
"""Extract version from file content."""
12+
match = re.search(r'version\s*=\s*["\']([^"\']+)["\']', content)
13+
return match.group(1) if match else None
14+
15+
16+
def bump_version(version: str, part: str) -> str:
17+
"""Bump the specified part of the version."""
18+
parts = version.split(".")
19+
if len(parts) != 3:
20+
raise ValueError(f"Expected semantic version (x.y.z), got: {version}")
21+
22+
major, minor, patch = map(int, parts)
23+
24+
if part == "major":
25+
major += 1
26+
minor = 0
27+
patch = 0
28+
elif part == "minor":
29+
minor += 1
30+
patch = 0
31+
elif part == "patch":
32+
patch += 1
33+
else:
34+
raise ValueError(f"Unknown version part: {part}")
35+
36+
return f"{major}.{minor}.{patch}"
37+
38+
39+
def update_file(filepath: Path, old_version: str, new_version: str) -> bool:
40+
"""Update version in a file. Returns True if file was modified."""
41+
if not filepath.exists():
42+
return False
43+
44+
content = filepath.read_text()
45+
new_content = re.sub(
46+
rf'(version\s*=\s*["\']){re.escape(old_version)}(["\'])',
47+
rf"\g<1>{new_version}\g<2>",
48+
content,
49+
)
50+
51+
if content != new_content:
52+
filepath.write_text(new_content)
53+
return True
54+
return False
55+
56+
57+
def main():
58+
parser = argparse.ArgumentParser(description="Bump package version")
59+
parser.add_argument(
60+
"part",
61+
choices=["major", "minor", "patch"],
62+
nargs="?",
63+
default="patch",
64+
help="Version part to bump (default: patch)",
65+
)
66+
parser.add_argument(
67+
"--set",
68+
dest="set_version",
69+
metavar="VERSION",
70+
help="Set a specific version instead of bumping",
71+
)
72+
parser.add_argument(
73+
"--dry-run",
74+
action="store_true",
75+
help="Show what would be changed without modifying files",
76+
)
77+
args = parser.parse_args()
78+
79+
root = Path(__file__).parent.parent
80+
files = [root / "setup.py", root / "pyproject.toml"]
81+
82+
# Get current version from pyproject.toml
83+
pyproject = root / "pyproject.toml"
84+
if not pyproject.exists():
85+
print("Error: pyproject.toml not found", file=sys.stderr)
86+
sys.exit(1)
87+
88+
current_version = get_current_version(pyproject.read_text())
89+
if not current_version:
90+
print("Error: Could not find version in pyproject.toml", file=sys.stderr)
91+
sys.exit(1)
92+
93+
# Determine new version
94+
if args.set_version:
95+
new_version = args.set_version
96+
else:
97+
new_version = bump_version(current_version, args.part)
98+
99+
print(f"Version: {current_version} -> {new_version}")
100+
101+
if args.dry_run:
102+
print("Dry run - no files modified")
103+
return
104+
105+
# Update files
106+
for filepath in files:
107+
if update_file(filepath, current_version, new_version):
108+
print(f"Updated: {filepath.name}")
109+
else:
110+
print(f"Skipped: {filepath.name} (not found or no changes)")
111+
112+
print(f"\nDone! Don't forget to commit and tag: git tag v{new_version}")
113+
114+
115+
if __name__ == "__main__":
116+
main()

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44

55
setup(
66
name="vllm-bart-plugin",
7-
version="0.2.0",
7+
version="0.3.4",
88
description="BART model plugin for vLLM",
99
author="Nicolò Lucchesi",
1010
author_email="nick.lucche@redhat.com",
1111
packages=find_packages(),
1212
python_requires=">=3.10",
1313
install_requires=[
14-
"vllm>=0.13.0",
14+
"vllm>=0.13.0,<=0.18",
1515
"torch>=2.9.0",
1616
"transformers >= 4.56.0, < 5",
1717
],

tests/test_vllm_018_compat.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""Regression tests for vLLM 0.18 compatibility in the BART processor."""
2+
3+
import torch
4+
5+
6+
def test_text_data_parser_handles_v018_empty_inputs():
7+
from vllm_bart_plugin.bart import TextDataParser
8+
9+
parser = TextDataParser()
10+
11+
assert parser._parse_text_data("") is None
12+
assert parser._parse_text_data([]) is None
13+
14+
15+
def test_create_encoder_prompt_uses_placeholder_token():
16+
from vllm_bart_plugin.bart import BartMultiModalProcessor
17+
18+
processor = BartMultiModalProcessor.__new__(BartMultiModalProcessor)
19+
20+
assert processor.create_encoder_prompt("<s>decoder text", {"texts": ["encoder text"]}) == [0]
21+
22+
23+
def test_call_hf_processor_accepts_pretokenized_decoder_prompt():
24+
from vllm_bart_plugin.bart import BartMultiModalProcessor
25+
26+
class FakeTokenizer:
27+
def __call__(self, text, return_tensors="pt", **kwargs):
28+
if text == "encoder text":
29+
return {"input_ids": torch.tensor([[11, 12, 13]])}
30+
return {"input_ids": torch.tensor([[21, 22]])}
31+
32+
class FakeInfo:
33+
def get_tokenizer(self):
34+
return FakeTokenizer()
35+
36+
processor = BartMultiModalProcessor.__new__(BartMultiModalProcessor)
37+
processor.info = FakeInfo()
38+
39+
out = processor._call_hf_processor(
40+
[7, 8, 9],
41+
{"texts": ["encoder text"]},
42+
{},
43+
{},
44+
)
45+
46+
assert torch.equal(out["encoder_input_ids"], torch.tensor([[11, 12, 13]]))
47+
assert torch.equal(out["input_ids"], torch.tensor([[7, 8, 9]]))

0 commit comments

Comments
 (0)