Skip to content

Commit 46bcde0

Browse files
authored
Merge branch 'dev' into patch-6
2 parents 3a5b114 + 7d3674e commit 46bcde0

15 files changed

Lines changed: 434 additions & 43 deletions

File tree

monai/data/image_reader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
11131113

11141114
for i, filename in zip(ensure_tuple(img), self.filenames):
11151115
header = self._get_meta_dict(i)
1116+
if MetaKeys.PIXDIM in header:
1117+
header[MetaKeys.ORIGINAL_PIXDIM] = np.array(header[MetaKeys.PIXDIM], copy=True)
11161118
header[MetaKeys.AFFINE] = self._get_affine(i)
11171119
header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)
11181120
header["as_closest_canonical"] = self.as_closest_canonical

monai/data/test_time_augmentation.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from copy import deepcopy
1717
from typing import TYPE_CHECKING, Any
1818

19-
import numpy as np
2019
import torch
2120

2221
from monai.config.type_definitions import NdarrayOrTensor
@@ -68,7 +67,7 @@ class TestTimeAugmentation:
6867
Args:
6968
transform: transform (or composed) to be applied to each realization. At least one transform must be of type
7069
`RandomizableTrait` (i.e. `Randomizable`, `RandomizableTransform`, or `RandomizableTrait`).
71-
. All random transforms must be of type `InvertibleTransform`.
70+
When `apply_inverse_to_pred` is True, all random transforms must be of type `InvertibleTransform`.
7271
batch_size: number of realizations to infer at once.
7372
num_workers: how many subprocesses to use for data.
7473
inferrer_fn: function to use to perform inference.
@@ -92,6 +91,11 @@ class TestTimeAugmentation:
9291
will return the full data. Dimensions will be same size as when passing a single image through
9392
`inferrer_fn`, with a dimension appended equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.
9493
progress: whether to display a progress bar.
94+
apply_inverse_to_pred: whether to apply inverse transformations to the predictions.
95+
If the model's prediction is spatial (e.g. segmentation), this should be `True` to map the predictions
96+
back to the original spatial reference.
97+
If the prediction is non-spatial (e.g. classification label or score), this should be `False` to
98+
aggregate the raw predictions directly. Defaults to `True`.
9599
96100
Example:
97101
.. code-block:: python
@@ -125,6 +129,7 @@ def __init__(
125129
post_func: Callable = _identity,
126130
return_full_data: bool = False,
127131
progress: bool = True,
132+
apply_inverse_to_pred: bool = True,
128133
) -> None:
129134
self.transform = transform
130135
self.batch_size = batch_size
@@ -134,6 +139,7 @@ def __init__(
134139
self.image_key = image_key
135140
self.return_full_data = return_full_data
136141
self.progress = progress
142+
self.apply_inverse_to_pred = apply_inverse_to_pred
137143
self._pred_key = CommonKeys.PRED
138144
self.inverter = Invertd(
139145
keys=self._pred_key,
@@ -152,20 +158,23 @@ def __init__(
152158

153159
def _check_transforms(self):
154160
"""Should be at least 1 random transform, and all random transforms should be invertible."""
155-
ts = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms
156-
randoms = np.array([isinstance(t, Randomizable) for t in ts])
157-
invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts])
158-
# check at least 1 random
159-
if sum(randoms) == 0:
161+
transforms = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms
162+
warns = []
163+
randoms = []
164+
165+
for idx, t in enumerate(transforms):
166+
if isinstance(t, Randomizable):
167+
randoms.append(t)
168+
if self.apply_inverse_to_pred and not isinstance(t, InvertibleTransform):
169+
warns.append(f"Transform #{idx} (type {type(t).__name__}) is random but not invertible.")
170+
171+
if len(randoms) == 0:
172+
warns.append("TTA usually requires at least one `Randomizable` transform in the given transform sequence.")
173+
174+
if len(warns) > 0:
160175
warnings.warn(
161-
"TTA usually has at least a `Randomizable` transform or `Compose` contains `Randomizable` transforms."
176+
"TTA has encountered issues with the given transforms:\n " + "\n ".join(warns), stacklevel=2
162177
)
163-
# check that whenever randoms is True, invertibles is also true
164-
for r, i in zip(randoms, invertibles):
165-
if r and not i:
166-
warnings.warn(
167-
f"Not all applied random transform(s) are invertible. Problematic transform: {type(r).__name__}"
168-
)
169178

170179
def __call__(
171180
self, data: dict[str, Any], num_examples: int = 10
@@ -199,7 +208,10 @@ def __call__(
199208
for b in tqdm(dl) if has_tqdm and self.progress else dl:
200209
# do model forward pass
201210
b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(self.device))
202-
outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])
211+
if self.apply_inverse_to_pred:
212+
outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])
213+
else:
214+
outs.extend([i[self._pred_key] for i in decollate_batch(b)])
203215

