Skip to content

Commit 6335c7c

Browse files
authored
fix: count_input_cfg_levels now resolves string file references (#15646)
* fix: count_input_cfg_levels now resolves string file references When input_cfg is overridden via CLI to a YAML file path (e.g. model.train_ds.input_cfg=train_all.yaml), the level counter only saw the top-level string and reported 1 level, causing reweight_temperature to broadcast incorrectly. Now it loads referenced YAML files to discover nested input_cfg keys, matching the runtime behavior of parse_and_combine_datasets. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * address Copilot review: memoize YAML loads, add debug logging, fix test name - Cache resolved YAML files in _resolve_if_path to avoid redundant disk/network I/O for sibling groups referencing the same file - Add logging.debug when a path cannot be resolved, so failures are never fully silent - Rename misleading test to match its assertion (total depth is 2, not 1; the unresolvable string is treated as a leaf) Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --------- Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
1 parent 24590f2 commit 6335c7c

3 files changed

Lines changed: 87 additions & 2 deletions

File tree

docs/source/audio/configs.rst

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,28 @@ The maximum nesting depth is calculated as the maximum depth of ``input_cfg`` ke
221221
input_cfg: # Level 2 (same as above)
222222
- type: lhotse_shar
223223
224+
When ``input_cfg`` is overridden via CLI to a YAML file path (e.g.
225+
``model.train_ds.input_cfg=train_all.yaml``), the depth calculation loads the
226+
referenced file and traverses its contents to count nested ``input_cfg`` keys.
227+
This also works with multi-level file references:
228+
229+
.. code-block:: yaml
230+
231+
# train_all.yaml (referenced via input_cfg=train_all.yaml)
232+
- type: group
233+
weight: 100
234+
input_cfg: ${oc.env:MANIFEST_ROOT}/train_en.yaml # resolved at runtime
235+
- type: group
236+
weight: 200
237+
input_cfg: ${oc.env:MANIFEST_ROOT}/train_de.yaml
238+
239+
.. note::
240+
241+
Paths containing OmegaConf interpolations (e.g. ``${oc.env:MANIFEST_ROOT}``)
242+
cannot be resolved during depth counting -- they are resolved later at runtime
243+
by ``OmegaConf.create()``. Such paths are treated as a single additional
244+
nesting level.
245+
224246
**Example: Balancing Multiple Task Groups**
225247

226248
.. code-block:: yaml

nemo/collections/common/data/lhotse/cutset.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,13 @@ def count_input_cfg_levels(config: Union[DictConfig, dict]) -> int:
476476
the same temperature (due to propagate_attrs.copy()), we count max depth,
477477
not total occurrences.
478478
479+
String/Path values for ``input_cfg`` are treated as file references (mirroring
480+
:func:`parse_and_combine_datasets`) and loaded so that nested ``input_cfg``
481+
keys inside those files are counted. If the file is not found (e.g. the
482+
path contains unresolved OmegaConf interpolations such as
483+
``${oc.env:MANIFEST_ROOT}``), it is conservatively counted as one additional
484+
level. All other I/O or parsing errors propagate immediately.
485+
479486
Args:
480487
config: Configuration dictionary that may contain nested 'input_cfg' keys.
481488
@@ -493,13 +500,35 @@ def count_input_cfg_levels(config: Union[DictConfig, dict]) -> int:
493500
2
494501
"""
495502

503+
_cache: dict[str, object] = {}
504+
505+
def _resolve_if_path(val):
506+
"""If *val* is a string/Path, load the YAML file it points to.
507+
508+
Raises on I/O or parse errors except ``FileNotFoundError``, which is
509+
expected when the path contains OmegaConf interpolations (e.g.
510+
``${oc.env:MANIFEST_ROOT}/file.yaml``) that raw ``yaml.load`` returns
511+
as literal strings. ``parse_and_combine_datasets`` resolves them at
512+
runtime via ``OmegaConf.create()``.
513+
"""
514+
if isinstance(val, (str, Path)):
515+
key = str(val)
516+
if key not in _cache:
517+
try:
518+
_cache[key] = load_yaml(key)
519+
except FileNotFoundError:
520+
logging.debug("count_input_cfg_levels: could not load %r, treating as leaf", key)
521+
_cache[key] = val
522+
return _cache[key]
523+
return val
524+
496525
def _max_depth(obj) -> int:
497526
if isinstance(obj, (dict, DictConfig)):
498527
depths = []
499528
for key, val in obj.items():
500529
if key == "input_cfg":
501-
# Found input_cfg: this level counts as 1 + max depth of children
502-
depths.append(1 + _max_depth(val))
530+
resolved = _resolve_if_path(val)
531+
depths.append(1 + _max_depth(resolved))
503532
else:
504533
depths.append(_max_depth(val))
505534
return max(depths, default=0)

tests/collections/common/test_lhotse_temperature_reweighting.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,40 @@ def test_with_omegaconf_dictconfig(self):
400400
)
401401
assert count_input_cfg_levels(config) == 2
402402

403+
def test_string_input_cfg_resolves_yaml_file(self, tmp_path):
404+
"""String input_cfg pointing to a YAML file is loaded and traversed."""
405+
inner_yaml = tmp_path / "inner.yaml"
406+
inner_yaml.write_text(
407+
"- type: lhotse_shar\n shar_path: /path1\n weight: 1.0\n"
408+
"- type: lhotse_shar\n shar_path: /path2\n weight: 2.0\n"
409+
)
410+
config = {"input_cfg": str(inner_yaml)}
411+
assert count_input_cfg_levels(config) == 1
412+
413+
def test_two_level_string_input_cfg(self, tmp_path):
414+
"""Top-level string ref → YAML with groups whose input_cfg are strings."""
415+
lang_yaml = tmp_path / "train_lang.yaml"
416+
lang_yaml.write_text("- type: lhotse_shar\n shar_path: /audio/cuts\n weight: 3.5\n")
417+
top_yaml = tmp_path / "train_all.yaml"
418+
top_yaml.write_text(
419+
f"- type: group\n weight: 100\n input_cfg: {lang_yaml}\n"
420+
f"- type: group\n weight: 200\n input_cfg: {lang_yaml}\n"
421+
)
422+
config = {"input_cfg": str(top_yaml)}
423+
assert count_input_cfg_levels(config) == 2
424+
425+
def test_unresolvable_nested_string_input_cfg_is_treated_as_leaf(self):
426+
"""Nested unresolvable string input_cfg is treated as a leaf, so total depth is 2."""
427+
config = {
428+
"input_cfg": [
429+
{
430+
"type": "group",
431+
"input_cfg": "${oc.env:MANIFEST_ROOT}/train_ar-AE.yaml",
432+
},
433+
]
434+
}
435+
assert count_input_cfg_levels(config) == 2
436+
403437

404438
class TestReweightTemperatureValidation:
405439
def test_scalar_temperature_broadcasts_to_all_levels(self):

0 commit comments

Comments
 (0)