@@ -475,16 +475,22 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N
475475
476476 def print_summary (self , name : str , prob : Any ) -> None :
477477 """Print basic dataset info."""
478- unique_nlocs = sorted (self ._nloc_groups .keys ())
479- nloc_info = ", " .join (
480- f"{ nloc } ({ len (idxs )} )" for nloc , idxs in sorted (self ._nloc_groups .items ())
481- )
478+ n_groups = len (self ._nloc_groups )
479+
482480 log .info (
483481 f"LMDB { name } : { self .lmdb_path } , "
484- f"{ self .nframes } frames, nloc groups: [ { nloc_info } ] , "
482+ f"{ self .nframes } frames, { n_groups } nloc groups, "
485483 f"batch_size={ 'auto' if self ._auto_rule else self .batch_size } , "
486484 f"mixed_batch={ self .mixed_batch } "
487485 )
486+ # Print nloc groups in rows of ~10 for readability
487+ items = [
488+ f"{ nloc } ({ len (idxs )} )" for nloc , idxs in sorted (self ._nloc_groups .items ())
489+ ]
490+ per_row = 10
491+ for i in range (0 , len (items ), per_row ):
492+ row = ", " .join (items [i : i + per_row ])
493+ log .info (f" nloc groups: { row } " )
488494
489495 def set_noise (self , noise_settings : dict [str , Any ]) -> None :
490496 """No-op for now."""
@@ -616,9 +622,11 @@ def compute_block_targets(
616622
617623def _expand_indices_by_blocks (
618624 indices : list [int ],
619- frame_system_ids : list [ int ] ,
625+ frame_system_ids : np . ndarray ,
620626 block_targets : list [tuple [list [int ], int ]],
621627 rng : np .random .Generator ,
628+ _block_total_actual : list [int ] | None = None ,
629+ _sid_to_blk_arr : np .ndarray | None = None ,
622630) -> list [int ]:
623631 """Expand frame indices according to block targets.
624632
@@ -630,78 +638,95 @@ def _expand_indices_by_blocks(
630638 ----------
631639 indices : list[int]
632640 Frame indices in the current nloc group.
633- frame_system_ids : list[int]
634- Per-frame system id for the entire dataset.
641+ frame_system_ids : np.ndarray
642+ Per-frame system id for the entire dataset (int64 array) .
635643 block_targets : list[tuple[list[int], int]]
636644 Per-block (system_ids, total_target_frames).
637645 rng : np.random.Generator
638646 RNG for remainder sampling.
647+ _block_total_actual : list[int] or None
648+ Pre-computed total actual frame count per block (across all nloc
649+ groups). When provided, avoids an O(N) scan of frame_system_ids.
650+ _sid_to_blk_arr : np.ndarray or None
651+ Pre-computed system-id to block-index lookup array. When provided,
652+ avoids rebuilding the mapping for each call.
639653
640654 Returns
641655 -------
642656 list[int]
643657 Expanded indices.
644658 """
645- # Build sys_id -> block_idx mapping
646- sys_to_block : dict [int , int ] = {}
647- for blk_idx , (sys_ids , _target ) in enumerate (block_targets ):
648- for sid in sys_ids :
649- sys_to_block [sid ] = blk_idx
650-
651- # Partition indices by block
652- block_indices : dict [int , list [int ]] = {i : [] for i in range (len (block_targets ))}
653- unassigned : list [int ] = []
654- for idx in indices :
655- sid = frame_system_ids [idx ]
656- blk = sys_to_block .get (sid )
657- if blk is not None :
658- block_indices [blk ].append (idx )
659- else :
660- unassigned .append (idx )
661-
662- # Compute total actual frames across all blocks (for proportional scaling)
663- total_actual = sum (len (block_indices [i ]) for i in range (len (block_targets )))
664- total_target_all = sum (t for _ , t in block_targets )
665-
666- expanded : list [int ] = list (unassigned )
667-
668- for blk_idx , (sys_ids , block_total_target ) in enumerate (block_targets ):
669- blk_idxs = block_indices [blk_idx ]
659+ n_blocks = len (block_targets )
660+
661+ # Build sys_id -> block_idx lookup array
662+ if _sid_to_blk_arr is None :
663+ sys_to_block : dict [int , int ] = {}
664+ for blk_idx , (sys_ids , _target ) in enumerate (block_targets ):
665+ for sid in sys_ids :
666+ sys_to_block [sid ] = blk_idx
667+ max_sid = max (sys_to_block .keys ()) + 1 if sys_to_block else 0
668+ _sid_to_blk_arr = np .full (max_sid , - 1 , dtype = np .int32 )
669+ for sid , blk in sys_to_block .items ():
670+ _sid_to_blk_arr [sid ] = blk
671+
672+ # Partition indices by block using numpy for speed
673+ idx_arr = np .asarray (indices , dtype = np .int64 )
674+ sid_arr = np .asarray (frame_system_ids , dtype = np .int64 )
675+ # Vectorized lookup: get block id for each index
676+ idx_sids = sid_arr [idx_arr ]
677+ idx_blks = _sid_to_blk_arr [idx_sids ]
678+
679+ # Pre-compute block_total_actual if not provided
680+ if _block_total_actual is None :
681+ _block_total_actual = []
682+ for sys_ids , _ in block_targets :
683+ total = sum (int (np .sum (sid_arr == sid )) for sid in sys_ids )
684+ _block_total_actual .append (total )
685+
686+ expanded_parts : list [np .ndarray ] = []
687+
688+ # Unassigned indices
689+ unassigned_mask = idx_blks == - 1
690+ if np .any (unassigned_mask ):
691+ expanded_parts .append (idx_arr [unassigned_mask ])
692+
693+ for blk_idx in range (n_blocks ):
694+ blk_mask = idx_blks == blk_idx
695+ blk_idxs = idx_arr [blk_mask ]
670696 n_actual = len (blk_idxs )
671697 if n_actual == 0 :
672698 continue
673699
674- # Proportional target for this nloc subset of the block
675- # block_total_target is for the entire block; scale by the fraction
676- # of block frames that fall in this nloc group
677- _ , block_total_target_all = block_targets [blk_idx ]
678- # Get total frames in this block across all nloc groups
679- block_total_actual = sum (
680- 1
681- for i in range (len (frame_system_ids ))
682- if frame_system_ids [i ] in set (sys_ids )
683- )
684- if block_total_actual > 0 :
685- target = round (block_total_target_all * n_actual / block_total_actual )
700+ _ , block_total_target = block_targets [blk_idx ]
701+ block_total_act = _block_total_actual [blk_idx ]
702+
703+ # Proportional target for this nloc subset
704+ if block_total_act > 0 :
705+ target = round (block_total_target * n_actual / block_total_act )
686706 else :
687707 target = n_actual
688708 target = max (target , n_actual ) # never shrink
689709
690710 # Full copies + remainder
691711 deficit = target - n_actual
692712 if deficit <= 0 :
693- expanded . extend (blk_idxs )
713+ expanded_parts . append (blk_idxs )
694714 else :
695715 full_copies = deficit // n_actual
696716 remainder = deficit % n_actual
697717 # Original + full copies
698- expanded .extend (blk_idxs * (1 + full_copies ))
718+ if full_copies > 0 :
719+ expanded_parts .append (np .tile (blk_idxs , 1 + full_copies ))
720+ else :
721+ expanded_parts .append (blk_idxs )
699722 # Remainder: sample without replacement
700723 if remainder > 0 :
701724 sampled = rng .choice (blk_idxs , size = remainder , replace = False )
702- expanded . extend (sampled . tolist () )
725+ expanded_parts . append (sampled )
703726
704- return expanded
727+ if expanded_parts :
728+ return np .concatenate (expanded_parts ).tolist ()
729+ return []
705730
706731
707732def _build_all_batches (
@@ -735,12 +760,39 @@ def _build_all_batches(
735760 """
736761 # Build per-group batches
737762 group_batches : list [list [list [int ]]] = []
763+
764+ # Pre-compute expensive objects once (avoids O(N) work per nloc group)
765+ block_total_actual : list [int ] | None = None
766+ sid_arr : np .ndarray | None = None
767+ sid_to_blk_arr : np .ndarray | None = None
768+ if block_targets and reader .frame_system_ids is not None :
769+ block_total_actual = []
770+ for sys_ids , _ in block_targets :
771+ total = sum (reader .system_nframes [s ] for s in sys_ids )
772+ block_total_actual .append (total )
773+ # Convert frame_system_ids to numpy once
774+ sid_arr = np .array (reader .frame_system_ids , dtype = np .int64 )
775+ # Build sys_id -> block_idx lookup array once
776+ sys_to_block : dict [int , int ] = {}
777+ for blk_idx , (sys_ids , _target ) in enumerate (block_targets ):
778+ for sid in sys_ids :
779+ sys_to_block [sid ] = blk_idx
780+ max_sid = max (sys_to_block .keys ()) + 1 if sys_to_block else 0
781+ sid_to_blk_arr = np .full (max_sid , - 1 , dtype = np .int32 )
782+ for sid , blk in sys_to_block .items ():
783+ sid_to_blk_arr [sid ] = blk
784+
738785 for nloc in sorted (reader .nloc_groups .keys ()):
739786 indices = list (reader .nloc_groups [nloc ])
740787 # Expand indices by block targets if provided
741- if block_targets and reader . frame_system_ids is not None :
788+ if block_targets and sid_arr is not None :
742789 indices = _expand_indices_by_blocks (
743- indices , reader .frame_system_ids , block_targets , rng
790+ indices ,
791+ sid_arr ,
792+ block_targets ,
793+ rng ,
794+ _block_total_actual = block_total_actual ,
795+ _sid_to_blk_arr = sid_to_blk_arr ,
744796 )
745797 if shuffle :
746798 rng .shuffle (indices )
0 commit comments