From 02568516a58150edeeb1d8744de7ccdee8a7a963 Mon Sep 17 00:00:00 2001 From: john Date: Thu, 21 May 2026 08:59:35 -0400 Subject: [PATCH] add weights and biases to cellpose training --- .gitignore | 3 + cellpose/__main__.py | 3 +- cellpose/cli.py | 12 ++- cellpose/train.py | 84 ++++++++++++++++++- cellpose/wandb_logger.py | 174 +++++++++++++++++++++++++++++++++++++++ setup.py | 7 +- 6 files changed, 278 insertions(+), 5 deletions(-) create mode 100644 cellpose/wandb_logger.py diff --git a/.gitignore b/.gitignore index b6e47617..b1814c94 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,6 @@ dmypy.json # Pyre type checker .pyre/ + +# Weights & Biases +wandb/ diff --git a/cellpose/__main__.py b/cellpose/__main__.py index 5ca07bbd..214a8c7f 100644 --- a/cellpose/__main__.py +++ b/cellpose/__main__.py @@ -167,7 +167,8 @@ def _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_mode save_path=os.path.realpath(args.dir), save_every=args.save_every, save_each=args.save_each, - model_name=args.model_name_out)[0] + model_name=args.model_name_out, + wandb_enabled=not args.no_wandb)[0] model.pretrained_model = cpmodel_path logger.info(">>>> model trained and saved to %s" % cpmodel_path) return model diff --git a/cellpose/cli.py b/cellpose/cli.py index b7a7c182..e0842769 100644 --- a/cellpose/cli.py +++ b/cellpose/cli.py @@ -229,7 +229,17 @@ def get_arg_parser(): "--model_name_out", default=None, type=str, help="Name of model to save as, defaults to name describing model architecture. " "Model is saved in the folder specified by --dir in models subfolder.") - + + # Weights & Biases logging (optional, no-op if wandb not installed / not logged in) + wandb_args = parser.add_argument_group("Weights & Biases Arguments") + wandb_args.add_argument( + "--no_wandb", action="store_true", + help="disable Weights & Biases logging even if wandb is installed and logged in. " + "Project defaults to $WANDB_PROJECT or 'cellpose'; run name defaults to --model_name_out " + "(or an auto-generated 'cellpose_' name). All other wandb settings (entity, " + "tags, group, notes, ...) can be set via standard wandb environment variables." + ) + # TODO: remove deprecated in future version training_args.add_argument( "--diam_mean", default=30., type=float, help= diff --git a/cellpose/train.py b/cellpose/train.py index aa312c74..cfb0d90f 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -3,6 +3,7 @@ import numpy as np from cellpose import io, utils, models, dynamics from cellpose.transforms import normalize_img, random_rotate_and_resize +from cellpose.wandb_logger import WandbLogger from pathlib import Path import torch from torch import nn @@ -314,7 +315,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, n_epochs=100, weight_decay=0.1, normalize=True, compute_flows=False, save_path=None, save_every=100, save_each=False, nimg_per_epoch=None, nimg_test_per_epoch=None, rescale=False, scale_range=None, bsize=256, - min_train_masks=5, model_name=None, class_weights=None): + min_train_masks=5, model_name=None, class_weights=None, + wandb_enabled=True): """ Train the network with images for segmentation. @@ -346,6 +348,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, rescale (bool, optional): Boolean - whether or not to rescale images during training. Defaults to False. min_train_masks (int, optional): Integer - minimum number of masks an image must have to use in the training set. Defaults to 5. model_name (str, optional): String - name of the network. Defaults to None. + wandb_enabled (bool, optional): Whether to attempt Weights & Biases logging if wandb is installed and credentials are available. Defaults to True. Logging is silently skipped if wandb is unavailable. The run name defaults to ``model_name``; the project defaults to ``$WANDB_PROJECT`` or "cellpose". All other wandb settings (entity, tags, group, notes, ...) are controlled via standard wandb environment variables. Returns: tuple: A tuple containing the path to the saved model weights, training losses, and test losses. @@ -429,9 +432,45 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, train_logger.info(f">>> saving model to {filename}") + # Initialize wandb logging (no-op if wandb not installed or not logged in) + wandb_config = { + "model_name": model_name, + "n_epochs": n_epochs, + "batch_size": batch_size, + "learning_rate": learning_rate, + "weight_decay": weight_decay, + "optimizer": "AdamW", + "normalize": normalize_params, + "rescale": rescale, + "scale_range": scale_range, + "bsize": bsize, + "min_train_masks": min_train_masks, + "nimg_train": nimg, + "nimg_test": nimg_test, + "nimg_per_epoch": nimg_per_epoch, + "nimg_test_per_epoch": nimg_test_per_epoch, + "device": str(device), + "net_dtype": str(original_net_dtype), + "diam_mean": float(net.diam_mean.item()), + "diam_labels": float(diam_train.mean()), + "channel_axis": channel_axis, + "has_class_weights": class_weights is not None, + "save_path": str(save_path), + "save_every": save_every, + "save_each": save_each, + } + wandb_logger = WandbLogger( + enabled=wandb_enabled, + run_name=model_name, + config=wandb_config, + ) + lavg, nsum = 0, 0 train_losses, test_losses = np.zeros(n_epochs), np.zeros(n_epochs) + best_test_loss = float("inf") + best_test_epoch = -1 for iepoch in range(n_epochs): + t_epoch_start = time.time() np.random.seed(iepoch) if nimg != nimg_per_epoch: # choose random images for epoch with probability train_probs @@ -479,6 +518,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, train_losses[iepoch] += train_loss train_losses[iepoch] /= nimg_per_epoch + epoch_time = time.time() - t_epoch_start + if iepoch == 5 or iepoch % 10 == 0: lavgt = 0. if test_data is not None or test_files is not None: @@ -516,12 +557,39 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, lavgt += test_loss lavgt /= len(rperm) test_losses[iepoch] = lavgt + if lavgt > 0 and lavgt < best_test_loss: + best_test_loss = float(lavgt) + best_test_epoch = int(iepoch) lavg /= nsum train_logger.info( f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time()-t0:.2f}s" ) + # log windowed/averaged + validation metrics on the same step as the per-epoch log + wandb_logger.log( + { + "val/loss": float(lavgt) if (test_data is not None or test_files is not None) else None, + "train/loss_avg_window": float(lavg), + "train/best_val_loss": float(best_test_loss) if best_test_epoch >= 0 else None, + "train/best_val_epoch": best_test_epoch if best_test_epoch >= 0 else None, + }, + step=iepoch, + commit=False, + ) lavg, nsum = 0, 0 + # log per-epoch metrics every epoch; commit=True flushes the step + wandb_logger.log( + { + "epoch": iepoch, + "train/loss": float(train_losses[iepoch]), + "train/learning_rate": float(LR[iepoch]), + "train/epoch_time_s": float(epoch_time), + "train/elapsed_s": float(time.time() - t0), + }, + step=iepoch, + commit=True, + ) + if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0): if save_each and iepoch != n_epochs - 1: #separate files as model progresses filename0 = str(filename) + f"_epoch_{iepoch:04d}" @@ -529,10 +597,22 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, filename0 = filename train_logger.info(f"saving network parameters to {filename0}") net.save_model(filename0) - + net.save_model(filename) if original_net_dtype != torch.float32: train_logger.info(f">>> converting network back to {original_net_dtype} after training") net.dtype = original_net_dtype + wandb_logger.log_summary( + { + "final/train_loss": float(train_losses[-1]), + "final/test_loss": float(test_losses[-1]), + "best/test_loss": float(best_test_loss) if best_test_epoch >= 0 else None, + "best/test_epoch": best_test_epoch if best_test_epoch >= 0 else None, + "total_train_time_s": float(time.time() - t0), + "model_path": str(filename), + } + ) + wandb_logger.finish() + return filename, train_losses, test_losses diff --git a/cellpose/wandb_logger.py b/cellpose/wandb_logger.py new file mode 100644 index 00000000..3293a36e --- /dev/null +++ b/cellpose/wandb_logger.py @@ -0,0 +1,174 @@ +""" +Optional Weights & Biases (wandb) logging for cellpose training. + +This module is intentionally defensive: training must continue normally +whether or not wandb is installed, whether or not the user is logged in, +and whether or not wandb's backend is reachable. Any failure here is +swallowed and downgraded to a log message - it must never raise into the +training loop. + +Standard wandb environment variables are honored (e.g. ``WANDB_PROJECT``, +``WANDB_ENTITY``, ``WANDB_API_KEY``, ``WANDB_MODE``, ``WANDB_DIR``, +``WANDB_RUN_GROUP``, ``WANDB_TAGS``). Setting ``WANDB_DISABLED=true`` or +``WANDB_MODE=disabled`` will skip wandb entirely. Setting +``CELLPOSE_DISABLE_WANDB=1`` does the same and is provided for users who +do not want to touch wandb's own env vars. +""" + +import logging +import os + +logger = logging.getLogger(__name__) + +_DEFAULT_PROJECT = "cellpose" + + +def _wandb_disabled_by_env(): + """Honor common opt-out env vars without trying to import wandb.""" + for var in ("CELLPOSE_DISABLE_WANDB", "WANDB_DISABLED"): + val = os.environ.get(var, "").strip().lower() + if val in ("1", "true", "yes", "on"): + return True + if os.environ.get("WANDB_MODE", "").strip().lower() == "disabled": + return True + return False + + +class WandbLogger: + """ + Thin wrapper around ``wandb`` for cellpose training. + + The logger is considered *enabled* only if: + + 1. ``enabled=True`` was passed (the default) + 2. wandb is importable + 3. no opt-out env var is set + 4. ``wandb.init(...)`` completes without raising + + If any of those fail, the logger silently becomes a no-op so that + ``log()``, ``log_summary()`` and ``finish()`` are always safe to call. + + Project, entity, tags, group, notes, etc. are driven by wandb's own + environment variables (``WANDB_PROJECT``, ``WANDB_ENTITY``, ``WANDB_TAGS``, + ``WANDB_RUN_GROUP``, ``WANDB_NOTES``, ...). The project defaults to + ``"cellpose"`` if ``WANDB_PROJECT`` is not set. + """ + + def __init__(self, enabled=True, run_name=None, config=None): + self.enabled = False + self.wandb = None + self.run = None + + if not enabled: + return + + if _wandb_disabled_by_env(): + logger.debug("wandb logging skipped (disabled via environment).") + return + + try: + import wandb # type: ignore + except ImportError: + logger.debug( + "wandb not installed; skipping wandb logging. " + "To enable, `pip install wandb` and `wandb login`." + ) + return + except Exception as e: + logger.debug(f"wandb import failed ({e!r}); skipping wandb logging.") + return + + # Heuristic check for credentials. wandb.init() will also error out if + # not logged in, but checking up-front lets us avoid a noisy stack + # trace in the common "not logged in" case. + if not self._looks_authenticated(wandb): + logger.info( + "wandb is installed but no credentials were found - skipping " + "wandb logging. Run `wandb login` (or set WANDB_API_KEY) to enable." + ) + return + + init_kwargs = dict( + project=os.environ.get("WANDB_PROJECT", _DEFAULT_PROJECT), + name=run_name, + config=config or {}, + ) + # The 'reinit' param was migrated to a string value ('finish_previous') + # in newer wandb. Try the new value first, fall back to the legacy bool. + try: + try: + self.run = wandb.init(reinit="finish_previous", **init_kwargs) + except TypeError: + self.run = wandb.init(reinit=True, **init_kwargs) + except Exception as e: + logger.warning( + f"wandb.init() failed ({e!r}); continuing without wandb logging." + ) + self.run = None + return + + self.wandb = wandb + self.enabled = True + try: + logger.info( + f">>> wandb logging enabled: project={self.run.project}, " + f"run={self.run.name}, url={self.run.url}" + ) + except Exception: + logger.info(">>> wandb logging enabled") + + @staticmethod + def _looks_authenticated(wandb_module): + """Best-effort check for a usable wandb credential.""" + if os.environ.get("WANDB_API_KEY"): + return True + if os.environ.get("WANDB_MODE", "").strip().lower() in ("offline", "dryrun"): + return True + try: + api_key = wandb_module.api.api_key + except Exception: + api_key = None + return bool(api_key) + + def log(self, metrics, step=None, commit=None): + """Log a dict of scalar metrics. ``None`` values are dropped. + Safe no-op if disabled.""" + if not self.enabled: + return + clean = {k: v for k, v in metrics.items() if v is not None} + if not clean: + return + try: + kwargs = {} + if step is not None: + kwargs["step"] = step + if commit is not None: + kwargs["commit"] = commit + self.wandb.log(clean, **kwargs) + except Exception as e: + logger.debug(f"wandb.log() failed ({e!r}); disabling further wandb logs.") + self.enabled = False + + def log_summary(self, summary): + """Update the run's summary dict (final/best metrics).""" + if not self.enabled: + return + try: + for k, v in summary.items(): + if v is None: + continue + self.run.summary[k] = v + except Exception as e: + logger.debug(f"wandb summary update failed ({e!r}).") + + def finish(self): + """End the wandb run if one was created.""" + if not self.enabled: + return + try: + self.wandb.finish() + except Exception as e: + logger.debug(f"wandb.finish() failed ({e!r}).") + finally: + self.enabled = False + self.run = None diff --git a/setup.py b/setup.py index cb4dc077..f58e1deb 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,10 @@ 'bioimageio.core', ] +wandb_deps = [ + 'wandb', +] + try: import torch a = torch.ones(2, 3) @@ -94,7 +98,8 @@ 'gui': gui_deps, 'distributed': distributed_deps, 'bioimageio': bioimageio_deps, - 'all': gui_deps + distributed_deps + image_deps + bioimageio_deps, + 'wandb': wandb_deps, + 'all': gui_deps + distributed_deps + image_deps + bioimageio_deps + wandb_deps, }, include_package_data=True, classifiers=( "Programming Language :: Python :: 3", "License :: OSI Approved :: BSD License",