From e979b9cdf56be7aed48d9878a9c7a867ebbdf10c Mon Sep 17 00:00:00 2001 From: aerdem4 Date: Tue, 24 Jun 2025 15:32:33 +0300 Subject: [PATCH 01/11] Use tensorrt llm api Signed-off-by: aerdem4 --- README.md | 10 ++ example_notebooks/trtllm/README.md | 4 +- .../trtllm/gen_length_logits_processor.py | 5 +- example_notebooks/trtllm/utils.py | 118 ++++-------------- .../trtllm/generation_length.py | 22 ++-- pyproject.toml | 4 +- 6 files changed, 53 insertions(+), 110 deletions(-) diff --git a/README.md b/README.md index cbcbe7e..b249621 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,16 @@ Struggling to get LLMs to follow your instructions? LogitsProcessorZoo offers a pip install logits-processor-zoo ``` +With vllm installation: +```bash +pip install logits-processor-zoo[vllm] +``` + +With tensorrt-llm installation: +```bash +pip install logits-processor-zoo[tensorrt-llm] +``` + ## Supported Frameworks * transformers * vLLM diff --git a/example_notebooks/trtllm/README.md b/example_notebooks/trtllm/README.md index 1934526..1a3b9fa 100644 --- a/example_notebooks/trtllm/README.md +++ b/example_notebooks/trtllm/README.md @@ -8,6 +8,6 @@ https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html ## Examples ``` -python example_notebooks/trtllm/gen_length_logits_processor.py --engine_path ../TensorRT-LLM/examples/llama/llama-engine/ --tokenizer_path ~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-chat-hf/snapshots/x/ -python example_notebooks/trtllm/multiple_choice_logits_processor.py --engine_path ../TensorRT-LLM/examples/llama/llama-engine/ --tokenizer_path ~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-chat-hf/snapshots/x/ --prompt "Which one is heavier?\n1. 1 kg\n2. 100 kg\n3. 10 kg\nAnswer:" +python example_notebooks/trtllm/gen_length_logits_processor.py --backend pytorch --model_name Qwen/Qwen2.5-1.5B-Instruct +python example_notebooks/trtllm/multiple_choice_logits_processor.py --model_name Qwen/Qwen2.5-1.5B-Instruct --prompt "Which one is heavier?\n1. 1 kg\n2. 100 kg\n3. 10 kg\nAnswer:" ``` \ No newline at end of file diff --git a/example_notebooks/trtllm/gen_length_logits_processor.py b/example_notebooks/trtllm/gen_length_logits_processor.py index f885985..0a0bb31 100644 --- a/example_notebooks/trtllm/gen_length_logits_processor.py +++ b/example_notebooks/trtllm/gen_length_logits_processor.py @@ -5,10 +5,9 @@ if __name__ == "__main__": args = get_parser() - beam_width = 1 - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + tokenizer = AutoTokenizer.from_pretrained(args.model_name) lp = GenLengthLogitsProcessor(tokenizer, boost_factor=1.0, complete_sentences=True) - TRTLLMTester(lp, tokenizer, args).run(args.prompt, beam_width) + TRTLLMTester(args.model_name, args.backend, lp).run(args.prompts) diff --git a/example_notebooks/trtllm/utils.py b/example_notebooks/trtllm/utils.py index 7739d9d..08dcd1f 100644 --- a/example_notebooks/trtllm/utils.py +++ b/example_notebooks/trtllm/utils.py @@ -1,114 +1,42 @@ import argparse -import datetime from typing import List +from tensorrt_llm.sampling_params import SamplingParams, LogitsProcessor -import tensorrt_llm.bindings.executor as trtllm - - -# TensorRT-LLM utility functions are taken from: -# https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/bindings/executor/example_logits_processor.py -# Prepare and enqueue the requests class TRTLLMTester: - def __init__(self, logits_processor, tokenizer, args): - self.logits_processor = logits_processor - self.tokenizer = tokenizer - self.args = args - - def enqueue_requests(self, prompt: List[int], executor: trtllm.Executor, - beam_width: int, max_new_tokens: int, batch_size: int = 1): - sampling_config = trtllm.SamplingConfig(beam_width) - - request_ids = [] - for iter_id in range(batch_size): - # Create the request. - request = trtllm.Request(input_token_ids=prompt, - max_new_tokens=max_new_tokens, - end_id=self.tokenizer.eos_token_id, - sampling_config=sampling_config, - client_id=iter_id % 2) - request.logits_post_processor_name = "my_logits_pp" - - # Enqueue the request. - req_id = executor.enqueue_request(request) - request_ids.append(req_id) - - return request_ids - - # Wait for responses and store output tokens - def wait_for_responses(self, request_ids: List[int], - executor: trtllm.Executor, beam_width: int): - output_tokens = { - req_id: {beam: [] - for beam in range(beam_width)} - for req_id in request_ids - } - num_finished = 0 - iter = 0 - while num_finished < len(request_ids) and iter < self.args.timeout_ms: - responses = executor.await_responses( - datetime.timedelta(milliseconds=self.args.timeout_ms)) - for response in responses: - req_id = response.request_id - if not response.has_error(): - result = response.result - num_finished += 1 if result.is_final else 0 - for beam, outTokens in enumerate(result.output_token_ids): - output_tokens[req_id][beam].extend(outTokens) - else: - raise RuntimeError(f"{req_id} encountered error: {response.error_msg}") - - return output_tokens - - def run(self, prompt: str, beam_width: int = 1, max_new_tokens: int = 2000): - # Create the executor. - executor_config = trtllm.ExecutorConfig(beam_width) - executor_config.logits_post_processor_map = { - "my_logits_pp": self.logits_processor - } - executor = trtllm.Executor(self.args.engine_path, trtllm.ModelType.DECODER_ONLY, - executor_config) - - prompt_encoded = self.tokenizer.encode(prompt) - print(f"Input text: {prompt}\n") + def __init__(self, model_name: str = "Qwen/Qwen2.5-1.5B-Instruct", backend: str = "tensorrt-llm", + logits_processor: LogitsProcessor = None): + if backend == "pytorch": + from tensorrt_llm._torch import LLM + else: + from tensorrt_llm import LLM - if executor.can_enqueue_requests(): - request_ids = self.enqueue_requests(prompt_encoded, executor, beam_width, max_new_tokens) - output_tokens = self.wait_for_responses(request_ids, executor, beam_width) + self.llm = LLM(model=model_name) + self.lp = logits_processor - # Print output - for req_id in request_ids: - for beam_id in range(beam_width): - result = self.tokenizer.decode( - output_tokens[req_id][beam_id][len(prompt_encoded):]) - generated_tokens = len( - output_tokens[req_id][beam_id]) - len(prompt_encoded) - print( - f"Request {req_id} Beam {beam_id} ({generated_tokens} tokens): {result}" - ) + def run(self, prompts: List[str], max_tokens: int = 256): + sparams = {"top_k": 1, "max_tokens": max_tokens, "temperature": 0.001} + if self.lp: + sparams["logits_processor"] = self.lp + output = self.llm.generate(prompts, SamplingParams(**sparams)) + print(output) def get_parser(): parser = argparse.ArgumentParser(description="Logits Processor Example") - parser.add_argument("--tokenizer_path", - "-t", + parser.add_argument("--model_name", + "-m", type=str, - required=True, - help="Directory containing model tokenizer") - parser.add_argument("--engine_path", - "-e", + default="Qwen/Qwen2.5-1.5B-Instruct", + help="Directory or HF link containing model") + parser.add_argument("--backend", + "-b", type=str, - required=True, - help="Directory containing model engine") + default="tensorrt-llm", + help="TensorRT-LLM backend") parser.add_argument("--prompt", "-p", type=str, default="Please give me information about macaques:", help="Prompt to test") - parser.add_argument( - "--timeout_ms", - type=int, - required=False, - default=10000, - help="The maximum time to wait for all responses, in milliseconds") return parser.parse_args() diff --git a/logits_processor_zoo/trtllm/generation_length.py b/logits_processor_zoo/trtllm/generation_length.py index 75496f5..2684049 100644 --- a/logits_processor_zoo/trtllm/generation_length.py +++ b/logits_processor_zoo/trtllm/generation_length.py @@ -18,10 +18,11 @@ from typing import List, Optional from transformers import PreTrainedTokenizer import torch +from tensorrt_llm.sampling_params import LogitsProcessor from logits_processor_zoo.utils import text_to_token -class GenLengthLogitsProcessor: +class GenLengthLogitsProcessor(LogitsProcessor): """ A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token based on the length of the generated sequence, encouraging or discouraging shorter answers. @@ -48,19 +49,22 @@ def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float, self.new_line_token = text_to_token(tokenizer, "It is a new line\n", last=True) self.complete_sentences = complete_sentences - def __call__(self, req_ids_batch: List[int], logits_batch: List[torch.Tensor], - ids_batch: List[List[List[int]]], stream_ptr, - client_ids_batch: List[Optional[int]]): + def __call__(self, req_id: int, logits: torch.Tensor, + token_ids: List[List[int]], stream_ptr: Optional[int], + client_id: Optional[int]) -> None: boost_val = self.boost_factor * (self.token_count ** self.p) / (10 ** self.p) - with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)): - ids_batch = torch.LongTensor(ids_batch).to(logits_batch.device, non_blocking=True) + stream = None if stream_ptr is None else torch.cuda.ExternalStream( + stream_ptr) + + with torch.cuda.stream(stream): + ids = torch.LongTensor(token_ids).to(logits.device, non_blocking=True) if self.complete_sentences: - enabled = (ids_batch[:, -1] == self.full_stop_token) | (ids_batch[:, -1] == self.new_line_token) - logits_batch[:, :, self.eos_token] += enabled * boost_val + enabled = (ids[:, -1] == self.full_stop_token) | (ids[:, -1] == self.new_line_token) + logits[:, :, self.eos_token] += enabled * boost_val else: - logits_batch[:, :, self.eos_token] += boost_val + logits[:, :, self.eos_token] += boost_val self.token_count += 1 diff --git a/pyproject.toml b/pyproject.toml index 1de2def..100f97a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "logits-processor-zoo" -version = "0.1.11" +version = "0.2.0" description = "A collection of LogitsProcessors to customize and enhance LLM behavior for specific tasks." authors = ["Ahmet Erdem", "Ivan Sorokin", "Maximilian Jeblick", "Darragh Hanley", "David Austin"] readme = "README.md" @@ -11,9 +11,11 @@ torch = "*" transformers = ">=4.41.2" accelerate = ">=0.26.1" vllm = { version = ">=0.5.0.post1", optional = true } +tensorrt-llm = { version = ">=0.20.0", optional = true} [tool.poetry.extras] vllm = ["vllm"] +tensorrt-llm = ["tensorrt-llm"] [build-system] From 910c7a0020791b97764eb03346bd449d10cdd353 Mon Sep 17 00:00:00 2001 From: aerdem4 Date: Fri, 27 Jun 2025 10:29:16 +0300 Subject: [PATCH 02/11] Update readme Signed-off-by: aerdem4 --- README.md | 14 +++----------- pyproject.toml | 2 -- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index b249621..ff86ffb 100644 --- a/README.md +++ b/README.md @@ -15,20 +15,10 @@ Struggling to get LLMs to follow your instructions? LogitsProcessorZoo offers a pip install logits-processor-zoo ``` -With vllm installation: -```bash -pip install logits-processor-zoo[vllm] -``` - -With tensorrt-llm installation: -```bash -pip install logits-processor-zoo[tensorrt-llm] -``` - ## Supported Frameworks * transformers * vLLM -* TensorRT-LLM +* TensorRT-LLM (>=0.20.0) ## Usage @@ -97,3 +87,5 @@ One common use case is to force writing python code just after thinking: trigger_python = TriggerPhraseLogitsProcessor(phrase="\n```python", trigger_token_phrase="", tokenizer=tokenizer, trigger_count=1, trigger_after=True) ``` +### PreventHallucinationLogitsProcessor +A logits processor that mitigates hallucinated model outputs by enforcing a predefined fallback phrase when token confidence falls below a specified threshold. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 100f97a..c1b63ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,11 +11,9 @@ torch = "*" transformers = ">=4.41.2" accelerate = ">=0.26.1" vllm = { version = ">=0.5.0.post1", optional = true } -tensorrt-llm = { version = ">=0.20.0", optional = true} [tool.poetry.extras] vllm = ["vllm"] -tensorrt-llm = ["tensorrt-llm"] [build-system] From 51903ea9765e564012ff266f7888e1114b327a17 Mon Sep 17 00:00:00 2001 From: aerdem4 Date: Fri, 27 Jun 2025 11:27:01 +0300 Subject: [PATCH 03/11] Improve GenLength LP example Signed-off-by: aerdem4 --- .../trtllm/gen_length_logits_processor.py | 4 +-- example_notebooks/trtllm/utils.py | 28 +++++++++++++------ .../trtllm/generation_length.py | 23 ++++++++------- 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/example_notebooks/trtllm/gen_length_logits_processor.py b/example_notebooks/trtllm/gen_length_logits_processor.py index 0a0bb31..f973359 100644 --- a/example_notebooks/trtllm/gen_length_logits_processor.py +++ b/example_notebooks/trtllm/gen_length_logits_processor.py @@ -7,7 +7,7 @@ args = get_parser() tokenizer = AutoTokenizer.from_pretrained(args.model_name) + llm_tester = TRTLLMTester(args.model_name, args.backend) lp = GenLengthLogitsProcessor(tokenizer, boost_factor=1.0, complete_sentences=True) - - TRTLLMTester(args.model_name, args.backend, lp).run(args.prompts) + llm_tester.run([args.prompt], logits_processor=lp) diff --git a/example_notebooks/trtllm/utils.py b/example_notebooks/trtllm/utils.py index 08dcd1f..d8fbe6a 100644 --- a/example_notebooks/trtllm/utils.py +++ b/example_notebooks/trtllm/utils.py @@ -3,22 +3,34 @@ from tensorrt_llm.sampling_params import SamplingParams, LogitsProcessor class TRTLLMTester: - def __init__(self, model_name: str = "Qwen/Qwen2.5-1.5B-Instruct", backend: str = "tensorrt-llm", - logits_processor: LogitsProcessor = None): + def __init__(self, model_name: str = "Qwen/Qwen2.5-1.5B-Instruct", backend: str = "tensorrt-llm"): if backend == "pytorch": from tensorrt_llm._torch import LLM else: from tensorrt_llm import LLM self.llm = LLM(model=model_name) - self.lp = logits_processor - def run(self, prompts: List[str], max_tokens: int = 256): + def run(self, prompts: List[str], max_tokens: int = 256, logits_processor: LogitsProcessor = None): sparams = {"top_k": 1, "max_tokens": max_tokens, "temperature": 0.001} - if self.lp: - sparams["logits_processor"] = self.lp - output = self.llm.generate(prompts, SamplingParams(**sparams)) - print(output) + if logits_processor: + sparams["logits_processor"] = logits_processor + + prompts_with_template = [] + for prompt in prompts: + messages = [ + { + "role": "user", + "content": prompt + } + ] + text = self.llm.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + prompts_with_template.append(text) + + gens = self.llm.generate(prompts_with_template, SamplingParams(**sparams)) + for prompt, gen in zip(prompts, gens): + print(prompt) + print(gen.outputs[0].text) def get_parser(): diff --git a/logits_processor_zoo/trtllm/generation_length.py b/logits_processor_zoo/trtllm/generation_length.py index 2684049..32bd799 100644 --- a/logits_processor_zoo/trtllm/generation_length.py +++ b/logits_processor_zoo/trtllm/generation_length.py @@ -26,7 +26,6 @@ class GenLengthLogitsProcessor(LogitsProcessor): """ A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token based on the length of the generated sequence, encouraging or discouraging shorter answers. - WARNING: Create a new object before every model.generate call since token_count is accumulated. Parameters ---------- @@ -36,18 +35,22 @@ class GenLengthLogitsProcessor(LogitsProcessor): p (int, optional): The power to which the token count is raised when computing the boost value. Default is 2. complete_sentences (bool, optional): If True, boosts EOS token likelihood only when the last token is a full stop or a new line. Default is False. - + boost_token_str (str, optional): A string to be tokenized and used instead of EOS. Especially useful for . """ - def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float, - p: int = 2, complete_sentences: bool = False): - self.eos_token = tokenizer.eos_token_id + p: int = 2, complete_sentences: bool = False, boost_token_str: str = None): + + self.tokenizer = tokenizer + self.boost_token = self.tokenizer.eos_token_id + self.boost_token_str = boost_token_str + if boost_token_str is not None: + self.boost_token = text_to_token(self.tokenizer, boost_token_str, last=False) self.boost_factor = boost_factor self.p = p - self.token_count = 0 - self.full_stop_token = text_to_token(tokenizer, "It is a sentence.", last=True) - self.new_line_token = text_to_token(tokenizer, "It is a new line\n", last=True) + self.full_stop_token = text_to_token(self.tokenizer, "It is a sentence.", last=True) + self.new_line_token = text_to_token(self.tokenizer, "It is a new line\n", last=True) self.complete_sentences = complete_sentences + self.token_count = 0 def __call__(self, req_id: int, logits: torch.Tensor, token_ids: List[List[int]], stream_ptr: Optional[int], @@ -63,8 +66,8 @@ def __call__(self, req_id: int, logits: torch.Tensor, if self.complete_sentences: enabled = (ids[:, -1] == self.full_stop_token) | (ids[:, -1] == self.new_line_token) - logits[:, :, self.eos_token] += enabled * boost_val + logits[:, :, self.boost_token] += enabled * boost_val else: - logits[:, :, self.eos_token] += boost_val + logits[:, :, self.boost_token] += boost_val self.token_count += 1 From 15532e7b37fbed9e22cc163a46b685243fb8083e Mon Sep 17 00:00:00 2001 From: aerdem4 Date: Fri, 27 Jun 2025 13:26:12 +0300 Subject: [PATCH 04/11] Update trtllm CiteFromPrompt LP Signed-off-by: aerdem4 --- example_notebooks/trtllm/README.md | 9 +++- .../trtllm/cite_prompt_logits_processor.py | 10 ++-- .../trtllm/gen_length_logits_processor.py | 3 ++ logits_processor_zoo/trtllm/cite_prompt.py | 52 +++++++++++++------ .../trtllm/generation_length.py | 4 +- 5 files changed, 53 insertions(+), 25 deletions(-) diff --git a/example_notebooks/trtllm/README.md b/example_notebooks/trtllm/README.md index 1a3b9fa..01b1417 100644 --- a/example_notebooks/trtllm/README.md +++ b/example_notebooks/trtllm/README.md @@ -8,6 +8,11 @@ https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html ## Examples ``` -python example_notebooks/trtllm/gen_length_logits_processor.py --backend pytorch --model_name Qwen/Qwen2.5-1.5B-Instruct -python example_notebooks/trtllm/multiple_choice_logits_processor.py --model_name Qwen/Qwen2.5-1.5B-Instruct --prompt "Which one is heavier?\n1. 1 kg\n2. 100 kg\n3. 10 kg\nAnswer:" +python example_notebooks/trtllm/gen_length_logits_processor.py +python example_notebooks/trtllm/cite_prompt_logits_processor.py -p " Retrieved information: + Pokémon is a Japanese media franchise consisting of video games, animated series and films, a trading card game, and other related media. + The franchise takes place in a shared universe in which humans co-exist with creatures known as Pokémon, a large variety of species endowed with special powers. + The franchise's target audience is children aged 5 to 12, but it is known to attract people of all ages. + + Can you shortly describe what Pokémon is?" ``` \ No newline at end of file diff --git a/example_notebooks/trtllm/cite_prompt_logits_processor.py b/example_notebooks/trtllm/cite_prompt_logits_processor.py index 9826882..9848c54 100644 --- a/example_notebooks/trtllm/cite_prompt_logits_processor.py +++ b/example_notebooks/trtllm/cite_prompt_logits_processor.py @@ -5,10 +5,12 @@ if __name__ == "__main__": args = get_parser() - beam_width = 1 - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + llm_tester = TRTLLMTester(args.model_name, args.backend) - lp = CiteFromPromptLogitsProcessor(tokenizer, [args.prompt], boost_factor=1.0) + lp = CiteFromPromptLogitsProcessor(tokenizer, boost_factor=1.0, boost_eos=False, conditional_boost_factor=3.0) + llm_tester.run([args.prompt], logits_processor=lp) - TRTLLMTester(lp, tokenizer, args).run(args.prompt, beam_width) + lp = CiteFromPromptLogitsProcessor(tokenizer, boost_factor=-1.0, boost_eos=False, conditional_boost_factor=-1.0) + llm_tester.run([args.prompt], logits_processor=lp) diff --git a/example_notebooks/trtllm/gen_length_logits_processor.py b/example_notebooks/trtllm/gen_length_logits_processor.py index f973359..deb56d5 100644 --- a/example_notebooks/trtllm/gen_length_logits_processor.py +++ b/example_notebooks/trtllm/gen_length_logits_processor.py @@ -11,3 +11,6 @@ lp = GenLengthLogitsProcessor(tokenizer, boost_factor=1.0, complete_sentences=True) llm_tester.run([args.prompt], logits_processor=lp) + + lp = GenLengthLogitsProcessor(tokenizer, boost_factor=-1.0, p=0, complete_sentences=True) + llm_tester.run([args.prompt], logits_processor=lp) diff --git a/logits_processor_zoo/trtllm/cite_prompt.py b/logits_processor_zoo/trtllm/cite_prompt.py index 4276251..9f52aa0 100644 --- a/logits_processor_zoo/trtllm/cite_prompt.py +++ b/logits_processor_zoo/trtllm/cite_prompt.py @@ -18,39 +18,57 @@ from typing import List, Optional import torch from transformers import PreTrainedTokenizer +from tensorrt_llm.sampling_params import LogitsProcessor -class CiteFromPromptLogitsProcessor: +class CiteFromPromptLogitsProcessor(LogitsProcessor): """ A logits processor which boosts or diminishes the likelihood of tokens present in the prompt (and optionally EOS token) to encourage the model to generate tokens similar to those seen in the prompt or vice versa. - WARNING: Create a new object before every model.generate call since every batch has different prompts. Parameters ---------- tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM. - prompts (List[str]): Prompts in the batch. boost_factor (float): A factor to boost the likelihood of the tokens from the prompt. Negative values are used for the opposite effect. boost_eos (bool, optional): If True, boosts EOS token too. + conditional_boost_factor (float, optional): A factor to boost the likelihood of the tokens based on previous token. """ - def __init__(self, tokenizer: PreTrainedTokenizer, prompts: List[str], boost_factor: float = 1.0, - boost_eos: bool = True): + def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float = 1.0, boost_eos: bool = True, + conditional_boost_factor: float = 0.0): + self.tokenizer = tokenizer self.boost_factor = boost_factor + self.eos_token_id = self.tokenizer.eos_token_id + self.boost_eos = boost_eos + self.conditional_boost_factor = conditional_boost_factor + self.first_token = True + self.prompt_token_ids = list() - self.boost_ids = [] - for prompt in prompts: - prompt_tokens = set(tokenizer.encode(prompt)) + def __call__(self, req_id: int, logits: torch.Tensor, + token_ids: List[List[int]], stream_ptr: Optional[int], + client_id: Optional[int]) -> None: + if self.first_token: + self.prompt_token_ids = list(token_ids[0]) # take first beam since all beams have the same prompt + self.first_token = False - if boost_eos: - prompt_tokens.add(tokenizer.eos_token_id) + tokens = set(self.prompt_token_ids) + if self.boost_eos: + tokens.add(self.eos_token_id) - self.boost_ids.append(list(prompt_tokens)) + tokens = [t for t in tokens if t < logits.shape[-1]] - def __call__(self, req_ids_batch: List[int], logits_batch: List[torch.Tensor], - ids_batch: List[List[List[int]]], stream_ptr, - client_ids_batch: List[Optional[int]]): + stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr) - with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)): - for i in range(logits_batch.shape[1]): - logits_batch[:, i, self.boost_ids[i]] += self.boost_factor + with torch.cuda.stream(stream): + logits[:, :, tokens] += self.boost_factor + + if self.conditional_boost_factor != 0: + + for i in range(len(token_ids)): # iterate over beams + tokens = set() + for prompt_token_idx in range(len(self.prompt_token_ids) - 1): + in_vocab = self.prompt_token_ids[prompt_token_idx + 1] < logits.shape[-1] + last_token = self.prompt_token_ids[prompt_token_idx] == token_ids[i][-1] + if last_token and in_vocab: + tokens.add(self.prompt_token_ids[prompt_token_idx + 1]) + logits[:, i, list(tokens)] += self.conditional_boost_factor diff --git a/logits_processor_zoo/trtllm/generation_length.py b/logits_processor_zoo/trtllm/generation_length.py index 32bd799..596ef92 100644 --- a/logits_processor_zoo/trtllm/generation_length.py +++ b/logits_processor_zoo/trtllm/generation_length.py @@ -58,8 +58,8 @@ def __call__(self, req_id: int, logits: torch.Tensor, boost_val = self.boost_factor * (self.token_count ** self.p) / (10 ** self.p) - stream = None if stream_ptr is None else torch.cuda.ExternalStream( - stream_ptr) + stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr) + with torch.cuda.stream(stream): ids = torch.LongTensor(token_ids).to(logits.device, non_blocking=True) From fff779ce1505b411792c8ca0eca4d630e4514f24 Mon Sep 17 00:00:00 2001 From: aerdem4 Date: Mon, 30 Jun 2025 09:52:43 +0300 Subject: [PATCH 05/11] Update trtllm LastPhrase LP Signed-off-by: aerdem4 --- example_notebooks/trtllm/README.md | 1 + .../trtllm/last_phrase_logits_processor.py | 9 ++--- logits_processor_zoo/trtllm/last_phrase.py | 37 +++++++++++-------- 3 files changed, 27 insertions(+), 20 deletions(-) diff --git a/example_notebooks/trtllm/README.md b/example_notebooks/trtllm/README.md index 01b1417..0f4b9c9 100644 --- a/example_notebooks/trtllm/README.md +++ b/example_notebooks/trtllm/README.md @@ -15,4 +15,5 @@ python example_notebooks/trtllm/cite_prompt_logits_processor.py -p " Retrieve The franchise's target audience is children aged 5 to 12, but it is known to attract people of all ages. Can you shortly describe what Pokémon is?" +python example_notebooks/trtllm/last_phrase_logits_processor.py ``` \ No newline at end of file diff --git a/example_notebooks/trtllm/last_phrase_logits_processor.py b/example_notebooks/trtllm/last_phrase_logits_processor.py index 8d774f3..bcbcbbb 100644 --- a/example_notebooks/trtllm/last_phrase_logits_processor.py +++ b/example_notebooks/trtllm/last_phrase_logits_processor.py @@ -5,12 +5,11 @@ if __name__ == "__main__": args = get_parser() - beam_width = 1 - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + llm_tester = TRTLLMTester(args.model_name, args.backend) phrase = "\n\nThanks for trying our application! If you have more questions about" + lp = ForceLastPhraseLogitsProcessor(phrase, tokenizer) - lp = ForceLastPhraseLogitsProcessor(phrase, tokenizer, batch_size=1) - - TRTLLMTester(lp, tokenizer, args).run(args.prompt, beam_width) + llm_tester.run([args.prompt], logits_processor=lp) diff --git a/logits_processor_zoo/trtllm/last_phrase.py b/logits_processor_zoo/trtllm/last_phrase.py index 9983267..10908f7 100644 --- a/logits_processor_zoo/trtllm/last_phrase.py +++ b/logits_processor_zoo/trtllm/last_phrase.py @@ -18,35 +18,42 @@ from typing import List, Optional from transformers import PreTrainedTokenizer import torch +from logits_processor_zoo.utils import enforce_tokens +from tensorrt_llm.sampling_params import LogitsProcessor -class ForceLastPhraseLogitsProcessor: +class ForceLastPhraseLogitsProcessor(LogitsProcessor): """ A logits processor which forces LLMs to use the given phrase before they finalize their answers. Most common use cases can be providing references, thanking user with context etc. - WARNING: Create a new object before every model.generate call to reset iterators. Parameters ---------- phrase (str): The phrase to be generated by LLM before the end of its speech. tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM. - batch_size (int): Number of prompts in the batch. """ - def __init__(self, phrase: str, tokenizer: PreTrainedTokenizer, batch_size: int): + def __init__(self, phrase: str, tokenizer: PreTrainedTokenizer): self.eos_token_id = tokenizer.eos_token_id self.phrase_tokens = tokenizer.encode(phrase, add_special_tokens=False) - self.iterators = torch.zeros(batch_size, dtype=torch.int32) + self.first_token = True + self.iterators = None - def __call__(self, req_ids_batch: List[int], logits_batch: List[torch.Tensor], - ids_batch: List[List[List[int]]], stream_ptr, - client_ids_batch: List[Optional[int]]): + def __call__(self, req_id: int, logits: torch.Tensor, + token_ids: List[List[int]], stream_ptr: Optional[int], + client_id: Optional[int]) -> None: + beam_width = len(token_ids) + if self.first_token: + self.iterators = torch.zeros(beam_width, dtype=torch.int32) + self.first_token = False - with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)): - for i in range(logits_batch.shape[1]): - it = self.iterators[i].item() - if logits_batch[:, i, :].argmax() == self.eos_token_id and it == 0: - logits_batch[:, i, self.phrase_tokens[it]] = logits_batch[:, i].max() + 1 + stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr) + + with torch.cuda.stream(stream): + for i in range(beam_width): # iterate over beams + current_index = self.iterators[i].item() + if logits[0, i].argmax() == self.eos_token_id and current_index == 0: + enforce_tokens(logits[0, i], [self.phrase_tokens[current_index]]) self.iterators[i] += 1 - elif len(self.phrase_tokens) > it > 0: - logits_batch[:, i, self.phrase_tokens[it]] = logits_batch[:, i].max() + 1 + elif len(self.phrase_tokens) > current_index > 0: + enforce_tokens(logits[0, i], [self.phrase_tokens[current_index]]) self.iterators[i] += 1 From e525d7e773ddee6baa1eeb9677bbc748589715df Mon Sep 17 00:00:00 2001 From: aerdem4 Date: Mon, 30 Jun 2025 09:55:02 +0300 Subject: [PATCH 06/11] Update trtllm LastPhrase LP Signed-off-by: aerdem4 --- example_notebooks/trtllm/utils.py | 1 + logits_processor_zoo/trtllm/cite_prompt.py | 4 ++-- logits_processor_zoo/trtllm/generation_length.py | 1 - 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/example_notebooks/trtllm/utils.py b/example_notebooks/trtllm/utils.py index d8fbe6a..349c931 100644 --- a/example_notebooks/trtllm/utils.py +++ b/example_notebooks/trtllm/utils.py @@ -2,6 +2,7 @@ from typing import List from tensorrt_llm.sampling_params import SamplingParams, LogitsProcessor + class TRTLLMTester: def __init__(self, model_name: str = "Qwen/Qwen2.5-1.5B-Instruct", backend: str = "tensorrt-llm"): if backend == "pytorch": diff --git a/logits_processor_zoo/trtllm/cite_prompt.py b/logits_processor_zoo/trtllm/cite_prompt.py index 9f52aa0..74acff3 100644 --- a/logits_processor_zoo/trtllm/cite_prompt.py +++ b/logits_processor_zoo/trtllm/cite_prompt.py @@ -48,7 +48,7 @@ def __call__(self, req_id: int, logits: torch.Tensor, token_ids: List[List[int]], stream_ptr: Optional[int], client_id: Optional[int]) -> None: if self.first_token: - self.prompt_token_ids = list(token_ids[0]) # take first beam since all beams have the same prompt + self.prompt_token_ids = list(token_ids[0]) # take first beam since all beams have the same prompt self.first_token = False tokens = set(self.prompt_token_ids) @@ -64,7 +64,7 @@ def __call__(self, req_id: int, logits: torch.Tensor, if self.conditional_boost_factor != 0: - for i in range(len(token_ids)): # iterate over beams + for i in range(len(token_ids)): # iterate over beams tokens = set() for prompt_token_idx in range(len(self.prompt_token_ids) - 1): in_vocab = self.prompt_token_ids[prompt_token_idx + 1] < logits.shape[-1] diff --git a/logits_processor_zoo/trtllm/generation_length.py b/logits_processor_zoo/trtllm/generation_length.py index 596ef92..606651b 100644 --- a/logits_processor_zoo/trtllm/generation_length.py +++ b/logits_processor_zoo/trtllm/generation_length.py @@ -60,7 +60,6 @@ def __call__(self, req_id: int, logits: torch.Tensor, stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr) - with torch.cuda.stream(stream): ids = torch.LongTensor(token_ids).to(logits.device, non_blocking=True) From e1490fc4f8edc77ed9117e95ee26c506cc263eb3 Mon Sep 17 00:00:00 2001 From: aerdem4 Date: Mon, 30 Jun 2025 11:32:52 +0300 Subject: [PATCH 07/11] Update trtllm MultipleChoice LP Signed-off-by: aerdem4 --- example_notebooks/trtllm/README.md | 5 ++ .../multiple_choice_logits_processor.py | 10 +-- .../trtllm/multiple_choice.py | 63 ++++++++++--------- 3 files changed, 45 insertions(+), 33 deletions(-) diff --git a/example_notebooks/trtllm/README.md b/example_notebooks/trtllm/README.md index 0f4b9c9..f8fb979 100644 --- a/example_notebooks/trtllm/README.md +++ b/example_notebooks/trtllm/README.md @@ -16,4 +16,9 @@ python example_notebooks/trtllm/cite_prompt_logits_processor.py -p " Retrieve Can you shortly describe what Pokémon is?" python example_notebooks/trtllm/last_phrase_logits_processor.py +python example_notebooks/trtllm/multiple_choice_logits_processor.py -p "I am getting a lot of calls during the day. What is more important for me to consider when I buy a new phone? +0. Camera +1. Screen resolution +2. Operating System +3. Battery" ``` \ No newline at end of file diff --git a/example_notebooks/trtllm/multiple_choice_logits_processor.py b/example_notebooks/trtllm/multiple_choice_logits_processor.py index 7d02f15..31e782e 100644 --- a/example_notebooks/trtllm/multiple_choice_logits_processor.py +++ b/example_notebooks/trtllm/multiple_choice_logits_processor.py @@ -5,10 +5,12 @@ if __name__ == "__main__": args = get_parser() - beam_width = 1 - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + llm_tester = TRTLLMTester(args.model_name, args.backend) - lp = MultipleChoiceLogitsProcessor(tokenizer, choices=["1", "2"], delimiter=".", boost_first_words=0.5) + lp = mclp = MultipleChoiceLogitsProcessor(tokenizer, choices=["0", "1", "2", "3"]) + llm_tester.run([args.prompt], logits_processor=lp, max_tokens=1) - TRTLLMTester(lp, tokenizer, args).run(args.prompt, beam_width, max_new_tokens=1) + lp = MultipleChoiceLogitsProcessor(tokenizer, choices=["0", "1", "2", "3"], delimiter=".", boost_first_words=2.0) + llm_tester.run([args.prompt], logits_processor=lp, max_tokens=1) diff --git a/logits_processor_zoo/trtllm/multiple_choice.py b/logits_processor_zoo/trtllm/multiple_choice.py index b471905..ef0cde1 100644 --- a/logits_processor_zoo/trtllm/multiple_choice.py +++ b/logits_processor_zoo/trtllm/multiple_choice.py @@ -18,10 +18,11 @@ from transformers import PreTrainedTokenizer from typing import List, Optional import torch -from logits_processor_zoo.utils import text_to_token, get_new_line_tokens +from logits_processor_zoo.utils import text_to_token, get_new_line_tokens, enforce_tokens +from tensorrt_llm.sampling_params import LogitsProcessor -class MultipleChoiceLogitsProcessor: +class MultipleChoiceLogitsProcessor(LogitsProcessor): """ A logits processor to answer multiple choice questions with one of the choices. A multiple choice question is like: @@ -50,38 +51,42 @@ def __init__(self, tokenizer: PreTrainedTokenizer, choices: List[str] = None, self.delimiter_token = text_to_token(tokenizer, delimiter, last=False) self.choice_tokens = [text_to_token(tokenizer, choice, last=False) for choice in choices] self.boost_first_words = boost_first_words - self.very_large_number = 999 + self.first_tokens = list() - def __call__(self, req_ids_batch: List[int], logits_batch: List[torch.Tensor], - ids_batch: List[List[List[int]]], stream_ptr, - client_ids_batch: List[Optional[int]]): + def _init_choice_first_words(self, prompt_token_ids): + choice = 0 - if self.boost_first_words: - with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)): - ids_batch = torch.LongTensor(ids_batch).to(logits_batch.device, non_blocking=True) + first_tokens = [] + for i in range(len(prompt_token_ids) - 3): + # A choice is like "\nA) hair dryer", where first token is "hair" + choice_starts = ( + (prompt_token_ids[i] in self.new_line_tokens) and + (prompt_token_ids[i + 1] == self.choice_tokens[choice]) and + (prompt_token_ids[i + 2] == self.delimiter_token) + ) - for row_ind in range(ids_batch.shape[0]): - if self.boost_first_words: - choice = 0 + if choice_starts: + first_tokens.append(prompt_token_ids[i + 3]) + choice += 1 - first_tokens = [] - for i in range(len(ids_batch[row_ind]) - 3): - # A choice is like "\nA) hair dryer", where first token is "hair" - choice_starts = ( - (ids_batch[row_ind, i].item() in self.new_line_tokens) and - (ids_batch[row_ind, i + 1] == self.choice_tokens[choice]) and - (ids_batch[row_ind, i + 2] == self.delimiter_token) - ) + if choice >= len(self.choice_tokens): + break + return first_tokens - if choice_starts: - first_tokens.append(ids_batch[row_ind, i + 3]) - choice += 1 + def __call__(self, req_id: int, logits: torch.Tensor, + token_ids: List[List[int]], stream_ptr: Optional[int], + client_id: Optional[int]) -> None: - if choice >= len(self.choice_tokens): - break + if len(self.first_tokens) == 0 and self.boost_first_words: + prompt_token_ids = list(token_ids[0]) # take first beam since all beams have the same prompt + self.first_tokens = self._init_choice_first_words(prompt_token_ids) - boost = self.boost_first_words * logits_batch[:, row_ind, first_tokens] - logits_batch[:, row_ind, self.choice_tokens[:len(first_tokens)]] += boost + beam_width = len(token_ids) + stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr) - with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)): - logits_batch[:, :, self.choice_tokens] += self.very_large_number + with torch.cuda.stream(stream): + if len(self.first_tokens) > 0: + boost = self.boost_first_words * logits[0, :, self.first_tokens] + logits[0, :, self.choice_tokens[:len(self.first_tokens)]] += boost + for i in range(beam_width): # iterate over beams + enforce_tokens(logits[0, i], self.choice_tokens) From cc08f364bdec9955c10e3b43ff156f203a8873f3 Mon Sep 17 00:00:00 2001 From: aerdem4 Date: Mon, 30 Jun 2025 14:28:52 +0300 Subject: [PATCH 08/11] Add trtllm PreventHallucination LP Signed-off-by: aerdem4 --- example_notebooks/trtllm/README.md | 7 +- .../multiple_choice_logits_processor.py | 2 +- .../prevent_hallucination_logits_processor.py | 13 +++ logits_processor_zoo/trtllm/__init__.py | 3 +- logits_processor_zoo/trtllm/cite_prompt.py | 11 +-- logits_processor_zoo/trtllm/last_phrase.py | 9 +- .../trtllm/prevent_hallucination.py | 86 +++++++++++++++++++ 7 files changed, 119 insertions(+), 12 deletions(-) create mode 100644 example_notebooks/trtllm/prevent_hallucination_logits_processor.py create mode 100644 logits_processor_zoo/trtllm/prevent_hallucination.py diff --git a/example_notebooks/trtllm/README.md b/example_notebooks/trtllm/README.md index f8fb979..1b13983 100644 --- a/example_notebooks/trtllm/README.md +++ b/example_notebooks/trtllm/README.md @@ -9,16 +9,21 @@ https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html ``` python example_notebooks/trtllm/gen_length_logits_processor.py -python example_notebooks/trtllm/cite_prompt_logits_processor.py -p " Retrieved information: + +python example_notebooks/trtllm/cite_prompt_logits_processor.py -p "Retrieved information: Pokémon is a Japanese media franchise consisting of video games, animated series and films, a trading card game, and other related media. The franchise takes place in a shared universe in which humans co-exist with creatures known as Pokémon, a large variety of species endowed with special powers. The franchise's target audience is children aged 5 to 12, but it is known to attract people of all ages. Can you shortly describe what Pokémon is?" + python example_notebooks/trtllm/last_phrase_logits_processor.py + python example_notebooks/trtllm/multiple_choice_logits_processor.py -p "I am getting a lot of calls during the day. What is more important for me to consider when I buy a new phone? 0. Camera 1. Screen resolution 2. Operating System 3. Battery" + +python example_notebooks/trtllm/prevent_hallucination_logits_processor.py -p "Tell me the Nobel Prizes in 1977" ``` \ No newline at end of file diff --git a/example_notebooks/trtllm/multiple_choice_logits_processor.py b/example_notebooks/trtllm/multiple_choice_logits_processor.py index 31e782e..018ccd7 100644 --- a/example_notebooks/trtllm/multiple_choice_logits_processor.py +++ b/example_notebooks/trtllm/multiple_choice_logits_processor.py @@ -9,7 +9,7 @@ tokenizer = AutoTokenizer.from_pretrained(args.model_name) llm_tester = TRTLLMTester(args.model_name, args.backend) - lp = mclp = MultipleChoiceLogitsProcessor(tokenizer, choices=["0", "1", "2", "3"]) + lp = MultipleChoiceLogitsProcessor(tokenizer, choices=["0", "1", "2", "3"]) llm_tester.run([args.prompt], logits_processor=lp, max_tokens=1) lp = MultipleChoiceLogitsProcessor(tokenizer, choices=["0", "1", "2", "3"], delimiter=".", boost_first_words=2.0) diff --git a/example_notebooks/trtllm/prevent_hallucination_logits_processor.py b/example_notebooks/trtllm/prevent_hallucination_logits_processor.py new file mode 100644 index 0000000..05d74c9 --- /dev/null +++ b/example_notebooks/trtllm/prevent_hallucination_logits_processor.py @@ -0,0 +1,13 @@ +from transformers import AutoTokenizer +from logits_processor_zoo.trtllm import PreventHallucinationLogitsProcessor +from utils import TRTLLMTester, get_parser + + +if __name__ == "__main__": + args = get_parser() + + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + llm_tester = TRTLLMTester(args.model_name, args.backend) + + lp = PreventHallucinationLogitsProcessor(tokenizer, minp=0.25, tolerate=1) + llm_tester.run([args.prompt], logits_processor=lp) diff --git a/logits_processor_zoo/trtllm/__init__.py b/logits_processor_zoo/trtllm/__init__.py index bba822b..761225f 100644 --- a/logits_processor_zoo/trtllm/__init__.py +++ b/logits_processor_zoo/trtllm/__init__.py @@ -19,6 +19,7 @@ from .last_phrase import ForceLastPhraseLogitsProcessor from .cite_prompt import CiteFromPromptLogitsProcessor from .multiple_choice import MultipleChoiceLogitsProcessor +from .prevent_hallucination import PreventHallucinationLogitsProcessor __all__ = ['GenLengthLogitsProcessor', 'ForceLastPhraseLogitsProcessor', 'CiteFromPromptLogitsProcessor', - 'MultipleChoiceLogitsProcessor'] + 'MultipleChoiceLogitsProcessor', 'PreventHallucinationLogitsProcessor'] diff --git a/logits_processor_zoo/trtllm/cite_prompt.py b/logits_processor_zoo/trtllm/cite_prompt.py index 74acff3..70d994d 100644 --- a/logits_processor_zoo/trtllm/cite_prompt.py +++ b/logits_processor_zoo/trtllm/cite_prompt.py @@ -41,15 +41,16 @@ def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float = 1.0, bo self.eos_token_id = self.tokenizer.eos_token_id self.boost_eos = boost_eos self.conditional_boost_factor = conditional_boost_factor - self.first_token = True - self.prompt_token_ids = list() + self.prompt_token_ids = None + + def _init_before_gen(self, token_ids): + self.prompt_token_ids = list(token_ids[0]) # take first beam since all beams have the same prompt def __call__(self, req_id: int, logits: torch.Tensor, token_ids: List[List[int]], stream_ptr: Optional[int], client_id: Optional[int]) -> None: - if self.first_token: - self.prompt_token_ids = list(token_ids[0]) # take first beam since all beams have the same prompt - self.first_token = False + if self.prompt_token_ids is None: + self._init_before_gen(token_ids) tokens = set(self.prompt_token_ids) if self.boost_eos: diff --git a/logits_processor_zoo/trtllm/last_phrase.py b/logits_processor_zoo/trtllm/last_phrase.py index 10908f7..40e4411 100644 --- a/logits_processor_zoo/trtllm/last_phrase.py +++ b/logits_processor_zoo/trtllm/last_phrase.py @@ -35,16 +35,17 @@ class ForceLastPhraseLogitsProcessor(LogitsProcessor): def __init__(self, phrase: str, tokenizer: PreTrainedTokenizer): self.eos_token_id = tokenizer.eos_token_id self.phrase_tokens = tokenizer.encode(phrase, add_special_tokens=False) - self.first_token = True self.iterators = None + def _init_before_gen(self, beam_width): + self.iterators = torch.zeros(beam_width, dtype=torch.int32) + def __call__(self, req_id: int, logits: torch.Tensor, token_ids: List[List[int]], stream_ptr: Optional[int], client_id: Optional[int]) -> None: beam_width = len(token_ids) - if self.first_token: - self.iterators = torch.zeros(beam_width, dtype=torch.int32) - self.first_token = False + if self.iterators is None: + self._init_before_gen(beam_width) stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr) diff --git a/logits_processor_zoo/trtllm/prevent_hallucination.py b/logits_processor_zoo/trtllm/prevent_hallucination.py new file mode 100644 index 0000000..f6bc189 --- /dev/null +++ b/logits_processor_zoo/trtllm/prevent_hallucination.py @@ -0,0 +1,86 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import List, Optional +import torch +from transformers import PreTrainedTokenizer +from logits_processor_zoo.utils import enforce_tokens +from tensorrt_llm.sampling_params import LogitsProcessor + + +class PreventHallucinationLogitsProcessor(LogitsProcessor): + """ + A logits processor that mitigates hallucinated model outputs by enforcing a predefined fallback phrase + when token confidence falls below a specified threshold. + + This processor monitors token probabilities during generation. If the model produces a number of + low-confidence tokens (below `minp`) exceeding `tolerate`, it begins injecting a fallback phrase + token-by-token to gracefully indicate uncertainty. + + Parameters + ---------- + tokenizer : PreTrainedTokenizer + The tokenizer used by the language model. It is used to tokenize the fallback phrase. + minp : float, optional (default=0.4) + The minimum probability threshold. Tokens with max probability below this are considered low-confidence. + tolerate : int, optional (default=1) + The number of consecutive low-confidence tokens tolerated before triggering the fallback phrase. + phrase : str, optional (default="...I don't know actually.\\n") + The phrase that will be inserted when hallucination is detected. It will be tokenized and injected + sequentially into the generation. + """ + def __init__(self, tokenizer: PreTrainedTokenizer, minp: float = 0.4, tolerate: int = 1, + phrase: str = "...I don't know actually.\n"): + self.phrase = phrase + self.eos_token_id = tokenizer.eos_token_id + self.phrase_tokens = tokenizer.encode(self.phrase, add_special_tokens=False) + self.tokenizer = tokenizer + self.minp = minp + self.tolerate = tolerate + self.iterators = None + self.minp_counts = None + + def _init_before_gen(self, beam_width): + self.iterators = torch.zeros(beam_width, dtype=torch.int32) + self.minp_counts = torch.zeros(beam_width, dtype=torch.int32) + + def __call__(self, req_id: int, logits: torch.Tensor, + token_ids: List[List[int]], stream_ptr: Optional[int], + client_id: Optional[int]) -> None: + beam_width = len(token_ids) + if self.iterators is None: + self._init_before_gen(beam_width) + + beam_width = len(token_ids) + stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr) + + with torch.cuda.stream(stream): + for i in range(beam_width): # iterate over beams + current_index = self.iterators[i].item() + + if logits[0, i, :].softmax(dim=-1).amax() < self.minp: + self.minp_counts[i] += 1 + + if self.minp_counts[i] > self.tolerate and current_index == 0: + enforce_tokens(logits[0, i], [self.phrase_tokens[current_index]]) + self.iterators[i] += 1 + elif len(self.phrase_tokens) > current_index > 0: + enforce_tokens(logits[0, i], [self.phrase_tokens[current_index]]) + self.iterators[i] += 1 + elif current_index == len(self.phrase_tokens): + self.iterators[i] = 0 + self.minp_counts[i] = 0 From f4b28755b6b2e8fca441e4d0f997a47fbcd036eb Mon Sep 17 00:00:00 2001 From: aerdem4 Date: Mon, 30 Jun 2025 16:08:13 +0300 Subject: [PATCH 09/11] Add trtllm TriggerPhrase LP Signed-off-by: aerdem4 --- .../trigger_phrase_logits_processor.ipynb | 98 ++++++++++++------- example_notebooks/trtllm/README.md | 2 + .../trtllm/trigger_phrase_logits_processor.py | 17 ++++ .../transformers/trigger_phrase.py | 2 +- logits_processor_zoo/trtllm/__init__.py | 3 +- logits_processor_zoo/trtllm/trigger_phrase.py | 79 +++++++++++++++ 6 files changed, 161 insertions(+), 40 deletions(-) create mode 100644 example_notebooks/trtllm/trigger_phrase_logits_processor.py create mode 100644 logits_processor_zoo/trtllm/trigger_phrase.py diff --git a/example_notebooks/transformers/trigger_phrase_logits_processor.ipynb b/example_notebooks/transformers/trigger_phrase_logits_processor.ipynb index 0a37367..fd8c774 100644 --- a/example_notebooks/transformers/trigger_phrase_logits_processor.ipynb +++ b/example_notebooks/transformers/trigger_phrase_logits_processor.ipynb @@ -28,11 +28,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", - " warnings.warn(\n", - "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", - "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", - " warnings.warn(\n" + "Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.\n" ] } ], @@ -70,14 +66,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", - " warnings.warn(\n", - "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", - " warnings.warn(\n", - "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:407: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n", - " warnings.warn(\n", "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", - "Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.\n" + "Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.\n", + "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" ] }, { @@ -113,9 +104,9 @@ "\n", "Let me test this function with some examples. For n=0, it returns 0. For n=1, returns 1. For n=2, it's F(1)+F(0) = 1+0=1. For n=3, F(2)+F(1)=1+1=2. That looks correct.\n", "\n", - "Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which matches the standard definition. So, the function should work regardless of the starting point as long as the base cases are correct.\n", + "Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which is correct.\n", "\n", - "Another thing to consider is the base cases. If the function is called with n=0, it returns 0, which is correct. For n=1, returns 1. For n=2, returns 1, which is correct. So, the function should handle all non-negative integers correctly.\n", + "Another test case: n=5. Let's compute it step by step. F(0)=0, F(1)=1, F(2)=1, F(3)=2, F(4)=3, F(5)=5. So the function should return 5 for n=5.\n", "\n", "I think this should work. So, the function is straightforward. It's a simple recursive implementation, but it's not the most efficient for large n. However, for the purpose of this problem, it's acceptable.\n", "\n", @@ -215,39 +206,64 @@ "\n", "Let me test this function with some examples. For n=0, it returns 0. For n=1, returns 1. For n=2, it's F(1)+F(0) = 1+0=1. For n=3, F(2)+F(1)=1+1=2. That looks correct.\n", "\n", - "Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which matches the standard definition. So, the function should work regardless of the starting point as long as the base cases are correct.\n", + "Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which is correct.\n", "\n", - "Another thing to consider is the base cases. If the function is called with n=0, it returns 0, which is correct. For n=1, returns 1. For n=2, returns 1, which is correct. So, the function should handle all non-negative integers correctly.\n", + "Another test case: n=5. Let's compute it step by step. F(0)=0, F(1)=1, F(2)=1, F(3)=2, F(4)=3, F(5)=5. So the function should return 5 for n=5.\n", "\n", "I think this should work. So, the function is straightforward. It's a simple recursive implementation, but it's not the most efficient for large n. However, for the purpose of this problem, it's acceptable.\n", - ",,,\n", + "Wait, but in the problem statement, it says to make it recursive. So, the function as written is recursive, but it's not optimized. So, I think this should be the solution.\n", + "Wait, but the problem says to make it recursive, so perhaps the function should handle larger n efficiently. But without memoization, it's not efficient. So, maybe the function should be written with memoization or an iterative approach.\n", + "\n", + "Alternatively, I can implement it iteratively, which is more efficient. Let's think about that.\n", + "\n", + "An iterative approach would start from the bottom and build up the Fibonacci numbers. For example:\n", + "\n", + "def fibonacci(n):\n", + " if n == 0:\n", + " return 0\n", + " a, b = 0, 1\n", + " for _ in range(2, n+1):\n", + " a, b = b, a + b\n", + " return b\n", + "\n", + "Wait, let's test this. For n=5, the loop runs from 2 to 5, which is 4 iterations. Let's see:\n", + "\n", + "Initial a=0, b=1.\n", + "\n", + "After first iteration (i=2): a becomes 1, b becomes 0+1=1.\n", + "\n", + "After second iteration (i=3): a=1, b=1+1=2.\n", + "\n", + "After third iteration (i=4): a=2, b=1+2=3.\n", + "\n", + "After fourth iteration (i=5): a=3, b=2+3=5.\n", + "\n", + "So, return b=5, which is correct.\n", + "\n", + "This approach is more efficient as it runs in O(n) time and uses O(1) space, which is better for large n.\n", "\n", - "Wait, but the problem says to make it recursive. So, the function should call itself with smaller arguments. The approach I have is correct and recursive.\n", + "But the problem says to make it recursive. So, perhaps the recursive approach is acceptable, but the iterative approach is better for larger n.\n", "\n", - "So, the final function is as I wrote above.\n", - ",,,\n", + "But since the problem doesn't specify, I'll proceed with the recursive approach as it's straightforward.\n", "\n", - "Wait, but in the function, for n=2, it's F(1)+F(0) = 1+0=1, which is correct. For n=3, F(2)+F(1)=1+1=2. So, the function works as expected.\n", + "Wait, but the problem says to make it recursive. So, perhaps the function should be written recursively, but without memoization. So, the function as I wrote earlier is correct.\n", "\n", - "I think this should solve the problem.\n", + "Alternatively, I can write it with memoization using a dictionary to store previously computed values. But that's more advanced and perhaps beyond the scope of this problem.\n", + "\n", + "So, the function I wrote earlier is correct and meets the requirements.\n", "\n", "\n", "To solve this problem, we need to generate the nth Fibonacci number using a recursive approach. The Fibonacci sequence is a series of numbers where each number is the sum of the two preceding ones, starting from 0 and 1. \n", "\n", "### Approach\n", - "The approach to solve this problem involves using recursion, which is a method where a function calls itself with a modified parameter to achieve the desired result. Here's a step-by-step breakdown of the approach:\n", + "The Fibonacci sequence is defined as follows:\n", + "- F(0) = 0\n", + "- F(1) = 1\n", + "- F(n) = F(n-1) + F(n-2) for n >= 2\n", "\n", - "1. **Base Cases**: \n", - " - If `n` is 0, return 0.\n", - " - If `n` is 1, return 1.\n", - " \n", - "2. **Recursive Case**:\n", - " - For any `n` greater than 1, the nth Fibonacci number is the sum of the (n-1)th and (n-2)th Fibonacci numbers. This is achieved by recursively calling the function with `n-1` and `n-2` and adding their results.\n", - "\n", - "This approach ensures that each Fibonacci number is computed by breaking down the problem into smaller subproblems, which are then solved recursively.\n", + "Given the requirement to use a recursive approach, we can define a function that calls itself with smaller values of n until it reaches the base cases. The function will handle the base cases directly and use recursion for the general case.\n", "\n", "### Solution Code\n", - "\n", "```python\n", "def fibonacci(n):\n", " if n == 0:\n", @@ -259,10 +275,16 @@ "```\n", "\n", "### Explanation\n", - "- **Base Cases**: The function first checks if `n` is 0 or 1. If `n` is 0, it returns 0. If `n` is 1, it returns 1. These are the simplest cases of the Fibonacci sequence.\n", - "- **Recursive Case**: For any `n` greater than 1, the function calls itself with `n-1` and `n-2`, and returns the sum of these two recursive calls. This builds up the solution by solving smaller subproblems and combining their results.\n", + "The function `fibonacci` takes an integer `n` as input and returns the nth Fibonacci number. \n", + "\n", + "1. **Base Cases**:\n", + " - If `n` is 0, the function returns 0.\n", + " - If `n` is 1, the function returns 1.\n", + "\n", + "2. **Recursive Case**:\n", + " - For `n >= 2`, the function calls itself with `n-1` and `n-2` and returns the sum of these two recursive calls. This builds up the Fibonacci sequence from the bottom up, ensuring that each value is computed only once.\n", "\n", - "This approach is straightforward and leverages the divide-and-conquer strategy inherent in recursion, making it easy to understand and implement. However, it's important to note that this approach has a time complexity of O(2^n) due to the exponential number of function calls, which is not efficient for large values of `n`. For larger values, an iterative approach or memoization would be more efficient.\n", + "This approach is straightforward and leverages the recursive nature of the Fibonacci sequence, making it easy to understand and implement. However, it's important to note that for very large values of `n`, this approach can be inefficient due to repeated calculations. For larger values, an iterative approach or memoization would be more efficient.\n", "-----END-----\n", "\n" ] @@ -332,9 +354,9 @@ "\n", "Let me test this function with some examples. For n=0, it returns 0. For n=1, returns 1. For n=2, it's F(1)+F(0) = 1+0=1. For n=3, F(2)+F(1)=1+1=2. That looks correct.\n", "\n", - "Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which matches the standard definition. So, the function should work regardless of the starting point as long as the base cases are correct.\n", + "Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which is correct.\n", "\n", - "Another thing to consider is the base cases. If the function is called with n=0, it returns 0, which is correct. For n=1, returns 1. For n=2, returns 1, which is correct. So, the function should handle all non-negative integers correctly.\n", + "Another test case: n=5. Let's compute it step by step. F(0)=0, F(1)=1, F(2)=1, F(3)=2, F(4)=3, F(5)=5. So the function should return 5 for n=5.\n", "\n", "I think this should work. So, the function is straightforward. It's a simple recursive implementation, but it's not the most efficient for large n. However, for the purpose of this problem, it's acceptable.\n", "\n", @@ -348,7 +370,7 @@ " return fibonacci(n-1) + fibonacci(n-2)\n", "```\n", "\n", - "This function calculates the nth Fibonacci number using a recursive approach. It handles the base cases where n is 0 or 1 and recursively computes the value for larger n by summing the two preceding Fibonacci numbers.\n", + "This function calculates the nth Fibonacci number using a recursive approach. It handles the base cases where n is 0 or 1 and for other values, it recursively calculates the sum of the two preceding Fibonacci numbers. While this implementation is straightforward, it's not the most efficient for large values of n due to repeated calculations.\n", "-----END-----\n", "\n" ] diff --git a/example_notebooks/trtllm/README.md b/example_notebooks/trtllm/README.md index 1b13983..e268675 100644 --- a/example_notebooks/trtllm/README.md +++ b/example_notebooks/trtllm/README.md @@ -26,4 +26,6 @@ python example_notebooks/trtllm/multiple_choice_logits_processor.py -p "I am get 3. Battery" python example_notebooks/trtllm/prevent_hallucination_logits_processor.py -p "Tell me the Nobel Prizes in 1977" + +python example_notebooks/trtllm/trigger_phrase_logits_processor.py -p "Generate a python function to calculate nth fibonacci number. Make it recursive. Keep thinking short." ``` \ No newline at end of file diff --git a/example_notebooks/trtllm/trigger_phrase_logits_processor.py b/example_notebooks/trtllm/trigger_phrase_logits_processor.py new file mode 100644 index 0000000..faf0b76 --- /dev/null +++ b/example_notebooks/trtllm/trigger_phrase_logits_processor.py @@ -0,0 +1,17 @@ +from transformers import AutoTokenizer +from logits_processor_zoo.trtllm import TriggerPhraseLogitsProcessor +from utils import TRTLLMTester, get_parser + + +if __name__ == "__main__": + args = get_parser() + + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + llm_tester = TRTLLMTester(args.model_name, args.backend) + + lp = TriggerPhraseLogitsProcessor("...Wait, let me think more.", " function", tokenizer, + trigger_count=2, trigger_after=False) + llm_tester.run([args.prompt], logits_processor=lp) + + lp = TriggerPhraseLogitsProcessor("\n```python", " function", tokenizer, trigger_count=1, trigger_after=True) + llm_tester.run([args.prompt], logits_processor=lp) diff --git a/logits_processor_zoo/transformers/trigger_phrase.py b/logits_processor_zoo/transformers/trigger_phrase.py index 3cdf677..949de51 100644 --- a/logits_processor_zoo/transformers/trigger_phrase.py +++ b/logits_processor_zoo/transformers/trigger_phrase.py @@ -55,7 +55,7 @@ def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to if scores[i, :].argmax() == self.trigger_token and it == -1: self.iterators[i] = 0 if not self.trigger_after: - scores[i] = enforce_tokens(scores[i], [self.phrase_tokens[it]]) + scores[i] = enforce_tokens(scores[i], [self.phrase_tokens[0]]) self.iterators[i] += 1 elif len(self.phrase_tokens) > it >= 0: scores[i] = enforce_tokens(scores[i], [self.phrase_tokens[it]]) diff --git a/logits_processor_zoo/trtllm/__init__.py b/logits_processor_zoo/trtllm/__init__.py index 761225f..97e5cf2 100644 --- a/logits_processor_zoo/trtllm/__init__.py +++ b/logits_processor_zoo/trtllm/__init__.py @@ -20,6 +20,7 @@ from .cite_prompt import CiteFromPromptLogitsProcessor from .multiple_choice import MultipleChoiceLogitsProcessor from .prevent_hallucination import PreventHallucinationLogitsProcessor +from .trigger_phrase import TriggerPhraseLogitsProcessor __all__ = ['GenLengthLogitsProcessor', 'ForceLastPhraseLogitsProcessor', 'CiteFromPromptLogitsProcessor', - 'MultipleChoiceLogitsProcessor', 'PreventHallucinationLogitsProcessor'] + 'MultipleChoiceLogitsProcessor', 'PreventHallucinationLogitsProcessor', 'TriggerPhraseLogitsProcessor'] diff --git a/logits_processor_zoo/trtllm/trigger_phrase.py b/logits_processor_zoo/trtllm/trigger_phrase.py new file mode 100644 index 0000000..c0aba35 --- /dev/null +++ b/logits_processor_zoo/trtllm/trigger_phrase.py @@ -0,0 +1,79 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import List, Optional +from transformers import PreTrainedTokenizer +import torch +from logits_processor_zoo.utils import enforce_tokens, text_to_token +from tensorrt_llm.sampling_params import LogitsProcessor + + +class TriggerPhraseLogitsProcessor(LogitsProcessor): + """ + A logits processor which triggers phrases when it encounters a given token. + + Parameters + ---------- + phrase (str): The phrase to be generated by LLM when it encounters the trigger token. + trigger_token_phrase (str): One token phrase in string to trigger phrases. + tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM. + trigger_count (int): How many times the phrase will be triggered. + trigger_after (bool): Whether the phrase is written after the trigger token or instead of the trigger token. + """ + def __init__(self, phrase: str, trigger_token_phrase: str, tokenizer: PreTrainedTokenizer, + trigger_count: int = 1, trigger_after: bool = False): + self.tokenizer = tokenizer + self.trigger_token = text_to_token(self.tokenizer, trigger_token_phrase, last=False) + self.phrase_tokens = self.tokenizer.encode(phrase, add_special_tokens=False) + self.initial_trigger_count = trigger_count + self.trigger_after = trigger_after + self.iterators = None + self.trigger_counts = None + + def _init_before_gen(self, beam_width): + self.iterators = -torch.ones(beam_width, dtype=torch.int32) + self.trigger_counts = self.initial_trigger_count*torch.ones(beam_width, dtype=torch.int32) + + def __call__(self, req_id: int, logits: torch.Tensor, + token_ids: List[List[int]], stream_ptr: Optional[int], + client_id: Optional[int]) -> None: + beam_width = len(token_ids) + if self.iterators is None: + self._init_before_gen(beam_width) + + stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr) + + with torch.cuda.stream(stream): + for i in range(beam_width): # iterate over beams + if self.trigger_counts[i] <= 0: + continue + + current_index = self.iterators[i].item() + + if logits[0, i].argmax() == self.trigger_token and current_index == -1: + self.iterators[i] = 0 + print("triggering...") + if not self.trigger_after: + enforce_tokens(logits[0, i], [self.phrase_tokens[0]]) + self.iterators[i] += 1 + elif len(self.phrase_tokens) > current_index >= 0: + enforce_tokens(logits[0, i], [self.phrase_tokens[current_index]]) + self.iterators[i] += 1 + + if len(self.phrase_tokens) == self.iterators[i].item(): # phrase completed, reset for next trigger + self.iterators[i] = -1 + self.trigger_counts[i] -= 1 From 98bd4f8dd70b0ecc21d4c01752187fcd34ed2113 Mon Sep 17 00:00:00 2001 From: aerdem4 Date: Wed, 2 Jul 2025 09:36:10 +0300 Subject: [PATCH 10/11] Use pytorch backend Signed-off-by: aerdem4 --- example_notebooks/trtllm/README.md | 32 ++++++++++++++++--- .../trtllm/cite_prompt_logits_processor.py | 2 +- .../trtllm/gen_length_logits_processor.py | 2 +- .../trtllm/last_phrase_logits_processor.py | 2 +- .../multiple_choice_logits_processor.py | 2 +- .../prevent_hallucination_logits_processor.py | 2 +- .../trtllm/trigger_phrase_logits_processor.py | 2 +- example_notebooks/trtllm/utils.py | 12 +++---- logits_processor_zoo/trtllm/trigger_phrase.py | 1 - 9 files changed, 37 insertions(+), 20 deletions(-) diff --git a/example_notebooks/trtllm/README.md b/example_notebooks/trtllm/README.md index e268675..718eb68 100644 --- a/example_notebooks/trtllm/README.md +++ b/example_notebooks/trtllm/README.md @@ -2,30 +2,52 @@ ## Quick Start -Follow this guide to create an engine: -https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html +It's recommended to use [TensorRT-LLM release containers](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release/tags) (>= 0.20.0) that has TensorRT-LLM pre-installed. +Alternatively, please follow [this documentation](https://nvidia.github.io/TensorRT-LLM/installation/linux.html) to install it in [NGC PyTorch containers](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags) (>=25.04). ## Examples +### GenLengthLogitsProcessor +A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token based on the length of the generated sequence, encouraging or discouraging shorter answers. ``` python example_notebooks/trtllm/gen_length_logits_processor.py +``` +### CiteFromPromptLogitsProcessor +A logits processor which boosts or diminishes the likelihood of tokens present in the prompt (and optionally EOS token) to encourage the model to generate tokens similar to those seen in the prompt or vice versa. +``` python example_notebooks/trtllm/cite_prompt_logits_processor.py -p "Retrieved information: Pokémon is a Japanese media franchise consisting of video games, animated series and films, a trading card game, and other related media. The franchise takes place in a shared universe in which humans co-exist with creatures known as Pokémon, a large variety of species endowed with special powers. The franchise's target audience is children aged 5 to 12, but it is known to attract people of all ages. Can you shortly describe what Pokémon is?" +``` +### ForceLastPhraseLogitsProcessor +A logits processor which forces LLMs to use the given phrase before they finalize their answers. Most common use cases can be providing references, thanking user with context etc. +``` python example_notebooks/trtllm/last_phrase_logits_processor.py +``` +### MultipleChoiceLogitsProcessor +A logits processor to answer multiple choice questions with one of the choices. +``` python example_notebooks/trtllm/multiple_choice_logits_processor.py -p "I am getting a lot of calls during the day. What is more important for me to consider when I buy a new phone? 0. Camera 1. Screen resolution 2. Operating System 3. Battery" +``` -python example_notebooks/trtllm/prevent_hallucination_logits_processor.py -p "Tell me the Nobel Prizes in 1977" - +### TriggerPhraseLogitsProcessor +A logits processor which triggers phrases when it encounters a given token. +``` python example_notebooks/trtllm/trigger_phrase_logits_processor.py -p "Generate a python function to calculate nth fibonacci number. Make it recursive. Keep thinking short." -``` \ No newline at end of file +``` + +### PreventHallucinationLogitsProcessor +A logits processor that mitigates hallucinated model outputs by enforcing a predefined fallback phrase when token confidence falls below a specified threshold. +``` +python example_notebooks/trtllm/prevent_hallucination_logits_processor.py -p "What are Nobel Prizes? Name the winners in 1977" +``` diff --git a/example_notebooks/trtllm/cite_prompt_logits_processor.py b/example_notebooks/trtllm/cite_prompt_logits_processor.py index 9848c54..1f19641 100644 --- a/example_notebooks/trtllm/cite_prompt_logits_processor.py +++ b/example_notebooks/trtllm/cite_prompt_logits_processor.py @@ -7,7 +7,7 @@ args = get_parser() tokenizer = AutoTokenizer.from_pretrained(args.model_name) - llm_tester = TRTLLMTester(args.model_name, args.backend) + llm_tester = TRTLLMTester(args.model_name) lp = CiteFromPromptLogitsProcessor(tokenizer, boost_factor=1.0, boost_eos=False, conditional_boost_factor=3.0) llm_tester.run([args.prompt], logits_processor=lp) diff --git a/example_notebooks/trtllm/gen_length_logits_processor.py b/example_notebooks/trtllm/gen_length_logits_processor.py index deb56d5..ba84d6c 100644 --- a/example_notebooks/trtllm/gen_length_logits_processor.py +++ b/example_notebooks/trtllm/gen_length_logits_processor.py @@ -7,7 +7,7 @@ args = get_parser() tokenizer = AutoTokenizer.from_pretrained(args.model_name) - llm_tester = TRTLLMTester(args.model_name, args.backend) + llm_tester = TRTLLMTester(args.model_name) lp = GenLengthLogitsProcessor(tokenizer, boost_factor=1.0, complete_sentences=True) llm_tester.run([args.prompt], logits_processor=lp) diff --git a/example_notebooks/trtllm/last_phrase_logits_processor.py b/example_notebooks/trtllm/last_phrase_logits_processor.py index bcbcbbb..4500259 100644 --- a/example_notebooks/trtllm/last_phrase_logits_processor.py +++ b/example_notebooks/trtllm/last_phrase_logits_processor.py @@ -7,7 +7,7 @@ args = get_parser() tokenizer = AutoTokenizer.from_pretrained(args.model_name) - llm_tester = TRTLLMTester(args.model_name, args.backend) + llm_tester = TRTLLMTester(args.model_name) phrase = "\n\nThanks for trying our application! If you have more questions about" lp = ForceLastPhraseLogitsProcessor(phrase, tokenizer) diff --git a/example_notebooks/trtllm/multiple_choice_logits_processor.py b/example_notebooks/trtllm/multiple_choice_logits_processor.py index 018ccd7..c7fab14 100644 --- a/example_notebooks/trtllm/multiple_choice_logits_processor.py +++ b/example_notebooks/trtllm/multiple_choice_logits_processor.py @@ -7,7 +7,7 @@ args = get_parser() tokenizer = AutoTokenizer.from_pretrained(args.model_name) - llm_tester = TRTLLMTester(args.model_name, args.backend) + llm_tester = TRTLLMTester(args.model_name) lp = MultipleChoiceLogitsProcessor(tokenizer, choices=["0", "1", "2", "3"]) llm_tester.run([args.prompt], logits_processor=lp, max_tokens=1) diff --git a/example_notebooks/trtllm/prevent_hallucination_logits_processor.py b/example_notebooks/trtllm/prevent_hallucination_logits_processor.py index 05d74c9..b6f835e 100644 --- a/example_notebooks/trtllm/prevent_hallucination_logits_processor.py +++ b/example_notebooks/trtllm/prevent_hallucination_logits_processor.py @@ -7,7 +7,7 @@ args = get_parser() tokenizer = AutoTokenizer.from_pretrained(args.model_name) - llm_tester = TRTLLMTester(args.model_name, args.backend) + llm_tester = TRTLLMTester(args.model_name) lp = PreventHallucinationLogitsProcessor(tokenizer, minp=0.25, tolerate=1) llm_tester.run([args.prompt], logits_processor=lp) diff --git a/example_notebooks/trtllm/trigger_phrase_logits_processor.py b/example_notebooks/trtllm/trigger_phrase_logits_processor.py index faf0b76..2056bd4 100644 --- a/example_notebooks/trtllm/trigger_phrase_logits_processor.py +++ b/example_notebooks/trtllm/trigger_phrase_logits_processor.py @@ -7,7 +7,7 @@ args = get_parser() tokenizer = AutoTokenizer.from_pretrained(args.model_name) - llm_tester = TRTLLMTester(args.model_name, args.backend) + llm_tester = TRTLLMTester(args.model_name) lp = TriggerPhraseLogitsProcessor("...Wait, let me think more.", " function", tokenizer, trigger_count=2, trigger_after=False) diff --git a/example_notebooks/trtllm/utils.py b/example_notebooks/trtllm/utils.py index 349c931..525191e 100644 --- a/example_notebooks/trtllm/utils.py +++ b/example_notebooks/trtllm/utils.py @@ -4,10 +4,11 @@ class TRTLLMTester: - def __init__(self, model_name: str = "Qwen/Qwen2.5-1.5B-Instruct", backend: str = "tensorrt-llm"): - if backend == "pytorch": + def __init__(self, model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"): + # Temporarily attempt to import the torch backend until it becomes default + try: from tensorrt_llm._torch import LLM - else: + except ImportError: from tensorrt_llm import LLM self.llm = LLM(model=model_name) @@ -41,11 +42,6 @@ def get_parser(): type=str, default="Qwen/Qwen2.5-1.5B-Instruct", help="Directory or HF link containing model") - parser.add_argument("--backend", - "-b", - type=str, - default="tensorrt-llm", - help="TensorRT-LLM backend") parser.add_argument("--prompt", "-p", type=str, diff --git a/logits_processor_zoo/trtllm/trigger_phrase.py b/logits_processor_zoo/trtllm/trigger_phrase.py index c0aba35..8afa7b8 100644 --- a/logits_processor_zoo/trtllm/trigger_phrase.py +++ b/logits_processor_zoo/trtllm/trigger_phrase.py @@ -66,7 +66,6 @@ def __call__(self, req_id: int, logits: torch.Tensor, if logits[0, i].argmax() == self.trigger_token and current_index == -1: self.iterators[i] = 0 - print("triggering...") if not self.trigger_after: enforce_tokens(logits[0, i], [self.phrase_tokens[0]]) self.iterators[i] += 1 From deb2c986c1ddad0f7b45fa7140a737481e58a8ae Mon Sep 17 00:00:00 2001 From: aerdem4 Date: Thu, 3 Jul 2025 09:11:32 +0300 Subject: [PATCH 11/11] Rename example_notebooks to examples Signed-off-by: aerdem4 --- .../transformers/cite_prompt_logits_processor.ipynb | 2 +- .../force_last_phrase_logits_processor.ipynb | 2 +- .../transformers/gen_length_logits_processor.ipynb | 2 +- .../multiple_choice_logits_processor.ipynb | 2 +- .../prevent_hallucination_logits_processor.ipynb | 2 +- .../trigger_phrase_logits_processor.ipynb | 2 +- .../transformers/utils.py | 0 {example_notebooks => examples}/trtllm/README.md | 12 ++++++------ .../trtllm/cite_prompt_logits_processor.py | 0 .../trtllm/gen_length_logits_processor.py | 0 .../trtllm/last_phrase_logits_processor.py | 0 .../trtllm/multiple_choice_logits_processor.py | 0 .../trtllm/prevent_hallucination_logits_processor.py | 0 .../trtllm/trigger_phrase_logits_processor.py | 0 {example_notebooks => examples}/trtllm/utils.py | 0 .../vllm/cite_prompt_logits_processor.ipynb | 2 +- .../vllm/force_last_phrase_logits_processor.ipynb | 2 +- .../vllm/gen_length_logits_processor.ipynb | 2 +- .../vllm/multiple_choice_logits_processor.ipynb | 2 +- .../vllm/performance_profiling.ipynb | 2 +- .../prevent_hallucination_logits_processor.ipynb | 2 +- .../vllm/trigger_phrase_logits_processor.ipynb | 2 +- {example_notebooks => examples}/vllm/utils.py | 0 .../vllm/vllm_serve.ipynb | 0 24 files changed, 19 insertions(+), 19 deletions(-) rename {example_notebooks => examples}/transformers/cite_prompt_logits_processor.ipynb (99%) rename {example_notebooks => examples}/transformers/force_last_phrase_logits_processor.ipynb (99%) rename {example_notebooks => examples}/transformers/gen_length_logits_processor.ipynb (99%) rename {example_notebooks => examples}/transformers/multiple_choice_logits_processor.ipynb (99%) rename {example_notebooks => examples}/transformers/prevent_hallucination_logits_processor.ipynb (99%) rename {example_notebooks => examples}/transformers/trigger_phrase_logits_processor.ipynb (99%) rename {example_notebooks => examples}/transformers/utils.py (100%) rename {example_notebooks => examples}/trtllm/README.md (75%) rename {example_notebooks => examples}/trtllm/cite_prompt_logits_processor.py (100%) rename {example_notebooks => examples}/trtllm/gen_length_logits_processor.py (100%) rename {example_notebooks => examples}/trtllm/last_phrase_logits_processor.py (100%) rename {example_notebooks => examples}/trtllm/multiple_choice_logits_processor.py (100%) rename {example_notebooks => examples}/trtllm/prevent_hallucination_logits_processor.py (100%) rename {example_notebooks => examples}/trtllm/trigger_phrase_logits_processor.py (100%) rename {example_notebooks => examples}/trtllm/utils.py (100%) rename {example_notebooks => examples}/vllm/cite_prompt_logits_processor.ipynb (99%) rename {example_notebooks => examples}/vllm/force_last_phrase_logits_processor.ipynb (99%) rename {example_notebooks => examples}/vllm/gen_length_logits_processor.ipynb (99%) rename {example_notebooks => examples}/vllm/multiple_choice_logits_processor.ipynb (99%) rename {example_notebooks => examples}/vllm/performance_profiling.ipynb (99%) rename {example_notebooks => examples}/vllm/prevent_hallucination_logits_processor.ipynb (99%) rename {example_notebooks => examples}/vllm/trigger_phrase_logits_processor.ipynb (99%) rename {example_notebooks => examples}/vllm/utils.py (100%) rename {example_notebooks => examples}/vllm/vllm_serve.ipynb (100%) diff --git a/example_notebooks/transformers/cite_prompt_logits_processor.ipynb b/examples/transformers/cite_prompt_logits_processor.ipynb similarity index 99% rename from example_notebooks/transformers/cite_prompt_logits_processor.ipynb rename to examples/transformers/cite_prompt_logits_processor.ipynb index 68b5145..13e5e30 100644 --- a/example_notebooks/transformers/cite_prompt_logits_processor.ipynb +++ b/examples/transformers/cite_prompt_logits_processor.ipynb @@ -33,7 +33,7 @@ } ], "source": [ - "from example_notebooks.transformers.utils import LLMRunner\n", + "from examples.transformers.utils import LLMRunner\n", "from logits_processor_zoo.transformers import CiteFromPromptLogitsProcessor\n", "\n", "\n", diff --git a/example_notebooks/transformers/force_last_phrase_logits_processor.ipynb b/examples/transformers/force_last_phrase_logits_processor.ipynb similarity index 99% rename from example_notebooks/transformers/force_last_phrase_logits_processor.ipynb rename to examples/transformers/force_last_phrase_logits_processor.ipynb index 4b1a41b..d1e67aa 100644 --- a/example_notebooks/transformers/force_last_phrase_logits_processor.ipynb +++ b/examples/transformers/force_last_phrase_logits_processor.ipynb @@ -37,7 +37,7 @@ } ], "source": [ - "from example_notebooks.transformers.utils import LLMRunner\n", + "from examples.transformers.utils import LLMRunner\n", "from logits_processor_zoo.transformers import ForceLastPhraseLogitsProcessor\n", "\n", "\n", diff --git a/example_notebooks/transformers/gen_length_logits_processor.ipynb b/examples/transformers/gen_length_logits_processor.ipynb similarity index 99% rename from example_notebooks/transformers/gen_length_logits_processor.ipynb rename to examples/transformers/gen_length_logits_processor.ipynb index 6c5ebe6..fac3efe 100644 --- a/example_notebooks/transformers/gen_length_logits_processor.ipynb +++ b/examples/transformers/gen_length_logits_processor.ipynb @@ -25,7 +25,7 @@ "metadata": {}, "outputs": [], "source": [ - "from example_notebooks.transformers.utils import LLMRunner\n", + "from examples.transformers.utils import LLMRunner\n", "from logits_processor_zoo.transformers import GenLengthLogitsProcessor\n", "\n", "example_prompts =[\n", diff --git a/example_notebooks/transformers/multiple_choice_logits_processor.ipynb b/examples/transformers/multiple_choice_logits_processor.ipynb similarity index 99% rename from example_notebooks/transformers/multiple_choice_logits_processor.ipynb rename to examples/transformers/multiple_choice_logits_processor.ipynb index b4f0047..4bb42cf 100644 --- a/example_notebooks/transformers/multiple_choice_logits_processor.ipynb +++ b/examples/transformers/multiple_choice_logits_processor.ipynb @@ -37,7 +37,7 @@ } ], "source": [ - "from example_notebooks.transformers.utils import LLMRunner\n", + "from examples.transformers.utils import LLMRunner\n", "from logits_processor_zoo.transformers import MultipleChoiceLogitsProcessor\n", "\n", "\n", diff --git a/example_notebooks/transformers/prevent_hallucination_logits_processor.ipynb b/examples/transformers/prevent_hallucination_logits_processor.ipynb similarity index 99% rename from example_notebooks/transformers/prevent_hallucination_logits_processor.ipynb rename to examples/transformers/prevent_hallucination_logits_processor.ipynb index 9dade9c..6502762 100644 --- a/example_notebooks/transformers/prevent_hallucination_logits_processor.ipynb +++ b/examples/transformers/prevent_hallucination_logits_processor.ipynb @@ -33,7 +33,7 @@ } ], "source": [ - "from example_notebooks.transformers.utils import LLMRunner\n", + "from examples.transformers.utils import LLMRunner\n", "from logits_processor_zoo.transformers import PreventHallucinationLogitsProcessor\n", "\n", "runner = LLMRunner()" diff --git a/example_notebooks/transformers/trigger_phrase_logits_processor.ipynb b/examples/transformers/trigger_phrase_logits_processor.ipynb similarity index 99% rename from example_notebooks/transformers/trigger_phrase_logits_processor.ipynb rename to examples/transformers/trigger_phrase_logits_processor.ipynb index fd8c774..72daa80 100644 --- a/example_notebooks/transformers/trigger_phrase_logits_processor.ipynb +++ b/examples/transformers/trigger_phrase_logits_processor.ipynb @@ -33,7 +33,7 @@ } ], "source": [ - "from example_notebooks.transformers.utils import LLMRunner\n", + "from examples.transformers.utils import LLMRunner\n", "from logits_processor_zoo.transformers import TriggerPhraseLogitsProcessor, GenLengthLogitsProcessor\n", "\n", "\n", diff --git a/example_notebooks/transformers/utils.py b/examples/transformers/utils.py similarity index 100% rename from example_notebooks/transformers/utils.py rename to examples/transformers/utils.py diff --git a/example_notebooks/trtllm/README.md b/examples/trtllm/README.md similarity index 75% rename from example_notebooks/trtllm/README.md rename to examples/trtllm/README.md index 718eb68..530329f 100644 --- a/example_notebooks/trtllm/README.md +++ b/examples/trtllm/README.md @@ -10,13 +10,13 @@ Alternatively, please follow [this documentation](https://nvidia.github.io/Tenso ### GenLengthLogitsProcessor A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token based on the length of the generated sequence, encouraging or discouraging shorter answers. ``` -python example_notebooks/trtllm/gen_length_logits_processor.py +python examples/trtllm/gen_length_logits_processor.py ``` ### CiteFromPromptLogitsProcessor A logits processor which boosts or diminishes the likelihood of tokens present in the prompt (and optionally EOS token) to encourage the model to generate tokens similar to those seen in the prompt or vice versa. ``` -python example_notebooks/trtllm/cite_prompt_logits_processor.py -p "Retrieved information: +python examples/trtllm/cite_prompt_logits_processor.py -p "Retrieved information: Pokémon is a Japanese media franchise consisting of video games, animated series and films, a trading card game, and other related media. The franchise takes place in a shared universe in which humans co-exist with creatures known as Pokémon, a large variety of species endowed with special powers. The franchise's target audience is children aged 5 to 12, but it is known to attract people of all ages. @@ -27,13 +27,13 @@ python example_notebooks/trtllm/cite_prompt_logits_processor.py -p "Retrieved in ### ForceLastPhraseLogitsProcessor A logits processor which forces LLMs to use the given phrase before they finalize their answers. Most common use cases can be providing references, thanking user with context etc. ``` -python example_notebooks/trtllm/last_phrase_logits_processor.py +python examples/trtllm/last_phrase_logits_processor.py ``` ### MultipleChoiceLogitsProcessor A logits processor to answer multiple choice questions with one of the choices. ``` -python example_notebooks/trtllm/multiple_choice_logits_processor.py -p "I am getting a lot of calls during the day. What is more important for me to consider when I buy a new phone? +python examples/trtllm/multiple_choice_logits_processor.py -p "I am getting a lot of calls during the day. What is more important for me to consider when I buy a new phone? 0. Camera 1. Screen resolution 2. Operating System @@ -43,11 +43,11 @@ python example_notebooks/trtllm/multiple_choice_logits_processor.py -p "I am get ### TriggerPhraseLogitsProcessor A logits processor which triggers phrases when it encounters a given token. ``` -python example_notebooks/trtllm/trigger_phrase_logits_processor.py -p "Generate a python function to calculate nth fibonacci number. Make it recursive. Keep thinking short." +python examples/trtllm/trigger_phrase_logits_processor.py -p "Generate a python function to calculate nth fibonacci number. Make it recursive. Keep thinking short." ``` ### PreventHallucinationLogitsProcessor A logits processor that mitigates hallucinated model outputs by enforcing a predefined fallback phrase when token confidence falls below a specified threshold. ``` -python example_notebooks/trtllm/prevent_hallucination_logits_processor.py -p "What are Nobel Prizes? Name the winners in 1977" +python examples/trtllm/prevent_hallucination_logits_processor.py -p "What are Nobel Prizes? Name the winners in 1977" ``` diff --git a/example_notebooks/trtllm/cite_prompt_logits_processor.py b/examples/trtllm/cite_prompt_logits_processor.py similarity index 100% rename from example_notebooks/trtllm/cite_prompt_logits_processor.py rename to examples/trtllm/cite_prompt_logits_processor.py diff --git a/example_notebooks/trtllm/gen_length_logits_processor.py b/examples/trtllm/gen_length_logits_processor.py similarity index 100% rename from example_notebooks/trtllm/gen_length_logits_processor.py rename to examples/trtllm/gen_length_logits_processor.py diff --git a/example_notebooks/trtllm/last_phrase_logits_processor.py b/examples/trtllm/last_phrase_logits_processor.py similarity index 100% rename from example_notebooks/trtllm/last_phrase_logits_processor.py rename to examples/trtllm/last_phrase_logits_processor.py diff --git a/example_notebooks/trtllm/multiple_choice_logits_processor.py b/examples/trtllm/multiple_choice_logits_processor.py similarity index 100% rename from example_notebooks/trtllm/multiple_choice_logits_processor.py rename to examples/trtllm/multiple_choice_logits_processor.py diff --git a/example_notebooks/trtllm/prevent_hallucination_logits_processor.py b/examples/trtllm/prevent_hallucination_logits_processor.py similarity index 100% rename from example_notebooks/trtllm/prevent_hallucination_logits_processor.py rename to examples/trtllm/prevent_hallucination_logits_processor.py diff --git a/example_notebooks/trtllm/trigger_phrase_logits_processor.py b/examples/trtllm/trigger_phrase_logits_processor.py similarity index 100% rename from example_notebooks/trtllm/trigger_phrase_logits_processor.py rename to examples/trtllm/trigger_phrase_logits_processor.py diff --git a/example_notebooks/trtllm/utils.py b/examples/trtllm/utils.py similarity index 100% rename from example_notebooks/trtllm/utils.py rename to examples/trtllm/utils.py diff --git a/example_notebooks/vllm/cite_prompt_logits_processor.ipynb b/examples/vllm/cite_prompt_logits_processor.ipynb similarity index 99% rename from example_notebooks/vllm/cite_prompt_logits_processor.ipynb rename to examples/vllm/cite_prompt_logits_processor.ipynb index 8e3877e..5c4ab63 100644 --- a/example_notebooks/vllm/cite_prompt_logits_processor.ipynb +++ b/examples/vllm/cite_prompt_logits_processor.ipynb @@ -70,7 +70,7 @@ } ], "source": [ - "from example_notebooks.vllm.utils import vLLMRunner\n", + "from examples.vllm.utils import vLLMRunner\n", "from logits_processor_zoo.vllm import CiteFromPromptLogitsProcessor\n", "\n", "\n", diff --git a/example_notebooks/vllm/force_last_phrase_logits_processor.ipynb b/examples/vllm/force_last_phrase_logits_processor.ipynb similarity index 99% rename from example_notebooks/vllm/force_last_phrase_logits_processor.ipynb rename to examples/vllm/force_last_phrase_logits_processor.ipynb index 6c7a69f..2d063a6 100644 --- a/example_notebooks/vllm/force_last_phrase_logits_processor.ipynb +++ b/examples/vllm/force_last_phrase_logits_processor.ipynb @@ -70,7 +70,7 @@ } ], "source": [ - "from example_notebooks.vllm.utils import vLLMRunner\n", + "from examples.vllm.utils import vLLMRunner\n", "from logits_processor_zoo.vllm import ForceLastPhraseLogitsProcessor\n", "\n", "\n", diff --git a/example_notebooks/vllm/gen_length_logits_processor.ipynb b/examples/vllm/gen_length_logits_processor.ipynb similarity index 99% rename from example_notebooks/vllm/gen_length_logits_processor.ipynb rename to examples/vllm/gen_length_logits_processor.ipynb index 9fd02dc..9b836b3 100644 --- a/example_notebooks/vllm/gen_length_logits_processor.ipynb +++ b/examples/vllm/gen_length_logits_processor.ipynb @@ -87,7 +87,7 @@ } ], "source": [ - "from example_notebooks.vllm.utils import vLLMRunner\n", + "from examples.vllm.utils import vLLMRunner\n", "from logits_processor_zoo.vllm import GenLengthLogitsProcessor\n", "\n", "example_prompts =[\n", diff --git a/example_notebooks/vllm/multiple_choice_logits_processor.ipynb b/examples/vllm/multiple_choice_logits_processor.ipynb similarity index 99% rename from example_notebooks/vllm/multiple_choice_logits_processor.ipynb rename to examples/vllm/multiple_choice_logits_processor.ipynb index d622c99..cd6f85f 100644 --- a/example_notebooks/vllm/multiple_choice_logits_processor.ipynb +++ b/examples/vllm/multiple_choice_logits_processor.ipynb @@ -87,7 +87,7 @@ } ], "source": [ - "from example_notebooks.vllm.utils import vLLMRunner\n", + "from examples.vllm.utils import vLLMRunner\n", "from logits_processor_zoo.vllm import MultipleChoiceLogitsProcessor\n", "\n", "\n", diff --git a/example_notebooks/vllm/performance_profiling.ipynb b/examples/vllm/performance_profiling.ipynb similarity index 99% rename from example_notebooks/vllm/performance_profiling.ipynb rename to examples/vllm/performance_profiling.ipynb index f6a2f12..d032077 100644 --- a/example_notebooks/vllm/performance_profiling.ipynb +++ b/examples/vllm/performance_profiling.ipynb @@ -73,7 +73,7 @@ } ], "source": [ - "from example_notebooks.vllm.utils import vLLMRunner\n", + "from examples.vllm.utils import vLLMRunner\n", "from logits_processor_zoo.vllm import MultipleChoiceLogitsProcessor\n", "\n", "\n", diff --git a/example_notebooks/vllm/prevent_hallucination_logits_processor.ipynb b/examples/vllm/prevent_hallucination_logits_processor.ipynb similarity index 99% rename from example_notebooks/vllm/prevent_hallucination_logits_processor.ipynb rename to examples/vllm/prevent_hallucination_logits_processor.ipynb index c679ad6..6405b6e 100644 --- a/example_notebooks/vllm/prevent_hallucination_logits_processor.ipynb +++ b/examples/vllm/prevent_hallucination_logits_processor.ipynb @@ -73,7 +73,7 @@ } ], "source": [ - "from example_notebooks.vllm.utils import vLLMRunner\n", + "from examples.vllm.utils import vLLMRunner\n", "from logits_processor_zoo.vllm import PreventHallucinationLogitsProcessor\n", "\n", "runner = vLLMRunner()" diff --git a/example_notebooks/vllm/trigger_phrase_logits_processor.ipynb b/examples/vllm/trigger_phrase_logits_processor.ipynb similarity index 99% rename from example_notebooks/vllm/trigger_phrase_logits_processor.ipynb rename to examples/vllm/trigger_phrase_logits_processor.ipynb index 81ffa61..2f796d4 100644 --- a/example_notebooks/vllm/trigger_phrase_logits_processor.ipynb +++ b/examples/vllm/trigger_phrase_logits_processor.ipynb @@ -89,7 +89,7 @@ } ], "source": [ - "from example_notebooks.vllm.utils import vLLMRunner\n", + "from examples.vllm.utils import vLLMRunner\n", "from logits_processor_zoo.vllm import TriggerPhraseLogitsProcessor, GenLengthLogitsProcessor\n", "\n", "\n", diff --git a/example_notebooks/vllm/utils.py b/examples/vllm/utils.py similarity index 100% rename from example_notebooks/vllm/utils.py rename to examples/vllm/utils.py diff --git a/example_notebooks/vllm/vllm_serve.ipynb b/examples/vllm/vllm_serve.ipynb similarity index 100% rename from example_notebooks/vllm/vllm_serve.ipynb rename to examples/vllm/vllm_serve.ipynb