Skip to content

Commit f6473f9

Browse files
authored
Merge branch 'dev' into 8587-test-erros-on-pytorch-release-2508-on-series-50
2 parents: 08a1cce + 65beb58 · commit f6473f9

39 files changed

Lines changed: 1805 additions & 173 deletions

MANIFEST.in

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,3 +3,5 @@ include monai/_version.py
33

44
include README.md
55
include LICENSE
6+
7+
prune tests

docs/source/metrics.rst

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -158,6 +158,13 @@ Metrics
158158

159159
`Fréchet Inception Distance`
160160
------------------------------
161+
`Embedding Collapse`
162+
------------------------------
163+
.. autofunction:: compute_embedding_collapse
164+
165+
.. autoclass:: EmbeddingCollapseMetric
166+
:members:
167+
161168
.. autofunction:: compute_frechet_distance
162169

163170
.. autoclass:: FIDMetric

monai/apps/detection/utils/anchor_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -253,7 +253,7 @@ def grid_anchors(self, grid_sizes: list[list[int]], strides: list[list[Tensor]])
253253
# compute anchor centers regarding to the image.
254254
# shifts_centers is [x_center, y_center] or [x_center, y_center, z_center]
255255
shifts_centers = [
256-
torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis]
256+
torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis] + stride[axis] // 2
257257
for axis in range(self.spatial_dims)
258258
]
259259

monai/apps/nuclick/transforms.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -367,14 +367,14 @@ def inclusion_map(self, mask, dtype):
367367

368368
def exclusion_map(self, others, dtype, jitter_range, drop_rate):
369369
point_mask = torch.zeros_like(others, dtype=dtype)
370-
if np.random.choice([True, False], p=[drop_rate, 1 - drop_rate]):
370+
if self.R.choice([True, False], p=[drop_rate, 1 - drop_rate]):
371371
return point_mask
372372

373373
max_x = point_mask.shape[0] - 1
374374
max_y = point_mask.shape[1] - 1
375375
stats = measure.regionprops(convert_to_numpy(others))
376376
for stat in stats:
377-
if np.random.choice([True, False], p=[drop_rate, 1 - drop_rate]):
377+
if self.R.choice([True, False], p=[drop_rate, 1 - drop_rate]):
378378
continue
379379

380380
# random jitter

monai/auto3dseg/analyzer.py

Lines changed: 33 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -216,50 +216,58 @@ def __init__(self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_STATS)
216216
super().__init__(stats_name, report_format)
217217
self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations())
218218

219+
@torch.no_grad()
219220
def __call__(self, data):
220-
# Input Validation Addition
221-
if not isinstance(data, dict):
222-
raise TypeError(f"Input data must be a dict, but got {type(data).__name__}.")
223-
if self.image_key not in data:
224-
raise KeyError(f"Key '{self.image_key}' not found in input data.")
225-
image = data[self.image_key]
226-
if not isinstance(image, (np.ndarray, torch.Tensor, MetaTensor)):
227-
raise TypeError(
228-
f"Value for '{self.image_key}' must be a numpy array, torch.Tensor, or MetaTensor, "
229-
f"but got {type(image).__name__}."
230-
)
231-
if image.ndim < 3:
232-
raise ValueError(
233-
f"Image data under '{self.image_key}' must have at least 3 dimensions, but got shape {image.shape}."
234-
)
235-
# --- End of validation ---
236221
"""
237-
Callable to execute the pre-defined functions
222+
Callable to execute the pre-defined functions.
238223
239224
Returns:
240225
A dictionary. The dict has the key in self.report_format. The value of
241226
ImageStatsKeys.INTENSITY is in a list format. Each element of the value list
242227
has stats pre-defined by SampleOperations (max, min, ....).
243228
244229
Raises:
245-
RuntimeError if the stats report generated is not consistent with the pre-
230+
KeyError: if ``self.image_key`` is not present in the input data.
231+
TypeError: if the input data is not a dictionary, or if the image value is
232+
not a numpy array, torch.Tensor, or MetaTensor.
233+
ValueError: if the image has fewer than 3 dimensions, or if pre-computed
234+
``nda_croppeds`` is not a list/tuple with one entry per image channel.
235+
RuntimeError: if the stats report generated is not consistent with the pre-
246236
defined report_format.
247237
248238
Note:
249239
The stats operation uses numpy and torch to compute max, min, and other
250240
functions. If the input has nan/inf, the stats results will be nan/inf.
251241
252242
"""
243+
if not isinstance(data, dict):
244+
raise TypeError(f"Input data must be a dict, but got {type(data).__name__}.")
245+
if self.image_key not in data:
246+
raise KeyError(f"Key '{self.image_key}' not found in input data.")
247+
image = data[self.image_key]
248+
if not isinstance(image, (np.ndarray, torch.Tensor, MetaTensor)):
249+
raise TypeError(
250+
f"Value for '{self.image_key}' must be a numpy array, torch.Tensor, or MetaTensor, "
251+
f"but got {type(image).__name__}."
252+
)
253+
if image.ndim < 3:
254+
raise ValueError(
255+
f"Image data under '{self.image_key}' must have at least 3 dimensions, but got shape {image.shape}."
256+
)
257+
253258
d = dict(data)
254259
start = time.time()
255-
restore_grad_state = torch.is_grad_enabled()
256-
torch.set_grad_enabled(False)
257-
258260
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
259-
if "nda_croppeds" not in d:
261+
if "nda_croppeds" in d:
262+
nda_croppeds = d["nda_croppeds"]
263+
if not isinstance(nda_croppeds, (list, tuple)) or len(nda_croppeds) != len(ndas):
264+
raise ValueError(
265+
"Pre-computed 'nda_croppeds' must be a list or tuple with one entry per image channel "
266+
f"(expected {len(ndas)})."
267+
)
268+
else:
260269
nda_croppeds = [get_foreground_image(nda) for nda in ndas]
261270

