
Training loss crashes after a few epochs #858

@sulaimanvesal

Hi team,

Thank you for all your support.

I have been using MONAI for a while for our multi-modal image segmentation (3D DynUNet). Recently, after fixing many issues, I was able to train a model. However, there is a weird behavior during training: the model converges nicely, but after 10-15 epochs both the loss and the dice_metric crash and go to zero.

At the beginning of training I also get the following warning, which, after debugging, I think comes from the Dice metric function:
invalid value encountered in true_divide
However, I made sure that all labels contain only binary values (0 and 1). My model does lesion segmentation, so it is a binary task.
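For context, `invalid value encountered in true_divide` is the RuntimeWarning NumPy emits when an element-wise division evaluates 0/0 (the exact wording varies across NumPy versions). The Dice score 2|A∩B| / (|A| + |B|) hits exactly that case for any sample where both the prediction and the label are empty, i.e. a case with no lesion. The following is a minimal NumPy sketch of that behavior, not MONAI's implementation; the function names `dice_naive` and `dice_smoothed` and the `eps` value are illustrative assumptions.

```python
import numpy as np


def dice_naive(pred: np.ndarray, label: np.ndarray) -> np.ndarray:
    """Per-sample Dice for a batch of binary masks shaped (N, ...)."""
    axes = tuple(range(1, pred.ndim))
    inter = (pred * label).sum(axis=axes)
    denom = pred.sum(axis=axes) + label.sum(axis=axes)
    # Samples where both masks are empty give 0 / 0 -> NaN plus a RuntimeWarning.
    return 2.0 * inter / denom


def dice_smoothed(pred: np.ndarray, label: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Same Dice, with a small smoothing term added to numerator and denominator."""
    axes = tuple(range(1, pred.ndim))
    inter = (pred * label).sum(axis=axes)
    denom = pred.sum(axis=axes) + label.sum(axis=axes)
    # An empty prediction on an empty label now scores 1.0 instead of NaN.
    return (2.0 * inter + eps) / (denom + eps)


pred = np.zeros((2, 4, 4))
label = np.zeros((2, 4, 4))
pred[0, :2, :2] = 1
label[0, :2, :2] = 1  # sample 0: perfect overlap; sample 1: both masks empty

with np.errstate(invalid="ignore"):
    print(dice_naive(pred, label))  # first entry 1.0, second entry nan
print(dice_smoothed(pred, label))   # both entries 1.0
```

Whether empty-vs-empty cases should count as 1.0, be skipped, or be reported as NaN is a design choice of the metric, so the smoothed variant is only one option.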

I was wondering whether I am doing something wrong here.

initialize network with normal
Tue Aug  9 06:33:18 2022 Epoch: 0
Final training  0/49 loss: 0.9928 time 2634.39s
y_pred should be a binarized tensor.
y should be a binarized tensor.
Final validation stats 0/49 , Dice_TC: 0.069751926 , Dice_Avg: 0.069751926 , time 491.24s
new best (0.000000 --> 0.069752).
Saving checkpoint weights_multiInput/model_fold2_multiInput.pt
Tue Aug  9 07:25:25 2022 Epoch: 1
Final training  1/49 loss: 0.9752 time 2545.19s
Final validation stats 1/49 , Dice_TC: 0.2406731 , Dice_Avg: 0.2406731 , time 474.64s
new best (0.069752 --> 0.240673).
Saving checkpoint weights_multiInput/model_fold2_multiInput.pt
Tue Aug  9 08:15:46 2022 Epoch: 2
Final training  2/49 loss: 0.9378 time 2539.51s
Final validation stats 2/49 , Dice_TC: 0.32946482 , Dice_Avg: 0.32946482 , time 463.56s
new best (0.240673 --> 0.329465).
Saving checkpoint weights_multiInput/model_fold2_multiInput.pt
Tue Aug  9 09:05:50 2022 Epoch: 3
Final training  3/49 loss: 0.9069 time 2520.52s
Final validation stats 3/49 , Dice_TC: 0.4997478 , Dice_Avg: 0.4997478 , time 473.23s
new best (0.329465 --> 0.499748).
Saving checkpoint weights_multiInput/model_fold2_multiInput.pt
Tue Aug  9 09:55:44 2022 Epoch: 4
Final training  4/49 loss: 0.9046 time 2647.55s
Final validation stats 4/49 , Dice_TC: 0.48587176 , Dice_Avg: 0.48587176 , time 444.82s
Tue Aug  9 10:47:16 2022 Epoch: 5
Final training  5/49 loss: 0.8924 time 2472.73s
Final validation stats 5/49 , Dice_TC: 0.5072948 , Dice_Avg: 0.5072948 , time 462.37s
new best (0.499748 --> 0.507295).
Saving checkpoint weights_multiInput/model_fold2_multiInput.pt
Tue Aug  9 11:36:12 2022 Epoch: 6
Final training  6/49 loss: 0.8802 time 2534.62s
Final validation stats 6/49 , Dice_TC: 0.51497203 , Dice_Avg: 0.51497203 , time 504.30s
new best (0.507295 --> 0.514972).
Saving checkpoint weights_multiInput/model_fold2_multiInput.pt
Tue Aug  9 12:26:52 2022 Epoch: 7
Final training  7/49 loss: 0.5781 time 2409.70s
Final validation stats 7/49 , Dice_TC: 2.0384609e-14 , Dice_Avg: 2.0384609e-14 , time 479.41s
Tue Aug  9 13:15:01 2022 Epoch: 8
Final training  8/49 loss: 0.2768 time 2407.49s
Final validation stats 8/49 , Dice_TC: 2.9713095e-17 , Dice_Avg: 2.9713095e-17 , time 429.29s
Tue Aug  9 14:02:18 2022 Epoch: 9
Final training  9/49 loss: 0.2783 time 2592.64s
Final validation stats 9/49 , Dice_TC: 2.2672533e-18 , Dice_Avg: 2.2672533e-18 , time 471.07s
Tue Aug  9 14:53:22 2022 Epoch: 10
Final training  10/49 loss: 0.2783 time 2667.51s
Final validation stats 10/49 , Dice_TC: 1.17504706e-20 , Dice_Avg: 1.17504706e-20 , time 502.52s
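The `y_pred should be a binarized tensor.` and `y should be a binarized tensor.` lines in the epoch-0 log are warnings the metric code emits when its inputs are not 0/1-valued, for example when raw logits or probabilities are passed instead of a mask. A minimal NumPy sketch of binarizing a single-channel sigmoid output and checking both tensors before scoring follows; the 0.5 threshold, the one-logit-channel output, and the helper names `binarize_sigmoid` and `is_binary` are illustrative assumptions (in MONAI this step is usually done with post-transforms such as `Activations` and `AsDiscrete`).

```python
import numpy as np


def binarize_sigmoid(logits: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Map raw single-channel logits to a {0, 1} mask via sigmoid + threshold."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    return (probs >= threshold).astype(np.float32)


def is_binary(x: np.ndarray) -> bool:
    """True if every voxel is exactly 0 or 1."""
    return bool(np.isin(x, (0, 1)).all())


logits = np.array([[-2.0, 0.3], [1.5, -0.1]])
label = np.array([[0.0, 1.0], [1.0, 0.0]])

print(is_binary(logits))                  # raw logits are not binary
mask = binarize_sigmoid(logits)
print(mask)
print(is_binary(mask), is_binary(label))  # both binary after thresholding
```

With a sigmoid, thresholding probabilities at 0.5 is equivalent to thresholding the logits at 0.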
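from monai import data
from monai.transforms import (
    AddChanneld,  # deprecated in newer MONAI releases in favor of EnsureChannelFirstd
    CenterSpatialCropd,
    Compose,
    ConcatItemsd,
    Lambdad,
    LoadImaged,
    OneOf,
    RandAdjustContrastd,
    RandAffined,
    RandCoarseDropoutd,
    RandFlipd,
    RandGridDistortiond,
    SpatialPadd,
    Spacingd,
    ToTensord,
)


def get_loader(train_images_t2w, train_images_adc, train_images_dwi, train_segs,
               valid_images_t2w, valid_images_adc, valid_images_dwi, valid_segs,
               patch_size=[256, 256, 20]):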
    
    data_dicts_train = [{'image_t2': image_name_t2, 'image_adc': image_name_adc, 'image_dwi': image_name_dwi,'label': label_name} for image_name_t2, image_name_adc, image_name_dwi, label_name in zip(train_images_t2w,train_images_adc,train_images_dwi, train_segs)]
    train_transform = Compose([
        LoadImaged(keys=['image_t2', 'image_adc','image_dwi', 'label']),
        AddChanneld(keys=['image_t2', 'image_adc','image_dwi', 'label']),
        Spacingd(keys=['image_t2', 'image_adc','image_dwi', 'label'], pixdim=(0.5, 0.5, 3.0),  mode=("bilinear",
                                                                                                    "bilinear",
                                                                                                    "bilinear", "nearest")),
        CenterSpatialCropd(keys=['image_t2', 'image_adc','image_dwi', 'label'],roi_size = patch_size), 
        SpatialPadd(keys=['image_t2', 'image_adc','image_dwi', 'label'], spatial_size=patch_size, method='end'),
        ConcatItemsd(keys=['image_t2', 'image_adc','image_dwi'], name="image"),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        RandAdjustContrastd(keys=['image'],gamma=(0.5, 2.5),prob=0.2),
        RandAffined(
           keys=("image", "label"),
           prob=0.5,
           #rotate_range=np.pi / 12,
           translate_range=(256*0.0625, 256*0.0625),
           scale_range=(0.1, 0.1),
           mode="nearest",
           padding_mode="reflection",
        ),
        OneOf(
           [ 
               RandGridDistortiond(keys=("image", "label"), 
                                   prob=0.5, distort_limit=(-0.05, 0.05), 
                                   mode="nearest", 
                                   padding_mode="reflection"),
               RandCoarseDropoutd(
                   keys=("image", "label"),
                   holes=5,
                   max_holes=8,
                   spatial_size=(1, 1, 1),
                   max_spatial_size=(12, 12, 12),
                   fill_value=0.0,
                   prob=0.5,
               ),
           ]
        ),

        # Epsilon guards against 0 / 0 (NaN) if a sample is entirely zero.
        # Alternative: NormalizeIntensityd(keys=['image'], nonzero=True, channel_wise=True),
        Lambdad(keys="image", func=lambda x: x / (x.max() + 1e-8)),
        ToTensord(keys=['image', 'label'])
    ])
    data_dicts_valid = [{'image_t2': image_name_t2, 
                         'image_adc': image_name_adc, 
                         'image_dwi': image_name_dwi,
                         'label': label_name} for image_name_t2, image_name_adc, image_name_dwi, label_name in zip(valid_images_t2w,valid_images_adc,valid_images_dwi, valid_segs)]
    valid_transform = Compose([
        LoadImaged(keys=['image_t2', 'image_adc','image_dwi', 'label']), 
        AddChanneld(keys=['image_t2', 'image_adc','image_dwi', 'label']),
        Spacingd(keys=['image_t2', 'image_adc','image_dwi', "label"], pixdim=[0.5, 0.5, 3.0], mode=("bilinear",
                                                                                                    "bilinear",
                                                                                                    "bilinear", "nearest")),

        CenterSpatialCropd(keys=['image_t2', 'image_adc','image_dwi', 'label'], roi_size = patch_size), 
        SpatialPadd(keys=['image_t2', 'image_adc','image_dwi', 'label'], spatial_size=patch_size, method='end'),
        ConcatItemsd(keys=['image_t2', 'image_adc','image_dwi'], name="image"),
        # Epsilon guards against 0 / 0 (NaN) if a sample is entirely zero.
        # Alternative: NormalizeIntensityd(keys=['image'], nonzero=True, channel_wise=True),
        Lambdad(keys="image", func=lambda x: x / (x.max() + 1e-8)),
        ToTensord(keys=['image', 'label'])
    ])
    
    train_ds = data.Dataset(data=data_dicts_train, transform=train_transform)
    train_loader = data.DataLoader(
        train_ds,
        batch_size=1,
        shuffle=True, #collate_fn=list_data_collate,
        num_workers=12,
        pin_memory=True,
    )
    val_ds = data.Dataset(data=data_dicts_valid, transform=valid_transform)
    val_loader = data.DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=12, #collate_fn=list_data_collate,
        pin_memory=False
    )
    return train_loader, val_loader
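Both pipelines scale each sample by a single maximum taken over the concatenated three-channel image, so T2w, ADC and DWI share one scale factor. As an illustrative alternative (an assumption, not a confirmed fix for the collapse), here is a NumPy sketch that scales each channel by its own maximum with an epsilon guard. The name `channelwise_max_scale` is hypothetical; plugging it into the existing `Lambdad` would need a tensor-compatible version if the image is already a `torch.Tensor` at that point. The commented-out `NormalizeIntensityd(..., nonzero=True, channel_wise=True)` is MONAI's built-in per-channel z-score option.

```python
import numpy as np


def channelwise_max_scale(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Scale each channel of a (C, H, W, D) array by its own maximum.

    Assumes non-negative intensities, so each channel ends up in [0, 1].
    """
    x = np.asarray(x, dtype=np.float32)
    spatial_axes = tuple(range(1, x.ndim))
    ch_max = x.max(axis=spatial_axes, keepdims=True)
    # eps avoids 0 / 0 -> NaN when a channel is all zeros (e.g. an empty padded crop).
    return x / np.maximum(ch_max, eps)


img = np.stack([
    np.full((4, 4, 2), 400.0),   # channel 0
    np.full((4, 4, 2), 2000.0),  # channel 1, on a different intensity scale
    np.zeros((4, 4, 2)),         # channel 2, all zeros
])
out = channelwise_max_scale(img)
print(out.max(axis=(1, 2, 3)))  # per-channel maxima: 1, 1, 0
print(np.isnan(out).any())      # no NaNs
```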
