Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

# Weights & Biases
wandb/
3 changes: 2 additions & 1 deletion cellpose/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_mode
save_path=os.path.realpath(args.dir),
save_every=args.save_every,
save_each=args.save_each,
model_name=args.model_name_out)[0]
model_name=args.model_name_out,
wandb_enabled=not args.no_wandb)[0]
model.pretrained_model = cpmodel_path
logger.info(">>>> model trained and saved to %s" % cpmodel_path)
return model
Expand Down
12 changes: 11 additions & 1 deletion cellpose/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,17 @@ def get_arg_parser():
"--model_name_out", default=None, type=str,
help="Name of model to save as, defaults to name describing model architecture. "
"Model is saved in the folder specified by --dir in models subfolder.")


# Weights & Biases logging (optional, no-op if wandb not installed / not logged in)
wandb_args = parser.add_argument_group("Weights & Biases Arguments")
wandb_args.add_argument(
"--no_wandb", action="store_true",
help="disable Weights & Biases logging even if wandb is installed and logged in. "
"Project defaults to $WANDB_PROJECT or 'cellpose'; run name defaults to --model_name_out "
"(or an auto-generated 'cellpose_<timestamp>' name). All other wandb settings (entity, "
"tags, group, notes, ...) can be set via standard wandb environment variables."
)

