Commit c98d734

Some training fixes and improvements (#65)
* Add missing wandb finish for aborted jobs
* Update utils.py
* Update plugin_model_training.py
* Update plugin_model_training.py
* Fix num_epochs for aborted training in csv
* Fix Wnet weight loading
* Fix reference to WNet weights for transfer learn
* Update plugin_model_training.py
* Improved safety for closing previous wandb runs
* Fix rogue Path instead of str
* Update worker_training.py
* Change bounds for WNet train parameters
* Shorten file names in train log
* More range for rec loss weight
* Update Dice coeff calculation
* Fix matching len check for csv
* Update training_wnet.rst
* Update WNet docs + typos
* Update utils.py
* Change rec loss weight default
* Change discrete output display for WNet training
* Fix incorrect checks for csv saving
* Small improvement to make_csv
* Fix logic issue in make_csv
1 parent f89e272 commit c98d734

File tree

5 files changed: +131 -57 lines changed

docs/source/guides/cropping_module_guide.rst

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ Once you have launched the review process, you will gain control over three slid
     you to **adjust the position** of the cropped volumes and labels in the x,y and z positions.

     .. note::
-        * If your **cropped volume isnt visible**, consider changing the **colormap** of the image and the cropped
+        * If your **cropped volume isn't visible**, consider changing the **colormap** of the image and the cropped
          volume to improve their visibility.
         * You may want to adjust the **opacity** and **contrast thresholds** depending on your image.
         * If the image appears empty:

docs/source/guides/training_wnet.rst

Lines changed: 51 additions & 17 deletions
@@ -4,22 +4,50 @@ Walkthrough - WNet3D training
 ===============================

 This plugin provides a reimplemented, custom version of the WNet3D model from `WNet, A Deep Model for Fully Unsupervised Image Segmentation`_.
+
 For training your model, you can choose among:

 * Directly within the plugin
 * The provided Jupyter notebook (locally)
-* Our Colab notebook (inspired by ZeroCostDL4Mic)
+* Our Colab notebook (inspired by https://github.com/HenriquesLab/ZeroCostDL4Mic)
+
+Selecting training data
+-------------------------
+
+The WNet3D **does not require a large amount of data to train**, but **choosing the right data** to train this unsupervised model **is crucial**.
+
+You may find below some guidelines, based on our own data and testing.
+
+The WNet3D is designed to segment objects based on their brightness, and is particularly well-suited for images with a clear contrast between objects and background.
+
+The WNet3D is not suitable for images with artifacts; care should therefore be taken that the images are clean and that the objects are at least somewhat distinguishable from the background.
+
+.. important::
+    For optimal performance, the following should be avoided for training:
+
+    - Images with very large, bright regions
+    - Almost-empty and empty images
+    - Images with large empty regions or "holes"

-The WNet3D does not require a large amount of data to train, but during inference images should be similar to those
-the model was trained on; you can retrain from our pretrained model to your image dataset to quickly reach good performance.
+However, the model can accommodate:
+
+- Uneven brightness distribution
+- Varied object shapes and radii
+- Noisy images
+- Uneven illumination across the image
+
+For optimal results, during inference, images should be similar to those the model was trained on; however, this is not a strict requirement.
+
+You may also retrain from our pretrained model on your image dataset to quickly reach good performance: simply check "Use pre-trained weights" in the training module, and lower the learning rate.

 .. note::
-    - The WNet3D relies on brightness to distinguish objects from the background. For better results, use image regions with minimal artifacts. If you notice many artifacts, consider training on one of the supervised models.
-    - The model has two losses, the **`SoftNCut loss`**, which clusters pixels according to brightness, and a reconstruction loss, either **`Mean Square Error (MSE)`** or **`Binary Cross Entropy (BCE)`**. Unlike the method described in the original paper, these losses are added in a weighted sum and the backward pass is performed for the whole model at once. The SoftNcuts and BCE are bounded between 0 and 1; the MSE may take large positive values. It is recommended to watch for the weighted sum of losses to be **close to one on the first epoch**, for training stability.
-    - For good performance, you should wait for the SoftNCut to reach a plateau; the reconstruction loss must also decrease but is generally less critical.
+    - The WNet3D relies on brightness to distinguish objects from the background. For better results, use image regions with minimal artifacts. If you notice many artifacts, consider trying one of our supervised models (for lightsheet microscopy).
+    - The model has two losses, the **`SoftNCut loss`**, which clusters pixels according to brightness, and a reconstruction loss, either **`Mean Square Error (MSE)`** or **`Binary Cross Entropy (BCE)`**.
+    - For good performance, wait for the SoftNCut to reach a plateau; the reconstruction loss should also be decreasing overall, but this is generally less critical for segmentation performance.

 Parameters
-----------
+-------------

 .. figure:: ../images/training_tab_4.png
    :scale: 100 %
@@ -29,7 +57,7 @@ Parameters

 _`When using the WNet3D training module`, the **Advanced** tab contains a set of additional options:

-- **Number of classes** : Dictates the segmentation classes (default is 2). Increasing the number of classes will result in a more progressive segmentation according to brightness; can be useful if you have "halos" around your objects or artifacts with a significantly different brightness.
+- **Number of classes** : Dictates the segmentation classes (default is 2). Increasing the number of classes will result in a more progressive segmentation according to brightness; can be useful if you have "halos" around your objects, or to approximate boundary labels.
 - **Reconstruction loss** : Choose between MSE or BCE (default is MSE). MSE is more precise but also sensitive to outliers; BCE is more robust against outliers at the cost of precision.

 - NCuts parameters:
@@ -43,22 +71,28 @@ _`When using the WNet3D training module`, the **Advanced** tab contains a set of

 - Weights for the sum of losses :
   - **NCuts weight** : Sets the weight of the NCuts loss (default is 0.5).
-  - **Reconstruction weight** : Sets the weight for the reconstruction loss (default is 0.5*1e-2).
+  - **Reconstruction weight** : Sets the weight for the reconstruction loss (default is 5*1e-3).

-.. note::
-    The weight of the reconstruction loss should be adjusted to ensure the weighted sum is around one during the first epoch;
-    ideally the reconstruction loss should be of the same order of magnitude as the NCuts loss after being multiplied by its weight.
+.. important::
+    The weight of the reconstruction loss should be adjusted to ensure that both losses are balanced.
+
+    This balance can be assessed using the live view of training outputs:
+    if the NCuts loss is "taking over", causing the segmentation to only label very large, brighter versus dimmer regions, the weight of the reconstruction loss should be increased.
+
+    This will help the model to focus on the details of the objects, rather than just the overall brightness of the volume.

 Common issues troubleshooting
 ------------------------------
-If you do not find a satisfactory answer here, please do not hesitate to `open an issue`_ on GitHub.

-- **The NCuts loss explodes after a few epochs** : Lower the learning rate, first by a factor of two, then ten.
+.. important::
+    If you do not find a satisfactory answer here, please do not hesitate to `open an issue`_ on GitHub.
+
+- **The NCuts loss "explodes" after a few epochs** : Lower the learning rate, for example first by a factor of two, then by a factor of ten.

-- **The NCuts loss does not converge and is unstable** :
-  The normalization step might not be adapted to your images. Disable normalization and change intensity_sigma according to the distribution of values in your image. For reference, by default images are remapped to values between 0 and 100, and intensity_sigma=1.
+- **Reconstruction (decoder) performance is poor** : First, try increasing the weight of the reconstruction loss. If this is ineffective, switch to BCE loss and set the scaling factor of the reconstruction loss to 0.5, OR adjust the weight of the MSE loss.

-- **Reconstruction (decoder) performance is poor** : switch to BCE and set the scaling factor of the reconstruction loss to 0.5, OR adjust the weight of the MSE loss to make it closer to 1 in the weighted sum.
+- **Segmentation only separates the brighter versus dimmer regions** : Increase the weight of the reconstruction loss.


.. _WNet, A Deep Model for Fully Unsupervised Image Segmentation: https://arxiv.org/abs/1711.08506
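The loss balancing discussed in the docs above can be sketched as a plain weighted sum. This is a minimal illustration, not the plugin's actual training loop; `combined_wnet_loss` and its argument values are hypothetical, with defaults matching the NCuts weight (0.5) and the new reconstruction weight default described in this commit:

```python
def combined_wnet_loss(soft_ncuts: float, reconstruction: float,
                       ncuts_weight: float = 0.5,
                       rec_weight: float = 0.5) -> float:
    """Weighted sum of the two WNet3D losses (illustrative sketch).

    If the NCuts term dominates and the segmentation only separates
    bright versus dim regions, increasing rec_weight pushes the model
    toward reconstructing fine object detail.
    """
    return ncuts_weight * soft_ncuts + rec_weight * reconstruction

# Balanced configuration: both terms contribute at comparable magnitude.
balanced = combined_wnet_loss(0.6, 0.5)  # 0.5*0.6 + 0.5*0.5 = 0.55

# With the old tiny reconstruction weight, the NCuts term dominates.
ncuts_dominated = combined_wnet_loss(0.6, 0.5, rec_weight=0.005)
```

The balance can then be judged from the live training view, as the docs suggest, rather than from the raw numbers alone.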

napari_cellseg3d/code_models/worker_training.py

Lines changed: 36 additions & 14 deletions
@@ -388,7 +388,7 @@ def log_parameters(self):
         self.log("-- Data --")
         self.log("Training data :\n")
         [
-            self.log(f"{v}")
+            self.log(f"{Path(v).stem}")
             for d in self.config.train_data_dict
             for k, v in d.items()
         ]

@@ -423,6 +423,11 @@ def train(
         if WANDB_INSTALLED:
             config_dict = self.config.__dict__
             logger.debug(f"wandb config : {config_dict}")
+            if wandb.run is not None:
+                logger.warning(
+                    "A previous wandb run is still active. It will be stopped before starting a new one."
+                )
+                wandb.finish()
             wandb.init(
                 config=config_dict,
                 project="CellSeg3D - WNet",

@@ -472,13 +477,21 @@ def train(
         if WANDB_INSTALLED:
             wandb.watch(model, log_freq=100)

-        if self.config.weights_info.use_custom:
+        if (
+            self.config.weights_info.use_pretrained
+            or self.config.weights_info.use_custom
+        ):
             if self.config.weights_info.use_pretrained:
-                weights_file = "wnet.pth"
+                from napari_cellseg3d.code_models.models.model_WNet import (
+                    WNet_,
+                )
+
+                weights_file = WNet_.weights_file
                 self.downloader.download_weights("WNet", weights_file)
-                weights = PRETRAINED_WEIGHTS_DIR / Path(weights_file)
+                weights = str(PRETRAINED_WEIGHTS_DIR / Path(weights_file))
                 self.config.weights_info.path = weights
-            else:
+
+            if self.config.weights_info.use_custom:
                 weights = str(Path(self.config.weights_info.path))

         try:

@@ -624,6 +637,9 @@ def train(
             del criterionW
             torch.cuda.empty_cache()

+            if WANDB_INSTALLED:
+                wandb.finish()
+
             self.ncuts_losses.append(
                 epoch_ncuts_loss / len(self.dataloader)
             )

@@ -642,9 +658,7 @@ def train(
                     "cmap": "turbo",
                 },
                 "Encoder output (discrete)": {
-                    "data": AsDiscrete(threshold=0.5)(
-                        enc_out
-                    ).numpy(),
+                    "data": np.where(enc_out > 0.5, enc_out, 0),
                     "cmap": "bop blue",
                 },
                 "Decoder output": {
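The change to the discrete encoder display above keeps the soft values above the 0.5 threshold instead of binarizing them, so the napari view preserves intensity information. A minimal numpy illustration of the difference (the array values are made up for the example):

```python
import numpy as np

# Hypothetical encoder output probabilities for four voxels
enc_out = np.array([0.2, 0.4, 0.6, 0.9])

# Old behaviour (MONAI AsDiscrete): hard 0/1 mask at the 0.5 threshold
binarized = (enc_out > 0.5).astype(float)

# New behaviour: keep the soft values above the threshold, zero the rest
thresholded = np.where(enc_out > 0.5, enc_out, 0)

print(binarized)    # [0. 0. 1. 1.]
print(thresholded)  # [0.  0.  0.6 0.9]
```

The thresholded view still suppresses background, but gradations among confident foreground voxels remain visible.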
@@ -736,7 +750,8 @@ def train(
             if epoch % 5 == 0:
                 torch.save(
                     model.state_dict(),
-                    self.config.results_path_folder + "/wnet_.pth",
+                    self.config.results_path_folder
+                    + "/wnet_checkpoint.pth",
                 )

         self.log("Training finished")

@@ -856,8 +871,7 @@ def eval(self, model, epoch) -> TrainingReport:
                 self.dice_metric(
                     y_pred=val_outputs[
                         :,
-                        max_dice_channel : (max_dice_channel + 1),
-                        :,
+                        max_dice_channel:,  # : (max_dice_channel + 1),
                         :,
                         :,
                     ],

@@ -1120,6 +1134,11 @@ def train(
             config_dict = self.config.__dict__
             logger.debug(f"wandb config : {config_dict}")
             try:
+                if wandb.run is not None:
+                    logger.warning(
+                        "A previous wandb run is still active. It will be stopped before starting a new one."
+                    )
+                    wandb.finish()
                 wandb.init(
                     config=config_dict,
                     project="CellSeg3D",

@@ -1410,13 +1429,13 @@ def get_patch_loader_func(num_samples):
             # time = utils.get_date_time()
             logger.debug("Weights")

-            if weights_config.use_custom:
+            if weights_config.use_custom or weights_config.use_pretrained:
                 if weights_config.use_pretrained:
                     weights_file = model_class.weights_file
                     self.downloader.download_weights(model_name, weights_file)
-                    weights = PRETRAINED_WEIGHTS_DIR / Path(weights_file)
+                    weights = str(PRETRAINED_WEIGHTS_DIR / Path(weights_file))
                     weights_config.path = weights
-                else:
+                elif weights_config.use_custom:
                     weights = str(Path(weights_config.path))

             try:

@@ -1523,6 +1542,9 @@ def get_patch_loader_func(num_samples):
             if device.type == "cuda":
                 torch.cuda.empty_cache()

+            if WANDB_INSTALLED:
+                wandb.finish()
+
             yield TrainingReport(
                 show_plot=False,
                 weights=model.state_dict(),
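The wandb guard this commit adds ("Improved safety for closing previous wandb runs") is a close-before-init pattern: if an aborted job left a run attached to the process, it is finished before a new one starts. A standalone sketch of the logic, using a hypothetical `FakeWandb` stand-in so it runs without a live wandb installation:

```python
import logging

logger = logging.getLogger("wandb_guard")


class FakeWandb:
    """Stand-in for the wandb module (hypothetical, for illustration)."""

    def __init__(self):
        self.run = None        # mirrors wandb.run: None when no run is active
        self.finished = 0      # counts how many stale runs were closed

    def init(self, **kwargs):
        self.run = object()    # a real wandb.init() would start a run here

    def finish(self):
        self.finished += 1
        self.run = None


def safe_init(wandb, **kwargs):
    # Mirror of the guard added in worker_training.py: close any lingering
    # run before starting a new one, so aborted jobs do not leave a stale
    # run attached to the process.
    if wandb.run is not None:
        logger.warning("A previous wandb run is still active; stopping it.")
        wandb.finish()
    wandb.init(**kwargs)


wandb = FakeWandb()
safe_init(wandb, project="demo")  # first call: nothing to finish
safe_init(wandb, project="demo")  # second call: the stale run is finished first
```

The same guard appears in both the WNet and the supervised training workers in this commit, since either can be aborted mid-run.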

napari_cellseg3d/code_plugins/plugin_model_training.py

Lines changed: 42 additions & 24 deletions
@@ -1054,7 +1054,7 @@ def _set_worker_config(
         logger.debug("Loading config...")
         model_config = config.ModelInfo(name=self.model_choice.currentText())

-        self.weights_config.path = self.weights_config.path
+        # self.weights_config.path = self.weights_config.path
         self.weights_config.use_custom = self.custom_weights_choice.isChecked()

         self.weights_config.use_pretrained = (
@@ -1365,31 +1365,44 @@ def on_yield(self, report: TrainingReport):
             self.on_stop()
             self._stop_requested = False

-    def _make_csv(self):
-        size_column = range(1, self.worker_config.max_epochs + 1)
+    def _check_lens(self, size_column, loss_values):
+        if len(size_column) != len(loss_values):
+            logger.info(
+                f"Training was stopped, setting epochs for csv to {len(loss_values)}"
+            )
+            return range(1, len(loss_values) + 1)
+        return size_column

-        if len(self.loss_1_values) == 0 or self.loss_1_values is None:
+    def _handle_loss_values(self, size_column, key):
+        loss_values = self.loss_1_values.get(key)
+        if loss_values is None:
+            return None
+
+        if len(loss_values) == 0:
             logger.warning("No loss values to add to csv !")
-            return
+            return None

-        try:
-            self.loss_1_values["Loss"]
-            supervised = True
-        except KeyError:
-            try:
-                self.loss_1_values["SoftNCuts"]
-                supervised = False
-            except KeyError as e:
-                raise KeyError(
-                    "Error when making csv. Check loss dict keys ?"
-                ) from e
+        return self._check_lens(size_column, loss_values)
+
+    def _make_csv(self):  # TODO(cyril) design could use a good rework
+        size_column = range(1, self.worker_config.max_epochs + 1)
+
+        supervised = True
+        size_column = self._handle_loss_values(size_column, "Loss")
+        if size_column is None:
+            size_column = self._handle_loss_values(size_column, "SoftNCuts")
+            if size_column is None:
+                raise KeyError("Error when making csv. Check loss dict keys ?")
+            supervised = False

         if supervised:
-            val = utils.fill_list_in_between(
+            val = utils.fill_list_in_between(  # fills the validation list based on validation interval
                 self.loss_2_values,
                 self.worker_config.validation_interval - 1,
                 "",
-            )[: len(size_column)]
+            )[
+                : len(size_column)
+            ]

             self.df = pd.DataFrame(
                 {
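The length check introduced above ("Fix num_epochs for aborted training in csv") shrinks the epoch column when training was stopped early, so the csv rows stay aligned with the recorded loss values. A standalone sketch of that logic (a free function here, rather than the plugin's method):

```python
def check_lens(size_column, loss_values):
    """Sketch of the _check_lens logic added in this commit.

    If training was aborted early, fewer loss values exist than
    planned epochs; the epoch column is shrunk to match.
    """
    if len(size_column) != len(loss_values):
        return range(1, len(loss_values) + 1)
    return size_column


# 10 epochs planned, aborted after 4 -> epoch column becomes 1..4:
aborted = list(check_lens(range(1, 11), [0.9, 0.7, 0.6, 0.55]))

# Lengths match -> the planned column is returned unchanged:
complete = check_lens(range(1, 5), [0.9, 0.7, 0.6, 0.55])
```

Without this check, `pd.DataFrame` would raise on mismatched column lengths, which is the failure this commit fixes.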
@@ -1404,8 +1417,13 @@ def _make_csv(self):
             raise ValueError(err)
         else:
             ncuts_loss = self.loss_1_values["SoftNCuts"]
+
+            logger.debug(f"Epochs : {len(size_column)}")
+            logger.debug(f"Loss 1 values : {len(ncuts_loss)}")
+            logger.debug(f"Loss 2 values : {len(self.loss_2_values)}")
             try:
                 dice_metric = self.loss_1_values["Dice metric"]
+                logger.debug(f"Dice metric : {dice_metric}")
                 self.df = pd.DataFrame(
                     {
                         "Epoch": size_column,
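The csv-building code above interleaves validation losses with blanks via `utils.fill_list_in_between`, since validation runs only every `validation_interval` epochs. A plausible sketch of that helper, inferred from the call site (the real implementation in `utils.py` may differ):

```python
def fill_list_in_between(lst, n, fill_value):
    """Append n copies of fill_value after each element of lst.

    Inferred sketch of the utils helper: pads the validation-loss
    column with blanks so it lines up with the per-epoch rows.
    """
    result = []
    for item in lst:
        result.append(item)
        result.extend([fill_value] * n)
    return result


# Validation every 2 epochs -> one blank row follows each value:
padded = fill_list_in_between([0.8, 0.9], 1, "")

# Interval of 1 (validate every epoch) -> list unchanged:
unchanged = fill_list_in_between([0.8, 0.9], 0, "")
```

The `[: len(size_column)]` slice at the call site then trims the padded list to the epoch count.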
@@ -1630,15 +1648,15 @@ def __init__(self, parent):
             text_label="Number of classes",
         )
         self.intensity_sigma_choice = ui.DoubleIncrementCounter(
-            lower=1.0,
+            lower=0.01,
             upper=100.0,
             default=self.default_config.intensity_sigma,
             parent=parent,
             text_label="Intensity sigma",
         )
         self.intensity_sigma_choice.setMaximumWidth(20)
         self.spatial_sigma_choice = ui.DoubleIncrementCounter(
-            lower=1.0,
+            lower=0.01,
             upper=100.0,
             default=self.default_config.spatial_sigma,
             parent=parent,

@@ -1674,10 +1692,10 @@ def __init__(self, parent):
         )
         self.reconstruction_weight_choice.setMaximumWidth(20)
         self.reconstruction_weight_divide_factor_choice = (
-            ui.IntIncrementCounter(
-                lower=1,
-                upper=10000,
-                default=100,
+            ui.DoubleIncrementCounter(
+                lower=0.01,
+                upper=10000.0,
+                default=1.0,
                 parent=parent,
                 text_label="Reconstruction weight divide factor",
             )

napari_cellseg3d/config.py

Lines changed: 1 addition & 1 deletion
@@ -399,7 +399,7 @@ class WNetTrainingWorkerConfig(TrainingWorkerConfig):
     reconstruction_loss: str = "MSE"  # or "BCE"
     # summed losses weights
     n_cuts_weight: float = 0.5
-    rec_loss_weight: float = 0.5 / 100
+    rec_loss_weight: float = 0.5 / 1.0  # 0.5 / 100
     # normalization params
     # normalizing_function: callable = remap_image  # FIXME: call directly in worker, not a param
     # data params

Comments (0)