204216
output: NdarrayOrTensor = stack(outs, 0)
205217

monai/data/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -597,11 +597,12 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
597597
type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
598598
):
599599
return batch
600+
# if scalar tensor/array, return the item itself.
601+
if getattr(batch, "ndim", -1) == 0 and hasattr(batch, "item"):
602+
return batch.item() if detach else batch
600603
if isinstance(batch, torch.Tensor):
601604
if detach:
602605
batch = batch.detach()
603-
if batch.ndim == 0:
604-
return batch.item() if detach else batch
605606
out_list = torch.unbind(batch, dim=0)
606607
# if of type MetaObj, decollate the metadata
607608
if isinstance(batch, MetaObj):

monai/inferers/utils.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def sliding_window_inference(
7676
7777
Args:
7878
inputs: input image to be processed (assuming NCHW[D])
79-
roi_size: the spatial window size for inferences.
79+
roi_size: the spatial window size for inferences, this must be a single value or a tuple with values
80+
for each spatial dimension (eg. 2 for 2D, 3 for 3D).
8081
When its components have None or non-positives, the corresponding inputs dimension will be used.
8182
if the components of the `roi_size` are non-positive values, the transform will use the
8283
corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
@@ -131,11 +132,30 @@ def sliding_window_inference(
131132
kwargs: optional keyword args to be passed to ``predictor``.
132133
133134
Note:
134-
- input must be channel-first and have a batch dim, supports N-D sliding window.
135+
- Inputs must be channel-first and have a batch dim (NCHW / NCDHW).
136+
- If your data is NHWC/NDHWC, please apply `EnsureChannelFirst` / `EnsureChannelFirstd` upstream.
137+
138+
Raises:
139+
ValueError: When the input dimensions do not match the expected dimensions based on ``roi_size``.
135140
136141
"""
137-
buffered = buffer_steps is not None and buffer_steps > 0
138142
num_spatial_dims = len(inputs.shape) - 2
143+
144+
# Only perform strict shape validation if roi_size is a sequence (explicit dimensions).
145+
# If roi_size is an integer, it is broadcast to all dimensions, so we cannot
146+
# infer the expected dimensionality to enforce a strict check here.
147+
if isinstance(roi_size, Sequence):
148+
roi_dims = len(roi_size)
149+
if num_spatial_dims != roi_dims:
150+
raise ValueError(
151+
f"Inputs must have {roi_dims + 2} dimensions for {roi_dims}D roi_size "
152+
f"(Batch, Channel, {', '.join(['Spatial'] * roi_dims)}), "
153+
f"but got inputs shape {inputs.shape}.\n"
154+
"If you have channel-last data (e.g. B, D, H, W, C), please use "
155+
"monai.transforms.EnsureChannelFirst or EnsureChannelFirstd upstream."
156+
)
157+
# -----------------------------------------------------------------
158+
buffered = buffer_steps is not None and buffer_steps > 0
139159
if buffered:
140160
if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims:
141161
raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.")

monai/losses/unified_focal_loss.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(
4444
Args:
4545
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
4646
delta : weight of the background. Defaults to 0.7.
47-
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
47+
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
4848
epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
4949
"""
5050
super().__init__(reduction=LossReduction(reduction).value)
@@ -108,7 +108,7 @@ def __init__(
108108
Args:
109109
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
110110
delta : weight of the background. Defaults to 0.7.
111-
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
111+
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
112112
epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
113113
"""
114114
super().__init__(reduction=LossReduction(reduction).value)
@@ -167,10 +167,11 @@ def __init__(
167167
Args:
168168
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
169169
num_classes : number of classes, it only supports 2 now. Defaults to 2.
170+
weight : weight for each loss function. Defaults to 0.5.
171+
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
170172
delta : weight of the background. Defaults to 0.7.
171-
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
172-
epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
173-
weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
173+
174+
174175
175176
Example:
176177
>>> import torch

monai/metrics/panoptic_quality.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
linear_sum_assignment, _ = optional_import("scipy.optimize", name="linear_sum_assignment")
2323

24-
__all__ = ["PanopticQualityMetric", "compute_panoptic_quality"]
24+
__all__ = ["PanopticQualityMetric", "compute_panoptic_quality", "compute_mean_iou"]
2525

2626

2727
class PanopticQualityMetric(CumulativeIterationMetric):
@@ -55,6 +55,8 @@ class PanopticQualityMetric(CumulativeIterationMetric):
5555
If set `match_iou_threshold` < 0.5, this function uses Munkres assignment to find the
5656
maximal amount of unique pairing.
5757
smooth_numerator: a small constant added to the numerator to avoid zero.
58+
return_confusion_matrix: if True, returns raw confusion matrix values (tp, fp, fn, iou_sum)
59+
instead of computed metrics. Default is False.
5860
5961
"""
6062

@@ -65,19 +67,22 @@ def __init__(
6567
reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
6668
match_iou_threshold: float = 0.5,
6769
smooth_numerator: float = 1e-6,
70+
return_confusion_matrix: bool = False,
6871
) -> None:
6972
super().__init__()
7073
self.num_classes = num_classes
7174
self.reduction = reduction
7275
self.match_iou_threshold = match_iou_threshold
7376
self.smooth_numerator = smooth_numerator
7477
self.metric_name = ensure_tuple(metric_name)
78+
self.return_confusion_matrix = return_confusion_matrix
7579

7680
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
7781
"""
7882
Args:
79-
y_pred: Predictions. It must be in the form of B2HW and have integer type. The first channel and the
80-
second channel represent the instance predictions and classification predictions respectively.
83+
y_pred: Predictions. It must be in the form of B2HW (2D) or B2HWD (3D) and have integer type.
84+
The first channel and the second channel represent the instance predictions and classification
85+
predictions respectively.
8186
y: ground truth. It must have the same shape as `y_pred` and have integer type. The first channel and the
8287
second channel represent the instance labels and classification labels respectively.
8388
Values in the second channel of `y_pred` and `y` should be in the range of 0 to `self.num_classes`,
@@ -86,7 +91,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
8691
Raises:
8792
ValueError: when `y_pred` and `y` have different shapes.
8893
ValueError: when `y_pred` and `y` have != 2 channels.
89-
ValueError: when `y_pred` and `y` have != 4 dimensions.
94+
ValueError: when `y_pred` and `y` have != 4 or 5 dimensions.
9095
9196
"""
9297
if y_pred.shape != y.shape:
@@ -98,8 +103,10 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
98103
)
99104

100105
dims = y_pred.ndimension()
101-
if dims != 4:
102-
raise ValueError(f"y_pred should have 4 dimensions (batch, 2, h, w), got {dims}.")
106+
if dims not in (4, 5):
107+
raise ValueError(
108+
f"y_pred should have 4 dimensions (batch, 2, h, w) or 5 dimensions (batch, 2, h, w, d), got {dims}."
109+
)
103110

104111
batch_size = y_pred.shape[0]
105112

@@ -131,13 +138,22 @@ def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Ten
131138
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
132139
``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.
133140
141+
Returns:
142+
If `return_confusion_matrix` is True, returns the raw confusion matrix [tp, fp, fn, iou_sum].
143+
Otherwise, returns the computed metric(s) based on `metric_name`.
144+
134145
"""
135146
data = self.get_buffer()
136147
if not isinstance(data, torch.Tensor):
137148
raise ValueError("the data to aggregate must be PyTorch Tensor.")
138149

139150
# do metric reduction
140151
f, _ = do_metric_reduction(data, reduction or self.reduction)
152+
153+
if self.return_confusion_matrix:
154+
# Return raw confusion matrix values
155+
return f
156+
141157
tp, fp, fn, iou_sum = f[..., 0], f[..., 1], f[..., 2], f[..., 3]
142158
results = []
143159
for metric_name in self.metric_name:
@@ -169,7 +185,7 @@ def compute_panoptic_quality(
169185
calculate PQ, and returning them directly enables further calculation over all images.
170186
171187
Args:
172-
pred: input data to compute, it must be in the form of HW and have integer type.
188+
pred: input data to compute, it must be in the form of HW (2D) or HWD (3D) and have integer type.
173189
gt: ground truth. It must have the same shape as `pred` and have integer type.
174190
metric_name: output metric. The value can be "pq", "sq" or "rq".
175191
remap: whether to remap `pred` and `gt` to ensure contiguous ordering of instance id.
@@ -294,3 +310,24 @@ def _check_panoptic_metric_name(metric_name: str) -> str:
294310
if metric_name in ["recognition_quality", "rq"]:
295311
return "rq"
296312
raise ValueError(f"metric name: {metric_name} is wrong, please use 'pq', 'sq' or 'rq'.")
313+
314+
315+
def compute_mean_iou(confusion_matrix: torch.Tensor, smooth_numerator: float = 1e-6) -> torch.Tensor:
316+
"""Compute mean IoU from confusion matrix values.
317+
318+
Args:
319+
confusion_matrix: tensor with shape (..., 4) where the last dimension contains
320+
[tp, fp, fn, iou_sum] as returned by `compute_panoptic_quality` with `output_confusion_matrix=True`.
321+
smooth_numerator: a small constant added to the numerator to avoid zero.
322+
323+
Returns:
324+
Mean IoU computed as iou_sum / (tp + smooth_numerator).
325+
326+
"""
327+
if confusion_matrix.shape[-1] != 4:
328+
raise ValueError(
329+
f"confusion_matrix should have shape (..., 4) with [tp, fp, fn, iou_sum], "
330+
f"got shape {confusion_matrix.shape}."
331+
)
332+
tp, iou_sum = confusion_matrix[..., 0], confusion_matrix[..., 3]
333+
return iou_sum / (tp + smooth_numerator)

monai/transforms/croppad/functional.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,23 +91,27 @@ def pad_nd(
9191
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
9292
kwargs: other arguments for the `np.pad` or `torch.pad` function.
9393
note that `np.pad` treats channel dimension as the first dimension.
94+
Raises:
95+
ValueError: If `value` is provided when `mode` is not ``"constant"``.
9496
"""
97+
if mode != "constant" and "value" in kwargs:
98+
raise ValueError("'value' argument is only valid when mode='constant'")
9599
if mode in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}:
96100
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
97101
try:
98102
_pad = _np_pad
99-
if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"} and img.dtype not in {
100-
torch.int16,
101-
torch.int64,
102-
torch.bool,
103-
torch.uint8,
104-
}:
103+
if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"}:
104+
# Try PyTorch pad for these modes; fallback to NumPy on error.
105105
_pad = _pt_pad
106106
return _pad(img, pad_width=to_pad, mode=mode, **kwargs)
107+
except NotImplementedError:
108+
# PyTorch does not support this combination, fall back to NumPy
109+
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
107110
except (ValueError, TypeError, RuntimeError) as err:
108-
if isinstance(err, NotImplementedError) or any(
109-
k in str(err) for k in ("supported", "unexpected keyword", "implemented", "value")
110-
):
111+
# PyTorch may raise generic errors for unsupported modes/dtypes or kwargs.
112+
# Since there are no stable exception types for these cases, we fall back
113+
# to NumPy by matching known error message patterns.
114+
if any(k in str(err) for k in ("supported", "unexpected keyword", "implemented", "value")):
111115
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
112116
raise ValueError(
113117
f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device if isinstance(img, torch.Tensor) else None}"

