     destroy_distributed_environment,
     destroy_model_parallel,
 )
-from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.tokenizers import get_tokenizer
 from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM
 
 logging.getLogger("vllm").propagate = True
@@ -291,7 +291,7 @@ def _create_auto_model(self, config: VLLMModelConfig) -> Optional[LLM]:
         # Inferring from the tokenizer will cause vllm to bug for models with mismatches between model
         # config and tk config, like mistralai/Mistral-7B-v0.1
         if self._max_length is None:
-            self._max_length = model.llm_engine.model_config.max_seq_len_to_capture
+            self._max_length = model.llm_engine.model_config.max_model_len
 
         return model
 
@@ -415,9 +415,9 @@ def _generate(
         generate: bool = True,
     ) -> list:
         """Contains the actual logic of the generation."""
-        sampling_params = SamplingParams(**self.config.generation_parameters.to_vllm_dict())
 
         if generate:
+            sampling_params = SamplingParams(**self.config.generation_parameters.to_vllm_dict())
             sampling_params.n = num_samples
             sampling_params.max_tokens = max_new_tokens
             sampling_params.stop = stop_tokens
@@ -427,17 +427,21 @@ def _generate(
                     "num_samples > 1 is not supported with temperature=0, please set temperature > 0 or use non sampling metrics."
                 )
         else:
-            sampling_params.temperature = 0
-            sampling_params.prompt_logprobs = 1
-            sampling_params.max_tokens = 1
-            sampling_params.detokenize = False
+            sampling_params = SamplingParams(
+                temperature=0.0,
+                prompt_logprobs=1,
+                max_tokens=1,
+                detokenize=False,
+            )
 
         if self.data_parallel_size > 1:
 
             @ray.remote(num_gpus=self.tensor_parallel_size)
             def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, requests):
                 llm = LLM(**model_args)
-                return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)
+                # Convert token IDs to TokensPrompt format for vLLM v0.15+
+                prompts = [{"prompt_token_ids": req} for req in requests]
+                return llm.generate(prompts=prompts, sampling_params=sampling_params)
 
             # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
             # interleaved important to balance context lengths across workers
@@ -454,8 +458,12 @@ def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, r
                 if x is not None
             ]
         else:
+            from vllm.inputs import TokenInputs
+
+            # Convert token IDs to TokensPrompt format for vLLM v0.15+
+            prompts = [TokenInputs(prompt_token_ids=token_ids) for token_ids in inputs]
             outputs = self.model.generate(
-                prompt_token_ids=inputs,
+                prompts=prompts,
                 sampling_params=sampling_params,
                 use_tqdm=True,
             )
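A minimal standalone sketch (not part of the diff) of the prompt-format change above, assuming a recent vLLM where generate() no longer accepts a prompt_token_ids keyword and pre-tokenized inputs are instead passed as prompt dicts carrying the token IDs; the model name is only a placeholder.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model for illustration
token_id_lists = [[1, 2, 3, 4], [5, 6, 7]]
# Each pre-tokenized request becomes a dict with a "prompt_token_ids" key
prompts = [{"prompt_token_ids": ids} for ids in token_id_lists]
outputs = llm.generate(prompts=prompts, sampling_params=SamplingParams(max_tokens=8))
for out in outputs:
    print(out.outputs[0].text)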
@@ -489,9 +497,6 @@ def _loglikelihood_tokens(
             tokenized_continuations_batch.append(tokenized_continuation)
             tokenized_contexts_batch.append(tokenized_context)
 
-        # Left truncate the inputs to the maximum length
-        if self.max_length:  # can be None if the model is initialized with ray
-            inputs = [input[-self.max_length :] for input in inputs]
         outputs = self._generate(inputs, generate=False)
 
         flat_index = 0
@@ -507,12 +512,18 @@ def _loglikelihood_tokens(
             for output, context, continuation in zip(
                 outputs_doc, tokenized_contexts_doc, tokenized_continuations_doc
             ):
+                actual_input_len = len(output.prompt_token_ids)
+                continuation_len = len(continuation)
+                continuation_start_idx = actual_input_len - continuation_len
+                continuation_prompt_logprobs = output.prompt_logprobs[continuation_start_idx:]
+
                 continuation_logprobs = []
-                for token, logprobs in zip(continuation[::-1], output.prompt_logprobs[::-1]):
-                    continuation_logprobs.append(logprobs[token])
+                for token, logprobs_at_position in zip(continuation, continuation_prompt_logprobs):
+                    continuation_logprobs.append(logprobs_at_position[token])
 
                 bool_score = all(logprob.rank == 1 for logprob in continuation_logprobs)
                 continuation_logprobs = [logprob.logprob for logprob in continuation_logprobs]
+
                 continuation_logprobs = sum(continuation_logprobs)
                 logprobs_doc.append(continuation_logprobs)
                 argmax_doc.append(bool_score)
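A small sketch (not part of the diff) of the index arithmetic used above: prompt_logprobs is position-aligned with the prompt tokens vLLM actually scored, so taking the last len(continuation) entries selects the continuation positions even if the context was truncated. Dummy floats stand in for vLLM's Logprob objects (which expose .logprob and .rank).

prompt_token_ids = [10, 11, 12, 20, 21]  # context tokens followed by continuation tokens, as scored
continuation = [20, 21]
prompt_logprobs = [None, {11: -0.3}, {12: -0.5}, {20: -0.1}, {21: -0.2}]  # one dict per prompt position

start = len(prompt_token_ids) - len(continuation)  # actual_input_len - continuation_len
scores = [logprobs_at_pos[tok] for tok, logprobs_at_pos in zip(continuation, prompt_logprobs[start:])]
assert scores == [-0.1, -0.2]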
@@ -544,6 +555,8 @@ class AsyncVLLMModel(VLLMModel):
     is_async = True
 
     def cleanup(self):
+        if self.model is not None:
+            del self.model
         gc.collect()
         destroy_distributed_environment()
         torch.cuda.empty_cache()
@@ -578,7 +591,7 @@ def _create_auto_model(self, config: VLLMModelConfig):
 
         # If the max_length can't get extracted from the config, it will be inferred from the model
         if self._max_length is None:
-            self._max_length = model.model_config.max_seq_len_to_capture
+            self._max_length = model.model_config.max_model_len
 
         return model
 