@@ -324,17 +324,18 @@ def generate_language_model(
         inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
             The sequence of generated hidden-states.
         """
-        cache_position_kwargs = {}
-        if is_transformers_version("<", "4.52.1"):
-            cache_position_kwargs["input_ids"] = inputs_embeds
-        else:
-            cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
-            cache_position_kwargs["device"] = (
-                self.language_model.device if getattr(self, "language_model", None) is not None else self.device
-            )
-        cache_position_kwargs["model_kwargs"] = model_kwargs
         max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
-        model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
+        if hasattr(self.language_model, "_get_initial_cache_position"):
+            cache_position_kwargs = {}
+            if is_transformers_version("<", "4.52.1"):
+                cache_position_kwargs["input_ids"] = inputs_embeds
+            else:
+                cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
+                cache_position_kwargs["device"] = (
+                    self.language_model.device if getattr(self, "language_model", None) is not None else self.device
+                )
+            cache_position_kwargs["model_kwargs"] = model_kwargs
+            model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
 
         for _ in range(max_new_tokens):
             # prepare model inputs
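A minimal standalone sketch of the compatibility guard added above: feature-detect the private helper with hasattr, then choose the keyword set expected by the installed transformers release. `init_cache_position` and the `pre_4_52_1` flag are hypothetical names for illustration only; the real code runs as a method on the model class and checks the version with optimum's is_transformers_version.

def init_cache_position(language_model, inputs_embeds, model_kwargs, pre_4_52_1):
    # Some language model backends never define the private helper at all,
    # so guard with hasattr instead of calling it unconditionally
    # (the bug this diff fixes).
    if not hasattr(language_model, "_get_initial_cache_position"):
        return model_kwargs
    cache_position_kwargs = {}
    if pre_4_52_1:
        # transformers < 4.52.1 derives cache positions from `input_ids`
        cache_position_kwargs["input_ids"] = inputs_embeds
    else:
        # newer releases take an explicit sequence length and device,
        # mirroring the diff's use of inputs_embeds.shape[0]
        cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
        cache_position_kwargs["device"] = language_model.device
    cache_position_kwargs["model_kwargs"] = model_kwargs
    return language_model._get_initial_cache_position(**cache_position_kwargs)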