@@ -160,7 +160,10 @@ def __call__(self, hooks, model_id, model, execution_device):
160160 if len (hooks ) == 0 :
161161 return []
162162
163- current_module_size = model .get_memory_footprint ()
163+ try :
164+ current_module_size = model .get_memory_footprint ()
165+ except AttributeError :
166+ raise AttributeError (f"Do not know how to compute memory footprint of `{ model .__class__ .__name__ }`." )
164167
165168 device_type = execution_device .type
166169 device_module = getattr (torch , device_type , torch .cuda )
@@ -703,19 +706,28 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None,
703706 if not is_accelerate_available ():
704707 raise ImportError ("Make sure to install accelerate to use auto_cpu_offload" )
705708
706- # TODO: add a warning if mem_get_info isn't available on `device`.
709+ if device is None :
710+ device = get_device ()
711+ if not isinstance (device , torch .device ):
712+ device = torch .device (device )
713+
714+ device_type = device .type
715+ device_module = getattr (torch , device_type , torch .cuda )
716+ if not hasattr (device_module , "mem_get_info" ):
717+ raise NotImplementedError (
718+ f"`enable_auto_cpu_offload()` relies on the `mem_get_info()` method. It's not implemented for { str (device .type )} ."
719+ )
720+
721+ if device .index is None :
722+ device = torch .device (f"{ device .type } :{ 0 } " )
707723
708724 for name , component in self .components .items ():
709725 if isinstance (component , torch .nn .Module ) and hasattr (component , "_hf_hook" ):
710726 remove_hook_from_module (component , recurse = True )
711727
712728 self .disable_auto_cpu_offload ()
713729 offload_strategy = AutoOffloadStrategy (memory_reserve_margin = memory_reserve_margin )
714- if device is None :
715- device = get_device ()
716- device = torch .device (device )
717- if device .index is None :
718- device = torch .device (f"{ device .type } :{ 0 } " )
730+
719731 all_hooks = []
720732 for name , component in self .components .items ():
721733 if isinstance (component , torch .nn .Module ):
0 commit comments