Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/source/audio/configs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,28 @@ The maximum nesting depth is calculated as the maximum depth of ``input_cfg`` ke
input_cfg: # Level 2 (same as above)
- type: lhotse_shar

When ``input_cfg`` is overridden via CLI to a YAML file path (e.g.
``model.train_ds.input_cfg=train_all.yaml``), the depth calculation loads the
referenced file and traverses its contents to count nested ``input_cfg`` keys.
This also works with multi-level file references:

.. code-block:: yaml

# train_all.yaml (referenced via input_cfg=train_all.yaml)
- type: group
weight: 100
input_cfg: ${oc.env:MANIFEST_ROOT}/train_en.yaml # resolved at runtime
- type: group
weight: 200
input_cfg: ${oc.env:MANIFEST_ROOT}/train_de.yaml

.. note::

Paths containing OmegaConf interpolations (e.g. ``${oc.env:MANIFEST_ROOT}``)
cannot be resolved during depth counting -- they are resolved later at runtime
by ``OmegaConf.create()``. Such paths are treated as a single additional
nesting level; any ``input_cfg`` keys nested inside the referenced file are
not counted.

**Example: Balancing Multiple Task Groups**

.. code-block:: yaml
Expand Down
33 changes: 31 additions & 2 deletions nemo/collections/common/data/lhotse/cutset.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,13 @@ def count_input_cfg_levels(config: Union[DictConfig, dict]) -> int:
the same temperature (due to propagate_attrs.copy()), we count max depth,
not total occurrences.

String/Path values for ``input_cfg`` are treated as file references (mirroring
:func:`parse_and_combine_datasets`) and loaded so that nested ``input_cfg``
keys inside those files are counted. If the file is not found (e.g. the
path contains unresolved OmegaConf interpolations such as
``${oc.env:MANIFEST_ROOT}``), it is conservatively counted as one additional
level. All other I/O or parsing errors propagate immediately.

Args:
config: Configuration dictionary that may contain nested 'input_cfg' keys.

Expand All @@ -493,13 +500,35 @@ def count_input_cfg_levels(config: Union[DictConfig, dict]) -> int:
2
"""

_cache: dict[str, object] = {}

def _resolve_if_path(val):
    """Return the parsed YAML that *val* refers to, or *val* itself.

    Only ``str`` and ``Path`` values are treated as file references; any
    other value is handed back untouched. Each reference is loaded at most
    once: results are memoized in the enclosing ``_cache`` under
    ``str(val)``.

    A ``FileNotFoundError`` from ``load_yaml`` is tolerated because it is
    expected when the path still contains an unresolved OmegaConf
    interpolation (e.g. ``${oc.env:MANIFEST_ROOT}/file.yaml``). In that case
    the original value is cached and returned so the caller sees a leaf.
    Every other exception raised by ``load_yaml`` propagates.
    """
    # Non-path values (lists, dicts, numbers, ...) need no resolution.
    if not isinstance(val, (str, Path)):
        return val

    path_key = str(val)
    if path_key in _cache:
        return _cache[path_key]

    try:
        loaded = load_yaml(path_key)
    except FileNotFoundError:
        logging.debug("count_input_cfg_levels: could not load %r, treating as leaf", path_key)
        # Fall back to the raw reference so the miss is remembered as well.
        loaded = val

    _cache[path_key] = loaded
    return loaded

def _max_depth(obj) -> int:
if isinstance(obj, (dict, DictConfig)):
depths = []
for key, val in obj.items():
if key == "input_cfg":
# Found input_cfg: this level counts as 1 + max depth of children
depths.append(1 + _max_depth(val))
resolved = _resolve_if_path(val)
depths.append(1 + _max_depth(resolved))
Comment thread
XuesongYang marked this conversation as resolved.
else:
depths.append(_max_depth(val))
return max(depths, default=0)
Expand Down
34 changes: 34 additions & 0 deletions tests/collections/common/test_lhotse_temperature_reweighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,40 @@ def test_with_omegaconf_dictconfig(self):
)
assert count_input_cfg_levels(config) == 2

def test_string_input_cfg_resolves_yaml_file(self, tmp_path):
    """A string ``input_cfg`` naming a flat YAML list is loaded and yields depth 1."""
    # The referenced file holds two leaf datasets and no nested ``input_cfg``.
    yaml_body = (
        "- type: lhotse_shar\n shar_path: /path1\n weight: 1.0\n"
        "- type: lhotse_shar\n shar_path: /path2\n weight: 2.0\n"
    )
    yaml_path = tmp_path / "inner.yaml"
    yaml_path.write_text(yaml_body)

    levels = count_input_cfg_levels({"input_cfg": str(yaml_path)})

    assert levels == 1

def test_two_level_string_input_cfg(self, tmp_path):
    """A file reference whose groups reference another file yields depth 2."""
    # Leaf file: a single dataset entry without any further ``input_cfg``.
    leaf_path = tmp_path / "train_lang.yaml"
    leaf_path.write_text("- type: lhotse_shar\n shar_path: /audio/cuts\n weight: 3.5\n")

    # Root file: two groups, each pointing at the leaf file by path.
    group_entries = [
        f"- type: group\n weight: {group_weight}\n input_cfg: {leaf_path}\n" for group_weight in (100, 200)
    ]
    root_path = tmp_path / "train_all.yaml"
    root_path.write_text("".join(group_entries))

    levels = count_input_cfg_levels({"input_cfg": str(root_path)})

    assert levels == 2

def test_unresolvable_nested_string_input_cfg_is_treated_as_leaf(self):
    """A nested string ``input_cfg`` that cannot be loaded counts as a leaf: depth 2."""
    # The interpolation is never resolved here, so no file exists at this literal path.
    unresolved_ref = "${oc.env:MANIFEST_ROOT}/train_ar-AE.yaml"
    group_entry = {"type": "group", "input_cfg": unresolved_ref}

    levels = count_input_cfg_levels({"input_cfg": [group_entry]})

    assert levels == 2


class TestReweightTemperatureValidation:
def test_scalar_temperature_broadcasts_to_all_levels(self):
Expand Down
Loading