@@ -194,6 +194,7 @@ def __init__(
194194 target_path : str ,
195195 merge_into_existing : bool = True ,
196196 cleanup_keys : Optional [Tuple [str , ...]] = None ,
197+ list_key : Optional [str ] = None ,
197198 ) -> None :
198199 if isinstance (selector_paths , str ):
199200 selector_paths = [selector_paths ]
@@ -202,6 +203,7 @@ def __init__(
202203 self .target_path = target_path
203204 self .merge_into_existing = merge_into_existing
204205 self .cleanup_keys = cleanup_keys if cleanup_keys is not None else (profiles_key ,)
206+ self .list_key = list_key
205207
206208 def _apply_list_overrides (self , yaml_conf : DictConfig , selector_path : str ) -> None :
207209 """Apply per-index overrides to a list-valued profile.
@@ -224,7 +226,8 @@ def _apply_list_overrides(self, yaml_conf: DictConfig, selector_path: str) -> No
224226 if overrides is None or not isinstance (overrides , (DictConfig , dict )):
225227 return
226228
227- target_list = OmegaConf .select (yaml_conf , self .target_path )
229+ list_path = f"{ self .target_path } .{ self .list_key } " if self .list_key else self .target_path
230+ target_list = OmegaConf .select (yaml_conf , list_path )
228231 if not isinstance (target_list , (ListConfig , list )):
229232 return
230233
@@ -343,10 +346,12 @@ def __init__(
343346 profiles_key : str ,
344347 target_paths : List [str ],
345348 profile_key : str = "profile" ,
349+ list_key : Optional [str ] = None ,
346350 ) -> None :
347351 self .profiles_key = profiles_key
348352 self .target_paths = target_paths
349353 self .profile_key = profile_key
354+ self .list_key = list_key
350355 self .cleanup_keys = (profiles_key ,)
351356
352357 def apply (self , yaml_conf : DictConfig ) -> None :
@@ -376,34 +381,40 @@ def apply(self, yaml_conf: DictConfig) -> None:
376381 )
377382
378383 profile_payload = profiles [profile_name ]
384+
385+ # Extract nested list from dict-valued profiles.
386+ profile_list = profile_payload
387+ if self .list_key and isinstance (profile_payload , (DictConfig , dict )):
388+ profile_list = profile_payload .get (self .list_key , profile_payload )
389+
379390 overrides = {k : v for k , v in item .items () if k != self .profile_key }
380391
381392 if not overrides :
382393 # Pure profile reference: expand entire profile list inline
383- if not isinstance (profile_payload , (ListConfig , list )):
394+ if not isinstance (profile_list , (ListConfig , list )):
384395 raise ValueError (
385396 f"Profile '{ profile_name } ' in { self .profiles_key } must resolve to a list "
386- f"for target '{ target_path } ', got { type (profile_payload )} "
397+ f"for target '{ target_path } ', got { type (profile_list )} "
387398 )
388- expanded_items .extend (list (profile_payload ))
399+ expanded_items .extend (list (profile_list ))
389400 else :
390401 # Profile + overrides: use first profile item as base, merge overrides on top
391- if isinstance (profile_payload , (ListConfig , list )):
392- if not profile_payload :
402+ if isinstance (profile_list , (ListConfig , list )):
403+ if not profile_list :
393404 raise ValueError (
394405 f"Profile '{ profile_name } ' in { self .profiles_key } is empty."
395406 )
396407 base = OmegaConf .create (
397- OmegaConf .to_container (profile_payload [0 ], resolve = False )
408+ OmegaConf .to_container (profile_list [0 ], resolve = False )
398409 )
399- elif isinstance (profile_payload , (DictConfig , dict )):
410+ elif isinstance (profile_list , (DictConfig , dict )):
400411 base = OmegaConf .create (
401- OmegaConf .to_container (profile_payload , resolve = False )
412+ OmegaConf .to_container (profile_list , resolve = False )
402413 )
403414 else :
404415 raise ValueError (
405416 f"Profile '{ profile_name } ' in { self .profiles_key } must be a list or dict, "
406- f"got { type (profile_payload )} "
417+ f"got { type (profile_list )} "
407418 )
408419 merged = OmegaConf .merge (base , OmegaConf .create (overrides ))
409420 expanded_items .append (merged )
@@ -460,26 +471,28 @@ def _stage_path(stage: str, rel_path: str) -> str:
460471
461472
462473# Each tuple:
463- # (profiles_key, stages, selector_rel, target_rel, merge_into_existing)
474+ # (profiles_key, stages, selector_rel, target_rel, merge_into_existing, list_key )
464475# Order matters: pipeline/system first, then arch, then the rest.
465- _VALUE_PROFILE_FAMILIES : List [Tuple [str , Tuple [str , ...], str , str , bool ]] = [
466- ("pipeline_profiles" , (_STAGE_DEFAULT ,), "pipeline_profile" , "" , True ),
467- ("system_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN , _STAGE_TEST , _STAGE_TUNE ), "system.profile" , "system" , True ),
468- ("arch_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN , _STAGE_TEST , _STAGE_TUNE ), "model.arch.profile" , "model" , True ),
469- ("augmentation_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN ), "data.augmentation.profile" , "data.augmentation" , True ),
470- ("dataloader_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN , _STAGE_TEST , _STAGE_TUNE ), "data.dataloader.profile" , "data.dataloader" , True ),
471- ("optimizer_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN ), "optimization.profile" , "optimization" , True ),
472- ("loss_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN ), "model.loss.profile" , "model.loss.losses" , False ),
473- ("label_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN ), "data.label_transform.profile" , "data.label_transform" , False ),
474- ("decoding_profiles" , (_STAGE_DEFAULT , _STAGE_TEST , _STAGE_TUNE ), "inference.decoding_profile" , "inference.decoding" , False ),
475- ("activation_profiles" , (_STAGE_DEFAULT , _STAGE_TEST , _STAGE_TUNE ), "inference.test_time_augmentation.activation_profile" , "inference.test_time_augmentation.channel_activations" , False ),
476+ # All profiles are dict-valued and use merge semantics. ``list_key`` identifies the
477+ # nested list within dict-valued profiles (used for positional overrides).
478+ _VALUE_PROFILE_FAMILIES : List [Tuple [str , Tuple [str , ...], str , str , bool , Optional [str ]]] = [
479+ ("pipeline_profiles" , (_STAGE_DEFAULT ,), "pipeline_profile" , "" , True , None ),
480+ ("system_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN , _STAGE_TEST , _STAGE_TUNE ), "system.profile" , "system" , True , None ),
481+ ("arch_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN , _STAGE_TEST , _STAGE_TUNE ), "model.arch.profile" , "model" , True , None ),
482+ ("augmentation_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN ), "data.augmentation.profile" , "data.augmentation" , True , None ),
483+ ("dataloader_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN , _STAGE_TEST , _STAGE_TUNE ), "data.dataloader.profile" , "data.dataloader" , True , None ),
484+ ("optimizer_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN ), "optimization.profile" , "optimization" , True , None ),
485+ ("loss_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN ), "model.loss.profile" , "model.loss" , True , "losses" ),
486+ ("label_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN ), "data.label_transform.profile" , "data.label_transform" , True , None ),
487+ ("decoding_profiles" , (_STAGE_DEFAULT , _STAGE_TEST , _STAGE_TUNE ), "inference.decoding_profile" , "inference" , True , "decoding" ),
488+ ("activation_profiles" , (_STAGE_DEFAULT , _STAGE_TEST , _STAGE_TUNE ), "inference.test_time_augmentation.activation_profile" , "inference.test_time_augmentation" , True , "channel_activations" ),
476489]
477490
478491
479- def _build_value_profile_specs () -> List [Tuple [List [str ], str , str , bool , Tuple [str , ...]]]:
480- specs : List [Tuple [List [str ], str , str , bool , Tuple [str , ...]]] = []
492+ def _build_value_profile_specs () -> List [Tuple [List [str ], str , str , bool , Tuple [str , ...], Optional [ str ] ]]:
493+ specs : List [Tuple [List [str ], str , str , bool , Tuple [str , ...], Optional [ str ] ]] = []
481494
482- for profiles_key , stages , selector_rel , target_rel , merge in _VALUE_PROFILE_FAMILIES :
495+ for profiles_key , stages , selector_rel , target_rel , merge , list_key in _VALUE_PROFILE_FAMILIES :
483496 for stage in stages :
484497 selector_path = _stage_path (stage , selector_rel )
485498 target_path = _stage_path (stage , target_rel )
@@ -490,6 +503,7 @@ def _build_value_profile_specs() -> List[Tuple[List[str], str, str, bool, Tuple[
490503 target_path ,
491504 merge ,
492505 (profiles_key , selector_path ),
506+ list_key ,
493507 )
494508 )
495509
@@ -498,10 +512,10 @@ def _build_value_profile_specs() -> List[Tuple[List[str], str, str, bool, Tuple[
498512
499513# Each tuple: (profiles_key, stages, target_rel)
500514_REFERENCE_PROFILE_FAMILIES : List [Tuple [str , Tuple [str , ...], str ]] = [
501- ("loss_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN ), "model.loss.losses " ),
515+ ("loss_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN ), "model.loss" ),
502516 ("label_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN ), "data.label_transform" ),
503- ("decoding_profiles" , (_STAGE_DEFAULT , _STAGE_TEST , _STAGE_TUNE ), "inference.decoding " ),
504- ("activation_profiles" , (_STAGE_DEFAULT , _STAGE_TEST , _STAGE_TUNE ), "inference.test_time_augmentation.channel_activations " ),
517+ ("decoding_profiles" , (_STAGE_DEFAULT , _STAGE_TEST , _STAGE_TUNE ), "inference" ),
518+ ("activation_profiles" , (_STAGE_DEFAULT , _STAGE_TEST , _STAGE_TUNE ), "inference.test_time_augmentation" ),
505519 ("augmentation_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN ), "data.augmentation" ),
506520 ("dataloader_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN , _STAGE_TEST , _STAGE_TUNE ), "data.dataloader" ),
507521 ("optimizer_profiles" , (_STAGE_DEFAULT , _STAGE_TRAIN ), "optimization" ),
@@ -517,24 +531,25 @@ def _build_reference_profile_specs() -> List[Tuple[str, List[str]]]:
517531 return specs
518532
519533
520- # Each tuple: (profiles_key, stages, target_rel)
521- _LIST_REFERENCE_FAMILIES : List [Tuple [str , Tuple [str , ...], str ]] = [
522- ("decoding_profiles" , (_STAGE_DEFAULT , _STAGE_TUNE , _STAGE_TEST ), "inference.decoding" ),
534+ # Each tuple: (profiles_key, stages, target_rel, list_key)
535+ # ``list_key`` identifies the nested list within dict-valued profiles.
536+ _LIST_REFERENCE_FAMILIES : List [Tuple [str , Tuple [str , ...], str , str ]] = [
537+ ("decoding_profiles" , (_STAGE_DEFAULT , _STAGE_TUNE , _STAGE_TEST ), "inference.decoding" , "decoding" ),
523538]
524539
525540
526- def _build_list_reference_specs () -> List [Tuple [str , List [str ]]]:
527- specs : List [Tuple [str , List [str ]]] = []
528- for profiles_key , stages , target_rel in _LIST_REFERENCE_FAMILIES :
541+ def _build_list_reference_specs () -> List [Tuple [str , List [str ], str ]]:
542+ specs : List [Tuple [str , List [str ], str ]] = []
543+ for profiles_key , stages , target_rel , list_key in _LIST_REFERENCE_FAMILIES :
529544 target_paths = [_stage_path (stage , target_rel ) for stage in stages ]
530- specs .append ((profiles_key , target_paths ))
545+ specs .append ((profiles_key , target_paths , list_key ))
531546 return specs
532547
533548
534549def _build_allowed_selector_paths () -> set [str ]:
535550 paths : set [str ] = set ()
536551
537- for _ , stages , selector_rel , _ , _ in _VALUE_PROFILE_FAMILIES :
552+ for _ , stages , selector_rel , _ , _ , _ in _VALUE_PROFILE_FAMILIES :
538553 for stage in stages :
539554 paths .add (_normalize_selector_path (_stage_path (stage , selector_rel )))
540555
@@ -574,13 +589,14 @@ def _build_profile_engine() -> YamlProfileEngine:
574589 appliers : List [YamlProfileApplier ] = []
575590
576591 # 1) Value profile appliers (order defined by _VALUE_PROFILE_FAMILIES table)
577- for selector_paths , profiles_key , target_path , merge , cleanup_keys in _VALUE_PROFILE_SPECS :
592+ for selector_paths , profiles_key , target_path , merge , cleanup_keys , list_key in _VALUE_PROFILE_SPECS :
578593 appliers .append (ValueProfileApplier (
579594 selector_paths = selector_paths ,
580595 profiles_key = profiles_key ,
581596 target_path = target_path ,
582597 merge_into_existing = merge ,
583598 cleanup_keys = cleanup_keys ,
599+ list_key = list_key ,
584600 ))
585601
586602 # 2) Reference profile appliers
@@ -591,10 +607,11 @@ def _build_profile_engine() -> YamlProfileEngine:
591607 ))
592608
593609 # 3) List profile reference appliers
594- for profiles_key , target_paths in _LIST_REFERENCE_SPECS :
610+ for profiles_key , target_paths , list_key in _LIST_REFERENCE_SPECS :
595611 appliers .append (ListProfileReferenceApplier (
596612 profiles_key = profiles_key ,
597613 target_paths = target_paths ,
614+ list_key = list_key ,
598615 ))
599616
600617 return YamlProfileEngine (appliers )
0 commit comments