diff --git a/mipcandy/common/optim/loss.py b/mipcandy/common/optim/loss.py index 34635d3..ffe49de 100644 --- a/mipcandy/common/optim/loss.py +++ b/mipcandy/common/optim/loss.py @@ -58,7 +58,7 @@ def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.T if not self.include_background: outputs = outputs[:, 1:] labels = labels[:, 1:] - dice = soft_dice(outputs, labels, smooth=self.smooth) + dice = soft_dice(outputs.float(), labels.float(), smooth=self.smooth) metrics = {"soft dice": dice.item(), "ce loss": ce.item()} c = self.lambda_ce * ce + self.lambda_soft_dice * (1 - dice) return c, metrics @@ -88,10 +88,10 @@ def __init__(self, *, lambda_bce: float = 1, lambda_soft_dice: float = 1, self.min_percentage_per_class: float | None = min_percentage_per_class def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]: + labels = labels.to(dtype=outputs.dtype) + bce = nn.functional.binary_cross_entropy_with_logits(outputs, labels) outputs = outputs.sigmoid() - labels = labels.float() - bce = nn.functional.binary_cross_entropy(outputs, labels) - dice = soft_dice(outputs, labels, smooth=self.smooth) + dice = soft_dice(outputs.float(), labels.float(), smooth=self.smooth) metrics = {"soft dice": dice.item(), "bce loss": bce.item()} c = self.lambda_bce * bce + self.lambda_soft_dice * (1 - dice) return c, metrics diff --git a/mipcandy/presets/segmentation.py b/mipcandy/presets/segmentation.py index 68f139b..2f6b707 100644 --- a/mipcandy/presets/segmentation.py +++ b/mipcandy/presets/segmentation.py @@ -121,7 +121,9 @@ def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerT outputs = list(torch.unbind(outputs, dim=1)) labels = self.prepare_deep_supervision_targets(labels, [m.shape[2:] for m in outputs]) loss, metrics = toolbox.criterion(outputs, labels) - loss.backward() + self._do_backward(loss, toolbox) + if toolbox.scaler: + toolbox.scaler.unscale_(toolbox.optimizer) nn.utils.clip_grad_norm_(toolbox.model.parameters(), 12) return loss.item(), metrics diff --git a/mipcandy/training.py b/mipcandy/training.py index ffd3311..d51a5ac 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -50,6 +50,7 @@ class TrainerToolbox(object): scheduler: optim.lr_scheduler.LRScheduler criterion: nn.Module ema: nn.Module | None = None + scaler: torch.amp.GradScaler | None = None @dataclass @@ -85,11 +86,14 @@ def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: Trainer **training_arguments) -> None: if self._unrecoverable: return - torch.save({ + state_dicts = { "optimizer": toolbox.optimizer.state_dict(), "scheduler": toolbox.scheduler.state_dict(), "criterion": toolbox.criterion.state_dict() - }, f"{self.experiment_folder()}/state_dicts.pth") + } + if toolbox.scaler: + state_dicts["scaler"] = toolbox.scaler.state_dict() + torch.save(state_dicts, f"{self.experiment_folder()}/state_dicts.pth") with open(f"{self.experiment_folder()}/state_orb.json", "w") as f: dump({"tracker": asdict(tracker), "training_arguments": training_arguments}, f) @@ -116,6 +120,9 @@ def load_toolbox(self, num_epochs: int, example_shape: AmbiguousShape, compile_m toolbox.optimizer.load_state_dict(state_dicts["optimizer"]) toolbox.scheduler.load_state_dict(state_dicts["scheduler"]) toolbox.criterion.load_state_dict(state_dicts["criterion"]) + if "scaler" in state_dicts: + toolbox.scaler = torch.amp.GradScaler(self._device_type()) + toolbox.scaler.load_state_dict(state_dicts["scaler"]) return toolbox def recover_from(self, experiment_id: str) -> Self: @@ -391,6 +398,12 @@ def empty_cache(self) -> None: # Training methods + def _do_backward(self, loss: torch.Tensor, toolbox: TrainerToolbox) -> None: + if toolbox.scaler: + toolbox.scaler.scale(loss).backward() + else: + loss.backward() + def sanity_check(self, template_model: nn.Module, example_shape: AmbiguousShape) -> SanityCheckResult: try: return sanity_check(template_model, example_shape, device=self._device) @@ -402,14 +415,27 @@ def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerT str, float]]: raise NotImplementedError + def _device_type(self) -> str: + return self._device.type if isinstance(self._device, torch.device) else str(self._device).split(":")[0] + def train_batch(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ str, float]]: toolbox.optimizer.zero_grad() - loss, metrics = self.backward(images, labels, toolbox) - toolbox.optimizer.step() - toolbox.scheduler.step() - if toolbox.ema: - toolbox.ema.update_parameters(toolbox.model) + with torch.amp.autocast(self._device_type(), enabled=toolbox.scaler is not None): + loss, metrics = self.backward(images, labels, toolbox) + if toolbox.scaler: + old_scale = toolbox.scaler.get_scale() + toolbox.scaler.step(toolbox.optimizer) + toolbox.scaler.update() + if old_scale <= toolbox.scaler.get_scale(): + toolbox.scheduler.step() + if toolbox.ema: + toolbox.ema.update_parameters(toolbox.model) + else: + toolbox.optimizer.step() + toolbox.scheduler.step() + if toolbox.ema: + toolbox.ema.update_parameters(toolbox.model) return loss, metrics def train_epoch(self, toolbox: TrainerToolbox) -> dict[str, list[float]]: @@ -440,7 +466,7 @@ def train_epoch(self, toolbox: TrainerToolbox) -> dict[str, list[float]]: def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, compile_model: bool = True, ema: bool = True, seed: int | None = None, early_stop_tolerance: int = 5, val_score_prediction: bool = True, val_score_prediction_degree: int = 5, save_preview: bool = True, - preview_quality: float = .75) -> None: + preview_quality: float = .75, amp: bool = False) -> None: training_arguments = self.filter_train_params(**locals()) self.init_experiment() if note: @@ -468,6 +494,12 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co toolbox = (self.load_toolbox if self.recovery() else self.build_toolbox)( num_epochs, example_shape, compile_model, ema ) + if amp and not toolbox.scaler: + toolbox.scaler = torch.amp.GradScaler(self._device_type()) + self.log("Mixed precision training enabled") + elif not amp and toolbox.scaler: + toolbox.scaler = None + self.log("Mixed precision training disabled") checkpoint_path = lambda v: f"{self.experiment_folder()}/checkpoint_{v}.pth" es_tolerance = early_stop_tolerance self._frontend.on_experiment_created(self._experiment_id, self._trainer_variant, model_name, note, @@ -550,7 +582,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co def filter_train_params(**kwargs) -> dict[str, Setting]: return {k: v for k, v in kwargs.items() if k in ( "note", "num_checkpoints", "compile_model", "ema", "seed", "early_stop_tolerance", "val_score_prediction", - "val_score_prediction_degree", "save_preview", "preview_quality" + "val_score_prediction_degree", "save_preview", "preview_quality", "amp" )} def train_with_settings(self, num_epochs: int, **kwargs) -> None: @@ -580,7 +612,7 @@ def validate(self, toolbox: TrainerToolbox) -> tuple[float, dict[str, list[float worst_score = float("+inf") metrics = {} num_cases = len(self._validation_dataloader) - with torch.no_grad(), Progress( + with torch.no_grad(), torch.amp.autocast(self._device_type(), enabled=toolbox.scaler is not None), Progress( *Progress.get_default_columns(), SpinnerColumn(), console=self._console ) as progress: task = progress.add_task(f"Validating", total=num_cases)