Skip to content

Commit 46b2c0b

Browse files
Enhance ExtractDataKeyFromMetaKeyd to work with MetaTensor (#8772)
## Fixes #7562 ### Description Enhances `ExtractDataKeyFromMetaKeyd` to support extracting metadata from `MetaTensor` objects, in addition to plain metadata dictionaries. **Before:** Only worked with metadata dictionaries (`image_only=False`): ```python li = LoadImaged(image_only=False) dat = li({"image": "image.nii"}) # {"image": tensor, "image_meta_dict": dict} e = ExtractDataKeyFromMetaKeyd("filename_or_obj", meta_key="image_meta_dict") dat = e(dat) # extracts from dict ``` **After:** Also works with MetaTensor (`image_only=True`, the default): ```python li = LoadImaged() # image_only=True by default dat = li({"image": "image.nii"}) # {"image": MetaTensor} e = ExtractDataKeyFromMetaKeyd("filename_or_obj", meta_key="image") dat = e(dat) # extracts from MetaTensor.meta assert dat["image"].meta["filename_or_obj"] == dat["filename_or_obj"] ``` ### Changes 1. **`monai/apps/reconstruction/transforms/dictionary.py`**: - Added `MetaTensor` import - Modified `ExtractDataKeyFromMetaKeyd.__call__()` to detect if `meta_key` references a `MetaTensor` and extract from its `.meta` attribute - Updated docstring with both usage modes and examples 2. **`tests/apps/reconstruction/transforms/test_extract_data_key_from_meta_keyd.py`** (new): - 8 test cases covering: dict extraction, MetaTensor extraction, multiple keys, missing keys (with/without `allow_missing_keys`), and data preservation ### Testing - [x] New unit tests for both dict-based and MetaTensor-based extraction - [x] Tests for edge cases (missing keys, allow_missing_keys) - [x] Backward compatible — existing dict-based usage unchanged Signed-off-by: haoyu-haoyu <haoyu-haoyu@users.noreply.github.com> Signed-off-by: haoyu-haoyu <haoyu-haoyu@users.noreply.github.com> Signed-off-by: SexyERIC0723 <haoyuwang144@gmail.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent a8176f1 commit 46b2c0b

File tree

2 files changed

+149
-3
lines changed

2 files changed

+149
-3
lines changed

monai/apps/reconstruction/transforms/dictionary.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from monai.apps.reconstruction.transforms.array import EquispacedKspaceMask, RandomKspaceMask
2121
from monai.config import DtypeLike, KeysCollection
2222
from monai.config.type_definitions import NdarrayOrTensor
23+
from monai.data.meta_tensor import MetaTensor
2324
from monai.transforms import InvertibleTransform
2425
from monai.transforms.croppad.array import SpatialCrop
2526
from monai.transforms.intensity.array import NormalizeIntensity
@@ -33,15 +34,36 @@ class ExtractDataKeyFromMetaKeyd(MapTransform):
3334
Moves keys from meta to data. It is useful when a dataset of paired samples
3435
is loaded and certain keys should be moved from meta to data.
3536
37+
This transform supports two modes:
38+
39+
1. When ``meta_key`` references a metadata dictionary in the data (e.g., when
40+
``image_only=False`` was used with ``LoadImaged``), the requested keys are
41+
extracted directly from that dictionary.
42+
43+
2. When ``meta_key`` references a ``MetaTensor`` in the data (e.g., when
44+
``image_only=True`` was used with ``LoadImaged``), the requested keys are
45+
extracted from its ``.meta`` attribute.
46+
3647
Args:
3748
keys: keys to be transferred from meta to data
38-
meta_key: the meta key where all the meta-data is stored
49+
meta_key: the key in the data dictionary where the metadata source is
50+
stored. This can be either a metadata dictionary or a ``MetaTensor``.
3951
allow_missing_keys: don't raise exception if key is missing
4052
4153
Example:
4254
When the fastMRI dataset is loaded, "kspace" is stored in the data dictionary,
4355
but the ground-truth image with the key "reconstruction_rss" is stored in the meta data.
4456
In this case, ExtractDataKeyFromMetaKeyd moves "reconstruction_rss" to data.
57+
58+
When ``LoadImaged`` is used with ``image_only=True`` (the default), the loaded
59+
data is a ``MetaTensor`` with metadata accessible via ``.meta``. In this case,
60+
set ``meta_key`` to the key of the ``MetaTensor`` itself::
61+
62+
li = LoadImaged(keys="image") # image_only=True by default
63+
dat = li({"image": "image.nii"})
64+
e = ExtractDataKeyFromMetaKeyd("filename_or_obj", meta_key="image")
65+
dat = e(dat)
66+
assert dat["image"].meta["filename_or_obj"] == dat["filename_or_obj"]
4567
"""
4668

4769
def __init__(self, keys: KeysCollection, meta_key: str, allow_missing_keys: bool = False) -> None:
@@ -58,9 +80,18 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, T
5880
the new data dictionary
5981
"""
6082
d = dict(data)
83+
meta_obj = d[self.meta_key]
84+
85+
# If meta_key references a MetaTensor, extract from its .meta attribute;
86+
# otherwise treat it as a metadata dictionary directly.
87+
if isinstance(meta_obj, MetaTensor):
88+
meta_dict: dict = meta_obj.meta
89+
else:
90+
meta_dict = dict(meta_obj)
91+
6192
for key in self.keys:
62-
if key in d[self.meta_key]:
63-
d[key] = d[self.meta_key][key] # type: ignore
93+
if key in meta_dict:
94+
d[key] = meta_dict[key] # type: ignore
6495
elif not self.allow_missing_keys:
6596
raise KeyError(
6697
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the meta data"
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import torch
17+
18+
from monai.apps.reconstruction.transforms.dictionary import ExtractDataKeyFromMetaKeyd
19+
from monai.data import MetaTensor
20+
21+
22+
class TestExtractDataKeyFromMetaKeyd(unittest.TestCase):
23+
"""Tests for ExtractDataKeyFromMetaKeyd covering both dict-based and MetaTensor-based metadata."""
24+
25+
def test_extract_from_dict(self):
26+
"""Test extracting keys from a plain metadata dictionary (image_only=False scenario)."""
27+
data = {
28+
"image": torch.zeros(1, 2, 2),
29+
"image_meta_dict": {"filename_or_obj": "image.nii", "spatial_shape": [2, 2]},
30+
}
31+
transform = ExtractDataKeyFromMetaKeyd(keys="filename_or_obj", meta_key="image_meta_dict")
32+
result = transform(data)
33+
self.assertIn("filename_or_obj", result)
34+
self.assertEqual(result["filename_or_obj"], "image.nii")
35+
self.assertEqual(result["image_meta_dict"]["filename_or_obj"], result["filename_or_obj"])
36+
37+
def test_extract_from_metatensor(self):
38+
"""Test extracting keys from a MetaTensor's .meta attribute (image_only=True scenario)."""
39+
mt = MetaTensor(torch.zeros(1, 2, 2))
40+
mt.meta["filename_or_obj"] = "image.nii"
41+
mt.meta["spatial_shape"] = [2, 2]
42+
data = {"image": mt}
43+
transform = ExtractDataKeyFromMetaKeyd(keys="filename_or_obj", meta_key="image")
44+
result = transform(data)
45+
self.assertIn("filename_or_obj", result)
46+
self.assertEqual(result["filename_or_obj"], "image.nii")
47+
self.assertEqual(result["image"].meta["filename_or_obj"], result["filename_or_obj"])
48+
49+
def test_extract_multiple_keys_from_metatensor(self):
50+
"""Test extracting multiple keys from a MetaTensor."""
51+
mt = MetaTensor(torch.zeros(1, 2, 2))
52+
mt.meta["filename_or_obj"] = "image.nii"
53+
mt.meta["spatial_shape"] = [2, 2]
54+
data = {"image": mt}
55+
transform = ExtractDataKeyFromMetaKeyd(keys=["filename_or_obj", "spatial_shape"], meta_key="image")
56+
result = transform(data)
57+
self.assertIn("filename_or_obj", result)
58+
self.assertIn("spatial_shape", result)
59+
self.assertEqual(result["filename_or_obj"], "image.nii")
60+
self.assertEqual(result["spatial_shape"], [2, 2])
61+
62+
def test_extract_multiple_keys_from_dict(self):
63+
"""Test extracting multiple keys from a plain dictionary."""
64+
data = {
65+
"image": torch.zeros(1, 2, 2),
66+
"image_meta_dict": {"filename_or_obj": "image.nii", "spatial_shape": [2, 2]},
67+
}
68+
transform = ExtractDataKeyFromMetaKeyd(keys=["filename_or_obj", "spatial_shape"], meta_key="image_meta_dict")
69+
result = transform(data)
70+
self.assertIn("filename_or_obj", result)
71+
self.assertIn("spatial_shape", result)
72+
self.assertEqual(result["filename_or_obj"], "image.nii")
73+
self.assertEqual(result["spatial_shape"], [2, 2])
74+
75+
def test_missing_key_raises(self):
76+
"""Test that a missing key raises KeyError when allow_missing_keys=False."""
77+
mt = MetaTensor(torch.zeros(1, 2, 2))
78+
mt.meta["filename_or_obj"] = "image.nii"
79+
data = {"image": mt}
80+
transform = ExtractDataKeyFromMetaKeyd(keys="nonexistent_key", meta_key="image")
81+
with self.assertRaises(KeyError):
82+
transform(data)
83+
84+
def test_missing_key_allowed_metatensor(self):
85+
"""Test that a missing key is silently skipped when allow_missing_keys=True with MetaTensor."""
86+
mt = MetaTensor(torch.zeros(1, 2, 2))
87+
mt.meta["filename_or_obj"] = "image.nii"
88+
data = {"image": mt}
89+
transform = ExtractDataKeyFromMetaKeyd(keys="nonexistent_key", meta_key="image", allow_missing_keys=True)
90+
result = transform(data)
91+
self.assertNotIn("nonexistent_key", result)
92+
93+
def test_missing_key_allowed_dict(self):
94+
"""Test that a missing key is silently skipped when allow_missing_keys=True with dict."""
95+
data = {"image": torch.zeros(1, 2, 2), "image_meta_dict": {"filename_or_obj": "image.nii"}}
96+
transform = ExtractDataKeyFromMetaKeyd(
97+
keys="nonexistent_key", meta_key="image_meta_dict", allow_missing_keys=True
98+
)
99+
result = transform(data)
100+
self.assertNotIn("nonexistent_key", result)
101+
102+
def test_original_data_preserved_metatensor(self):
103+
"""Test that the original MetaTensor remains in the data dictionary."""
104+
mt = MetaTensor(torch.ones(1, 2, 2))
105+
mt.meta["filename_or_obj"] = "image.nii"
106+
data = {"image": mt}
107+
transform = ExtractDataKeyFromMetaKeyd(keys="filename_or_obj", meta_key="image")
108+
result = transform(data)
109+
self.assertIn("image", result)
110+
self.assertIsInstance(result["image"], MetaTensor)
111+
self.assertTrue(torch.equal(result["image"], mt))
112+
113+
114+
if __name__ == "__main__":
115+
unittest.main()

0 commit comments

Comments
 (0)