Commit ecffff1

aerdem4 authored and alonsosilvaallende committed
parent ec4e07f
author Ahmet Erdem <ahmeterd4@gmail.com> 1742464764 +0300
committer Alonso Silva <alonso.silva@gmail.com> 1745697927 +0200

Add the possibility of batched inference with vllm

Reset processors after each batch to be able to re-use (#13)

* Reset processors after each batch to be able to re-use

Signed-off-by: aerdem4 <ahmeterd4@gmail.com>

* Re-use processor objects in example notebooks

Signed-off-by: aerdem4 <ahmeterd4@gmail.com>

* Add TriggerPhraseLogitsProcessor to readme

Signed-off-by: aerdem4 <ahmeterd4@gmail.com>

* Fix grammar mistake

Signed-off-by: aerdem4 <ahmeterd4@gmail.com>

* Comment reset_if_new_batch logic

Signed-off-by: aerdem4 <ahmeterd4@gmail.com>

* Make new generation detection logic more robust

Signed-off-by: aerdem4 <ahmeterd4@gmail.com>

---------

Signed-off-by: aerdem4 <ahmeterd4@gmail.com>

Add batched inference possibility
Fix issues with flake8
Increment package version
Use initial_trigger_phrase instead of trigger_phrase

19 files changed

Lines changed: 259 additions & 85 deletions

README.md

Lines changed: 9 additions & 1 deletion
@@ -78,4 +78,12 @@ I am getting a lot of calls during the day. What is more important for me to con
 2. Operating System
 3. Battery
 ```
-The goal is to make LLM generate "3" as an answer.
+The goal is to make LLM generate "3" as an answer.
+
+### TriggerPhraseLogitsProcessor
+A logits processor which triggers phrases when it encounters a given token.
+One common use case is to force writing python code just after thinking:
+```python
+trigger_python = TriggerPhraseLogitsProcessor(phrase="\n```python", trigger_token_phrase="</think>",
+                                              tokenizer=tokenizer, trigger_count=1, trigger_after=True)
+```
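
The README addition describes forcing a phrase once a trigger token appears. A minimal, torch-free sketch of that idea follows, assuming the phrase and trigger are already single token ids and that forcing means masking every other logit. `TriggerPhraseSketch` and its arguments are hypothetical names for illustration; this is not the library's implementation and it only covers the "phrase after trigger" case.

```python
import math
from typing import List, Optional


class TriggerPhraseSketch:
    """Hypothetical stand-in that forces `phrase_ids` right after `trigger_id`."""

    def __init__(self, phrase_ids: List[int], trigger_id: int, trigger_count: int = 1):
        self.phrase_ids = phrase_ids        # token ids of the phrase to force (non-empty)
        self.trigger_id = trigger_id        # token id that starts the forcing
        self.trigger_count = trigger_count  # how many times the phrase may still fire
        self.cursor: Optional[int] = None   # position in phrase_ids while forcing

    def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
        # Start forcing when the most recent token is the trigger.
        just_triggered = bool(input_ids) and input_ids[-1] == self.trigger_id
        if self.cursor is None and self.trigger_count > 0 and just_triggered:
            self.cursor = 0
            self.trigger_count -= 1

        if self.cursor is None:
            return scores  # not forcing: leave the scores untouched

        forced = self.phrase_ids[self.cursor]
        self.cursor += 1
        if self.cursor == len(self.phrase_ids):
            self.cursor = None  # phrase complete, stop forcing

        # Keep only the forced token; every other token becomes impossible.
        return [s if i == forced else -math.inf for i, s in enumerate(scores)]
```

With `phrase_ids=[2, 3]` and `trigger_id=1`, the two calls following the trigger token leave only tokens 2 and then 3 selectable, after which the scores pass through unchanged.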

example_notebooks/transformers/cite_prompt_logits_processor.ipynb

Lines changed: 21 additions & 5 deletions
@@ -136,7 +136,7 @@
 " \n",
 "\n",
 "LLM response:\n",
-"The user seems to have mixed feelings about the price of the product. They find it expensive, but they also appreciate its softness, colorfulness, and style, which suggests that the product is well-made and worth the cost.\n",
+"The user seems to have mixed feelings about the price of the product. They find it expensive, but they also appreciate its softness, colorfulness, and style.\n",
 "-----END-----\n",
 "\n",
 "Prompt: \n",
@@ -158,7 +158,7 @@
 "source": [
 "runner.generate_response(\n",
 " example_prompts,\n",
-" [CiteFromPromptLogitsProcessor(runner.tokenizer, example_prompts, boost_factor=2.0)]\n",
+" [CiteFromPromptLogitsProcessor(runner.tokenizer, boost_factor=2.0, boost_eos=False)]\n",
 ")"
 ]
 },
@@ -187,9 +187,17 @@
 " \n",
 "\n",
 "LLM response:\n",
-"The reviewer seems to have mixed feelings towards the pricing of the product. They describe it as \"expensive\" and \"deserves its price,\" which suggests that they find it worth paying for its quality or unique features. The use of words like \"stylish\" further emphasizes their positive impression.\n",
+"The reviewer seems to have mixed feelings towards the pricing of the product:\n",
 "\n",
-"So in summary, while they appreciate the design and style of the product, they also acknowledge that it might be quite pricey. Therefore, their overall sentiment can be described as **mixed**, with appreciation for both aspects (quality and style).\n",
+"- They describe it as \"very soft\" and \"colorful\", suggesting that they appreciate these qualities.\n",
+"\n",
+"- They also mention that it is \"expensive,\" which might be seen as negative if you're looking for an affordable option or if this was their first time buying something like this.\n",
+"\n",
+"- However, they state that it \"deserves its price,\" indicating that they believe the high cost reflects on quality or value.\n",
+"\n",
+"Overall, while they seem satisfied with the overall experience and don't mind paying more for what they perceive as good-quality materials and design, they may feel that the price point could be higher than expected for everyday use or budget-conscious shoppers.\n",
+"\n",
+"So in summary, they find the item to be well-made and aesthetically pleasing despite feeling that it might not be suitable for everyone due to being too pricey for some people's budgets. The reviewer seems generally positive toward the purchase decision itself rather than just the specific item.\n",
 "-----END-----\n",
 "\n",
 "Prompt: \n",
@@ -215,9 +223,17 @@
 "source": [
 "runner.generate_response(\n",
 " example_prompts,\n",
-" [CiteFromPromptLogitsProcessor(runner.tokenizer, example_prompts, boost_factor=-2.0)]\n",
+" [CiteFromPromptLogitsProcessor(runner.tokenizer, boost_factor=-2.0, boost_eos=False)]\n",
 ")"
 ]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "c29fedb3",
+"metadata": {},
+"outputs": [],
+"source": []
 }
 ],
 "metadata": {

example_notebooks/vllm/force_last_phrase_logits_processor.ipynb

Lines changed: 54 additions & 15 deletions
@@ -28,17 +28,17 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"WARNING 02-12 13:42:36 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead. See https://pypi.org/project/pynvml for more information.\n",
-"WARNING 02-12 13:42:39 config.py:1563] Casting torch.bfloat16 to torch.float16.\n",
-"INFO 02-12 13:42:39 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='google/gemma-1.1-2b-it', speculative_config=None, tokenizer='google/gemma-1.1-2b-it', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=google/gemma-1.1-2b-it, use_v2_block_manager=False, enable_prefix_caching=False)\n",
-"INFO 02-12 13:42:40 model_runner.py:879] Starting to load model google/gemma-1.1-2b-it...\n",
-"INFO 02-12 13:42:40 weight_utils.py:236] Using model weights format ['*.safetensors']\n"
+"WARNING 03-18 13:40:54 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead. See https://pypi.org/project/pynvml for more information.\n",
+"WARNING 03-18 13:40:58 config.py:1563] Casting torch.bfloat16 to torch.float16.\n",
+"INFO 03-18 13:40:58 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='google/gemma-1.1-2b-it', speculative_config=None, tokenizer='google/gemma-1.1-2b-it', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=google/gemma-1.1-2b-it, use_v2_block_manager=False, enable_prefix_caching=False)\n",
+"INFO 03-18 13:40:59 model_runner.py:879] Starting to load model google/gemma-1.1-2b-it...\n",
+"INFO 03-18 13:41:00 weight_utils.py:236] Using model weights format ['*.safetensors']\n"
 ]
 },
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "",
+"model_id": "ebad9294acfd4e15aa9272a1aac448df",
 "version_major": 2,
 "version_minor": 0
 },
@@ -53,8 +53,8 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"INFO 02-12 13:42:42 model_runner.py:890] Loading model weights took 4.6720 GB\n",
-"INFO 02-12 13:42:44 gpu_executor.py:121] # GPU blocks: 49686, # CPU blocks: 14563\n"
+"INFO 03-18 13:41:02 model_runner.py:890] Loading model weights took 4.6720 GB\n",
+"INFO 03-18 13:41:05 gpu_executor.py:121] # GPU blocks: 49742, # CPU blocks: 14563\n"
 ]
 }
 ],
@@ -156,10 +156,10 @@
 }
 ],
 "source": [
-"phrase = \"\\n\\nReferences:\"\n",
+"reference = ForceLastPhraseLogitsProcessor(\"\\n\\nReferences:\", runner.tokenizer)\n",
 "\n",
 "runner.generate_response(example_prompts,\n",
-" [ForceLastPhraseLogitsProcessor(phrase, runner.tokenizer)])"
+" [reference])"
 ]
 },
 {
@@ -199,19 +199,58 @@
 }
 ],
 "source": [
-"phrase = \"\\n\\nThanks for trying our RAG application! If you have more questions about\"\n",
+"thank = ForceLastPhraseLogitsProcessor(\"\\n\\nThanks for trying our RAG application! If you have more questions about\",\n",
+" runner.tokenizer)\n",
 "\n",
 "runner.generate_response(example_prompts,\n",
-" [ForceLastPhraseLogitsProcessor(phrase, runner.tokenizer)])"
+" [thank])"
+]
+},
+{
+"cell_type": "markdown",
+"id": "34735f41",
+"metadata": {},
+"source": [
+"## Both"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 6,
 "id": "88032bf3",
 "metadata": {},
-"outputs": [],
-"source": []
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Prompt: \n",
+" Retrieved information from: https://en.wikipedia.org/wiki/Bulbasaur\n",
+" Bulbasaur is a fictional Pokémon species in Nintendo and Game Freak's Pokémon franchise. \n",
+" Designed by Atsuko Nishida, Bulbasaur is a Grass and Poison-type, first appearing in Pocket Monsters: Red and Green (Pokémon Red and Blue outside Japan) as a starter Pokémon. \n",
+" Since then, it has reappeared in sequels, spin-off games, related merchandise, and animated and printed adaptations of the franchise. \n",
+" It is a central character in the Pokémon anime, being one of Ash Ketchum's main Pokémon for the first season, with a different one later being obtained by supporting character May. \n",
+" It is featured in various manga and is owned by protagonist Red in Pokémon Adventures.\n",
+" \n",
+" What is Bulbasaur?\n",
+" \n",
+"Bulbasaur is a fictional Pokémon species that is a Grass and Poison-type.\n",
+"\n",
+"Thanks for trying our RAG application! If you have more questions about Bulbasaur or anything else, please let me know.\n",
+"\n",
+"References:\n",
+"\n",
+"* Wikipedia: Bulbasaur\n",
+"* Pokémon franchise website\n",
+"-----END-----\n",
+"\n"
+]
+}
+],
+"source": [
+"runner.generate_response(example_prompts,\n",
+" [thank, reference])"
+]
 }
 ],
 "metadata": {

example_notebooks/vllm/trigger_phrase_logits_processor.ipynb

Lines changed: 21 additions & 17 deletions
@@ -28,20 +28,20 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"WARNING 02-13 10:32:45 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead. See https://pypi.org/project/pynvml for more information.\n",
-"WARNING 02-13 10:32:49 config.py:1563] Casting torch.bfloat16 to torch.float16.\n",
-"WARNING 02-13 10:32:49 arg_utils.py:839] Chunked prefill is enabled by default for models with max_model_len > 32K. Currently, chunked prefill might not work with some features or models. If you encounter any issues, please disable chunked prefill by setting --enable-chunked-prefill=False.\n",
-"INFO 02-13 10:32:49 config.py:911] Chunked prefill is enabled with max_num_batched_tokens=512.\n",
-"INFO 02-13 10:32:49 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', speculative_config=None, tokenizer='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, use_v2_block_manager=False, enable_prefix_caching=False)\n",
-"INFO 02-13 10:32:50 model_runner.py:879] Starting to load model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B...\n",
-"INFO 02-13 10:32:51 weight_utils.py:236] Using model weights format ['*.safetensors']\n",
-"INFO 02-13 10:32:52 weight_utils.py:280] No model.safetensors.index.json found in remote.\n"
+"WARNING 03-18 13:37:20 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead. See https://pypi.org/project/pynvml for more information.\n",
+"WARNING 03-18 13:37:24 config.py:1563] Casting torch.bfloat16 to torch.float16.\n",
+"WARNING 03-18 13:37:24 arg_utils.py:839] Chunked prefill is enabled by default for models with max_model_len > 32K. Currently, chunked prefill might not work with some features or models. If you encounter any issues, please disable chunked prefill by setting --enable-chunked-prefill=False.\n",
+"INFO 03-18 13:37:24 config.py:911] Chunked prefill is enabled with max_num_batched_tokens=512.\n",
+"INFO 03-18 13:37:24 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', speculative_config=None, tokenizer='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, use_v2_block_manager=False, enable_prefix_caching=False)\n",
+"INFO 03-18 13:37:25 model_runner.py:879] Starting to load model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B...\n",
+"INFO 03-18 13:37:26 weight_utils.py:236] Using model weights format ['*.safetensors']\n",
+"INFO 03-18 13:37:27 weight_utils.py:280] No model.safetensors.index.json found in remote.\n"
 ]
 },
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "35561044f6c848eb9c56be591d6c50c8",
+"model_id": "e3ff26a7bf23415c9e29952e2a5f2d1a",
 "version_major": 2,
 "version_minor": 0
 },
@@ -56,8 +56,8 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"INFO 02-13 10:32:53 model_runner.py:890] Loading model weights took 3.3460 GB\n",
-"INFO 02-13 10:32:53 gpu_executor.py:121] # GPU blocks: 37897, # CPU blocks: 9362\n"
+"INFO 03-18 13:37:29 model_runner.py:890] Loading model weights took 3.3460 GB\n",
+"INFO 03-18 13:37:29 gpu_executor.py:121] # GPU blocks: 37898, # CPU blocks: 9362\n"
 ]
 }
 ],
@@ -338,9 +338,11 @@
 }
 ],
 "source": [
+"trigger_python = TriggerPhraseLogitsProcessor(\"\\n```python\", \"</think>\", runner.tokenizer, \n",
+" trigger_count=1, trigger_after=True)\n",
+"\n",
 "runner.generate_response(example_prompts,\n",
-" [TriggerPhraseLogitsProcessor(\"\\n```python\", \"</think>\", runner.tokenizer, \n",
-" trigger_count=1, trigger_after=True)],\n",
+" [trigger_python],\n",
 " max_tokens=4096)"
 ]
 },
@@ -387,12 +389,14 @@
 }
 ],
 "source": [
+"keep_thinking_short = GenLengthLogitsProcessor(runner.tokenizer, boost_factor=0.1, complete_sentences=True,\n",
+" boost_token_str=\"</think>\")\n",
+"\n",
 "runner.generate_response(example_prompts,\n",
 " [\n",
-" GenLengthLogitsProcessor(runner.tokenizer, boost_factor=0.1, complete_sentences=True,\n",
-" boost_token_str=\"</think>\"),\n",
-" TriggerPhraseLogitsProcessor(\"\\n```python\", \"</think>\", runner.tokenizer, \n",
-" trigger_count=1, trigger_after=True)],\n",
+" keep_thinking_short,\n",
+" trigger_python\n",
+" ],\n",
 " max_tokens=4096)"
 ]
 },
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+
+
+class BaseLogitsProcessor:
+    def __init__(self):
+        self.prompt_token_ids = None
+        self.prev_token_ids = None
+
+    def _reset(self):
+        pass
+
+    def _check_new_generation(self, input_ids: torch.LongTensor):
+        first_time = self.prompt_token_ids is None
+        if first_time:
+            self._reset()
+            self.prompt_token_ids = input_ids
+        else:
+            same_gen = False
+            if input_ids.shape[1] > 1:
+                same_gen = torch.equal(input_ids[:, :-1], self.prev_token_ids)
+
+            if not same_gen:
+                self._reset()
+                self.prompt_token_ids = input_ids
+
+        self.prev_token_ids = input_ids
+
+    def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor:
+        return scores
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor:
+        self._check_new_generation(input_ids)
+        scores = self._process(input_ids, scores)
+        return scores
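
The new-generation check in `BaseLogitsProcessor` is what lets one processor object be re-used across batches: a call counts as the same generation only if its `input_ids`, minus the newest token, equal the ids seen on the previous call. The sketch below mirrors that logic with plain Python lists instead of tensors so it runs without `torch`. `ListBaseProcessor` and `StepCounter` are hypothetical names for illustration and are not part of the commit.

```python
from typing import List, Optional

Batch = List[List[int]]  # one row of token ids per sequence in the batch


class ListBaseProcessor:
    """Torch-free analogue of BaseLogitsProcessor (illustration only)."""

    def __init__(self) -> None:
        self.prompt_token_ids: Optional[Batch] = None
        self.prev_token_ids: Optional[Batch] = None

    def _reset(self) -> None:
        pass  # subclasses clear their per-generation state here

    def _check_new_generation(self, input_ids: Batch) -> None:
        same_gen = False
        if self.prev_token_ids is not None and len(input_ids[0]) > 1:
            # Same generation if dropping the newest token of every row
            # reproduces exactly what the previous call received.
            same_gen = [row[:-1] for row in input_ids] == self.prev_token_ids

        if not same_gen:
            self._reset()
            self.prompt_token_ids = input_ids

        self.prev_token_ids = input_ids

    def _process(self, input_ids: Batch, scores: list) -> list:
        return scores

    def __call__(self, input_ids: Batch, scores: list) -> list:
        self._check_new_generation(input_ids)
        return self._process(input_ids, scores)


class StepCounter(ListBaseProcessor):
    """Hypothetical subclass: counts decoding steps since the last reset."""

    def __init__(self) -> None:
        super().__init__()
        self.steps = 0

    def _reset(self) -> None:
        self.steps = 0

    def _process(self, input_ids: Batch, scores: list) -> list:
        self.steps += 1
        return scores


counter = StepCounter()
counter([[1, 2]], [0.0])     # first call: treated as a new generation
counter([[1, 2, 3]], [0.0])  # one token appended: same generation
counter([[7, 8]], [0.0])     # unrelated ids: state is reset before processing
```

After the third call `counter.steps` is back to 1 and `counter.prompt_token_ids` holds the new prompt, which is the behaviour the notebooks rely on when they pass the same processor objects to several `generate_response` calls.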
