55from collections .abc import Sequence
66from contextlib import ContextDecorator
77from pathlib import Path
8- from typing import Literal , Optional , Union
8+ from typing import Literal , Union
99
1010import torch
1111import torch .optim
@@ -49,7 +49,7 @@ def get_post_transform(
4949 target_grid : Grid ,
5050 source_grid : Grid ,
5151 align = False ,
52- ) -> Optional [ SpatialTransform ] :
52+ ) -> SpatialTransform | None :
5353 r"""Get constant rigid transformation between image grid domains."""
5454 if align is False or align is None :
5555 return None
@@ -90,7 +90,7 @@ def load_transform(path: PathStr, grid: Grid) -> SpatialTransform:
9090 """
9191 target_grid = grid
9292
93- def convert_matrix (matrix : Tensor , grid : Optional [ Grid ] = None ) -> Tensor :
93+ def convert_matrix (matrix : Tensor , grid : Grid | None = None ) -> Tensor :
9494 if grid is None :
9595 pre = target_grid .transform (Axes .CUBE_CORNERS , Axes .WORLD )
9696 post = target_grid .transform (Axes .WORLD , Axes .CUBE_CORNERS )
@@ -166,7 +166,7 @@ def __exit__(self, exc_type, exc_value, traceback):
166166 self .scheduler .step ()
167167
168168
169- def overlap_mask (source_mask : Tensor | None , target_mask : Tensor | None ) -> Optional [ Tensor ] :
169+ def overlap_mask (source_mask : Tensor | None , target_mask : Tensor | None ) -> Tensor | None :
170170 r"""Overlap mask at which to evaluate pairwise data term."""
171171 if source_mask is None :
172172 return target_mask
@@ -184,11 +184,12 @@ def make_foreground_mask(image: Image, foreground_lower_threshold, foreground_up
184184 return Image (mask , image .grid ())
185185
186186
187- def normalize_img (image : Image , normalize_strategy : Optional [ Literal ["auto" , "CT" , "MRI" ]] ):
187+ def normalize_img (image : Image , normalize_strategy : Literal ["auto" , "CT" , "MRI" ] | None ):
188188 if normalize_strategy is None :
189189 return image
190190 data = image .tensor ()
191191 if normalize_strategy == "MRI" :
192+ data = data .float ()
192193 max_v = torch .quantile (data [data > 0 ], q = 0.95 )
193194 min_v = 0
194195 elif normalize_strategy == "CT" :
@@ -210,7 +211,7 @@ def normalize_img(image: Image, normalize_strategy: Optional[Literal["auto", "CT
210211 return Image (data , image .grid ())
211212
212213
213- def clamp_mask (image : Optional [ Image ] ):
214+ def clamp_mask (image : Image | None ):
214215 if image is None :
215216 return image
216217 data = image .tensor ()
0 commit comments