Skip to content

Commit 0b0d65f

Browse files
committed
address review feedback on #1074
1 parent cf73491 commit 0b0d65f

1 file changed

Lines changed: 62 additions & 48 deletions

File tree

modelopt/torch/opt/conversion.py

Lines changed: 62 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,66 @@ def update_last_state_before_save(self, model: nn.Module) -> None:
311311
last_mode.update_for_save(model, last_config, self._last_metadata)
312312
self._last_config = last_config
313313

314+
@staticmethod
315+
def validate_modelopt_state(modelopt_state: Any) -> None:
316+
"""Validate that the loaded object is a valid modelopt state file.
317+
318+
Args:
319+
modelopt_state: The loaded object to validate.
320+
321+
Raises:
322+
TypeError: If the loaded object is not a dictionary or has invalid types for nested fields.
323+
ValueError: If the loaded dictionary doesn't have the expected schema for a modelopt state file.
324+
"""
325+
# Validate that the loaded object is a dictionary
326+
if not isinstance(modelopt_state, dict):
327+
raise TypeError(
328+
f"Expected loaded modelopt state to be a dictionary, "
329+
f"but got {type(modelopt_state).__name__}. "
330+
f"The file may not be a valid modelopt state file."
331+
)
332+
333+
# Validate that the dictionary has the expected keys
334+
required_keys = {"modelopt_state_dict", "modelopt_version"}
335+
missing_keys = required_keys - set(modelopt_state.keys())
336+
if missing_keys:
337+
raise ValueError(
338+
f"The loaded modelopt state is missing required keys: {missing_keys}. "
339+
f"Expected keys: {required_keys}. "
340+
f"The file may not be a valid modelopt state file."
341+
)
342+
343+
# Validate that modelopt_state_dict is a list
344+
state_dict = modelopt_state["modelopt_state_dict"]
345+
if not isinstance(state_dict, list):
346+
raise TypeError(
347+
f"Expected 'modelopt_state_dict' to be a list, "
348+
f"but got {type(state_dict).__name__}. "
349+
f"The file may not be a valid modelopt state file."
350+
)
351+
352+
# Validate that each entry in the state_dict is a tuple with 2 elements
353+
for i, entry in enumerate(state_dict):
354+
if not isinstance(entry, tuple) or len(entry) != 2:
355+
raise ValueError(
356+
f"Expected each entry in 'modelopt_state_dict' to be a tuple of length 2, "
357+
f"but entry {i} is {type(entry).__name__} with length {len(entry) if isinstance(entry, (tuple, list)) else 'N/A'}. "
358+
f"The file may not be a valid modelopt state file."
359+
)
360+
mode_name, mode_state = entry
361+
if not isinstance(mode_name, str):
362+
raise TypeError(
363+
f"Expected mode name (first element of tuple) to be a string, "
364+
f"but got {type(mode_name).__name__} at entry {i}. "
365+
f"The file may not be a valid modelopt state file."
366+
)
367+
if not isinstance(mode_state, dict):
368+
raise TypeError(
369+
f"Expected mode state (second element of tuple) to be a dictionary, "
370+
f"but got {type(mode_state).__name__} at entry {i}. "
371+
f"The file may not be a valid modelopt state file."
372+
)
373+
314374

315375
class ApplyModeError(RuntimeError):
316376
"""Error raised when applying a mode to a model fails."""
@@ -532,54 +592,8 @@ def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dic
532592
kwargs.setdefault("map_location", "cpu")
533593
modelopt_state = torch.load(modelopt_state_path, **kwargs)
534594

535-
# Validate that the loaded object is a dictionary
536-
if not isinstance(modelopt_state, dict):
537-
raise TypeError(
538-
f"Expected loaded modelopt state to be a dictionary, "
539-
f"but got {type(modelopt_state).__name__}. "
540-
f"The file may not be a valid modelopt state file."
541-
)
542-
543-
# Validate that the dictionary has the expected keys
544-
required_keys = {"modelopt_state_dict", "modelopt_version"}
545-
missing_keys = required_keys - set(modelopt_state.keys())
546-
if missing_keys:
547-
raise ValueError(
548-
f"The loaded modelopt state is missing required keys: {missing_keys}. "
549-
f"Expected keys: {required_keys}. "
550-
f"The file may not be a valid modelopt state file."
551-
)
552-
553-
# Validate that modelopt_state_dict is a list
554-
state_dict = modelopt_state["modelopt_state_dict"]
555-
if not isinstance(state_dict, list):
556-
raise TypeError(
557-
f"Expected 'modelopt_state_dict' to be a list, "
558-
f"but got {type(state_dict).__name__}. "
559-
f"The file may not be a valid modelopt state file."
560-
)
561-
562-
# Validate that each entry in the state_dict is a tuple with 2 elements
563-
for i, entry in enumerate(state_dict):
564-
if not isinstance(entry, tuple) or len(entry) != 2:
565-
raise ValueError(
566-
f"Expected each entry in 'modelopt_state_dict' to be a tuple of length 2, "
567-
f"but entry {i} is {type(entry).__name__} with length {len(entry) if isinstance(entry, (tuple, list)) else 'N/A'}. "
568-
f"The file may not be a valid modelopt state file."
569-
)
570-
mode_name, mode_state = entry
571-
if not isinstance(mode_name, str):
572-
raise TypeError(
573-
f"Expected mode name (first element of tuple) to be a string, "
574-
f"but got {type(mode_name).__name__} at entry {i}. "
575-
f"The file may not be a valid modelopt state file."
576-
)
577-
if not isinstance(mode_state, dict):
578-
raise TypeError(
579-
f"Expected mode state (second element of tuple) to be a dictionary, "
580-
f"but got {type(mode_state).__name__} at entry {i}. "
581-
f"The file may not be a valid modelopt state file."
582-
)
595+
# Validate the loaded modelopt state
596+
ModeloptStateManager.validate_modelopt_state(modelopt_state)
583597

584598
return modelopt_state
585599

0 commit comments

Comments
 (0)