262-
# perform calculation
263271
report = deepcopy(self.get_report_format())
264272

265273
report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
@@ -284,7 +292,6 @@ def __call__(self, data):
284292

285293
d[self.stats_name] = report
286294

287-
torch.set_grad_enabled(restore_grad_state)
288295
logger.debug(f"Get image stats spent {time.time() - start}")
289296
return d
290297

@@ -321,6 +328,7 @@ def __init__(self, image_key: str, label_key: str, stats_name: str = DataStatsKe
321328
super().__init__(stats_name, report_format)
322329
self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations())
323330

331+
@torch.no_grad()
324332
def __call__(self, data: Mapping) -> dict:
325333
"""
326334
Callable to execute the pre-defined functions
@@ -341,9 +349,6 @@ def __call__(self, data: Mapping) -> dict:
341349

342350
d = dict(data)
343351
start = time.time()
344-
restore_grad_state = torch.is_grad_enabled()
345-
torch.set_grad_enabled(False)
346-
347352
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
348353
ndas_label = d[self.label_key] # (H,W,D)
349354

@@ -353,7 +358,6 @@ def __call__(self, data: Mapping) -> dict:
353358
nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
354359
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
355360

356-
# perform calculation
357361
report = deepcopy(self.get_report_format())
358362

359363
report[ImageStatsKeys.INTENSITY] = [
@@ -365,7 +369,6 @@ def __call__(self, data: Mapping) -> dict:
365369

366370
d[self.stats_name] = report
367371

368-
torch.set_grad_enabled(restore_grad_state)
369372
logger.debug(f"Get foreground image stats spent {time.time() - start}")
370373
return d
371374

@@ -418,6 +421,7 @@ def __init__(
418421
id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, "0", LabelStatsKeys.IMAGE_INTST])
419422
self.update_ops_nested_label(id_seq, SampleOperations())
420423

424+
@torch.no_grad()
421425
def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor | dict]:
422426
"""
423427
Callable to execute the pre-defined functions.
@@ -470,19 +474,15 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
470474
start = time.time()
471475
image_tensor = d[self.image_key]
472476
label_tensor = d[self.label_key]
473-
# Check if either tensor is on CUDA to determine if we should move both to CUDA for processing
474477
using_cuda = any(
475478
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
476479
)
477-
restore_grad_state = torch.is_grad_enabled()
478-
torch.set_grad_enabled(False)
479480

480481
if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
481482
label_tensor, (MetaTensor, torch.Tensor)
482483
):
483484
if label_tensor.device != image_tensor.device:
484485
if using_cuda:
485-
# Move both tensors to CUDA when mixing devices
486486
cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
487487
image_tensor = cast(MetaTensor, image_tensor.to(cuda_device))
488488
label_tensor = cast(MetaTensor, label_tensor.to(cuda_device))
@@ -548,7 +548,6 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
548548

549549
d[self.stats_name] = report # type: ignore[assignment]
550550

551-
torch.set_grad_enabled(restore_grad_state)
552551
logger.debug(f"Get label stats spent {time.time() - start}")
553552
return d # type: ignore[return-value]
554553

monai/losses/image_dissimilarity.py

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
from torch.nn import functional as F
1616
from torch.nn.modules.loss import _Loss
1717

