-
-
Notifications
You must be signed in to change notification settings - Fork 695
Fix max iters issue and add tests #3439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -128,13 +128,14 @@ def compute_mean_std(engine, batch): | |
|
|
||
| """ | ||
|
|
||
| _state_dict_all_req_keys = ("epoch_length", "max_epochs") | ||
| _state_dict_one_of_opt_keys = ("iteration", "epoch") | ||
| _state_dict_all_req_keys = ("epoch_length",) | ||
| _state_dict_one_of_opt_keys = (("iteration", "epoch"), ("max_epochs", "max_iters")) | ||
|
|
||
| # Flag to disable engine._internal_run as generator feature for BC | ||
| interrupt_resume_enabled = True | ||
|
|
||
| def __init__(self, process_function: Callable[["Engine", Any], Any]): | ||
| super(Engine, self).__init__() | ||
| self._event_handlers: Dict[Any, List] = defaultdict(list) | ||
| self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__) | ||
| self._process_function = process_function | ||
|
|
@@ -147,7 +148,6 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]): | |
| self.should_terminate_single_epoch: Union[bool, str] = False | ||
| self.should_interrupt = False | ||
| self.state = State() | ||
| self._state_dict_user_keys: List[str] = [] | ||
| self._allowed_events: List[EventEnum] = [] | ||
|
|
||
| self._dataloader_iter: Optional[Iterator[Any]] = None | ||
|
|
@@ -691,14 +691,20 @@ def save_engine(_): | |
| a dictionary containing engine's state | ||
|
|
||
| """ | ||
| keys: Tuple[str, ...] = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],) | ||
| keys: Tuple[str, ...] = self._state_dict_all_req_keys | ||
| keys += ("iteration",) | ||
| # Include either max_epochs or max_iters based on which was originally set | ||
| if self.state.max_iters is not None: | ||
| keys += ("max_iters",) | ||
| else: | ||
| keys += ("max_epochs",) | ||
| keys += tuple(self._state_dict_user_keys) | ||
| return OrderedDict([(k, getattr(self.state, k)) for k in keys]) | ||
|
|
||
| def load_state_dict(self, state_dict: Mapping) -> None: | ||
| """Setups engine from `state_dict`. | ||
|
|
||
| State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` and `epoch_length`. | ||
| State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` or `max_iters`, and `epoch_length`. | ||
| If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary. | ||
| Iteration and epoch values are 0-based: the first iteration or epoch is zero. | ||
|
|
||
|
|
@@ -709,10 +715,12 @@ def load_state_dict(self, state_dict: Mapping) -> None: | |
|
|
||
| .. code-block:: python | ||
|
|
||
| # Restore from the 4rd epoch | ||
| # Restore from the 4th epoch | ||
|
goanpeca marked this conversation as resolved.
|
||
| state_dict = {"epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)} | ||
| # or 500th iteration | ||
| # state_dict = {"iteration": 499, "max_epochs": 100, "epoch_length": len(data_loader)} | ||
| # or with max_iters | ||
| # state_dict = {"iteration": 499, "max_iters": 1000, "epoch_length": len(data_loader)} | ||
|
|
||
| trainer = Engine(...) | ||
| trainer.load_state_dict(state_dict) | ||
|
|
@@ -721,22 +729,20 @@ def load_state_dict(self, state_dict: Mapping) -> None: | |
| """ | ||
| super(Engine, self).load_state_dict(state_dict) | ||
|
|
||
| for k in self._state_dict_user_keys: | ||
| if k not in state_dict: | ||
| raise ValueError( | ||
| f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'" | ||
| ) | ||
| self.state.max_epochs = state_dict["max_epochs"] | ||
| # Set epoch_length | ||
| self.state.epoch_length = state_dict["epoch_length"] | ||
|
|
||
| # Set user keys | ||
| for k in self._state_dict_user_keys: | ||
| setattr(self.state, k, state_dict[k]) | ||
|
|
||
| # Set iteration or epoch | ||
| if "iteration" in state_dict: | ||
| self.state.iteration = state_dict["iteration"] | ||
| self.state.epoch = 0 | ||
| if self.state.epoch_length is not None: | ||
| if self.state.epoch_length is not None and self.state.epoch_length > 0: | ||
| self.state.epoch = self.state.iteration // self.state.epoch_length | ||
| elif "epoch" in state_dict: | ||
| else: # epoch is in state_dict | ||
| self.state.epoch = state_dict["epoch"] | ||
| if self.state.epoch_length is None: | ||
| raise ValueError( | ||
|
|
@@ -745,6 +751,36 @@ def load_state_dict(self, state_dict: Mapping) -> None: | |
| ) | ||
| self.state.iteration = self.state.epoch_length * self.state.epoch | ||
|
|
||
| # Set max_epochs or max_iters with validation | ||
| max_epochs_value = state_dict.get("max_epochs", None) | ||
| max_iters_value = state_dict.get("max_iters", None) | ||
|
|
||
| # Validate max_epochs if present | ||
| if max_epochs_value is not None: | ||
| if max_epochs_value < 1: | ||
| raise ValueError("max_epochs in state_dict is invalid. Please, set a correct max_epochs positive value") | ||
| if max_epochs_value < self.state.epoch: | ||
| raise ValueError( | ||
| "max_epochs in state_dict should be larger than or equal to the current epoch " | ||
| f"defined in the state: {max_epochs_value} vs {self.state.epoch}. " | ||
| ) | ||
| self.state.max_epochs = max_epochs_value | ||
| else: | ||
| self.state.max_epochs = None | ||
|
|
||
| # Validate max_iters if present | ||
| if max_iters_value is not None: | ||
| if max_iters_value < 1: | ||
| raise ValueError("max_iters in state_dict is invalid. Please, set a correct max_iters positive value") | ||
| if max_iters_value < self.state.iteration: | ||
| raise ValueError( | ||
| "max_iters in state_dict should be larger than or equal to the current iteration " | ||
| f"defined in the state: {max_iters_value} vs {self.state.iteration}. " | ||
| ) | ||
| self.state.max_iters = max_iters_value | ||
| else: | ||
| self.state.max_iters = None | ||
|
|
||
| @staticmethod | ||
| def _is_done(state: State) -> bool: | ||
| is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters | ||
|
|
@@ -756,6 +792,59 @@ def _is_done(state: State) -> bool: | |
| is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs | ||
| return is_done_iters or is_done_count or is_done_epochs | ||
|
|
||
| def _check_and_set_max_epochs(self, max_epochs: Optional[int] = None) -> None: | ||
| """Validate and set max_epochs with proper checks.""" | ||
| if max_epochs is not None: | ||
| if max_epochs < 1: | ||
| raise ValueError("Argument max_epochs is invalid. Please, set a correct max_epochs positive value") | ||
| # Only validate if training is actually done - allow resuming interrupted training | ||
| if self.state.max_epochs is not None and max_epochs < self.state.epoch: | ||
| raise ValueError( | ||
| "Argument max_epochs should be greater than or equal to the start " | ||
| f"epoch defined in the state: {max_epochs} vs {self.state.epoch}. " | ||
| "Please, set engine.state.max_epochs = None " | ||
| "before calling engine.run() in order to restart the training from the beginning." | ||
| ) | ||
| self.state.max_epochs = max_epochs | ||
|
|
||
| def _check_and_set_max_iters(self, max_iters: Optional[int] = None) -> None: | ||
| """Validate and set max_iters with proper checks.""" | ||
| if max_iters is not None: | ||
| if max_iters < 1: | ||
| raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value") | ||
| # Only validate if training is actually done - allow resuming interrupted training | ||
| if (self.state.max_iters is not None) and max_iters < self.state.iteration: | ||
| raise ValueError( | ||
| "Argument max_iters should be greater than or equal to the start " | ||
| f"iteration defined in the state: {max_iters} vs {self.state.iteration}. " | ||
| "Please, set engine.state.max_iters = None " | ||
| "before calling engine.run() in order to restart the training from the beginning." | ||
| ) | ||
| self.state.max_iters = max_iters | ||
|
|
||
| def _check_and_set_epoch_length(self, data: Optional[Iterable], epoch_length: Optional[int] = None) -> None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function is not used anywhere. Best to remove it. |
||
| """Validate and set epoch_length.""" | ||
| # Check if we can redefine epoch_length | ||
| if self.state.epoch_length is not None: | ||
| if epoch_length is not None: | ||
| if epoch_length != self.state.epoch_length: | ||
| raise ValueError( | ||
| "Argument epoch_length should be same as in the state, " | ||
| f"but given {epoch_length} vs {self.state.epoch_length}" | ||
| ) | ||
| else: | ||
| if epoch_length is None: | ||
| if data is not None: | ||
| epoch_length = self._get_data_length(data) | ||
|
|
||
| if epoch_length is not None: | ||
| if epoch_length < 1: | ||
| raise ValueError( | ||
| "Argument epoch_length is invalid. Please, either set a correct epoch_length value or " | ||
| "check if input data has non-zero size." | ||
| ) | ||
| self.state.epoch_length = epoch_length | ||
|
|
||
| def set_data(self, data: Union[Iterable, DataLoader]) -> None: | ||
| """Method to set data. After calling the method the next batch passed to `processing_function` is | ||
| from newly provided data. Please, note that epoch length is not modified. | ||
|
|
@@ -854,59 +943,98 @@ def switch_batch(engine): | |
| if data is not None and not isinstance(data, Iterable): | ||
| raise TypeError("Argument data should be iterable") | ||
|
|
||
| if self.state.max_epochs is not None: | ||
| # Check and apply overridden parameters | ||
| if max_epochs is not None: | ||
| if max_epochs < self.state.epoch: | ||
| raise ValueError( | ||
| "Argument max_epochs should be greater than or equal to the start " | ||
| f"epoch defined in the state: {max_epochs} vs {self.state.epoch}. " | ||
| "Please, set engine.state.max_epochs = None " | ||
| "before calling engine.run() in order to restart the training from the beginning." | ||
| ) | ||
| self.state.max_epochs = max_epochs | ||
| if epoch_length is not None: | ||
| if epoch_length != self.state.epoch_length: | ||
| raise ValueError( | ||
| "Argument epoch_length should be same as in the state, " | ||
| f"but given {epoch_length} vs {self.state.epoch_length}" | ||
| ) | ||
| if max_epochs is not None and max_iters is not None: | ||
| raise ValueError( | ||
| "Arguments max_iters and max_epochs are mutually exclusive." | ||
| "Please provide only max_epochs or max_iters." | ||
| ) | ||
|
|
||
| if self.state.max_epochs is None or (self._is_done(self.state) and self._internal_run_generator is None): | ||
| # Create new state | ||
| if epoch_length is None: | ||
| if data is None: | ||
| raise ValueError("epoch_length should be provided if data is None") | ||
| # Check if we need to create new state or resume | ||
| # Create new state if: | ||
| # 1. No termination params set (first run), OR | ||
| # 2. Training is done AND generator is None AND no new params provided | ||
| # 3. Training is done AND same termination params provided (restart case) | ||
| should_create_new_state = ( | ||
| (self.state.max_epochs is None and self.state.max_iters is None) | ||
| or ( | ||
| self._is_done(self.state) | ||
| and self._internal_run_generator is None | ||
| and max_epochs is None | ||
| and max_iters is None | ||
| ) | ||
| or ( | ||
| self._is_done(self.state) | ||
| and self._internal_run_generator is None | ||
| and ( | ||
| (max_epochs is not None and max_epochs == self.state.max_epochs) | ||
| or (max_iters is not None and max_iters == self.state.max_iters) | ||
| ) | ||
| ) | ||
| ) | ||
|
Comment on lines
+952
to
+973
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should discuss this more. Why are we creating a new state only if the same parameters are passed, and resuming otherwise? I personally think the master behavior of always restarting is much more intuitive. The currently proposed change can be misleading for the user. For instance, on master, |
||
|
|
||
| epoch_length = self._get_data_length(data) | ||
| if epoch_length is not None and epoch_length < 1: | ||
| raise ValueError("Input data has zero size. Please provide non-empty data") | ||
| if should_create_new_state: | ||
| # Create new state | ||
| if data is None and epoch_length is None and self.state.epoch_length is None: | ||
| raise ValueError("epoch_length should be provided if data is None") | ||
|
|
||
| # Set epoch_length for new state | ||
| if epoch_length is None: | ||
| # Try to get from data first, then fall back to existing state | ||
| if data is not None: | ||
| epoch_length = self._get_data_length(data) | ||
| if epoch_length is None and self.state.epoch_length is not None: | ||
| epoch_length = self.state.epoch_length | ||
| if epoch_length is not None and epoch_length < 1: | ||
| raise ValueError("Input data has zero size. Please provide non-empty data") | ||
|
|
||
| # Determine max_epochs/max_iters | ||
| if max_iters is None: | ||
| if max_epochs is None: | ||
| max_epochs = 1 | ||
| else: | ||
| if max_epochs is not None: | ||
| raise ValueError( | ||
| "Arguments max_iters and max_epochs are mutually exclusive." | ||
| "Please provide only max_epochs or max_iters." | ||
| ) | ||
| if epoch_length is not None: | ||
| max_epochs = math.ceil(max_iters / epoch_length) | ||
|
|
||
| # Initialize new state | ||
| self.state.iteration = 0 | ||
| self.state.epoch = 0 | ||
| self.state.max_epochs = max_epochs | ||
| self.state.max_iters = max_iters | ||
| self.state.epoch_length = epoch_length | ||
| # Reset generator if previously used | ||
| self._internal_run_generator = None | ||
| self.logger.info(f"Engine run starting with max_epochs={max_epochs}.") | ||
|
|
||
| # Log start message | ||
| if self.state.max_epochs is not None: | ||
| self.logger.info(f"Engine run starting with max_epochs={self.state.max_epochs}.") | ||
| else: | ||
| self.logger.info(f"Engine run starting with max_iters={self.state.max_iters}.") | ||
| else: | ||
| self.logger.info( | ||
| f"Engine run resuming from iteration {self.state.iteration}, " | ||
| f"epoch {self.state.epoch} until {self.state.max_epochs} epochs" | ||
| ) | ||
| # Resume from existing state | ||
| # Apply overridden parameters using helper methods | ||
| self._check_and_set_max_epochs(max_epochs) | ||
| self._check_and_set_max_iters(max_iters) | ||
|
|
||
| # Handle epoch_length validation (simplified from original) | ||
| if epoch_length is not None: | ||
| if epoch_length != self.state.epoch_length: | ||
| raise ValueError( | ||
| "Argument epoch_length should be same as in the state, " | ||
| f"but given {epoch_length} vs {self.state.epoch_length}" | ||
| ) | ||
|
|
||
| # Log resuming message | ||
| if self.state.max_epochs is not None: | ||
| self.logger.info( | ||
| f"Engine run resuming from iteration {self.state.iteration}, " | ||
| f"epoch {self.state.epoch} until {self.state.max_epochs} epochs" | ||
| ) | ||
| else: | ||
| self.logger.info( | ||
| f"Engine run resuming from iteration {self.state.iteration}, " | ||
| f"epoch {self.state.epoch} until {self.state.max_iters} iterations" | ||
| ) | ||
|
|
||
| if self.state.epoch_length is None and data is None: | ||
| raise ValueError("epoch_length should be provided if data is None") | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some subclasses like
EarlyStopping, Checkpoint, and Metric never call super().__init__(). I think a better approach would be to revert it so that _state_dict_user_keys stays in Engine, as not all subclasses need this.