monai/transforms/inverse.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from monai import transforms
2323
from monai.data.meta_obj import MetaObj, get_track_meta
2424
from monai.data.meta_tensor import MetaTensor
25-
from monai.data.utils import to_affine_nd
25+
from monai.data.utils import affine_to_spacing, to_affine_nd
2626
from monai.transforms.traits import InvertibleTrait
2727
from monai.transforms.transform import Transform
2828
from monai.utils import (
@@ -224,6 +224,9 @@ def track_transform_meta(
224224
else:
225225
raise
226226
out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"), dtype=torch.float64)
227+
if MetaKeys.PIXDIM in out_obj.meta:
228+
spacing = affine_to_spacing(out_obj.meta[MetaKeys.AFFINE])
229+
out_obj.meta[MetaKeys.PIXDIM][1 : 1 + len(spacing)] = spacing
227230

228231
if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)):
229232
if isinstance(data, Mapping):

monai/transforms/spatial/array.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2436,6 +2436,13 @@ def __init__(
24362436
- :py:class:`RandAffineGrid` for the random affine parameters configurations.
24372437
- :py:class:`Affine` for the affine transformation parameters configurations.
24382438
2439+
Note:
2440+
The affine transformations in MONAI use a 'backward mapping' (image-to-grid) logic.
2441+
This can be counter-intuitive:
2442+
- Translation: A positive value shifts the image in the negative direction.
2443+
- Scaling: Positive scale_range values decrease the image size; values in [-1, 0) increase it.
2444+
- Rotation: The direction (CW/CCW) may vary depending on the axis.
2445+
24392446
"""
24402447
RandomizableTransform.__init__(self, prob)
24412448
LazyTransform.__init__(self, lazy=lazy)

0 commit comments

Comments
 (0)