@@ -105,6 +105,12 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         or not isinstance(model_args.torch_dtype, str)
         else getattr(torch, model_args.torch_dtype)
     )
+    # NOTE: for models that cannot fit in 1 GPU, keep it on CPU and use block-wise calibration,
+    # or leverage HF's device_map="auto"; BUT tracing will not work properly with "auto".
+    total_gpu_memory = 1e-5
+    if torch.cuda.is_available():
+        total_gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
+
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -113,8 +119,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         revision="main",
         use_auth_token=True if model_args.use_auth_token else None,
         torch_dtype=torch_dtype,
-        low_cpu_mem_usage=model_args.low_cpu_mem_usage,
-        device_map="auto" if model_args.low_cpu_mem_usage else None,
+        device_map=model_args.device_map,
+        low_cpu_mem_usage=bool(model_args.device_map),
     )
 
     embedding_size = model.get_input_embeddings().weight.shape[0]
@@ -125,11 +131,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     logger.info(f"Model is at {model.device} after intialization")
     logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}")
     qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
-    # for models that cannot fit in 1 GPU, keep it on CPU and use block-wise calibration.
-    # or leverage HF's device_map="auto"
-    total_gpu_memory = 1e-5
-    if torch.cuda.is_available():
-        total_gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
+
     model_size = model_size_Wb(model, unit="GB")
     gpu_mem_util_per = model_size / total_gpu_memory
 
@@ -145,7 +147,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         name in model_args.model_name_or_path for name in known_large_models
     ) or (gpu_mem_util_per > 0.7)
     dev = "cpu" if qcfg["large_model"] else "cuda"
-    model.to(dev)
+    if model_args.device_map is None:
+        model.to(dev)
 
     if hasattr(model.config, "model_type"):
         qcfg["model_type"] = model.config.model_type
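Taken together, these hunks move the GPU-memory probe ahead of model loading and hand device placement over to an explicit model_args.device_map: low_cpu_mem_usage is now implied by setting a device map, and the manual model.to(dev) only runs when no device map was given, since a model dispatched by accelerate should not be moved wholesale afterwards. Below is a minimal standalone sketch of the resulting placement flow; the checkpoint name and the estimate_model_size_gb helper are illustrative stand-ins (the real code uses fms-mo's model_size_Wb), so treat it as an assumption-laden sketch rather than the library's API.

# Hedged sketch of the device-placement flow after this change. Assumptions:
# "facebook/opt-125m" is only an example checkpoint, and estimate_model_size_gb
# is a hypothetical stand-in for fms-mo's model_size_Wb helper.
import torch
from transformers import AutoModelForCausalLM

device_map = None  # e.g. "auto"; keep None when tracing-based calibration is needed

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    torch_dtype=torch.float16,
    device_map=device_map,
    # low_cpu_mem_usage is implied whenever a device_map is given
    low_cpu_mem_usage=bool(device_map),
)

def estimate_model_size_gb(m: torch.nn.Module) -> float:
    """Rough parameter + buffer footprint in GB (stand-in for model_size_Wb)."""
    n_bytes = sum(p.numel() * p.element_size() for p in m.parameters())
    n_bytes += sum(b.numel() * b.element_size() for b in m.buffers())
    return n_bytes / 1e9

# Probe GPU 0 before any transfer; 1e-5 keeps the ratio finite on CPU-only hosts.
total_gpu_memory = 1e-5
if torch.cuda.is_available():
    total_gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9

# Same 0.7 utilization threshold as the diff: oversized models stay on CPU
# (the real code additionally checks a known_large_models name list).
large_model = estimate_model_size_gb(model) / total_gpu_memory > 0.7

# Only move the model manually when HF/accelerate did not already dispatch it.
if device_map is None:
    model.to("cpu" if large_model else "cuda")

With device_map="auto", accelerate shards the model across available devices at load time, which is why the explicit .to() is skipped in that case, and why the NOTE in the first hunk warns that tracing will not work properly under "auto".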