Skip to content

Commit 8d730cd

Browse files
authored
6387 update_kwargs for merging multiple configs (#7109)
Fixes #6387 Fixes #5899 ### Description - add api for update_kwargs - add support of merging multiple configs files and dictionaries - remove warning message of directory in `runtests.sh` ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent fc1350a commit 8d730cd

File tree

10 files changed

+52
-32
lines changed

10 files changed

+52
-32
lines changed

docs/source/bundle.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,4 @@ Model Bundle
4848
.. autofunction:: verify_metadata
4949
.. autofunction:: verify_net_in_out
5050
.. autofunction:: init_bundle
51+
.. autofunction:: update_kwargs

monai/bundle/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
run,
3030
run_workflow,
3131
trt_export,
32+
update_kwargs,
3233
verify_metadata,
3334
verify_net_in_out,
3435
)

monai/bundle/config_parser.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,13 +412,16 @@ def load_config_files(cls, files: PathLike | Sequence[PathLike] | dict, **kwargs
412412
413413
Args:
414414
files: path of target files to load, supported postfixes: `.json`, `.yml`, `.yaml`.
415-
if providing a list of files, wil merge the content of them.
415+
if providing a list of files, will merge the content of them.
416+
if providing a string with comma separated file paths, will merge the content of them.
416417
if providing a dictionary, return it directly.
417418
kwargs: other arguments for ``json.load`` or ```yaml.safe_load``, depends on the file format.
418419
"""
419420
if isinstance(files, dict): # already a config dict
420421
return files
421422
parser = ConfigParser(config={})
423+
if isinstance(files, str) and not Path(files).is_file() and "," in files:
424+
files = files.split(",")
422425
for i in ensure_tuple(files):
423426
for k, v in (cls.load_config_file(i, **kwargs)).items():
424427
parser[k] = v

monai/bundle/scripts.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,33 +66,46 @@
6666
PPRINT_CONFIG_N = 5
6767

6868

69-
def _update_args(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict:
69+
def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict:
7070
"""
71-
Update the `args` with the input `kwargs`.
71+
Update the `args` dictionary with the input `kwargs`.
7272
For dict data, recursively update the content based on the keys.
7373
74+
Example::
75+
76+
from monai.bundle import update_kwargs
77+
update_kwargs({'exist': 1}, exist=2, new_arg=3)
78+
# return {'exist': 2, 'new_arg': 3}
79+
7480
Args:
75-
args: source args to update.
81+
args: source `args` dictionary (or a json/yaml filename to read as dictionary) to update.
7682
ignore_none: whether to ignore input args with None value, default to `True`.
77-
kwargs: destination args to update.
83+
kwargs: key=value pairs to be merged into `args`.
7884
7985
"""
8086
args_: dict = args if isinstance(args, dict) else {}
8187
if isinstance(args, str):
8288
# args are defined in a structured file
8389
args_ = ConfigParser.load_config_file(args)
90+
if isinstance(args, (tuple, list)) and all(isinstance(x, str) for x in args):
91+
primary, overrides = args
92+
args_ = update_kwargs(primary, ignore_none, **update_kwargs(overrides, ignore_none, **kwargs))
93+
if not isinstance(args_, dict):
94+
return args_
8495
# recursively update the default args with new args
8596
for k, v in kwargs.items():
86-
print(k, v)
8797
if ignore_none and v is None:
8898
continue
8999
if isinstance(v, dict) and isinstance(args_.get(k), dict):
90-
args_[k] = _update_args(args_[k], ignore_none, **v)
100+
args_[k] = update_kwargs(args_[k], ignore_none, **v)
91101
else:
92102
args_[k] = v
93103
return args_
94104

95105

106+
_update_args = update_kwargs # backward compatibility
107+
108+
96109
def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple:
97110
"""
98111
Pop args from the `src` dictionary based on specified keys in `args` and (key, default value) pairs in `kwargs`.
@@ -318,7 +331,7 @@ def download(
318331
so that the command line inputs can be simplified.
319332
320333
"""
321-
_args = _update_args(
334+
_args = update_kwargs(
322335
args=args_file,
323336
name=name,
324337
version=version,
@@ -834,7 +847,7 @@ def verify_metadata(
834847
835848
"""
836849

837-
_args = _update_args(
850+
_args = update_kwargs(
838851
args=args_file,
839852
meta_file=meta_file,
840853
filepath=filepath,
@@ -958,7 +971,7 @@ def verify_net_in_out(
958971
959972
"""
960973

961-
_args = _update_args(
974+
_args = update_kwargs(
962975
args=args_file,
963976
net_id=net_id,
964977
meta_file=meta_file,
@@ -1127,7 +1140,7 @@ def onnx_export(
11271140
e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.
11281141
11291142
"""
1130-
_args = _update_args(
1143+
_args = update_kwargs(
11311144
args=args_file,
11321145
net_id=net_id,
11331146
filepath=filepath,
@@ -1242,7 +1255,7 @@ def ckpt_export(
12421255
e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.
12431256
12441257
"""
1245-
_args = _update_args(
1258+
_args = update_kwargs(
12461259
args=args_file,
12471260
net_id=net_id,
12481261
filepath=filepath,
@@ -1401,7 +1414,7 @@ def trt_export(
14011414
e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.
14021415
14031416
"""
1404-
_args = _update_args(
1417+
_args = update_kwargs(
14051418
args=args_file,
14061419
net_id=net_id,
14071420
filepath=filepath,
@@ -1614,7 +1627,7 @@ def create_workflow(
16141627
kwargs: arguments to instantiate the workflow class.
16151628
16161629
"""
1617-
_args = _update_args(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs)
1630+
_args = update_kwargs(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs)
16181631
_log_input_summary(tag="run", args=_args)
16191632
(workflow_name, config_file) = _pop_args(
16201633
_args, workflow_name=ConfigWorkflow, config_file=None

monai/bundle/workflows.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import os
1515
import sys
1616
import time
17-
import warnings
1817
from abc import ABC, abstractmethod
1918
from copy import copy
2019
from logging.config import fileConfig
@@ -158,7 +157,7 @@ def add_property(self, name: str, required: str, desc: str | None = None) -> Non
158157
if self.properties is None:
159158
self.properties = {}
160159
if name in self.properties:
161-
warnings.warn(f"property '{name}' already exists in the properties list, overriding it.")
160+
logger.warn(f"property '{name}' already exists in the properties list, overriding it.")
162161
self.properties[name] = {BundleProperty.DESC: desc, BundleProperty.REQUIRED: required}
163162

164163
def check_properties(self) -> list[str] | None:
@@ -241,7 +240,7 @@ def __init__(
241240
for _config_file in _config_files:
242241
_config_file = Path(_config_file)
243242
if _config_file.parent != self.config_root_path:
244-
warnings.warn(
243+
logger.warn(
245244
f"Not all config files are in {self.config_root_path}. If logging_file and meta_file are"
246245
f"not specified, {self.config_root_path} will be used as the default config root directory."
247246
)
@@ -254,7 +253,7 @@ def __init__(
254253
if logging_file is not None:
255254
if not os.path.exists(logging_file):
256255
if logging_file == str(self.config_root_path / "logging.conf"):
257-
warnings.warn(f"Default logging file in {logging_file} does not exist, skipping logging.")
256+
logger.warn(f"Default logging file in {logging_file} does not exist, skipping logging.")
258257
else:
259258
raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.")
260259
else:
@@ -265,7 +264,10 @@ def __init__(
265264
self.parser.read_config(f=config_file)
266265
meta_file = str(self.config_root_path / "metadata.json") if meta_file is None else meta_file
267266
if isinstance(meta_file, str) and not os.path.exists(meta_file):
268-
raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.")
267+
logger.error(
268+
f"Cannot find the metadata config file: {meta_file}. "
269+
"Please see: https://docs.monai.io/en/stable/mb_specification.html"
270+
)
269271
else:
270272
self.parser.read_meta(f=meta_file)
271273

@@ -323,17 +325,17 @@ def check_properties(self) -> list[str] | None:
323325
"""
324326
ret = super().check_properties()
325327
if self.properties is None:
326-
warnings.warn("No available properties had been set, skipping check.")
328+
logger.warn("No available properties had been set, skipping check.")
327329
return None
328330
if ret:
329-
warnings.warn(f"Loaded bundle does not contain the following required properties: {ret}")
331+
logger.warn(f"Loaded bundle does not contain the following required properties: {ret}")
330332
# also check whether the optional properties use correct ID name if existing
331333
wrong_props = []
332334
for n, p in self.properties.items():
333335
if not p.get(BundleProperty.REQUIRED, False) and not self._check_optional_id(name=n, property=p):
334336
wrong_props.append(n)
335337
if wrong_props:
336-
warnings.warn(f"Loaded bundle defines the following optional properties with wrong ID: {wrong_props}")
338+
logger.warn(f"Loaded bundle defines the following optional properties with wrong ID: {wrong_props}")
337339
if ret is not None:
338340
ret.extend(wrong_props)
339341
return ret

monai/data/meta_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def astype(self, dtype, device=None, *_args, **_kwargs):
462462
@property
463463
def affine(self) -> torch.Tensor:
464464
"""Get the affine. Defaults to ``torch.eye(4, dtype=torch.float64)``"""
465-
return self.meta.get(MetaKeys.AFFINE, self.get_default_affine())
465+
return self.meta.get(MetaKeys.AFFINE, self.get_default_affine()) # type: ignore
466466

467467
@affine.setter
468468
def affine(self, d: NdarrayTensor) -> None:

runtests.sh

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ function print_usage {
108108
echo " -c, --clean : clean temporary files from tests and exit"
109109
echo " -h, --help : show this help message and exit"
110110
echo " -v, --version : show MONAI and system version information and exit"
111-
echo " -p, --path : specify the path used for formatting"
111+
echo " -p, --path : specify the path used for formatting, default is the current dir if unspecified"
112112
echo " --formatfix : format code using \"isort\" and \"black\" for user specified directories"
113113
echo ""
114114
echo "${separator}For bug reports and feature requests, please file an issue at:"
@@ -359,10 +359,9 @@ if [ -e "$testdir" ]
359359
then
360360
homedir=$testdir
361361
else
362-
print_error_msg "Incorrect path: $testdir provided, run under $currentdir"
363362
homedir=$currentdir
364363
fi
365-
echo "run tests under $homedir"
364+
echo "Run tests under $homedir"
366365
cd "$homedir"
367366

368367
# python path

tests/test_bundle_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import torch
2020

21+
from monai.bundle import update_kwargs
2122
from monai.bundle.utils import load_bundle_config
2223
from monai.networks.nets import UNet
2324
from monai.utils import pprint_edges
@@ -141,6 +142,7 @@ def test_str(self):
141142
"[{'a': 1, 'b': 2},\n\n ... omitted 18 line(s)\n\n {'a': 1, 'b': 2}]",
142143
)
143144
self.assertEqual(pprint_edges([{"a": 1, "b": 2}] * 8, 4), pprint_edges([{"a": 1, "b": 2}] * 8, 3))
145+
self.assertEqual(update_kwargs({"a": 1}, a=2, b=3), {"a": 2, "b": 3})
144146

145147

146148
if __name__ == "__main__":

tests/test_integration_bundle_run.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,8 @@ def test_tiny(self):
8686
with self.assertRaises(RuntimeError):
8787
# test wrong run_id="run"
8888
command_line_tests(cmd + ["run", "run", "--config_file", config_file])
89-
with self.assertRaises(RuntimeError):
90-
# test missing meta file
91-
command_line_tests(cmd + ["run", "training", "--config_file", config_file])
89+
# test missing meta file
90+
self.assertIn("ERROR", command_line_tests(cmd + ["run", "training", "--config_file", config_file]))
9291

9392
def test_scripts_fold(self):
9493
# test scripts directory has been added to Python search directories automatically
@@ -150,9 +149,8 @@ def test_scripts_fold(self):
150149
print(output)
151150
self.assertTrue(expected_condition in output)
152151

153-
with self.assertRaises(RuntimeError):
154-
# test missing meta file
155-
command_line_tests(cmd + ["run", "training", "--config_file", config_file])
152+
# test missing meta file
153+
self.assertIn("ERROR", command_line_tests(cmd + ["run", "training", "--config_file", config_file]))
156154

157155
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
158156
def test_shape(self, config_file, expected_shape):

tests/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,7 @@ def command_line_tests(cmd, copy_env=True):
818818
try:
819819
normal_out = subprocess.run(cmd, env=test_env, check=True, capture_output=True)
820820
print(repr(normal_out).replace("\\n", "\n").replace("\\t", "\t"))
821+
return repr(normal_out)
821822
except subprocess.CalledProcessError as e:
822823
output = repr(e.stdout).replace("\\n", "\n").replace("\\t", "\t")
823824
errors = repr(e.stderr).replace("\\n", "\n").replace("\\t", "\t")

0 commit comments

Comments
 (0)