@@ -37,7 +37,11 @@ class AppState(Stateful):
3737 """
3838
3939 def __init__ (
40- self , model : nn .Module | list [nn .Module ], optimizer : Optimizer , lr_scheduler : Optional [LRScheduler ] = None
40+ self ,
41+ model : nn .Module | list [nn .Module ],
42+ optimizer : Optimizer ,
43+ lr_scheduler : Optional [LRScheduler ] = None ,
44+ components_to_load : list [StatefulComponents ] | None = None ,
4145 ):
4246 """Initializes the AppState object.
4347
@@ -46,12 +50,29 @@ def __init__(
4650 a non-sharded model, FSDP1 or FSDP2 model.
4751 optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer.
4852 lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None.
53+ components_to_load (list[StatefulComponents] | None, optional): The list of components to load from the
54+ checkpoint. If None, all components are loaded. Defaults to None.
4955 """
5056 self ._model_parts = list (model ) if isinstance (model , list ) else [model ]
5157 self ._optimizer = optimizer
5258 self ._lr_scheduler = lr_scheduler
5359 self ._is_loaded = False
5460
61+ # policy for which components to load from the checkpoint. If None, defaults to loading all components.
62+ if components_to_load is None :
63+ self ._components_to_load = [StatefulComponents .MODEL , StatefulComponents .OPTIMIZER ]
64+ if lr_scheduler is not None :
65+ self ._components_to_load .append (StatefulComponents .LR_SCHEDULER )
66+ else :
67+ self ._components_to_load = components_to_load
68+
69+ invalid_components = [c for c in self ._components_to_load if not isinstance (c , StatefulComponents )]
70+ if invalid_components :
71+ raise ValueError (
72+ f"components_to_load must only contain StatefulComponents, but got invalid entries: "
73+ f"{ invalid_components } "
74+ )
75+
5576 @property
5677 def is_loaded (self ) -> bool :
5778 """Returns whether the state dict has been loaded.
@@ -106,12 +127,14 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
106127 "Cannot call load_state_dict twice on the same AppState object. " "State dict has already been loaded."
107128 )
108129
109- ModelStateRetriever .load_state_dict_ (app_state = self , state_dict = state_dict [StatefulComponents .MODEL .value ])
110- OptimizerStateRetriever .load_state_dict_ (
111- app_state = self ,
112- state_dict = state_dict [StatefulComponents .OPTIMIZER .value ],
113- )
114- if self ._lr_scheduler is not None :
130+ if StatefulComponents .MODEL in self ._components_to_load :
131+ ModelStateRetriever .load_state_dict_ (app_state = self , state_dict = state_dict [StatefulComponents .MODEL .value ])
132+ if StatefulComponents .OPTIMIZER in self ._components_to_load :
133+ OptimizerStateRetriever .load_state_dict_ (
134+ app_state = self ,
135+ state_dict = state_dict [StatefulComponents .OPTIMIZER .value ],
136+ )
137+ if self ._lr_scheduler is not None and StatefulComponents .LR_SCHEDULER in self ._components_to_load :
115138 LRSchedulerStateRetriever .load_state_dict_ (
116139 app_state = self , state_dict = state_dict [StatefulComponents .LR_SCHEDULER .value ]
117140 )
0 commit comments