Skip to content

Commit fc1350a

Browse files
authored
7107 Add support to validate at training start (#7108)
Fixes #7107 . ### Description This PR added support to optionally execute validation at training start first, this is useful for transfer learning to validate the initial model. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Nic Ma <nma@nvidia.com>
1 parent 5c6d199 commit fc1350a

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

monai/handlers/validation_handler.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,19 @@ class ValidationHandler:
3131
3232
"""
3333

34-
def __init__(self, interval: int, validator: Evaluator | None = None, epoch_level: bool = True) -> None:
34+
def __init__(
35+
self, interval: int, validator: Evaluator | None = None, epoch_level: bool = True, exec_at_start: bool = False
36+
) -> None:
3537
"""
3638
Args:
3739
interval: do validation every N epochs or every N iterations during training.
3840
validator: run the validator when trigger validation, suppose to be Evaluator.
3941
if None, should call `set_validator()` before training.
4042
epoch_level: execute validation every N epochs or N iterations.
4143
`True` is epoch level, `False` is iteration level.
44+
exec_at_start: whether to execute a validation first when starting the training.
45+
default to `False`. It can be useful especially for some transfer-learning cases
46+
to validate the initial model before training.
4247
4348
Raises:
4449
TypeError: When ``validator`` is not a ``monai.engines.evaluator.Evaluator``.
@@ -49,6 +54,7 @@ def __init__(self, interval: int, validator: Evaluator | None = None, epoch_leve
4954
self.validator = validator
5055
self.interval = interval
5156
self.epoch_level = epoch_level
57+
self.exec_at_start = exec_at_start
5258

5359
def set_validator(self, validator: Evaluator) -> None:
5460
"""
@@ -67,6 +73,8 @@ def attach(self, engine: Engine) -> None:
6773
engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.interval), self)
6874
else:
6975
engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.interval), self)
76+
if self.exec_at_start:
77+
engine.add_event_handler(Events.STARTED, self)
7078

7179
def __call__(self, engine: Engine) -> None:
7280
"""

tests/test_handler_validation.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,11 @@ def _train_func(engine, batch):
3939
# set up testing handler
4040
val_data_loader = torch.utils.data.DataLoader(Dataset(data))
4141
evaluator = TestEvaluator(torch.device("cpu:0"), val_data_loader)
42-
saver = ValidationHandler(interval=2, validator=evaluator)
43-
saver.attach(engine)
42+
ValidationHandler(interval=2, validator=evaluator, exec_at_start=True).attach(engine)
43+
# test execution at start
44+
engine.run(data, max_epochs=1)
45+
self.assertEqual(evaluator.state.max_epochs, 0)
46+
self.assertEqual(evaluator.state.epoch_length, 8)
4447

4548
engine.run(data, max_epochs=5)
4649
self.assertEqual(evaluator.state.max_epochs, 4)

0 commit comments

Comments
 (0)