@@ -280,15 +280,14 @@ class VLMTransformersBackend(BaseBackend):
280280 def load_model (self ):
281281 from transformers import AutoModelForImageTextToText , AutoProcessor
282282
283- default_kwargs = {
284- "dtype" : torch .bfloat16 ,
285- "device_map" : "auto" ,
286- "trust_remote_code" : True ,
287- }
288- default_kwargs .update (self .kwargs )
283+ device = decide_device_for_distributed ()
284+ print_with_rank (f"Loading model to device: { device } " )
285+
286+ # Prepare model loading configuration
287+ model_kwargs = self ._prepare_model_kwargs (device )
289288
290289 self .model = AutoModelForImageTextToText .from_pretrained (
291- self .model_path , ** default_kwargs
290+ self .model_path , ** model_kwargs
292291 )
293292
294293 # Freeze the base model
@@ -300,6 +299,24 @@ def load_model(self):
300299 self .model_path , trust_remote_code = True
301300 )
302301
def _prepare_model_kwargs(self, device: str) -> dict:
    """
    Assemble the keyword arguments used to load the model.

    The baseline configuration requests bfloat16 weights, places the
    model on ``device`` through ``device_map`` and enables
    ``trust_remote_code``. Entries from ``self.kwargs`` are applied on
    top, so a caller-supplied value replaces the baseline value for
    the same key.

    Args:
        device: Value stored under the ``device_map`` key.

    Returns:
        A newly built dictionary of model loading arguments.
    """
    loading_args = dict(
        dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    )
    # Applied last so that user-provided settings win over the baseline.
    loading_args.update(self.kwargs)
    return loading_args
319+
303320 def get_hidden_states_and_logits (
304321 self ,
305322 input_ids : torch .Tensor ,
@@ -317,6 +334,12 @@ def get_hidden_states_and_logits(
317334 Returns:
318335 Tuple of (concatenated_hidden_states, logits)
319336 """
337+ pixel_values = None
338+ image_grid_thw = None
339+ if "pixel_values" in kwargs :
340+ pixel_values = kwargs ["pixel_values" ].squeeze (0 )
341+ if "image_grid_thw" in kwargs :
342+ image_grid_thw = kwargs ["image_grid_thw" ].squeeze (0 )
320343 inputs_embeds_list , position_ids_list = [], []
321344
322345 def hook (module , args , kwargs ):
@@ -336,6 +359,8 @@ def hook(module, args, kwargs):
336359 outputs = self .model (
337360 input_ids ,
338361 attention_mask = attention_mask ,
362+ pixel_values = pixel_values ,
363+ image_grid_thw = image_grid_thw ,
339364 output_hidden_states = True ,
340365 output_logits = True ,
341366 )
@@ -375,6 +400,12 @@ def get_aux_and_target_hiddens(
375400 Returns:
376401 Tuple of (auxiliary_hidden_states, final_hidden_states)
377402 """
403+ pixel_values = None
404+ image_grid_thw = None
405+ if "pixel_values" in kwargs :
406+ pixel_values = kwargs ["pixel_values" ].squeeze (0 )
407+ if "image_grid_thw" in kwargs :
408+ image_grid_thw = kwargs ["image_grid_thw" ].squeeze (0 )
378409 inputs_embeds_list , position_ids_list = [], []
379410
380411 def hook (module , args , kwargs ):
@@ -393,6 +424,8 @@ def hook(module, args, kwargs):
393424 with torch .no_grad ():
394425 outputs = self .model (
395426 input_ids ,
427+ pixel_values = pixel_values ,
428+ image_grid_thw = image_grid_thw ,
396429 attention_mask = attention_mask ,
397430 output_hidden_states = True ,
398431 output_logits = True ,
0 commit comments