# TODO: remove deprecated in future version
training_args.add_argument(
"--diam_mean", default=30., type=float, help=
Expand Down
84 changes: 82 additions & 2 deletions cellpose/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from cellpose import io, utils, models, dynamics
from cellpose.transforms import normalize_img, random_rotate_and_resize
from cellpose.wandb_logger import WandbLogger
from pathlib import Path
import torch
from torch import nn
Expand Down Expand Up @@ -314,7 +315,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
n_epochs=100, weight_decay=0.1, normalize=True, compute_flows=False,
save_path=None, save_every=100, save_each=False, nimg_per_epoch=None,
nimg_test_per_epoch=None, rescale=False, scale_range=None, bsize=256,
min_train_masks=5, model_name=None, class_weights=None):
min_train_masks=5, model_name=None, class_weights=None,
wandb_enabled=True):
"""
Train the network with images for segmentation.

Expand Down Expand Up @@ -346,6 +348,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
rescale (bool, optional): Boolean - whether or not to rescale images during training. Defaults to False.
min_train_masks (int, optional): Integer - minimum number of masks an image must have to use in the training set. Defaults to 5.
model_name (str, optional): String - name of the network. Defaults to None.
wandb_enabled (bool, optional): Whether to attempt Weights & Biases logging if wandb is installed and credentials are available. Defaults to True. Logging is silently skipped if wandb is unavailable. The run name defaults to ``model_name``; the project defaults to ``$WANDB_PROJECT`` or "cellpose". All other wandb settings (entity, tags, group, notes, ...) are controlled via standard wandb environment variables.

Returns:
tuple: A tuple containing the path to the saved model weights, training losses, and test losses.
Expand Down Expand Up @@ -429,9 +432,45 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,

train_logger.info(f">>> saving model to {filename}")

# Initialize wandb logging (no-op if wandb not installed or not logged in)
wandb_config = {
"model_name": model_name,
"n_epochs": n_epochs,
"batch_size": batch_size,
"learning_rate": learning_rate,
"weight_decay": weight_decay,
"optimizer": "AdamW",
"normalize": normalize_params,
"rescale": rescale,
"scale_range": scale_range,
"bsize": bsize,
"min_train_masks": min_train_masks,
"nimg_train": nimg,
"nimg_test": nimg_test,
"nimg_per_epoch": nimg_per_epoch,
"nimg_test_per_epoch": nimg_test_per_epoch,
"device": str(device),
"net_dtype": str(original_net_dtype),
"diam_mean": float(net.diam_mean.item()),
"diam_labels": float(diam_train.mean()),
"channel_axis": channel_axis,
"has_class_weights": class_weights is not None,
"save_path": str(save_path),
"save_every": save_every,
"save_each": save_each,
}
wandb_logger = WandbLogger(
enabled=wandb_enabled,
run_name=model_name,
config=wandb_config,
)

lavg, nsum = 0, 0
train_losses, test_losses = np.zeros(n_epochs), np.zeros(n_epochs)
best_test_loss = float("inf")
best_test_epoch = -1
for iepoch in range(n_epochs):
t_epoch_start = time.time()
np.random.seed(iepoch)
if nimg != nimg_per_epoch:
# choose random images for epoch with probability train_probs
Expand Down Expand Up @@ -479,6 +518,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
train_losses[iepoch] += train_loss
train_losses[iepoch] /= nimg_per_epoch

epoch_time = time.time() - t_epoch_start

if iepoch == 5 or iepoch % 10 == 0:
lavgt = 0.
if test_data is not None or test_files is not None:
Expand Down Expand Up @@ -516,23 +557,62 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
lavgt += test_loss
lavgt /= len(rperm)
test_losses[iepoch] = lavgt
if lavgt > 0 and lavgt < best_test_loss:
best_test_loss = float(lavgt)
best_test_epoch = int(iepoch)
lavg /= nsum
train_logger.info(
f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time()-t0:.2f}s"
)
# log windowed/averaged + validation metrics on the same step as the per-epoch log
wandb_logger.log(
{
"val/loss": float(lavgt) if (test_data is not None or test_files is not None) else None,
"train/loss_avg_window": float(lavg),
"train/best_val_loss": float(best_test_loss) if best_test_epoch >= 0 else None,
"train/best_val_epoch": best_test_epoch if best_test_epoch >= 0 else None,
},
step=iepoch,
commit=False,
)
lavg, nsum = 0, 0

# log per-epoch metrics every epoch; commit=True flushes the step
wandb_logger.log(
{
"epoch": iepoch,
"train/loss": float(train_losses[iepoch]),
"train/learning_rate": float(LR[iepoch]),
"train/epoch_time_s": float(epoch_time),
"train/elapsed_s": float(time.time() - t0),
},
step=iepoch,
commit=True,
)

if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
if save_each and iepoch != n_epochs - 1: #separate files as model progresses
filename0 = str(filename) + f"_epoch_{iepoch:04d}"
else:
filename0 = filename
train_logger.info(f"saving network parameters to {filename0}")
net.save_model(filename0)

net.save_model(filename)
if original_net_dtype != torch.float32:
train_logger.info(f">>> converting network back to {original_net_dtype} after training")
net.dtype = original_net_dtype

wandb_logger.log_summary(
{
"final/train_loss": float(train_losses[-1]),
"final/test_loss": float(test_losses[-1]),
"best/test_loss": float(best_test_loss) if best_test_epoch >= 0 else None,
"best/test_epoch": best_test_epoch if best_test_epoch >= 0 else None,
"total_train_time_s": float(time.time() - t0),
"model_path": str(filename),
}
)
wandb_logger.finish()

return filename, train_losses, test_losses
174 changes: 174 additions & 0 deletions cellpose/wandb_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""
Optional Weights & Biases (wandb) logging for cellpose training.

This module is intentionally defensive: training must continue normally
whether or not wandb is installed, whether or not the user is logged in,
and whether or not wandb's backend is reachable. Any failure here is
swallowed and downgraded to a log message - it must never raise into the
training loop.

Standard wandb environment variables are honored (e.g. ``WANDB_PROJECT``,
``WANDB_ENTITY``, ``WANDB_API_KEY``, ``WANDB_MODE``, ``WANDB_DIR``,
``WANDB_RUN_GROUP``, ``WANDB_TAGS``). Setting ``WANDB_DISABLED=true`` or
``WANDB_MODE=disabled`` will skip wandb entirely. Setting
``CELLPOSE_DISABLE_WANDB=1`` does the same and is provided for users who
do not want to touch wandb's own env vars.
"""

import logging
import os

logger = logging.getLogger(__name__)

_DEFAULT_PROJECT = "cellpose"


def _wandb_disabled_by_env():
"""Honor common opt-out env vars without trying to import wandb."""
for var in ("CELLPOSE_DISABLE_WANDB", "WANDB_DISABLED"):
val = os.environ.get(var, "").strip().lower()
if val in ("1", "true", "yes", "on"):
return True
if os.environ.get("WANDB_MODE", "").strip().lower() == "disabled":
return True
return False


