Skip to content

Commit f3e7cc0

Browse files
Nic-Ma and monai-bot authored
3430 support dataframes and streams in CSVDataset (#3440)
* [DLMED] add dataframe Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] enhance CSV iterable dataset Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] add unit tests Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] fix typehints Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] add comment Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] update according to comments Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] update according to comments Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] update according to comments Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] fix file close issue Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] fix doc Signed-off-by: Nic Ma <nma@nvidia.com> * [MONAI] python code formatting Signed-off-by: monai-bot <monai.miccai2019@gmail.com> Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
1 parent 0935d5a commit f3e7cc0

5 files changed

Lines changed: 148 additions & 35 deletions

File tree

monai/data/dataset.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing
3434
from monai.transforms import Compose, Randomizable, ThreadUnsafe, Transform, apply_transform
35-
from monai.utils import MAX_SEED, ensure_tuple, get_seed, look_up_option, min_version, optional_import
35+
from monai.utils import MAX_SEED, deprecated_arg, get_seed, look_up_option, min_version, optional_import
3636
from monai.utils.misc import first
3737

3838
if TYPE_CHECKING:
@@ -1222,8 +1222,9 @@ class CSVDataset(Dataset):
12221222
]
12231223
12241224
Args:
1225-
filename: the filename of expected CSV file to load. if providing a list
1226-
of filenames, it will load all the files and join tables.
1225+
src: if the filename of a CSV file is provided, it can be a str, URL, path object or file-like object to load.
1226+
a pandas `DataFrame` can also be provided directly, in which case loading from a file is skipped.
1227+
if a list of filenames or pandas `DataFrame` objects is provided, the tables are joined.
12271228
row_indices: indices of the expected rows to load. it should be a list,
12281229
every item can be a int number or a range `[start, end)` for the indices.
12291230
for example: `row_indices=[[0, 100], 200, 201, 202, 300]`. if None,
@@ -1249,20 +1250,32 @@ class CSVDataset(Dataset):
12491250
transform: transform to apply on the loaded items of a dictionary data.
12501251
kwargs: additional arguments for `pandas.merge()` API to join tables.
12511252
1253+
.. deprecated:: 0.8.0
1254+
``filename`` is deprecated, use ``src`` instead.
1255+
12521256
"""
12531257

1258+
@deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.")
12541259
def __init__(
12551260
self,
1256-
filename: Union[str, Sequence[str]],
1261+
src: Optional[Union[str, Sequence[str]]] = None, # also can be `DataFrame` or sequence of `DataFrame`
12571262
row_indices: Optional[Sequence[Union[int, str]]] = None,
12581263
col_names: Optional[Sequence[str]] = None,
12591264
col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
12601265
col_groups: Optional[Dict[str, Sequence[str]]] = None,
12611266
transform: Optional[Callable] = None,
12621267
**kwargs,
12631268
):
1264-
files = ensure_tuple(filename)
1265-
dfs = [pd.read_csv(f) for f in files]
1269+
srcs = (src,) if not isinstance(src, (tuple, list)) else src
1270+
dfs: List = []
1271+
for i in srcs:
1272+
if isinstance(i, str):
1273+
dfs.append(pd.read_csv(i))
1274+
elif isinstance(i, pd.DataFrame):
1275+
dfs.append(i)
1276+
else:
1277+
raise ValueError("`src` must be file path or pandas `DataFrame`.")
1278+
12661279
data = convert_tables_to_dicts(
12671280
dfs=dfs, row_indices=row_indices, col_names=col_names, col_types=col_types, col_groups=col_groups, **kwargs
12681281
)

monai/data/iterable_dataset.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union
12+
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union
1313

1414
import numpy as np
1515
from torch.utils.data import IterableDataset as _TorchIterableDataset
@@ -18,7 +18,7 @@
1818
from monai.data.utils import convert_tables_to_dicts
1919
from monai.transforms import apply_transform
2020
from monai.transforms.transform import Randomizable
21-
from monai.utils import ensure_tuple, optional_import
21+
from monai.utils import deprecated_arg, optional_import
2222

2323
pd, _ = optional_import("pandas")
2424

@@ -147,8 +147,9 @@ class CSVIterableDataset(IterableDataset):
147147
]
148148
149149
Args:
150-
filename: the filename of CSV file to load. it can be a str, URL, path object or file-like object.
151-
if providing a list of filenames, it will load all the files and join tables.
150+
src: if the filename of a CSV file is provided, it can be a str, URL, path object or file-like object to load.
151+
an iterable can also be provided directly for stream input, in which case loading from a file is skipped.
152+
if a list of filenames or iterables is provided, the tables are joined.
152153
chunksize: rows of a chunk when loading iterable data from CSV files, default to 1000. more details:
153154
https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html.
154155
buffer_size: size of the buffer to store the loaded chunks, if None, set to `2 x chunksize`.
@@ -177,11 +178,15 @@ class CSVIterableDataset(IterableDataset):
177178
https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98.
178179
kwargs: additional arguments for `pandas.merge()` API to join tables.
179180
181+
.. deprecated:: 0.8.0
182+
``filename`` is deprecated, use ``src`` instead.
183+
180184
"""
181185

