@@ -311,6 +311,66 @@ def update_last_state_before_save(self, model: nn.Module) -> None:
311311 last_mode .update_for_save (model , last_config , self ._last_metadata )
312312 self ._last_config = last_config
313313
314+ @staticmethod
315+ def validate_modelopt_state (modelopt_state : Any ) -> None :
316+ """Validate that the loaded object is a valid modelopt state file.
317+
318+ Args:
319+ modelopt_state: The loaded object to validate.
320+
321+ Raises:
322+ TypeError: If the loaded object is not a dictionary or has invalid types for nested fields.
323+ ValueError: If the loaded dictionary doesn't have the expected schema for a modelopt state file.
324+ """
325+ # Validate that the loaded object is a dictionary
326+ if not isinstance (modelopt_state , dict ):
327+ raise TypeError (
328+ f"Expected loaded modelopt state to be a dictionary, "
329+ f"but got { type (modelopt_state ).__name__ } . "
330+ f"The file may not be a valid modelopt state file."
331+ )
332+
333+ # Validate that the dictionary has the expected keys
334+ required_keys = {"modelopt_state_dict" , "modelopt_version" }
335+ missing_keys = required_keys - set (modelopt_state .keys ())
336+ if missing_keys :
337+ raise ValueError (
338+ f"The loaded modelopt state is missing required keys: { missing_keys } . "
339+ f"Expected keys: { required_keys } . "
340+ f"The file may not be a valid modelopt state file."
341+ )
342+
343+ # Validate that modelopt_state_dict is a list
344+ state_dict = modelopt_state ["modelopt_state_dict" ]
345+ if not isinstance (state_dict , list ):
346+ raise TypeError (
347+ f"Expected 'modelopt_state_dict' to be a list, "
348+ f"but got { type (state_dict ).__name__ } . "
349+ f"The file may not be a valid modelopt state file."
350+ )
351+
352+ # Validate that each entry in the state_dict is a tuple with 2 elements
353+ for i , entry in enumerate (state_dict ):
354+ if not isinstance (entry , tuple ) or len (entry ) != 2 :
355+ raise ValueError (
356+ f"Expected each entry in 'modelopt_state_dict' to be a tuple of length 2, "
357+ f"but entry { i } is { type (entry ).__name__ } with length { len (entry ) if isinstance (entry , (tuple , list )) else 'N/A' } . "
358+ f"The file may not be a valid modelopt state file."
359+ )
360+ mode_name , mode_state = entry
361+ if not isinstance (mode_name , str ):
362+ raise TypeError (
363+ f"Expected mode name (first element of tuple) to be a string, "
364+ f"but got { type (mode_name ).__name__ } at entry { i } . "
365+ f"The file may not be a valid modelopt state file."
366+ )
367+ if not isinstance (mode_state , dict ):
368+ raise TypeError (
369+ f"Expected mode state (second element of tuple) to be a dictionary, "
370+ f"but got { type (mode_state ).__name__ } at entry { i } . "
371+ f"The file may not be a valid modelopt state file."
372+ )
373+
314374
315375class ApplyModeError (RuntimeError ):
316376 """Error raised when applying a mode to a model fails."""
@@ -532,54 +592,8 @@ def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dic
532592 kwargs .setdefault ("map_location" , "cpu" )
533593 modelopt_state = torch .load (modelopt_state_path , ** kwargs )
534594
535- # Validate that the loaded object is a dictionary
536- if not isinstance (modelopt_state , dict ):
537- raise TypeError (
538- f"Expected loaded modelopt state to be a dictionary, "
539- f"but got { type (modelopt_state ).__name__ } . "
540- f"The file may not be a valid modelopt state file."
541- )
542-
543- # Validate that the dictionary has the expected keys
544- required_keys = {"modelopt_state_dict" , "modelopt_version" }
545- missing_keys = required_keys - set (modelopt_state .keys ())
546- if missing_keys :
547- raise ValueError (
548- f"The loaded modelopt state is missing required keys: { missing_keys } . "
549- f"Expected keys: { required_keys } . "
550- f"The file may not be a valid modelopt state file."
551- )
552-
553- # Validate that modelopt_state_dict is a list
554- state_dict = modelopt_state ["modelopt_state_dict" ]
555- if not isinstance (state_dict , list ):
556- raise TypeError (
557- f"Expected 'modelopt_state_dict' to be a list, "
558- f"but got { type (state_dict ).__name__ } . "
559- f"The file may not be a valid modelopt state file."
560- )
561-
562- # Validate that each entry in the state_dict is a tuple with 2 elements
563- for i , entry in enumerate (state_dict ):
564- if not isinstance (entry , tuple ) or len (entry ) != 2 :
565- raise ValueError (
566- f"Expected each entry in 'modelopt_state_dict' to be a tuple of length 2, "
567- f"but entry { i } is { type (entry ).__name__ } with length { len (entry ) if isinstance (entry , (tuple , list )) else 'N/A' } . "
568- f"The file may not be a valid modelopt state file."
569- )
570- mode_name , mode_state = entry
571- if not isinstance (mode_name , str ):
572- raise TypeError (
573- f"Expected mode name (first element of tuple) to be a string, "
574- f"but got { type (mode_name ).__name__ } at entry { i } . "
575- f"The file may not be a valid modelopt state file."
576- )
577- if not isinstance (mode_state , dict ):
578- raise TypeError (
579- f"Expected mode state (second element of tuple) to be a dictionary, "
580- f"but got { type (mode_state ).__name__ } at entry { i } . "
581- f"The file may not be a valid modelopt state file."
582- )
595+ # Validate the loaded modelopt state
596+ ModeloptStateManager .validate_modelopt_state (modelopt_state )
583597
584598 return modelopt_state
585599
0 commit comments