18-
from monai.networks.layers import gaussian_1d, separable_filtering
18+
from monai.networks.layers import separable_filtering
1919
from monai.utils import LossReduction
2020
from monai.utils.module import look_up_option
2121

@@ -34,11 +34,11 @@ def make_triangular_kernel(kernel_size: int) -> torch.Tensor:
3434

3535

3636
def make_gaussian_kernel(kernel_size: int) -> torch.Tensor:
37-
sigma = torch.tensor(kernel_size / 3.0)
38-
kernel = gaussian_1d(sigma=sigma, truncated=kernel_size // 2, approx="sampled", normalize=False) * (
39-
2.5066282 * sigma
40-
)
41-
return kernel[:kernel_size]
37+
sigma = kernel_size / 3.0
38+
half = kernel_size // 2
39+
x = torch.arange(-half, half + 1, dtype=torch.float)
40+
kernel = torch.exp(-0.5 / (sigma * sigma) * x**2)
41+
return kernel
4242

4343

4444
kernel_dict = {
@@ -111,14 +111,16 @@ def __init__(
111111
raise ValueError(f"kernel_size must be odd, got {self.kernel_size}")
112112

113113
_kernel = look_up_option(kernel_type, kernel_dict)
114-
self.kernel = _kernel(self.kernel_size)
115-
self.kernel.require_grads = False
116-
self.kernel_vol = self.get_kernel_vol()
114+
self.kernel: torch.Tensor
115+
self.kernel_vol: torch.Tensor
116+
self.register_buffer("kernel", _kernel(self.kernel_size), persistent=False)
117+
self.register_buffer("kernel_vol", self.get_kernel_vol(), persistent=False)
117118

118119
self.smooth_nr = float(smooth_nr)
119120
self.smooth_dr = float(smooth_dr)
120121

121-
def get_kernel_vol(self):
122+
def get_kernel_vol(self) -> torch.Tensor:
123+
assert self.kernel is not None
122124
vol = self.kernel
123125
for _ in range(self.ndim - 1):
124126
vol = torch.matmul(vol.unsqueeze(-1), self.kernel.unsqueeze(0))
@@ -138,6 +140,8 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
138140
raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})")
139141

140142
t2, p2, tp = target * target, pred * pred, target * pred
143+
assert self.kernel is not None
144+
assert self.kernel_vol is not None
141145
kernel, kernel_vol = self.kernel.to(pred), self.kernel_vol.to(pred)
142146
kernels = [kernel] * self.ndim
143147
# sum over kernel

monai/losses/spectral_loss.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -55,8 +55,8 @@ def __init__(
5555
self.fft_norm = fft_norm
5656

5757
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
58-
input_amplitude = self._get_fft_amplitude(target)
59-
target_amplitude = self._get_fft_amplitude(input)
58+
input_amplitude = self._get_fft_amplitude(input)
59+
target_amplitude = self._get_fft_amplitude(target)
6060

6161
# Compute distance between amplitude of frequency components
6262
# See Section 3.3 from https://arxiv.org/abs/2005.00341

monai/losses/ssim_loss.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -111,17 +111,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
111111
# 2D data
112112
x = torch.ones([1,1,10,10])/2
113113
y = torch.ones([1,1,10,10])/2
114-
print(1-SSIMLoss(spatial_dims=2)(x,y))
114+
print(SSIMLoss(spatial_dims=2)(x,y))
115115
116116
# pseudo-3D data
117117
x = torch.ones([1,5,10,10])/2 # 5 could represent number of slices
118118
y = torch.ones([1,5,10,10])/2
119-
print(1-SSIMLoss(spatial_dims=2)(x,y))
119+
print(SSIMLoss(spatial_dims=2)(x,y))
120120
121121
# 3D data
122122
x = torch.ones([1,1,10,10,10])/2
123123
y = torch.ones([1,1,10,10,10])/2
124-
print(1-SSIMLoss(spatial_dims=3)(x,y))
124+
print(SSIMLoss(spatial_dims=3)(x,y))
125125
"""
126126
ssim_value = self.ssim_metric._compute_tensor(input, target).view(-1, 1)
127127
loss: torch.Tensor = 1 - ssim_value

monai/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
from .calibration import CalibrationErrorMetric, CalibrationReduction, calibration_binning
1717
from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
1818
from .cumulative_average import CumulativeAverage
19+
from .embedding_collapse import EmbeddingCollapseMetric, compute_embedding_collapse
1920
from .f_beta_score import FBetaScore
2021
from .fid import FIDMetric, compute_frechet_distance
2122
from .froc import compute_fp_tp_probs, compute_fp_tp_probs_nd, compute_froc_curve_data, compute_froc_score

0 commit comments

Comments
 (0)