2222from monai .networks .blocks import Convolution
2323from monai .networks .blocks .spatialattention import SpatialAttentionBlock
2424from monai .networks .nets .autoencoderkl import AEKLResBlock , AutoencoderKL
25+ from monai .utils .deprecate_utils import deprecated_arg
2526from monai .utils .type_conversion import convert_to_tensor
2627
2728# Set up logging configuration
@@ -34,6 +35,7 @@ def _empty_cuda_cache(save_mem: bool) -> None:
3435 return
3536
3637
38+ @deprecated_arg ("norm_float16" , since = "1.5.0" , removed = "1.7.0" )
3739class MaisiGroupNorm3D (nn .GroupNorm ):
3840 """
3941 Custom 3D Group Normalization with optional print_info output.
@@ -43,7 +45,7 @@ class MaisiGroupNorm3D(nn.GroupNorm):
4345 num_channels: Number of channels for the group norm.
4446 eps: Epsilon value for numerical stability.
4547 affine: Whether to use learnable affine parameters, default to `True`.
46- norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False` .
48+ norm_float16: Deprecated argument .
4749 print_info: Whether to print information, default to `False`.
4850 save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
4951 """
@@ -59,14 +61,15 @@ def __init__(
5961 save_mem : bool = True ,
6062 ):
6163 super ().__init__ (num_groups , num_channels , eps , affine )
62- self .norm_float16 = norm_float16
6364 self .print_info = print_info
6465 self .save_mem = save_mem
6566
6667 def forward (self , input : torch .Tensor ) -> torch .Tensor :
6768 if self .print_info :
6869 logger .info (f"MaisiGroupNorm3D with input size: { input .size ()} " )
6970
71+ target_dtype = input .dtype
72+
7073 if len (input .shape ) != 5 :
7174 raise ValueError ("Expected a 5D tensor" )
7275
@@ -75,13 +78,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
7578
7679 inputs = []
7780 for i in range (input .size (1 )):
78- array = input [:, i : i + 1 , ...]. to ( dtype = torch . float32 )
81+ array = input [:, i : i + 1 , ...]
7982 mean = array .mean ([2 , 3 , 4 , 5 ], keepdim = True )
8083 std = array .var ([2 , 3 , 4 , 5 ], unbiased = False , keepdim = True ).add_ (self .eps ).sqrt_ ()
81- if self .norm_float16 :
82- inputs .append (((array - mean ) / std ).to (dtype = torch .float16 ))
83- else :
84- inputs .append ((array - mean ) / std )
84+ inputs .append (((array - mean ) / std ).to (dtype = target_dtype ))
8585
8686 del input
8787 _empty_cuda_cache (self .save_mem )
@@ -376,6 +376,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
376376 return x
377377
378378
379+ @deprecated_arg ("norm_float16" , since = "1.5.0" , removed = "1.7.0" )
379380class MaisiResBlock (nn .Module ):
380381 """
381382 Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a
@@ -417,7 +418,6 @@ def __init__(
417418 num_channels = in_channels ,
418419 eps = norm_eps ,
419420 affine = True ,
420- norm_float16 = norm_float16 ,
421421 print_info = print_info ,
422422 save_mem = save_mem ,
423423 )
@@ -439,7 +439,6 @@ def __init__(
439439 num_channels = out_channels ,
440440 eps = norm_eps ,
441441 affine = True ,
442- norm_float16 = norm_float16 ,
443442 print_info = print_info ,
444443 save_mem = save_mem ,
445444 )
@@ -501,6 +500,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
501500 return out_tensor
502501
503502
503+ @deprecated_arg ("norm_float16" , since = "1.5.0" , removed = "1.7.0" )
504504class MaisiEncoder (nn .Module ):
505505 """
506506 Convolutional cascade that downsamples the image into a spatial latent space.
@@ -520,7 +520,7 @@ class MaisiEncoder(nn.Module):
520520 use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
521521 num_splits: Number of splits for the input tensor.
522522 dim_split: Dimension of splitting for the input tensor.
523- norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False` .
523+ norm_float16: Deprecated argument .
524524 print_info: Whether to print information, default to `False`.
525525 save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
526526 """
@@ -591,7 +591,6 @@ def __init__(
591591 out_channels = output_channel ,
592592 num_splits = num_splits ,
593593 dim_split = dim_split ,
594- norm_float16 = norm_float16 ,
595594 print_info = print_info ,
596595 save_mem = save_mem ,
597596 )
@@ -660,7 +659,6 @@ def __init__(
660659 num_channels = num_channels [- 1 ],
661660 eps = norm_eps ,
662661 affine = True ,
663- norm_float16 = norm_float16 ,
664662 print_info = print_info ,
665663 save_mem = save_mem ,
666664 )
@@ -690,6 +688,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
690688 return x
691689
692690
691+ @deprecated_arg ("norm_float16" , since = "1.5.0" , removed = "1.7.0" )
693692class MaisiDecoder (nn .Module ):
694693 """
695694 Convolutional cascade upsampling from a spatial latent space into an image space.
@@ -710,7 +709,7 @@ class MaisiDecoder(nn.Module):
710709 use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
711710 num_splits: Number of splits for the input tensor.
712711 dim_split: Dimension of splitting for the input tensor.
713- norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False` .
712+ norm_float16: Deprecated argument .
714713 print_info: Whether to print information, default to `False`.
715714 save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
716715 """
@@ -809,7 +808,6 @@ def __init__(
809808 out_channels = block_out_ch ,
810809 num_splits = num_splits ,
811810 dim_split = dim_split ,
812- norm_float16 = norm_float16 ,
813811 print_info = print_info ,
814812 save_mem = save_mem ,
815813 )
@@ -848,7 +846,6 @@ def __init__(
848846 num_channels = block_in_ch ,
849847 eps = norm_eps ,
850848 affine = True ,
851- norm_float16 = norm_float16 ,
852849 print_info = print_info ,
853850 save_mem = save_mem ,
854851 )
@@ -878,6 +875,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
878875 return x
879876
880877
878+ @deprecated_arg ("norm_float16" , since = "1.5.0" , removed = "1.7.0" )
881879class AutoencoderKlMaisi (AutoencoderKL ):
882880 """
883881 AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
@@ -901,7 +899,7 @@ class AutoencoderKlMaisi(AutoencoderKL):
901899 use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
902900 num_splits: Number of splits for the input tensor.
903901 dim_split: Dimension of splitting for the input tensor.
904- norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False` .
902+ norm_float16: Deprecated argument .
905903 print_info: Whether to print information, default to `False`.
906904 save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
907905 """
@@ -964,7 +962,6 @@ def __init__(
964962 use_flash_attention = use_flash_attention ,
965963 num_splits = num_splits ,
966964 dim_split = dim_split ,
967- norm_float16 = norm_float16 ,
968965 print_info = print_info ,
969966 save_mem = save_mem ,
970967 )
@@ -985,7 +982,6 @@ def __init__(
985982 use_convtranspose = use_convtranspose ,
986983 num_splits = num_splits ,
987984 dim_split = dim_split ,
988- norm_float16 = norm_float16 ,
989985 print_info = print_info ,
990986 save_mem = save_mem ,
991987 )
0 commit comments