-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathexample_bart_usage.py
More file actions
70 lines (64 loc) · 2.24 KB
/
Copy pathexample_bart_usage.py
File metadata and controls
70 lines (64 loc) · 2.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python3
"""
Example usage of the vLLM BART plugin.
This script demonstrates how to use BART models with vLLM
after installing the BART plugin.
"""
import vllm_bart_plugin
from vllm import LLM, SamplingParams
def main():
    """Generate text with a BART checkpoint through vLLM and print it.

    Creates an ``LLM`` engine for ``facebook/bart-large-cnn``, submits two
    explicit encoder/decoder prompts in a single ``generate`` call using
    ``SamplingParams(temperature=0.0, max_tokens=20)``, and prints the text
    of the first completion of every returned request.
    """

    def encoder_decoder_prompt(encoder_text, decoder_text):
        """Build one explicit encoder/decoder prompt dict.

        The encoder text travels as a ``multi_modal_data`` "text" item next
        to an empty encoder ``prompt``. This format is needed so that no
        custom encoder-only prompt logic has to be added to preprocess.py
        (vLLM core) to convert encoder token ids into a multimodal text item.
        """
        return {
            "encoder_prompt": {
                "prompt": "",
                "multi_modal_data": {
                    "text": encoder_text,
                },
            },
            "decoder_prompt": decoder_text,
        }

    engine_options = {
        "model": "facebook/bart-large-cnn",
        "tensor_parallel_size": 1,
        "enforce_eager": False,
        "max_model_len": 1024,
        "max_num_seqs": 4,
        "max_num_batched_tokens": 11024,
        "gpu_memory_utilization": 0.5,
        "dtype": "float16",
    }
    engine = LLM(**engine_options)
    sampling = SamplingParams(temperature=0.0, max_tokens=20)

    # Prompt shapes that are NOT usable here:
    #   * Not supported -- a plain single-prompt dict:
    #         {"prompt": "The president of the United States is"}
    #   * Not supported without changes to vLLM core -- an explicit
    #     encoder/decoder prompt whose encoder side is a plain "prompt":
    #         {
    #             "encoder_prompt": {
    #                 "prompt": "The president of the United States is",
    #             },
    #             "decoder_prompt": "<s>Donald",
    #         }
    requests = [
        # NOTE: explicit encoder/decoder prompt; use <s> to start the
        # decoder prompt.
        encoder_decoder_prompt(
            "The president of the United States is",
            "<s>Donald",
        ),
        # NOTE: the output is really sensitive to the BOS token, which should
        # always be present in the decoder prompt.
        encoder_decoder_prompt(
            "<s>",
            "<s>The capital of France is",
        ),
    ]

    results = engine.generate(requests, sampling_params=sampling)
    for request_output in results:
        # Only the first completion of each request is shown.
        print("output:", request_output.outputs[0].text)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()