186+
@deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.")
182187
def __init__(
183188
self,
184-
filename: Union[str, Sequence[str]],
189+
src: Union[Union[str, Sequence[str]], Union[Iterable, Sequence[Iterable]]],
185190
chunksize: int = 1000,
186191
buffer_size: Optional[int] = None,
187192
col_names: Optional[Sequence[str]] = None,
@@ -192,7 +197,7 @@ def __init__(
192197
seed: int = 0,
193198
**kwargs,
194199
):
195-
self.files = ensure_tuple(filename)
200+
self.src = src
196201
self.chunksize = chunksize
197202
self.buffer_size = 2 * chunksize if buffer_size is None else buffer_size
198203
self.col_names = col_names
@@ -201,16 +206,46 @@ def __init__(
201206
self.shuffle = shuffle
202207
self.seed = seed
203208
self.kwargs = kwargs
204-
self.iters = self.reset()
209+
self.iters: List[Iterable] = self.reset()
205210
super().__init__(data=None, transform=transform) # type: ignore
206211

207-
def reset(self, filename: Optional[Union[str, Sequence[str]]] = None):
208-
if filename is not None:
209-
# update files if necessary
210-
self.files = ensure_tuple(filename)
211-
self.iters = [pd.read_csv(f, chunksize=self.chunksize) for f in self.files]
212+
@deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.")
def reset(self, src: Optional[Union[Union[str, Sequence[str]], Union[Iterable, Sequence[Iterable]]]] = None):
    """
    Rebuild the list of chunk iterators stored in `self.iters`. For more details about the pandas
    `TextFileReader` object created for file inputs, please check:
    https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration.

    Args:
        src: new source(s) to read from. a str (file path or URL) or a path object is opened with
            `pd.read_csv(..., chunksize=self.chunksize)`; any other iterable is used directly as the
            stream input, without loading from a file. if a list or tuple is provided, one iterator
            is created per item. if None, default to `self.src`.

    Returns:
        the new list of iterators, which is also stored in `self.iters`.

    Raises:
        ValueError: when an item is neither a file path nor an iterable object.

    """
    from os import PathLike  # local import: only needed here to recognise path objects

    src = self.src if src is None else src
    srcs = src if isinstance(src, (tuple, list)) else (src,)
    self.iters = []
    for i in srcs:
        # paths must be tested before `Iterable`: a str is itself iterable, and a path object
        # (e.g. `pathlib.Path`) is not iterable at all and would otherwise be rejected below.
        if isinstance(i, (str, PathLike)):
            self.iters.append(pd.read_csv(i, chunksize=self.chunksize))
        elif isinstance(i, Iterable):
            # NOTE(review): file-like objects are iterable too, so they are stored as-is here
            # rather than parsed by `pd.read_csv` - confirm this is the intended behaviour.
            self.iters.append(i)
        else:
            raise ValueError("`src` must be file path or iterable object.")
    return self.iters
213236

237+
def close(self):
    """
    Close the iterable objects held in `self.iters`.

    If the input src is a file path, a pandas `TextFileReader` was created internally and needs to be closed.
    If the input src is a user-provided iterable object, it is closed only when it exposes a callable
    `close` attribute; iterables without one (for example a plain list) are left untouched instead of
    raising `AttributeError`.
    For more details, please check:
    https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration.

    """
    for i in self.iters:
        # not every iterable accepted as a source has a `close` method, so look it up defensively
        close_fn = getattr(i, "close", None)
        if callable(close_fn):
            close_fn()
248+
214249
def _flattened(self):
215250
for chunks in zip(*self.iters):
216251
yield from convert_tables_to_dicts(

monai/networks/blocks/activation.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
def monai_mish(x, inplace: bool = False):
2020
return torch.nn.functional.mish(x, inplace=inplace)
2121

22-
2322
else:
2423

2524
def monai_mish(x, inplace: bool = False):
@@ -31,7 +30,6 @@ def monai_mish(x, inplace: bool = False):
3130
def monai_swish(x, inplace: bool = False):
3231
return torch.nn.functional.silu(x, inplace=inplace)
3332

34-
3533
else:
3634

3735
def monai_swish(x, inplace: bool = False):

tests/test_csv_dataset.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import unittest
1515

1616
import numpy as np
17+
import pandas as pd
1718

1819
from monai.data import CSVDataset
1920
from monai.transforms import ToNumpyd
@@ -57,6 +58,7 @@ def prepare_csv_file(data, filepath):
5758
filepath1 = os.path.join(tempdir, "test_data1.csv")
5859
filepath2 = os.path.join(tempdir, "test_data2.csv")
5960
filepath3 = os.path.join(tempdir, "test_data3.csv")
61+
filepaths = [filepath1, filepath2, filepath3]
6062
prepare_csv_file(test_data1, filepath1)
6163
prepare_csv_file(test_data2, filepath2)
6264
prepare_csv_file(test_data3, filepath3)
@@ -76,7 +78,7 @@ def prepare_csv_file(data, filepath):
7678
)
7779

7880
# test multiple CSV files, join tables with kwargs
79-
dataset = CSVDataset([filepath1, filepath2, filepath3], on="subject_id")
81+
dataset = CSVDataset(filepaths, on="subject_id")
8082
self.assertDictEqual(
8183
{k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in dataset[3].items()},
8284
{
@@ -102,7 +104,7 @@ def prepare_csv_file(data, filepath):
102104

103105
# test selected rows and columns
104106
dataset = CSVDataset(
105-
filename=[filepath1, filepath2, filepath3],
107+
src=filepaths,
106108
row_indices=[[0, 2], 3], # load row: 0, 1, 3
107109
col_names=["subject_id", "image", "ehr_1", "ehr_7", "meta_1"],
108110
)
@@ -120,7 +122,7 @@ def prepare_csv_file(data, filepath):
120122

121123
# test group columns
122124
dataset = CSVDataset(
123-
filename=[filepath1, filepath2, filepath3],
125+
src=filepaths,
124126
row_indices=[1, 3], # load row: 1, 3
125127
col_names=["subject_id", "image", *[f"ehr_{i}" for i in range(11)], "meta_0", "meta_1", "meta_2"],
126128
col_groups={"ehr": [f"ehr_{i}" for i in range(11)], "meta12": ["meta_1", "meta_2"]},
@@ -133,9 +135,7 @@ def prepare_csv_file(data, filepath):
133135

134136
# test transform
135137
dataset = CSVDataset(
136-
filename=[filepath1, filepath2, filepath3],
137-
col_groups={"ehr": [f"ehr_{i}" for i in range(5)]},
138-
transform=ToNumpyd(keys="ehr"),
138+
src=filepaths, col_groups={"ehr": [f"ehr_{i}" for i in range(5)]}, transform=ToNumpyd(keys="ehr")
139139
)
140140
self.assertEqual(len(dataset), 5)
141141
expected = [
@@ -151,7 +151,7 @@ def prepare_csv_file(data, filepath):
151151

152152
# test default values and dtype
153153
dataset = CSVDataset(
154-
filename=[filepath1, filepath2, filepath3],
154+
src=filepaths,
155155
col_names=["subject_id", "image", "ehr_1", "ehr_9", "meta_1"],
156156
col_types={"image": {"type": str, "default": "No image"}, "ehr_1": {"type": int, "default": 0}},
157157
how="outer", # generate NaN values in this merge mode
@@ -161,6 +161,29 @@ def prepare_csv_file(data, filepath):
161161
self.assertEqual(type(dataset[-1]["ehr_1"]), int)
162162
np.testing.assert_allclose(dataset[-1]["ehr_9"], 3.3537, rtol=1e-2)
163163

164+
# test pre-loaded DataFrame
165+
df = pd.read_csv(filepath1)
166+
dataset = CSVDataset(src=df)
167+
self.assertDictEqual(
168+
{k: round(v, 4) if not isinstance(v, str) else v for k, v in dataset[2].items()},
169+
{
170+
"subject_id": "s000002",
171+
"label": 4,
172+
"image": "./imgs/s000002.png",
173+
"ehr_0": 3.7725,
174+
"ehr_1": 4.2118,
175+
"ehr_2": 4.6353,
176+
},
177+
)
178+
179+
# test pre-loaded multiple DataFrames, join tables with kwargs
180+
dfs = [pd.read_csv(i) for i in filepaths]
181+
dataset = CSVDataset(src=dfs, on="subject_id")
182+
self.assertEqual(dataset[3]["subject_id"], "s000003")
183+
self.assertEqual(dataset[3]["label"], 1)
184+
self.assertEqual(round(dataset[3]["ehr_0"], 4), 3.3333)
185+
self.assertEqual(dataset[3]["meta_0"], False)
186+
164187

165188
if __name__ == "__main__":
166189
unittest.main()

0 commit comments

Comments
 (0)