Skip to content

Commit 346e9db

Browse files
committed
ruff: remove Optional
1 parent 67f4798 commit 346e9db

1 file changed

Lines changed: 7 additions & 6 deletions

File tree

TPTBox/registration/deepali/_utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Sequence
66
from contextlib import ContextDecorator
77
from pathlib import Path
8-
from typing import Literal, Optional, Union
8+
from typing import Literal, Union
99

1010
import torch
1111
import torch.optim
@@ -49,7 +49,7 @@ def get_post_transform(
4949
target_grid: Grid,
5050
source_grid: Grid,
5151
align=False,
52-
) -> Optional[SpatialTransform]:
52+
) -> SpatialTransform | None:
5353
r"""Get constant rigid transformation between image grid domains."""
5454
if align is False or align is None:
5555
return None
@@ -90,7 +90,7 @@ def load_transform(path: PathStr, grid: Grid) -> SpatialTransform:
9090
"""
9191
target_grid = grid
9292

93-
def convert_matrix(matrix: Tensor, grid: Optional[Grid] = None) -> Tensor:
93+
def convert_matrix(matrix: Tensor, grid: Grid | None = None) -> Tensor:
9494
if grid is None:
9595
pre = target_grid.transform(Axes.CUBE_CORNERS, Axes.WORLD)
9696
post = target_grid.transform(Axes.WORLD, Axes.CUBE_CORNERS)
@@ -166,7 +166,7 @@ def __exit__(self, exc_type, exc_value, traceback):
166166
self.scheduler.step()
167167

168168

169-
def overlap_mask(source_mask: Tensor | None, target_mask: Tensor | None) -> Optional[Tensor]:
169+
def overlap_mask(source_mask: Tensor | None, target_mask: Tensor | None) -> Tensor | None:
170170
r"""Overlap mask at which to evaluate pairwise data term."""
171171
if source_mask is None:
172172
return target_mask
@@ -184,11 +184,12 @@ def make_foreground_mask(image: Image, foreground_lower_threshold, foreground_up
184184
return Image(mask, image.grid())
185185

186186

187-
def normalize_img(image: Image, normalize_strategy: Optional[Literal["auto", "CT", "MRI"]]):
187+
def normalize_img(image: Image, normalize_strategy: Literal["auto", "CT", "MRI"] | None):
188188
if normalize_strategy is None:
189189
return image
190190
data = image.tensor()
191191
if normalize_strategy == "MRI":
192+
data = data.float()
192193
max_v = torch.quantile(data[data > 0], q=0.95)
193194
min_v = 0
194195
elif normalize_strategy == "CT":
@@ -210,7 +211,7 @@ def normalize_img(image: Image, normalize_strategy: Optional[Literal["auto", "CT
210211
return Image(data, image.grid())
211212

212213

213-
def clamp_mask(image: Optional[Image]):
214+
def clamp_mask(image: Image | None):
214215
if image is None:
215216
return image
216217
data = image.tensor()

0 commit comments

Comments
 (0)