Skip to content

Commit 2ae9c79

Browse files
committed
fix: count_input_cfg_levels now resolves string file references
When input_cfg is overridden via CLI to a YAML file path (e.g. model.train_ds.input_cfg=train_all.yaml), the level counter only saw the top-level string and reported 1 level, causing reweight_temperature to broadcast incorrectly. Now it loads referenced YAML files to discover nested input_cfg keys, matching the runtime behavior of parse_and_combine_datasets. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
1 parent 78694d5 commit 2ae9c79

2 files changed

Lines changed: 51 additions & 2 deletions

File tree

nemo/collections/common/data/lhotse/cutset.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,12 @@ def count_input_cfg_levels(config: Union[DictConfig, dict]) -> int:
476476
the same temperature (due to propagate_attrs.copy()), we count max depth,
477477
not total occurrences.
478478
479+
String/Path values for ``input_cfg`` are treated as file references (mirroring
480+
:func:`parse_and_combine_datasets`) and loaded so that nested ``input_cfg``
481+
keys inside those files are counted. If a file cannot be loaded (e.g. the
482+
path contains unresolved OmegaConf interpolations), it is conservatively
483+
counted as one additional level.
484+
479485
Args:
480486
config: Configuration dictionary that may contain nested 'input_cfg' keys.
481487
@@ -493,13 +499,22 @@ def count_input_cfg_levels(config: Union[DictConfig, dict]) -> int:
493499
2
494500
"""
495501

502+
def _resolve_if_path(val):
503+
"""If *val* is a string/Path, try to load the YAML it points to."""
504+
if isinstance(val, (str, Path)):
505+
try:
506+
return load_yaml(str(val))
507+
except Exception:
508+
return val
509+
return val
510+
496511
def _max_depth(obj) -> int:
497512
if isinstance(obj, (dict, DictConfig)):
498513
depths = []
499514
for key, val in obj.items():
500515
if key == "input_cfg":
501-
# Found input_cfg: this level counts as 1 + max depth of children
502-
depths.append(1 + _max_depth(val))
516+
resolved = _resolve_if_path(val)
517+
depths.append(1 + _max_depth(resolved))
503518
else:
504519
depths.append(_max_depth(val))
505520
return max(depths, default=0)

tests/collections/common/test_lhotse_temperature_reweighting.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,40 @@ def test_with_omegaconf_dictconfig(self):
400400
)
401401
assert count_input_cfg_levels(config) == 2
402402

403+
def test_string_input_cfg_resolves_yaml_file(self, tmp_path):
404+
"""String input_cfg pointing to a YAML file is loaded and traversed."""
405+
inner_yaml = tmp_path / "inner.yaml"
406+
inner_yaml.write_text(
407+
"- type: lhotse_shar\n shar_path: /path1\n weight: 1.0\n"
408+
"- type: lhotse_shar\n shar_path: /path2\n weight: 2.0\n"
409+
)
410+
config = {"input_cfg": str(inner_yaml)}
411+
assert count_input_cfg_levels(config) == 1
412+
413+
def test_two_level_string_input_cfg(self, tmp_path):
414+
"""Top-level string ref → YAML with groups whose input_cfg are strings."""
415+
lang_yaml = tmp_path / "train_lang.yaml"
416+
lang_yaml.write_text("- type: lhotse_shar\n shar_path: /audio/cuts\n weight: 3.5\n")
417+
top_yaml = tmp_path / "train_all.yaml"
418+
top_yaml.write_text(
419+
f"- type: group\n weight: 100\n input_cfg: {lang_yaml}\n"
420+
f"- type: group\n weight: 200\n input_cfg: {lang_yaml}\n"
421+
)
422+
config = {"input_cfg": str(top_yaml)}
423+
assert count_input_cfg_levels(config) == 2
424+
425+
def test_unresolvable_string_input_cfg_counts_as_one(self):
426+
"""Unresolvable path (e.g. OmegaConf interpolation) counts as 1 level."""
427+
config = {
428+
"input_cfg": [
429+
{
430+
"type": "group",
431+
"input_cfg": "${oc.env:MANIFEST_ROOT}/train_ar-AE.yaml",
432+
},
433+
]
434+
}
435+
assert count_input_cfg_levels(config) == 2
436+
403437

404438
class TestReweightTemperatureValidation:
405439
def test_scalar_temperature_broadcasts_to_all_levels(self):

0 commit comments

Comments
 (0)