@@ -232,7 +232,22 @@ class LmdbDataReader:
232232 type_map : list[str]
233233 Global type map from model config.
234234 batch_size : int or str
235- Batch size. Supports int, "auto", "auto:N".
235+ Batch size rule used to derive per-nloc batch sizes. Supports:
236+
237+ - ``int``: fixed, identical batch size for every nloc group.
238+ - ``"auto"`` / ``"auto:N"``: ``ceil(N / nloc)`` per nloc group
239+ (``N=32`` for bare ``"auto"``). Acts as a *lower* bound —
240+ each batch has at least ``N`` atoms, but may exceed ``N``
241+ by up to ``nloc - 1``.
242+ - ``"max:N"``: ``max(1, floor(N / nloc))`` per nloc group.
243+ Acts as an *upper* bound for groups with ``nloc <= N``
244+ (batch has at most ``N`` atoms). For groups with
245+ ``nloc > N`` the ``max(1, ...)`` floor kicks in: ``bsi=1``
246+ and a single-frame batch still carries ``nloc`` atoms,
247+ which exceeds ``N``.
248+ - ``"filter:N"``: same per-nloc formula as ``"max:N"`` **and**
249+ drops every frame whose ``nloc > N`` from the dataset. By
250+ construction every retained batch has at most ``N`` atoms.
236251 mixed_batch : bool
237252 If True, allow different nloc in the same batch (future).
238253 If False (default), enforce same-nloc-per-batch.
@@ -283,6 +298,10 @@ def __init__(
283298
284299 # Scan per-frame nloc only when needed for same-nloc batching.
285300 # For mixed_batch=True, skip the scan entirely (future: padding handles it).
301+ # We keep _frame_nlocs / _frame_system_ids indexable by the *original*
302+ # LMDB frame index even after filter:N: entries for dropped frames
303+ # simply never get referenced because _nloc_groups / _system_groups
304+ # no longer reference them.
286305 if not mixed_batch :
287306 # Fast path: use pre-computed frame_nlocs from metadata if available.
288307 # Falls back to scanning each frame's atom_types shape (~10 us/frame).
@@ -293,41 +312,95 @@ def __init__(
293312 self ._frame_nlocs = _scan_frame_nlocs (
294313 self ._env , self .nframes , self ._frame_fmt , self ._natoms
295314 )
296- self ._nloc_groups : dict [int , list [int ]] = {}
297- for idx , nloc in enumerate (self ._frame_nlocs ):
298- self ._nloc_groups .setdefault (nloc , []).append (idx )
299315 else :
300316 self ._frame_nlocs = []
301- self ._nloc_groups = {}
302317
303- # Parse frame_system_ids for auto_prob support
318+ # Parse frame_system_ids for auto_prob support. _nsystems must stay at
319+ # ``max(original_sid) + 1`` even after filter:N so that user-facing
320+ # auto_prob block slicing (e.g. ``prob_sys_size;0:284:0.5;284:842:0.5``)
321+ # keeps its meaning across filter thresholds.
304322 meta_sys_ids = meta .get ("frame_system_ids" )
305323 if meta_sys_ids is not None :
306324 self ._frame_system_ids : list [int ] | None = [int (s ) for s in meta_sys_ids ]
307325 self ._nsystems = max (self ._frame_system_ids ) + 1
308- self ._system_groups : dict [int , list [int ]] = {}
309- for idx , sid in enumerate (self ._frame_system_ids ):
310- self ._system_groups .setdefault (sid , []).append (idx )
311- self ._system_nframes : list [int ] = [
312- len (self ._system_groups .get (i , [])) for i in range (self ._nsystems )
313- ]
314326 else :
315327 self ._frame_system_ids = None
316328 self ._nsystems = 1
317- self ._system_groups = {0 : list (range (self .nframes ))}
318- self ._system_nframes = [self .nframes ]
319329
320- # Parse batch_size spec
330+ # Parse batch_size spec. ``auto_rule`` and ``max_rule`` are mutually
331+ # exclusive; ``filter_rule`` implies ``max_rule`` plus dropping frames
332+ # whose nloc exceeds the threshold.
321333 self ._auto_rule : int | None = None
334+ self ._max_rule : int | None = None
335+ self ._filter_rule : int | None = None
322336 if isinstance (batch_size , str ):
323337 if batch_size == "auto" :
324338 self ._auto_rule = 32
325339 elif batch_size .startswith ("auto:" ):
326340 self ._auto_rule = int (batch_size .split (":" )[1 ])
341+ elif batch_size .startswith ("max:" ):
342+ self ._max_rule = int (batch_size .split (":" )[1 ])
343+ elif batch_size .startswith ("filter:" ):
344+ self ._filter_rule = int (batch_size .split (":" )[1 ])
345+ self ._max_rule = self ._filter_rule
327346 else :
328- self ._auto_rule = 32
329- # Default batch_size uses first frame's nloc (for total_batch estimate)
347+ raise ValueError (
348+ f"Unsupported batch_size { batch_size !r} . "
349+ "Expected int, 'auto', 'auto:N', 'max:N', or 'filter:N'."
350+ )
351+
352+ # Determine which original-index frames survive the filter. Without
353+ # ``filter:N`` every frame is retained. ``mixed_batch=True`` has no
354+ # per-frame nloc info to filter against, so filter:N is a no-op there.
355+ if self ._filter_rule is not None and not mixed_batch :
356+ retained_indices = [
357+ i for i , n in enumerate (self ._frame_nlocs ) if n <= self ._filter_rule
358+ ]
359+ n_dropped = self .nframes - len (retained_indices )
360+ if n_dropped > 0 :
361+ log .info (
362+ f"LMDB filter:{ self ._filter_rule } drops { n_dropped } /"
363+ f"{ self .nframes } frames with nloc > { self ._filter_rule } "
364+ f"({ self .lmdb_path } )."
365+ )
366+ else :
367+ retained_indices = list (range (self .nframes ))
368+
369+ # Group retained frames by nloc. _nloc_groups only contains nlocs
370+ # that passed the filter; its values stay as *original* LMDB frame
371+ # indices so __getitem__(index) keeps reading the right LMDB key.
372+ if not mixed_batch :
373+ self ._nloc_groups : dict [int , list [int ]] = {}
374+ for idx in retained_indices :
375+ self ._nloc_groups .setdefault (self ._frame_nlocs [idx ], []).append (idx )
376+ else :
377+ self ._nloc_groups = {}
378+
379+ # Group retained frames by system id. _system_nframes is indexed by
380+ # *original* sid and stays length _nsystems even if some systems are
381+ # fully dropped — those entries are simply zero so auto_prob block
382+ # slicing still parses predictably.
383+ if self ._frame_system_ids is not None :
384+ self ._system_groups : dict [int , list [int ]] = {}
385+ for idx in retained_indices :
386+ sid = self ._frame_system_ids [idx ]
387+ self ._system_groups .setdefault (sid , []).append (idx )
388+ self ._system_nframes : list [int ] = [
389+ len (self ._system_groups .get (i , [])) for i in range (self ._nsystems )
390+ ]
391+ else :
392+ self ._system_groups = {0 : list (retained_indices )}
393+ self ._system_nframes = [len (retained_indices )]
394+
395+ # nframes now reflects retained frames; __len__ returns this.
396+ self .nframes = len (retained_indices )
397+
398+ # Default batch_size used only by the index/total_batch estimate. The
399+ # sampler always goes through get_batch_size_for_nloc for real batches.
400+ if self ._auto_rule is not None :
330401 self .batch_size = _compute_batch_size (self ._natoms , self ._auto_rule )
402+ elif self ._max_rule is not None :
403+ self .batch_size = max (1 , self ._max_rule // max (self ._natoms , 1 ))
331404 else :
332405 self .batch_size = int (batch_size )
333406
@@ -382,9 +455,19 @@ def __del__(self) -> None:
382455 _close_lmdb (path )
383456
384457 def get_batch_size_for_nloc (self , nloc : int ) -> int :
385- """Get batch_size for a given nloc. Uses auto rule if configured."""
458+ """Return the per-nloc batch size for the configured rule.
459+
460+ - ``auto`` / ``auto:N``: ``ceil(N / nloc)`` — may overshoot the
461+ atom budget by up to ``nloc - 1`` atoms.
462+ - ``max:N`` / ``filter:N``: ``max(1, floor(N / nloc))`` — never
463+ overshoots; clamps to 1 when a single frame already exceeds ``N``
464+ atoms.
465+ - fixed int: the same value for every nloc group.
466+ """
386467 if self ._auto_rule is not None :
387468 return _compute_batch_size (nloc , self ._auto_rule )
469+ if self ._max_rule is not None :
470+ return max (1 , self ._max_rule // max (nloc , 1 ))
388471 return self .batch_size
389472
390473 def __len__ (self ) -> int :
@@ -538,11 +621,19 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N
538621 def print_summary (self , name : str , prob : Any ) -> None :
539622 """Print basic dataset info."""
540623 n_groups = len (self ._nloc_groups )
624+ if self ._auto_rule is not None :
625+ bs_str = f"auto:{ self ._auto_rule } "
626+ elif self ._filter_rule is not None :
627+ bs_str = f"filter:{ self ._filter_rule } "
628+ elif self ._max_rule is not None :
629+ bs_str = f"max:{ self ._max_rule } "
630+ else :
631+ bs_str = str (self .batch_size )
541632
542633 log .info (
543634 f"LMDB { name } : { self .lmdb_path } , "
544635 f"{ self .nframes } frames, { n_groups } nloc groups, "
545- f"batch_size={ 'auto' if self . _auto_rule else self . batch_size } , "
636+ f"batch_size={ bs_str } , "
546637 f"mixed_batch={ self .mixed_batch } "
547638 )
548639 # Print nloc groups in rows of ~10 for readability
@@ -646,6 +737,24 @@ def compute_block_targets(
646737 stt , end , weight = part .split (":" )
647738 blocks .append ((int (stt ), int (end ), float (weight )))
648739
740+ # Drop blocks that retain zero frames (can happen when ``filter:N``
741+ # eliminates every system in a block). prob_sys_size_ext's per-block
742+ # ``nbatch_block / sum(nbatch_block)`` would otherwise propagate NaN
743+ # when the whole block sums to zero. An all-zero dataset yields no
744+ # targets at all.
745+ nonempty = [
746+ (stt , end , weight )
747+ for stt , end , weight in blocks
748+ if sum (system_nframes [stt :end ]) > 0
749+ ]
750+ if not nonempty :
751+ return []
752+ if len (nonempty ) < len (blocks ):
753+ auto_prob_style = "prob_sys_size;" + ";" .join (
754+ f"{ stt } :{ end } :{ weight } " for stt , end , weight in nonempty
755+ )
756+ blocks = nonempty
757+
649758 # Compute per-system probabilities using the standard function
650759 sys_probs = prob_sys_size_ext (auto_prob_style , nsystems , system_nframes )
651760
0 commit comments