From 587c79b3f4c72c81ff334987de142a0f4c44f6da Mon Sep 17 00:00:00 2001 From: Steven Chen Date: Wed, 18 Feb 2026 19:44:23 -0500 Subject: [PATCH 1/6] Support mixed precision training (#179) --- mipcandy/common/optim/loss.py | 2 +- mipcandy/presets/segmentation.py | 4 +++- mipcandy/training.py | 38 ++++++++++++++++++++++++++------ 3 files changed, 35 insertions(+), 9 deletions(-) diff --git a/mipcandy/common/optim/loss.py b/mipcandy/common/optim/loss.py index 2bfcf86..9b802d1 100644 --- a/mipcandy/common/optim/loss.py +++ b/mipcandy/common/optim/loss.py @@ -94,7 +94,7 @@ def __init__(self, *, lambda_bce: float = 1, lambda_soft_dice: float = 1, def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]: outputs = outputs.sigmoid() - labels = labels.float() + labels = labels.to(dtype=outputs.dtype) bce = nn.functional.binary_cross_entropy(outputs, labels) soft_dice = soft_dice_coefficient(outputs, labels, smooth=self.smooth) metrics = {"soft dice": soft_dice.item(), "bce loss": bce.item()} diff --git a/mipcandy/presets/segmentation.py b/mipcandy/presets/segmentation.py index a101fae..2dc7552 100644 --- a/mipcandy/presets/segmentation.py +++ b/mipcandy/presets/segmentation.py @@ -121,7 +121,9 @@ def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerT outputs = list(torch.unbind(outputs, dim=1)) labels = self.prepare_deep_supervision_targets(labels, [m.shape[2:] for m in outputs]) loss, metrics = toolbox.criterion(outputs, labels) - loss.backward() + self._do_backward(loss, toolbox) + if toolbox.scaler: + toolbox.scaler.unscale_(toolbox.optimizer) nn.utils.clip_grad_norm_(toolbox.model.parameters(), 12) return loss.item(), metrics diff --git a/mipcandy/training.py b/mipcandy/training.py index d71469c..251c9a4 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -50,6 +50,7 @@ class TrainerToolbox(object): scheduler: optim.lr_scheduler.LRScheduler criterion: nn.Module ema: nn.Module | None = None + scaler: torch.amp.GradScaler | None = None @dataclass @@ -85,11 +86,14 @@ def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: Trainer **training_arguments) -> None: if self._unrecoverable: return - torch.save({ + state_dicts = { "optimizer": toolbox.optimizer.state_dict(), "scheduler": toolbox.scheduler.state_dict(), "criterion": toolbox.criterion.state_dict() - }, f"{self.experiment_folder()}/state_dicts.pth") + } + if toolbox.scaler: + state_dicts["scaler"] = toolbox.scaler.state_dict() + torch.save(state_dicts, f"{self.experiment_folder()}/state_dicts.pth") with open(f"{self.experiment_folder()}/state_orb.json", "w") as f: dump({"tracker": asdict(tracker), "training_arguments": training_arguments}, f) @@ -116,6 +120,9 @@ def load_toolbox(self, num_epochs: int, example_shape: AmbiguousShape, compile_m toolbox.optimizer.load_state_dict(state_dicts["optimizer"]) toolbox.scheduler.load_state_dict(state_dicts["scheduler"]) toolbox.criterion.load_state_dict(state_dicts["criterion"]) + if "scaler" in state_dicts: + toolbox.scaler = torch.amp.GradScaler() + toolbox.scaler.load_state_dict(state_dicts["scaler"]) return toolbox def recover_from(self, experiment_id: str) -> Self: @@ -391,16 +398,30 @@ def empty_cache(self) -> None: # Training methods + def _do_backward(self, loss: torch.Tensor, toolbox: TrainerToolbox) -> None: + if toolbox.scaler: + toolbox.scaler.scale(loss).backward() + else: + loss.backward() + @abstractmethod def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ str, float]]: raise NotImplementedError + def _device_type(self) -> str: + return self._device.type if isinstance(self._device, torch.device) else str(self._device).split(":")[0] + def train_batch(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ str, float]]: toolbox.optimizer.zero_grad() - loss, metrics = self.backward(images, labels, toolbox) - toolbox.optimizer.step() + with torch.amp.autocast(self._device_type(), enabled=toolbox.scaler is not None): + loss, metrics = self.backward(images, labels, toolbox) + if toolbox.scaler: + toolbox.scaler.step(toolbox.optimizer) + toolbox.scaler.update() + else: + toolbox.optimizer.step() toolbox.scheduler.step() if toolbox.ema: toolbox.ema.update_parameters(toolbox.model) @@ -434,7 +455,7 @@ def train_epoch(self, toolbox: TrainerToolbox) -> dict[str, list[float]]: def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, compile_model: bool = True, ema: bool = True, seed: int | None = None, early_stop_tolerance: int = 5, val_score_prediction: bool = True, val_score_prediction_degree: int = 5, save_preview: bool = True, - preview_quality: float = .75) -> None: + preview_quality: float = .75, amp: bool = False) -> None: training_arguments = self.filter_train_params(**locals()) self.init_experiment() if note: @@ -462,6 +483,9 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co toolbox = (self.load_toolbox if self.recovery() else self.build_toolbox)( num_epochs, example_shape, compile_model, ema ) + if amp and not toolbox.scaler: + toolbox.scaler = torch.amp.GradScaler() + self.log("Mixed precision training enabled") checkpoint_path = lambda v: f"{self.experiment_folder()}/checkpoint_{v}.pth" es_tolerance = early_stop_tolerance self._frontend.on_experiment_created(self._experiment_id, self._trainer_variant, model_name, note, @@ -543,7 +567,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co def filter_train_params(**kwargs) -> dict[str, Setting]: return {k: v for k, v in kwargs.items() if k in ( "note", "num_checkpoints", "compile_model", "ema", "seed", "early_stop_tolerance", "val_score_prediction", - "val_score_prediction_degree", "save_preview", "preview_quality" + "val_score_prediction_degree", "save_preview", "preview_quality", "amp" )} def train_with_settings(self, num_epochs: int, **kwargs) -> None: @@ -573,7 +597,7 @@ def validate(self, toolbox: TrainerToolbox) -> tuple[float, dict[str, list[float worst_score = float("+inf") metrics = {} num_cases = len(self._validation_dataloader) - with torch.no_grad(), Progress( + with torch.no_grad(), torch.amp.autocast(self._device_type(), enabled=toolbox.scaler is not None), Progress( *Progress.get_default_columns(), SpinnerColumn(), console=self._console ) as progress: task = progress.add_task(f"Validating", total=num_cases) From 74fc98c23bd75b7fba7eae823671490ad47abefa Mon Sep 17 00:00:00 2001 From: Steven Chen Date: Sun, 1 Mar 2026 01:31:45 -0500 Subject: [PATCH 2/6] Fix AMP scheduler drift and GradScaler device mismatch (#219) --- mipcandy/training.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mipcandy/training.py b/mipcandy/training.py index 0286303..e139e5f 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -121,7 +121,7 @@ def load_toolbox(self, num_epochs: int, example_shape: AmbiguousShape, compile_m toolbox.scheduler.load_state_dict(state_dicts["scheduler"]) toolbox.criterion.load_state_dict(state_dicts["criterion"]) if "scaler" in state_dicts: - toolbox.scaler = torch.amp.GradScaler() + toolbox.scaler = torch.amp.GradScaler(self._device_type()) toolbox.scaler.load_state_dict(state_dicts["scaler"]) return toolbox @@ -424,11 +424,14 @@ def train_batch(self, images: torch.Tensor, labels: torch.Tensor, toolbox: Train with torch.amp.autocast(self._device_type(), enabled=toolbox.scaler is not None): loss, metrics = self.backward(images, labels, toolbox) if toolbox.scaler: + old_scale = toolbox.scaler.get_scale() toolbox.scaler.step(toolbox.optimizer) toolbox.scaler.update() + if old_scale <= toolbox.scaler.get_scale(): + toolbox.scheduler.step() else: toolbox.optimizer.step() - toolbox.scheduler.step() + toolbox.scheduler.step() if toolbox.ema: toolbox.ema.update_parameters(toolbox.model) return loss, metrics @@ -490,7 +493,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co num_epochs, example_shape, compile_model, ema ) if amp and not toolbox.scaler: - toolbox.scaler = torch.amp.GradScaler() + toolbox.scaler = torch.amp.GradScaler(self._device_type()) self.log("Mixed precision training enabled") checkpoint_path = lambda v: f"{self.experiment_folder()}/checkpoint_{v}.pth" es_tolerance = early_stop_tolerance From 45be0149291509ca6d279a82f3b30be7987fd14b Mon Sep 17 00:00:00 2001 From: Steven Chen Date: Fri, 6 Mar 2026 10:06:22 -0500 Subject: [PATCH 3/6] Use `binary_cross_entropy_with_logits` for autocast safety (#179) --- mipcandy/common/optim/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mipcandy/common/optim/loss.py b/mipcandy/common/optim/loss.py index 7edcb0e..78c2bb2 100644 --- a/mipcandy/common/optim/loss.py +++ b/mipcandy/common/optim/loss.py @@ -88,9 +88,9 @@ def __init__(self, *, lambda_bce: float = 1, lambda_soft_dice: float = 1, self.min_percentage_per_class: float | None = min_percentage_per_class def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]: - outputs = outputs.sigmoid() labels = labels.to(dtype=outputs.dtype) - bce = nn.functional.binary_cross_entropy(outputs, labels) + bce = nn.functional.binary_cross_entropy_with_logits(outputs, labels) + outputs = outputs.sigmoid() dice = soft_dice(outputs, labels, smooth=self.smooth) metrics = {"soft dice": dice.item(), "bce loss": bce.item()} c = self.lambda_bce * bce + self.lambda_soft_dice * (1 - dice) From 7ad71be4fa7ddfff12e9c6ceed67e7ebeb7d096b Mon Sep 17 00:00:00 2001 From: Steven Chen Date: Fri, 6 Mar 2026 21:51:18 -0500 Subject: [PATCH 4/6] Guard EMA update on successful AMP optimizer step (#179) --- mipcandy/training.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mipcandy/training.py b/mipcandy/training.py index e139e5f..5ba5c16 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -429,11 +429,13 @@ def train_batch(self, images: torch.Tensor, labels: torch.Tensor, toolbox: Train toolbox.scaler.update() if old_scale <= toolbox.scaler.get_scale(): toolbox.scheduler.step() + if toolbox.ema: + toolbox.ema.update_parameters(toolbox.model) else: toolbox.optimizer.step() toolbox.scheduler.step() - if toolbox.ema: - toolbox.ema.update_parameters(toolbox.model) + if toolbox.ema: + toolbox.ema.update_parameters(toolbox.model) return loss, metrics def train_epoch(self, toolbox: TrainerToolbox) -> dict[str, list[float]]: From 84dfb45a5d7de6a441422abd7f9e7c14f0144d5e Mon Sep 17 00:00:00 2001 From: Steven Chen Date: Mon, 16 Mar 2026 12:51:23 -0400 Subject: [PATCH 5/6] Honor `amp=False` when resuming from a mixed-precision run (#219) --- mipcandy/training.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mipcandy/training.py b/mipcandy/training.py index 5ba5c16..d51a5ac 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -497,6 +497,9 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co if amp and not toolbox.scaler: toolbox.scaler = torch.amp.GradScaler(self._device_type()) self.log("Mixed precision training enabled") + elif not amp and toolbox.scaler: + toolbox.scaler = None + self.log("Mixed precision training disabled") checkpoint_path = lambda v: f"{self.experiment_folder()}/checkpoint_{v}.pth" es_tolerance = early_stop_tolerance self._frontend.on_experiment_created(self._experiment_id, self._trainer_variant, model_name, note, From 1dd724767ac9868afbb1f32d9cfee451864c8a13 Mon Sep 17 00:00:00 2001 From: Steven Chen Date: Thu, 19 Mar 2026 11:49:34 -0400 Subject: [PATCH 6/6] Cast soft-dice inputs to float32 for AMP compatibility (#219) --- mipcandy/common/optim/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mipcandy/common/optim/loss.py b/mipcandy/common/optim/loss.py index 78c2bb2..ffe49de 100644 --- a/mipcandy/common/optim/loss.py +++ b/mipcandy/common/optim/loss.py @@ -58,7 +58,7 @@ def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.T if not self.include_background: outputs = outputs[:, 1:] labels = labels[:, 1:] - dice = soft_dice(outputs, labels, smooth=self.smooth) + dice = soft_dice(outputs.float(), labels.float(), smooth=self.smooth) metrics = {"soft dice": dice.item(), "ce loss": ce.item()} c = self.lambda_ce * ce + self.lambda_soft_dice * (1 - dice) return c, metrics @@ -91,7 +91,7 @@ def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.T labels = labels.to(dtype=outputs.dtype) bce = nn.functional.binary_cross_entropy_with_logits(outputs, labels) outputs = outputs.sigmoid() - dice = soft_dice(outputs, labels, smooth=self.smooth) + dice = soft_dice(outputs.float(), labels.float(), smooth=self.smooth) metrics = {"soft dice": dice.item(), "bce loss": bce.item()} c = self.lambda_bce * bce + self.lambda_soft_dice * (1 - dice) return c, metrics