class WandbLogger:
"""
Thin wrapper around ``wandb`` for cellpose training.

The logger is considered *enabled* only if:

1. ``enabled=True`` was passed (the default)
2. wandb is importable
3. no opt-out env var is set
4. ``wandb.init(...)`` completes without raising

If any of those fail, the logger silently becomes a no-op so that
``log()``, ``log_summary()`` and ``finish()`` are always safe to call.

Project, entity, tags, group, notes, etc. are driven by wandb's own
environment variables (``WANDB_PROJECT``, ``WANDB_ENTITY``, ``WANDB_TAGS``,
``WANDB_RUN_GROUP``, ``WANDB_NOTES``, ...). The project defaults to
``"cellpose"`` if ``WANDB_PROJECT`` is not set.
"""

def __init__(self, enabled=True, run_name=None, config=None):
self.enabled = False
self.wandb = None
self.run = None

if not enabled:
return

if _wandb_disabled_by_env():
logger.debug("wandb logging skipped (disabled via environment).")
return

try:
import wandb # type: ignore
except ImportError:
logger.debug(
"wandb not installed; skipping wandb logging. "
"To enable, `pip install wandb` and `wandb login`."
)
return
except Exception as e:
logger.debug(f"wandb import failed ({e!r}); skipping wandb logging.")
return

# Heuristic check for credentials. wandb.init() will also error out if
# not logged in, but checking up-front lets us avoid a noisy stack
# trace in the common "not logged in" case.
if not self._looks_authenticated(wandb):
logger.info(
"wandb is installed but no credentials were found - skipping "
"wandb logging. Run `wandb login` (or set WANDB_API_KEY) to enable."
)
return

init_kwargs = dict(
project=os.environ.get("WANDB_PROJECT", _DEFAULT_PROJECT),
name=run_name,
config=config or {},
)
# The 'reinit' param was migrated to a string value ('finish_previous')
# in newer wandb. Try the new value first, fall back to the legacy bool.
try:
try:
self.run = wandb.init(reinit="finish_previous", **init_kwargs)
except TypeError:
self.run = wandb.init(reinit=True, **init_kwargs)
except Exception as e:
logger.warning(
f"wandb.init() failed ({e!r}); continuing without wandb logging."
)
self.run = None
return

self.wandb = wandb
self.enabled = True
try:
logger.info(
f">>> wandb logging enabled: project={self.run.project}, "
f"run={self.run.name}, url={self.run.url}"
)
except Exception:
logger.info(">>> wandb logging enabled")

@staticmethod
def _looks_authenticated(wandb_module):
"""Best-effort check for a usable wandb credential."""
if os.environ.get("WANDB_API_KEY"):
return True
if os.environ.get("WANDB_MODE", "").strip().lower() in ("offline", "dryrun"):
return True
try:
api_key = wandb_module.api.api_key
except Exception:
api_key = None
return bool(api_key)

def log(self, metrics, step=None, commit=None):
"""Log a dict of scalar metrics. ``None`` values are dropped.
Safe no-op if disabled."""
if not self.enabled:
return
clean = {k: v for k, v in metrics.items() if v is not None}
if not clean:
return
try:
kwargs = {}
if step is not None:
kwargs["step"] = step
if commit is not None:
kwargs["commit"] = commit
self.wandb.log(clean, **kwargs)
except Exception as e:
logger.debug(f"wandb.log() failed ({e!r}); disabling further wandb logs.")
self.enabled = False

def log_summary(self, summary):
"""Update the run's summary dict (final/best metrics)."""
if not self.enabled:
return
try:
for k, v in summary.items():
if v is None:
continue
self.run.summary[k] = v
except Exception as e:
logger.debug(f"wandb summary update failed ({e!r}).")

def finish(self):
"""End the wandb run if one was created."""
if not self.enabled:
return
try:
self.wandb.finish()
except Exception as e:
logger.debug(f"wandb.finish() failed ({e!r}).")
finally:
self.enabled = False
self.run = None
7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@
'bioimageio.core',
]

wandb_deps = [
'wandb',
]

try:
import torch
a = torch.ones(2, 3)
Expand Down Expand Up @@ -94,7 +98,8 @@
'gui': gui_deps,
'distributed': distributed_deps,
'bioimageio': bioimageio_deps,
'all': gui_deps + distributed_deps + image_deps + bioimageio_deps,
'wandb': wandb_deps,
'all': gui_deps + distributed_deps + image_deps + bioimageio_deps + wandb_deps,
}, include_package_data=True, classifiers=(
"Programming Language :: Python :: 3",
"License :: OSI Approved :: BSD License",
Expand Down