@@ -47,7 +47,6 @@ class HuggingFaceModelLoader:
4747 cache_dir: Directory to cache downloaded models.
4848 token: HuggingFace authentication token for private/gated models.
4949 """
50-
5150 def __init__ (self , cache_dir : Optional [str ] = None , token : Optional [str ] = None ):
5251 """Initialize the HuggingFace model loader.
5352
@@ -100,11 +99,7 @@ def load_model(
10099 dtype = self ._get_torch_dtype (torch_dtype ) if torch_dtype else None
101100
102101 # Prepare loading kwargs
103- load_kwargs = {
104- 'cache_dir' : self .cache_dir ,
105- 'revision' : revision ,
106- ** kwargs
107- }
102+ load_kwargs = {'cache_dir' : self .cache_dir , 'revision' : revision , ** kwargs }
108103
109104 # Add token if available
110105 if self .token :
@@ -117,19 +112,15 @@ def load_model(
117112 # Load config (use pre-downloaded config if provided)
118113 if config is None :
119114 logger .info ('Loading model configuration...' )
120- config = AutoConfig .from_pretrained (
121- model_identifier , trust_remote_code = True , ** load_kwargs
122- )
115+ config = AutoConfig .from_pretrained (model_identifier , trust_remote_code = True , ** load_kwargs )
123116 else :
124117 logger .info ('Using pre-downloaded model configuration.' )
125118
126119 # Load tokenizer (may fail for some models, that's ok)
127120 tokenizer = None
128121 try :
129122 logger .info ('Loading tokenizer...' )
130- tokenizer = AutoTokenizer .from_pretrained (
131- model_identifier , trust_remote_code = True , ** load_kwargs
132- )
123+ tokenizer = AutoTokenizer .from_pretrained (model_identifier , trust_remote_code = True , ** load_kwargs )
133124 except Exception as e :
134125 logger .warning (f'Could not load tokenizer: { e } . Continuing without tokenizer.' )
135126
@@ -179,7 +170,9 @@ def load_model(
179170 raise ModelLoadError (f"Unexpected error loading model '{ model_identifier } ': { e } " ) from e
180171
181172 def load_model_from_config (
182- self , config : ModelSourceConfig , device : Optional [str ] = None ,
173+ self ,
174+ config : ModelSourceConfig ,
175+ device : Optional [str ] = None ,
183176 config_pretrained : Optional [PretrainedConfig ] = None ,
184177 ) -> Tuple [PreTrainedModel , PretrainedConfig , AutoTokenizer ]:
185178 """Load a model using ModelSourceConfig.
@@ -197,10 +190,7 @@ def load_model_from_config(
197190 ModelLoadError: If model loading fails.
198191 """
199192 if not config .is_huggingface ():
200- raise ValueError (
201- f"Cannot load model with source '{ config .source } '. "
202- "Use 'huggingface' source."
203- )
193+ raise ValueError (f"Cannot load model with source '{ config .source } '. Use 'huggingface' source." )
204194
205195 # Validate config
206196 is_valid , error = config .validate ()
@@ -244,10 +234,7 @@ def _get_torch_dtype(self, dtype_str: str) -> torch.dtype:
244234 }
245235
246236 if dtype_str .lower () not in dtype_map :
247- raise ValueError (
248- f"Invalid dtype '{ dtype_str } '. "
249- f'Must be one of { list (dtype_map .keys ())} '
250- )
237+ raise ValueError (f"Invalid dtype '{ dtype_str } '. Must be one of { list (dtype_map .keys ())} " )
251238
252239 return dtype_map [dtype_str .lower ()]
253240
@@ -289,9 +276,7 @@ def estimate_param_count_from_config(hf_config) -> Optional[int]:
289276
290277 # Embeddings: token + (optional) position
291278 max_pos = getattr (hf_config , 'max_position_embeddings' , 0 )
292- has_pos_embed = getattr (hf_config , 'position_embedding_type' , None ) not in (
293- 'rotary' , None
294- )
279+ has_pos_embed = getattr (hf_config , 'position_embedding_type' , None ) not in ('rotary' , None )
295280 embed_params = vocab * hidden
296281 if has_pos_embed and max_pos > 0 :
297282 embed_params += max_pos * hidden
@@ -346,7 +331,7 @@ def estimate_memory(param_count, precision_str, mode='training'):
346331 precision_lower = precision_str .lower ()
347332 if precision_lower in ('float16' , 'fp16' , 'bfloat16' , 'bf16' ):
348333 bytes_per_param = 2
349- elif precision_lower in ('int8' ,):
334+ elif precision_lower in ('int8' , ):
350335 bytes_per_param = 1
351336 else :
352337 bytes_per_param = 4
@@ -368,7 +353,7 @@ def estimate_memory(param_count, precision_str, mode='training'):
368353 except ImportError :
369354 logger .warning ('psutil not installed — cannot check system memory. Skipping memory check.' )
370355 return 0 , 0 , True
371- max_gpu_mem = 80 * (1024 ** 3 ) # 80GB — largest common single-GPU memory
356+ max_gpu_mem = 80 * (1024 ** 3 ) # 80GB — largest common single-GPU memory
372357 effective_mem = min (sys_mem , max_gpu_mem )
373358 fits = (estimated_bytes / effective_mem ) < 0.85
374359 return estimated_bytes , effective_mem , fits
0 commit comments