From c03309d3e12a0fd25e44e0ae8c90a420ba51d1bd Mon Sep 17 00:00:00 2001 From: aerdem4 Date: Tue, 29 Apr 2025 15:40:25 +0300 Subject: [PATCH 1/2] Enforce tokens without adding very large numbers to avoid floating point issues Signed-off-by: aerdem4 --- .../force_last_phrase_logits_processor.ipynb | 23 ++- .../gen_length_logits_processor.ipynb | 2 +- .../multiple_choice_logits_processor.ipynb | 23 ++- .../trigger_phrase_logits_processor.ipynb | 59 ++++-- .../vllm/cite_prompt_logits_processor.ipynb | 182 +++++++++++++++--- .../force_last_phrase_logits_processor.ipynb | 69 ++++--- .../vllm/gen_length_logits_processor.ipynb | 161 ++++++++++------ .../multiple_choice_logits_processor.ipynb | 66 +++++-- .../trigger_phrase_logits_processor.ipynb | 51 +++-- example_notebooks/vllm/utils.py | 2 +- .../transformers/last_phrase.py | 5 +- .../transformers/multiple_choice.py | 5 +- .../transformers/trigger_phrase.py | 7 +- logits_processor_zoo/utils.py | 13 +- logits_processor_zoo/vllm/last_phrase.py | 5 +- logits_processor_zoo/vllm/multiple_choice.py | 5 +- logits_processor_zoo/vllm/trigger_phrase.py | 7 +- pyproject.toml | 2 +- tests/test_utils.py | 12 +- 19 files changed, 517 insertions(+), 182 deletions(-) diff --git a/example_notebooks/transformers/force_last_phrase_logits_processor.ipynb b/example_notebooks/transformers/force_last_phrase_logits_processor.ipynb index d915144..4b1a41b 100644 --- a/example_notebooks/transformers/force_last_phrase_logits_processor.ipynb +++ b/example_notebooks/transformers/force_last_phrase_logits_processor.ipynb @@ -23,7 +23,19 @@ "execution_count": 2, "id": "a85f8503", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n" + ] + } + ], "source": [ "from example_notebooks.transformers.utils import LLMRunner\n", "from logits_processor_zoo.transformers import ForceLastPhraseLogitsProcessor\n", @@ -63,7 +75,12 @@ "name": "stderr", "output_type": "stream", "text": [ - "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", + " warnings.warn(\n", + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", + " warnings.warn(\n", + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:407: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n", + " warnings.warn(\n" ] }, { @@ -216,7 +233,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.17" } }, "nbformat": 4, diff --git a/example_notebooks/transformers/gen_length_logits_processor.ipynb b/example_notebooks/transformers/gen_length_logits_processor.ipynb index b3d4df9..6c5ebe6 100644 --- a/example_notebooks/transformers/gen_length_logits_processor.ipynb +++ b/example_notebooks/transformers/gen_length_logits_processor.ipynb @@ -288,7 +288,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.17" } }, "nbformat": 4, diff --git a/example_notebooks/transformers/multiple_choice_logits_processor.ipynb b/example_notebooks/transformers/multiple_choice_logits_processor.ipynb index 9d06940..b4f0047 100644 --- a/example_notebooks/transformers/multiple_choice_logits_processor.ipynb +++ b/example_notebooks/transformers/multiple_choice_logits_processor.ipynb @@ -23,7 +23,19 @@ "execution_count": 2, "id": "a85f8503", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n" + ] + } + ], "source": [ "from example_notebooks.transformers.utils import LLMRunner\n", "from logits_processor_zoo.transformers import MultipleChoiceLogitsProcessor\n", @@ -68,7 +80,12 @@ "name": "stderr", "output_type": "stream", "text": [ - "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", + " warnings.warn(\n", + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", + " warnings.warn(\n", + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:407: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n", + " warnings.warn(\n" ] }, { @@ -237,7 +254,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.17" } }, "nbformat": 4, diff --git a/example_notebooks/transformers/trigger_phrase_logits_processor.ipynb b/example_notebooks/transformers/trigger_phrase_logits_processor.ipynb index 3eb21f1..0a37367 100644 --- a/example_notebooks/transformers/trigger_phrase_logits_processor.ipynb +++ b/example_notebooks/transformers/trigger_phrase_logits_processor.ipynb @@ -23,7 +23,19 @@ "execution_count": 2, "id": "b89279fe", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n" + ] + } + ], "source": [ "from example_notebooks.transformers.utils import LLMRunner\n", "from logits_processor_zoo.transformers import TriggerPhraseLogitsProcessor, GenLengthLogitsProcessor\n", @@ -58,9 +70,14 @@ "name": "stderr", "output_type": "stream", "text": [ + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", + " warnings.warn(\n", + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", + " warnings.warn(\n", + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:407: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n", + " warnings.warn(\n", "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", - "Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n", - "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" + "Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.\n" ] }, { @@ -96,9 +113,9 @@ "\n", "Let me test this function with some examples. For n=0, it returns 0. For n=1, returns 1. For n=2, it's F(1)+F(0) = 1+0=1. For n=3, F(2)+F(1)=1+1=2. That looks correct.\n", "\n", - "Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which is correct.\n", + "Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which matches the standard definition. So, the function should work regardless of the starting point as long as the base cases are correct.\n", "\n", - "Another test case: n=5. Let's compute it step by step. F(0)=0, F(1)=1, F(2)=1, F(3)=2, F(4)=3, F(5)=5. So the function should return 5 for n=5.\n", + "Another thing to consider is the base cases. If the function is called with n=0, it returns 0, which is correct. For n=1, returns 1. For n=2, returns 1, which is correct. So, the function should handle all non-negative integers correctly.\n", "\n", "I think this should work. So, the function is straightforward. It's a simple recursive implementation, but it's not the most efficient for large n. However, for the purpose of this problem, it's acceptable.\n", "\n", @@ -162,7 +179,7 @@ "output_type": "stream", "text": [ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", - "Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n" + "Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.\n" ] }, { @@ -198,25 +215,27 @@ "\n", "Let me test this function with some examples. For n=0, it returns 0. For n=1, returns 1. For n=2, it's F(1)+F(0) = 1+0=1. For n=3, F(2)+F(1)=1+1=2. That looks correct.\n", "\n", - "Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which is correct.\n", + "Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which matches the standard definition. So, the function should work regardless of the starting point as long as the base cases are correct.\n", "\n", - "Another test case: n=5. Let's compute it step by step. F(0)=0, F(1)=1, F(2)=1, F(3)=2, F(4)=3, F(5)=5. So the function should return 5 for n=5.\n", + "Another thing to consider is the base cases. If the function is called with n=0, it returns 0, which is correct. For n=1, returns 1. For n=2, returns 1, which is correct. So, the function should handle all non-negative integers correctly.\n", "\n", "I think this should work. So, the function is straightforward. It's a simple recursive implementation, but it's not the most efficient for large n. However, for the purpose of this problem, it's acceptable.\n", ",,,\n", "\n", - "Wait, but the problem says to make it recursive. So, the function as written is recursive. So, I think this should be the solution.\n", + "Wait, but the problem says to make it recursive. So, the function should call itself with smaller arguments. The approach I have is correct and recursive.\n", + "\n", + "So, the final function is as I wrote above.\n", ",,,\n", "\n", - "Wait, but in the function I wrote, for n=0, it returns 0, which is correct. For n=1, returns 1. For n=2, returns 1, which is correct. So, the function should work as intended.\n", + "Wait, but in the function, for n=2, it's F(1)+F(0) = 1+0=1, which is correct. For n=3, F(2)+F(1)=1+1=2. So, the function works as expected.\n", "\n", - "I think that's a good approach. So, the function is as I wrote above.\n", + "I think this should solve the problem.\n", "\n", "\n", "To solve this problem, we need to generate the nth Fibonacci number using a recursive approach. The Fibonacci sequence is a series of numbers where each number is the sum of the two preceding ones, starting from 0 and 1. \n", "\n", "### Approach\n", - "The approach to solve this problem involves using recursion, which is a method where the function calls itself with a smaller input until it reaches a base case. Here's a step-by-step breakdown of the approach:\n", + "The approach to solve this problem involves using recursion, which is a method where a function calls itself with a modified parameter to achieve the desired result. Here's a step-by-step breakdown of the approach:\n", "\n", "1. **Base Cases**: \n", " - If `n` is 0, return 0.\n", @@ -236,14 +255,14 @@ " elif n == 1:\n", " return 1\n", " else:\n", - " return fibonacci(n - 1) + fibonacci(n - 2)\n", + " return fibonacci(n-1) + fibonacci(n-2)\n", "```\n", "\n", "### Explanation\n", "- **Base Cases**: The function first checks if `n` is 0 or 1. If `n` is 0, it returns 0. If `n` is 1, it returns 1. These are the simplest cases of the Fibonacci sequence.\n", "- **Recursive Case**: For any `n` greater than 1, the function calls itself with `n-1` and `n-2`, and returns the sum of these two recursive calls. This builds up the solution by solving smaller subproblems and combining their results.\n", "\n", - "This approach is straightforward and easy to understand, but it's important to note that it is not the most efficient for large values of `n` due to the exponential time complexity. However, for the purpose of this problem, the recursive approach is sufficient.\n", + "This approach is straightforward and leverages the divide-and-conquer strategy inherent in recursion, making it easy to understand and implement. However, it's important to note that this approach has a time complexity of O(2^n) due to the exponential number of function calls, which is not efficient for large values of `n`. For larger values, an iterative approach or memoization would be more efficient.\n", "-----END-----\n", "\n" ] @@ -277,7 +296,7 @@ "output_type": "stream", "text": [ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", - "Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n" + "Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.\n" ] }, { @@ -313,9 +332,9 @@ "\n", "Let me test this function with some examples. For n=0, it returns 0. For n=1, returns 1. For n=2, it's F(1)+F(0) = 1+0=1. For n=3, F(2)+F(1)=1+1=2. That looks correct.\n", "\n", - "Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which is correct.\n", + "Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which matches the standard definition. So, the function should work regardless of the starting point as long as the base cases are correct.\n", "\n", - "Another test case: n=5. Let's compute it step by step. F(0)=0, F(1)=1, F(2)=1, F(3)=2, F(4)=3, F(5)=5. So the function should return 5 for n=5.\n", + "Another thing to consider is the base cases. If the function is called with n=0, it returns 0, which is correct. For n=1, returns 1. For n=2, returns 1, which is correct. So, the function should handle all non-negative integers correctly.\n", "\n", "I think this should work. So, the function is straightforward. It's a simple recursive implementation, but it's not the most efficient for large n. However, for the purpose of this problem, it's acceptable.\n", "\n", @@ -329,7 +348,7 @@ " return fibonacci(n-1) + fibonacci(n-2)\n", "```\n", "\n", - "This function calculates the nth Fibonacci number using a recursive approach. It handles the base cases where n is 0 or 1 and for other values, it recursively calculates the sum of the two preceding Fibonacci numbers. While this implementation is straightforward, it's not the most efficient for large values of n due to repeated calculations.\n", + "This function calculates the nth Fibonacci number using a recursive approach. It handles the base cases where n is 0 or 1 and recursively computes the value for larger n by summing the two preceding Fibonacci numbers.\n", "-----END-----\n", "\n" ] @@ -361,7 +380,7 @@ "output_type": "stream", "text": [ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", - "Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n" + "Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.\n" ] }, { @@ -429,7 +448,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.17" } }, "nbformat": 4, diff --git a/example_notebooks/vllm/cite_prompt_logits_processor.ipynb b/example_notebooks/vllm/cite_prompt_logits_processor.ipynb index d05a1a4..475ef7f 100644 --- a/example_notebooks/vllm/cite_prompt_logits_processor.ipynb +++ b/example_notebooks/vllm/cite_prompt_logits_processor.ipynb @@ -28,22 +28,50 @@ "name": "stdout", "output_type": "stream", "text": [ - "WARNING 02-12 13:41:31 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead. See https://pypi.org/project/pynvml for more information.\n", - "WARNING 02-12 13:41:34 config.py:1563] Casting torch.bfloat16 to torch.float16.\n", - "INFO 02-12 13:41:34 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='google/gemma-1.1-2b-it', speculative_config=None, tokenizer='google/gemma-1.1-2b-it', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=google/gemma-1.1-2b-it, use_v2_block_manager=False, enable_prefix_caching=False)\n", - "INFO 02-12 13:41:35 model_runner.py:879] Starting to load model google/gemma-1.1-2b-it...\n", - "INFO 02-12 13:41:35 weight_utils.py:236] Using model weights format ['*.safetensors']\n" + "WARNING 04-29 15:38:10 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead. See https://pypi.org/project/pynvml for more information.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING 04-29 15:38:13 config.py:1563] Casting torch.bfloat16 to torch.float16.\n", + "INFO 04-29 15:38:13 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='Qwen/Qwen2.5-1.5B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-1.5B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=Qwen/Qwen2.5-1.5B-Instruct, use_v2_block_manager=False, enable_prefix_caching=False)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO 04-29 15:38:14 model_runner.py:879] Starting to load model Qwen/Qwen2.5-1.5B-Instruct...\n", + "INFO 04-29 15:38:15 weight_utils.py:236] Using model weights format ['*.safetensors']\n", + "INFO 04-29 15:38:15 weight_utils.py:280] No model.safetensors.index.json found in remote.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "afb66bd12b4148bdb3021b4229dd5a23", + "model_id": "ec8e95e203c148d4b882e5a21f8d6134", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Loading safetensors checkpoint shards: 0% Completed | 0/2 [00:00 32K. Currently, chunked prefill might not work with some features or models. If you encounter any issues, please disable chunked prefill by setting --enable-chunked-prefill=False.\n", - "INFO 03-18 13:37:24 config.py:911] Chunked prefill is enabled with max_num_batched_tokens=512.\n", - "INFO 03-18 13:37:24 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', speculative_config=None, tokenizer='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, use_v2_block_manager=False, enable_prefix_caching=False)\n", - "INFO 03-18 13:37:25 model_runner.py:879] Starting to load model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B...\n", - "INFO 03-18 13:37:26 weight_utils.py:236] Using model weights format ['*.safetensors']\n", - "INFO 03-18 13:37:27 weight_utils.py:280] No model.safetensors.index.json found in remote.\n" + "WARNING 04-29 15:31:18 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead. See https://pypi.org/project/pynvml for more information.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING 04-29 15:31:21 config.py:1563] Casting torch.bfloat16 to torch.float16.\n", + "WARNING 04-29 15:31:21 arg_utils.py:839] Chunked prefill is enabled by default for models with max_model_len > 32K. Currently, chunked prefill might not work with some features or models. If you encounter any issues, please disable chunked prefill by setting --enable-chunked-prefill=False.\n", + "INFO 04-29 15:31:21 config.py:911] Chunked prefill is enabled with max_num_batched_tokens=512.\n", + "INFO 04-29 15:31:21 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', speculative_config=None, tokenizer='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, use_v2_block_manager=False, enable_prefix_caching=False)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO 04-29 15:31:22 model_runner.py:879] Starting to load model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B...\n", + "INFO 04-29 15:31:22 weight_utils.py:236] Using model weights format ['*.safetensors']\n", + "INFO 04-29 15:31:23 weight_utils.py:280] No model.safetensors.index.json found in remote.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e3ff26a7bf23415c9e29952e2a5f2d1a", + "model_id": "fbe3d3d1cfc44271bc837babbcc9e71c", "version_major": 2, "version_minor": 0 }, @@ -56,8 +83,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO 03-18 13:37:29 model_runner.py:890] Loading model weights took 3.3460 GB\n", - "INFO 03-18 13:37:29 gpu_executor.py:121] # GPU blocks: 37898, # CPU blocks: 9362\n" + "INFO 04-29 15:31:24 model_runner.py:890] Loading model weights took 3.3460 GB\n", + "INFO 04-29 15:31:24 gpu_executor.py:121] # GPU blocks: 37898, # CPU blocks: 9362\n" ] } ], @@ -425,7 +452,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.17" } }, "nbformat": 4, diff --git a/example_notebooks/vllm/utils.py b/example_notebooks/vllm/utils.py index 34432f5..f735c26 100644 --- a/example_notebooks/vllm/utils.py +++ b/example_notebooks/vllm/utils.py @@ -2,7 +2,7 @@ class vLLMRunner: - def __init__(self, model_name='google/gemma-1.1-2b-it'): + def __init__(self, model_name="Qwen/Qwen2.5-1.5B-Instruct"): self.model = vllm.LLM( model_name, trust_remote_code=True, diff --git a/logits_processor_zoo/transformers/last_phrase.py b/logits_processor_zoo/transformers/last_phrase.py index 0e4628c..0269912 100644 --- a/logits_processor_zoo/transformers/last_phrase.py +++ b/logits_processor_zoo/transformers/last_phrase.py @@ -18,6 +18,7 @@ from transformers import PreTrainedTokenizer import torch from logits_processor_zoo.transformers.base import BaseLogitsProcessor +from logits_processor_zoo.utils import enforce_tokens class ForceLastPhraseLogitsProcessor(BaseLogitsProcessor): @@ -45,10 +46,10 @@ def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to for i in range(scores.shape[0]): it = self.iterators[i].item() if scores[i, :].argmax() == self.eos_token_id and it == 0: - scores[i, self.phrase_tokens[it]] = scores[i].max() + 1 + scores[i] = enforce_tokens(scores[i], [self.phrase_tokens[it]]) self.iterators[i] += 1 elif len(self.phrase_tokens) > it > 0: - scores[i, self.phrase_tokens[it]] = scores[i].max() + 1 + scores[i] = enforce_tokens(scores[i], [self.phrase_tokens[it]]) self.iterators[i] += 1 return scores diff --git a/logits_processor_zoo/transformers/multiple_choice.py b/logits_processor_zoo/transformers/multiple_choice.py index 1ea6db3..c900c51 100644 --- a/logits_processor_zoo/transformers/multiple_choice.py +++ b/logits_processor_zoo/transformers/multiple_choice.py @@ -18,7 +18,7 @@ from transformers import PreTrainedTokenizer from typing import List import torch -from logits_processor_zoo.utils import text_to_token, get_new_line_tokens +from logits_processor_zoo.utils import text_to_token, get_new_line_tokens, enforce_tokens from logits_processor_zoo.transformers.base import BaseLogitsProcessor @@ -78,6 +78,7 @@ def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to boost = self.boost_first_words * scores[row_ind, first_tokens] scores[row_ind, self.choice_tokens[:len(first_tokens)]] += boost - scores[:, self.choice_tokens] += self.very_large_number + for i in range(scores.shape[0]): + scores[i] = enforce_tokens(scores[i], self.choice_tokens) return scores diff --git a/logits_processor_zoo/transformers/trigger_phrase.py b/logits_processor_zoo/transformers/trigger_phrase.py index 455ed8e..3cdf677 100644 --- a/logits_processor_zoo/transformers/trigger_phrase.py +++ b/logits_processor_zoo/transformers/trigger_phrase.py @@ -17,7 +17,7 @@ from transformers import PreTrainedTokenizer import torch -from logits_processor_zoo.utils import text_to_token +from logits_processor_zoo.utils import text_to_token, enforce_tokens from logits_processor_zoo.transformers.base import BaseLogitsProcessor @@ -38,7 +38,6 @@ def __init__(self, phrase: str, trigger_token_phrase: str, tokenizer: PreTrained super().__init__() self.trigger_token = text_to_token(tokenizer, trigger_token_phrase, last=False) self.phrase_tokens = tokenizer.encode(phrase, add_special_tokens=False) - self.very_large_number = 999 self.trigger_after = trigger_after self.batch_size = batch_size self.initial_trigger_count = trigger_count @@ -56,10 +55,10 @@ def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to if scores[i, :].argmax() == self.trigger_token and it == -1: self.iterators[i] = 0 if not self.trigger_after: - scores[i, self.phrase_tokens[it]] = scores[i].max() + self.very_large_number + scores[i] = enforce_tokens(scores[i], [self.phrase_tokens[it]]) self.iterators[i] += 1 elif len(self.phrase_tokens) > it >= 0: - scores[i, self.phrase_tokens[it]] = scores[i].max() + self.very_large_number + scores[i] = enforce_tokens(scores[i], [self.phrase_tokens[it]]) self.iterators[i] += 1 if len(self.phrase_tokens) == self.iterators[i].item(): # phrase completed, reset for next trigger diff --git a/logits_processor_zoo/utils.py b/logits_processor_zoo/utils.py index fb08dc5..3775113 100644 --- a/logits_processor_zoo/utils.py +++ b/logits_processor_zoo/utils.py @@ -16,6 +16,8 @@ # from transformers import PreTrainedTokenizer +from typing import List +import torch def text_to_token(tokenizer: PreTrainedTokenizer, text: str, last: bool): @@ -28,8 +30,17 @@ def text_to_token(tokenizer: PreTrainedTokenizer, text: str, last: bool): return tokens[-1] -def get_new_line_tokens(tokenizer): +def get_new_line_tokens(tokenizer: PreTrainedTokenizer): new_line_tokens = [token for token in tokenizer.get_vocab().values() if tokenizer.decode(token).endswith("\n")] return set(new_line_tokens) + + +def enforce_tokens(scores: torch.Tensor, tokens: List[int]): + choice_scores = scores[tokens].clone() + gap = scores.max() - choice_scores.min() + choice_scores += gap + scores.fill_(scores.min()) + scores[tokens] = choice_scores + return scores diff --git a/logits_processor_zoo/vllm/last_phrase.py b/logits_processor_zoo/vllm/last_phrase.py index 1e75aaf..f241a98 100644 --- a/logits_processor_zoo/vllm/last_phrase.py +++ b/logits_processor_zoo/vllm/last_phrase.py @@ -18,6 +18,7 @@ from transformers import PreTrainedTokenizer from typing import List import torch +from logits_processor_zoo.utils import enforce_tokens class ForceLastPhraseLogitsProcessor: @@ -50,10 +51,10 @@ def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scor self._reset() if scores.argmax() == self.eos_token_id and self.index == 0: - scores[self.phrase_tokens[self.index]] = scores.max() + 1 + scores = enforce_tokens(scores, self.phrase_tokens[self.index]) self.index += 1 elif len(self.phrase_tokens) > self.index > 0: - scores[self.phrase_tokens[self.index]] = scores.max() + 1 + scores = enforce_tokens(scores, self.phrase_tokens[self.index]) self.index += 1 return scores diff --git a/logits_processor_zoo/vllm/multiple_choice.py b/logits_processor_zoo/vllm/multiple_choice.py index dce2d2d..345b274 100644 --- a/logits_processor_zoo/vllm/multiple_choice.py +++ b/logits_processor_zoo/vllm/multiple_choice.py @@ -18,7 +18,7 @@ from transformers import PreTrainedTokenizer from typing import List import torch -from logits_processor_zoo.utils import text_to_token, get_new_line_tokens +from logits_processor_zoo.utils import text_to_token, get_new_line_tokens, enforce_tokens class MultipleChoiceLogitsProcessor: @@ -53,7 +53,6 @@ def __init__(self, tokenizer: PreTrainedTokenizer, choices: List[str] = None, self.delimiter_token = text_to_token(tokenizer, delimiter, last=False) self.choice_tokens = [text_to_token(tokenizer, choice, last=False) for choice in choices] self.boost_first_words = boost_first_words - self.very_large_number = 999 def clone(self): return MultipleChoiceLogitsProcessor(self.tokenizer, self.choices, self.delimiter, self.boost_first_words) @@ -81,5 +80,5 @@ def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scor scores[self.choice_tokens[:len(first_tokens)]] += self.boost_first_words * scores[first_tokens] - scores[self.choice_tokens] += self.very_large_number + scores = enforce_tokens(scores, self.choice_tokens) return scores diff --git a/logits_processor_zoo/vllm/trigger_phrase.py b/logits_processor_zoo/vllm/trigger_phrase.py index c4220b2..ccb3bb2 100644 --- a/logits_processor_zoo/vllm/trigger_phrase.py +++ b/logits_processor_zoo/vllm/trigger_phrase.py @@ -18,7 +18,7 @@ from transformers import PreTrainedTokenizer from typing import List import torch -from logits_processor_zoo.utils import text_to_token +from logits_processor_zoo.utils import text_to_token, enforce_tokens class TriggerPhraseLogitsProcessor: @@ -43,7 +43,6 @@ def __init__(self, phrase: str, trigger_token_phrase: str, tokenizer: PreTrained self.phrase_tokens = tokenizer.encode(phrase, add_special_tokens=False) self.initial_trigger_count = trigger_count self.trigger_after = trigger_after - self.very_large_number = 999 self._reset() def clone(self): @@ -64,10 +63,10 @@ def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scor if scores.argmax() == self.trigger_token and self.index == -1: self.index = 0 if not self.trigger_after: - scores[self.phrase_tokens[self.index]] = scores.max() + self.very_large_number + scores = enforce_tokens(scores, [self.phrase_tokens[self.index]]) self.index += 1 elif len(self.phrase_tokens) > self.index >= 0: - scores[self.phrase_tokens[self.index]] = scores.max() + self.very_large_number + scores = enforce_tokens(scores, [self.phrase_tokens[self.index]]) self.index += 1 if len(self.phrase_tokens) == self.index: # phrase completed, reset for next trigger diff --git a/pyproject.toml b/pyproject.toml index 0cc2223..1a77a89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "logits-processor-zoo" -version = "0.1.5" +version = "0.1.6" description = "A collection of LogitsProcessors to customize and enhance LLM behavior for specific tasks." authors = ["Ahmet Erdem", "Ivan Sorokin", "Maximilian Jeblick", "Darragh Hanley", "David Austin"] readme = "README.md" diff --git a/tests/test_utils.py b/tests/test_utils.py index 6e001ed..f458ea3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,5 @@ -from logits_processor_zoo.utils import text_to_token, get_new_line_tokens +from logits_processor_zoo.utils import text_to_token, get_new_line_tokens, enforce_tokens +import torch def test_text_to_token(llm_runner): @@ -16,3 +17,12 @@ def test_text_to_token(llm_runner): def test_get_new_line_tokens(llm_runner): assert get_new_line_tokens(llm_runner.tokenizer) == {13} + + +def test_enforce_tokens(): + scores = torch.FloatTensor([0.1, -0.4, -0.2, -0.6, 1.1]) + tokens = [1, 2] + + scores = enforce_tokens(scores, tokens) + _, top2_tokens = torch.topk(scores, k=2) + assert torch.equal(top2_tokens, torch.tensor([2, 1])) From 23a23d8643a93f4b435ce3e88c357d1fd0f5c45c Mon Sep 17 00:00:00 2001 From: aerdem4 Date: Wed, 30 Apr 2025 15:02:07 +0300 Subject: [PATCH 2/2] Make vllm cite prompt example better Signed-off-by: aerdem4 --- .../vllm/cite_prompt_logits_processor.ipynb | 172 ++++-------------- 1 file changed, 31 insertions(+), 141 deletions(-) diff --git a/example_notebooks/vllm/cite_prompt_logits_processor.ipynb b/example_notebooks/vllm/cite_prompt_logits_processor.ipynb index 475ef7f..28558cc 100644 --- a/example_notebooks/vllm/cite_prompt_logits_processor.ipynb +++ b/example_notebooks/vllm/cite_prompt_logits_processor.ipynb @@ -28,7 +28,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "WARNING 04-29 15:38:10 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead. See https://pypi.org/project/pynvml for more information.\n" + "WARNING 04-30 15:00:30 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead. See https://pypi.org/project/pynvml for more information.\n" ] }, { @@ -43,8 +43,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "WARNING 04-29 15:38:13 config.py:1563] Casting torch.bfloat16 to torch.float16.\n", - "INFO 04-29 15:38:13 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='Qwen/Qwen2.5-1.5B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-1.5B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=Qwen/Qwen2.5-1.5B-Instruct, use_v2_block_manager=False, enable_prefix_caching=False)\n" + "WARNING 04-30 15:00:33 config.py:1563] Casting torch.bfloat16 to torch.float16.\n", + "INFO 04-30 15:00:33 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='Qwen/Qwen2.5-1.5B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-1.5B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=Qwen/Qwen2.5-1.5B-Instruct, use_v2_block_manager=False, enable_prefix_caching=False)\n" ] }, { @@ -58,15 +58,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO 04-29 15:38:14 model_runner.py:879] Starting to load model Qwen/Qwen2.5-1.5B-Instruct...\n", - "INFO 04-29 15:38:15 weight_utils.py:236] Using model weights format ['*.safetensors']\n", - "INFO 04-29 15:38:15 weight_utils.py:280] No model.safetensors.index.json found in remote.\n" + "INFO 04-30 15:00:34 model_runner.py:879] Starting to load model Qwen/Qwen2.5-1.5B-Instruct...\n", + "INFO 04-30 15:00:34 weight_utils.py:236] Using model weights format ['*.safetensors']\n", + "INFO 04-30 15:00:35 weight_utils.py:280] No model.safetensors.index.json found in remote.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ec8e95e203c148d4b882e5a21f8d6134", + "model_id": "e9c350b056a04694bf4f2eade35244ba", "version_major": 2, "version_minor": 0 }, @@ -81,8 +81,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO 04-29 15:38:16 model_runner.py:890] Loading model weights took 2.8875 GB\n", - "INFO 04-29 15:38:18 gpu_executor.py:121] # GPU blocks: 37605, # CPU blocks: 9362\n" + "INFO 04-30 15:00:36 model_runner.py:890] Loading model weights took 2.8875 GB\n", + "INFO 04-30 15:00:38 gpu_executor.py:121] # GPU blocks: 37541, # CPU blocks: 9362\n" ] } ], @@ -93,7 +93,8 @@ "\n", "example_prompts =[\n", " \"\"\"\n", - " A user review: very soft, colorful, expensive but deserves its price, stylish.\n", + " A user review: very soft, colorful, expensive but deserves its price.\n", + " I would like to wear it in my friend's wedding.\n", " \n", " What is the user's opinion about the product's price?\n", " \"\"\",\n", @@ -103,7 +104,7 @@ " The franchise takes place in a shared universe in which humans co-exist with creatures known as Pokémon, a large variety of species endowed with special powers. \n", " The franchise's target audience is children aged 5 to 12, but it is known to attract people of all ages.\n", " \n", - " Based on the retrieved information, what is a Pokémon?\n", + " Can you shortly describe what Pokémon is?\n", " \"\"\"\n", "]\n", "\n", @@ -129,11 +130,12 @@ "output_type": "stream", "text": [ "Prompt: \n", - " A user review: very soft, colorful, expensive but deserves its price, stylish.\n", + " A user review: very soft, colorful, expensive but deserves its price.\n", + " I would like to wear it in my friend's wedding.\n", " \n", " What is the user's opinion about the product's price?\n", " \n", - "The user's opinion about the product's price is that it is expensive, but they believe it is worth the price due to its softness, colorfulness, and stylish design.\n", + "The user's opinion about the product's price is that it is expensive, but they believe it is worth the price.\n", "-----END-----\n", "\n", "Prompt: \n", @@ -142,9 +144,9 @@ " The franchise takes place in a shared universe in which humans co-exist with creatures known as Pokémon, a large variety of species endowed with special powers. \n", " The franchise's target audience is children aged 5 to 12, but it is known to attract people of all ages.\n", " \n", - " Based on the retrieved information, what is a Pokémon?\n", + " Can you shortly describe what Pokémon is?\n", " \n", - "A Pokémon is a fictional creature in the Pokémon franchise, which is a Japanese media franchise consisting of video games, animated series, films, a trading card game, and other related media. The franchise takes place in a shared universe where humans coexist with Pokémon, a large variety of species endowed with special powers. The target audience for the franchise is children aged 5 to 12, but it is known to attract people of all ages.\n", + "Pokémon is a Japanese media franchise that includes video games, animated series, films, trading card games, and other related media. It features a shared universe where humans coexist with Pokémon, a diverse group of creatures with special powers. The franchise is aimed at children aged 5 to 12, but it has a broad appeal across all ages.\n", "-----END-----\n", "\n" ] @@ -173,11 +175,12 @@ "output_type": "stream", "text": [ "Prompt: \n", - " A user review: very soft, colorful, expensive but deserves its price, stylish.\n", + " A user review: very soft, colorful, expensive but deserves its price.\n", + " I would like to wear it in my friend's wedding.\n", " \n", " What is the user's opinion about the product's price?\n", " \n", - "The user's opinion about the product's price is that the product is expensive, but the user believes the price is justified by the product's soft, colorful, stylish, and expensive.\n", + "The user's opinion about the product's price is that it is expensive, but the user is willing to pay the price to wear it in a friend's wedding.\n", "-----END-----\n", "\n", "Prompt: \n", @@ -186,9 +189,9 @@ " The franchise takes place in a shared universe in which humans co-exist with creatures known as Pokémon, a large variety of species endowed with special powers. \n", " The franchise's target audience is children aged 5 to 12, but it is known to attract people of all ages.\n", " \n", - " Based on the retrieved information, what is a Pokémon?\n", + " Can you shortly describe what Pokémon is?\n", " \n", - "A Pokémon is a large variety of species endowed with special powers, which co-exist with humans in a shared universe in the media franchise of Pokémon.\n", + " Pokémon is a Japanese media franchise consisting of video games, animated series, and films. The franchise takes place in a shared universe in which humans co-exist with Pokémon, a large variety of species endowed with special powers. The franchise's target audience is children aged 5 to 12, but it is known to attract people of all ages.\n", "-----END-----\n", "\n" ] @@ -196,7 +199,7 @@ ], "source": [ "runner.generate_response(example_prompts,\n", - " [CiteFromPromptLogitsProcessor(runner.tokenizer, boost_factor=5.0)])" + " [CiteFromPromptLogitsProcessor(runner.tokenizer, boost_factor=5.0, boost_eos=False)])" ] }, { @@ -218,19 +221,16 @@ "output_type": "stream", "text": [ "Prompt: \n", - " A user review: very soft, colorful, expensive but deserves its price, stylish.\n", + " A user review: very soft, colorful, expensive but deserves its price.\n", + " I would like to wear it in my friend's wedding.\n", " \n", " What is the user's opinion about the product's price?\n", " \n", - "The user seems to have mixed feelings regarding the product's price:\n", + "The user's opinion about the product's price seems to be mixed. They appreciate that the product is \"very soft\" and \"colorful,\" indicating that these features contribute positively to their satisfaction with the item. However, they also mention that the product is \"expensive,\" which suggests that they feel the price is justified based on the quality they perceive.\n", "\n", - "1. **Positive**: The phrase \"deserves its price\" suggests that they believe the high cost of the item is justified and that it provides good value for money.\n", + "The phrase \"deserves its price\" implies that the user believes the cost of the product is appropriate for what they have received. This indicates that they find value in the product and feel that they are getting good value for their money.\n", "\n", - "2. **Neutral**: The word \"expensive\" indicates that they find it to be costly.\n", - "\n", - "3. **Positive**: The phrase \"stylish\" suggests that they appreciate how well-designed or fashionable the item is.\n", - "\n", - "Overall, while they acknowledge that it costs more than they might expect (the \"expensive\" part), they also seem to value it enough to believe that it \"deserves\" this high cost and that it meets their expectations in terms of style and quality (\"stylish\"). The use of \"very\" in describing how \"soft\" it feels further emphasizes their positive perception of both comfort and quality of the item despite being more costly than they anticipated. Therefore, their overall impression leans more towards appreciation and satisfaction with their purchase despite it being more pricey than they initially thought it would be.\n", + "In summary, while the user appreciates the product's qualities and finds them worth the price, they also acknowledge that the cost is higher than they might have expected for such features. This suggests that they view the product as a good investment for their needs and preferences.\n", "-----END-----\n", "\n", "Prompt: \n", @@ -239,119 +239,9 @@ " The franchise takes place in a shared universe in which humans co-exist with creatures known as Pokémon, a large variety of species endowed with special powers. \n", " The franchise's target audience is children aged 5 to 12, but it is known to attract people of all ages.\n", " \n", - " Based on the retrieved information, what is a Pokémon?\n", + " Can you shortly describe what Pokémon is?\n", " \n", - "A Pokémon is an imaginary creature that exists within the fictional world of the Pokémon franchise. These characters have unique abilities or characteristics that allow them to interact or battle against each other within this imaginary world.\n", - "\n", - "Pokémon can be categorized into different types based on their abilities:\n", - "\n", - "- **Fire**: Creatures that can use fire-based attacks or resist fire-based attacks.\n", - "\n", - "- **Water**: Creatures that can use water-based attacks or resist water-based attacks.\n", - "\n", - "- **Grass**: Creatures that can use grass-based attacks or resist grass-based attacks.\n", - "\n", - "- **Electric**: Creatures that can use electric-based attacks or resist electric-based attacks.\n", - "\n", - "- **Ice**: Creatures that can use ice-based attacks or resist ice-based attacks.\n", - "\n", - "- **Fighting**: Creatures that can use physical attacks or resist physical attacks.\n", - "\n", - "- **Poison**: Creatures that can use poison-based attacks or resist poison-based attacks.\n", - "\n", - "- **Ground**: Creatures that can use ground-based attacks or resist ground-based attacks.\n", - "\n", - "- **Flying**: Creatures that can use flying-based attacks or resist flying-based attacks.\n", - "\n", - "- **Psychic**: Creatures that can use psychic-based attacks or resist psychic-based attacks.\n", - "\n", - "- **Bug**: Creatures that can use bug-based attacks or resist bug-based attacks.\n", - "\n", - "- **Dark**: Creatures that can use dark-based attacks or resist dark-based attacks.\n", - "\n", - "- **Dragon**: Creatures that can use dragon-based attacks or resist dragon-based attacks.\n", - "\n", - "- **Ghost**: Creatures that can use ghost-based attacks or resist ghost-based attacks.\n", - "\n", - "- **Steel**: Creatures that can use steel-based attacks or resist steel-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "Pokémon can also be classified into different types based on their abilities:\n", - "\n", - "- **Normal**: Creatures that can use normal-based attacks or resist normal-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks or resist fairy-based attacks.\n", - "\n", - "- **Fairy**: Creatures that can use fairy-based attacks\n", + "Pokémon is a popular Japanese media franchise that features a world where humans live alongside magical creatures called Pokémon. These Pokémon have unique abilities that allow them to fight alongside humans in various adventures. The franchise includes video games, animated series, films, trading cards, and other forms of media aimed at children aged 5 to 12, though it has also gained popularity among adults.\n", "-----END-----\n", "\n" ] @@ -359,7 +249,7 @@ ], "source": [ "runner.generate_response(example_prompts,\n", - " [CiteFromPromptLogitsProcessor(runner.tokenizer, boost_factor=-5.0)])" + " [CiteFromPromptLogitsProcessor(runner.tokenizer, boost_factor=-2.0, boost_eos=False)])" ] }, {