@@ -132,8 +132,7 @@ class DataLoader:
132132 Like ``torch.utils.data.DataLoader``, it batches (and optionally
133133 shuffles) an :class:`IterableDataset` into the minibatch stream that
134134 :class:`Trainer` feeds to the model. It is iterable and sized (``len(loader)``
135- is the dataset size ``N``). The full dataset never enters memory; only one
136- ``(batch_size, *sample_shape)`` batch does.
135+ is the dataset size ``N``). The full dataset never enters memory.
137136
138137 Parameters
139138 ----------
@@ -199,8 +198,13 @@ def __init__(
199198 if shuffle :
200199 if buffer_size is None :
201200 buffer_size = 50 * int (batch_size )
201+ # shuffle_buffer concatenates yields along the leading axis, so single
202+ # samples must be promoted to one-row blocks before shuffling.
202203 source_factory = shuffle_buffer (
203- raw_factory , buffer_size = buffer_size , batch_size = batch_size , seed = seed
204+ _block_factory (raw_factory , tuple (sample_shape )),
205+ buffer_size = buffer_size ,
206+ batch_size = batch_size ,
207+ seed = seed ,
204208 )
205209 self ._source_factory = source_factory
206210
@@ -490,8 +494,8 @@ def shuffle_buffer(
490494 buffer always accumulates at least ``max(buffer_size, batch_size)`` rows before
491495 emitting (so a ``buffer_size`` smaller than ``batch_size`` still yields full
492496 batches instead of silently dropping the stream), and a single chunk larger
493- than that is taken whole, so peak buffer memory is
494- ``max(buffer_size, batch_size, largest_chunk_rows)``.
497+ than that is taken whole, so the buffer holds at most
498+ ``max(buffer_size, batch_size, largest_chunk_rows)`` rows.
495499
496500 Each epoch (each call of the returned factory) draws a fresh permutation from
497501 a sub-stream of ``seed``, so the shuffle order differs across epochs while
@@ -549,6 +553,40 @@ def factory() -> Iterator[np.ndarray]:
549553 return factory
550554
551555
556+
557+ def _promote_to_block (a : np .ndarray , sample_shape : tuple [int , ...]) -> np .ndarray :
558+ """Return ``a`` as a ``(rows, *sample_shape)`` block; a single sample becomes one row."""
559+ if a .shape == sample_shape :
560+ return a [None , ...]
561+ if a .ndim != len (sample_shape ) + 1 or a .shape [1 :] != sample_shape :
562+ raise ValueError (
563+ f"source yielded shape { a .shape } ; expected a single sample of shape "
564+ f"{ sample_shape } or a block of shape (rows, *{ sample_shape } )"
565+ )
566+ return a
567+
568+
def _block_factory(
    factory: Callable[[], Iterator[np.ndarray]],
    sample_shape: tuple[int, ...],
) -> Callable[[], Iterator[np.ndarray]]:
    """Return a factory whose iterators yield only ``(rows, *sample_shape)`` blocks.

    Each item produced by ``factory()`` is converted with ``np.asarray`` and
    passed through :func:`_promote_to_block`, so a single-sample yield (e.g. one
    row of a raw array) comes out as a one-row block. :func:`shuffle_buffer`
    counts and concatenates yields along the leading axis, which is why the
    promotion has to happen before shuffling.

    If ``factory`` carries an ``n_rows`` attribute that is not ``None``, the same
    value is attached to the returned factory.
    """

    def blocks() -> Iterator[np.ndarray]:
        # Generator function: ``factory()`` is not called until iteration starts.
        source = factory()
        for chunk in source:
            block = _promote_to_block(np.asarray(chunk), sample_shape)
            yield block

    known_rows = getattr(factory, "n_rows", None)
    if known_rows is None:
        return blocks
    blocks.n_rows = known_rows  # type: ignore[attr-defined]
    return blocks
588+
589+
552590def _rebatch (
553591 blocks : Iterable [np .ndarray ],
554592 batch_size : int ,
@@ -567,14 +605,7 @@ def _rebatch(
567605 buf : list [np .ndarray ] = []
568606 have = 0
569607 for arr in blocks :
570- a = np .asarray (arr )
571- if a .shape == sample_shape : # a single sample, not a block
572- a = a [None , ...]
573- elif a .ndim != len (sample_shape ) + 1 or a .shape [1 :] != sample_shape :
574- raise ValueError (
575- f"source yielded shape { a .shape } ; expected a single sample of shape "
576- f"{ sample_shape } or a block of shape (rows, *{ sample_shape } )"
577- )
608+ a = _promote_to_block (np .asarray (arr ), sample_shape )
578609 buf .append (a )
579610 have += a .shape [0 ]
580611 if have < batch_size :
@@ -681,7 +712,7 @@ def _auto_total_size(
681712class _ParquetDataset (IterableDataset ):
682713 """An :class:`IterableDataset` over a directory of Parquet shards.
683714
684- Yields one ``(rows, n_columns)`` ``float64`` array per file and exposes
715+ Yields one ``(rows, n_columns)`` array per file and exposes
685716 :attr:`n_rows` read from Parquet metadata (no data scan).
686717 """
687718
@@ -706,7 +737,7 @@ def parquet_source(
706737) -> _ParquetDataset :
707738 """An :class:`IterableDataset` over a directory of Parquet files.
708739
709- Yields one ``(rows, n_columns)`` ``float64`` array per file, and carries an
740+ Yields one ``(rows, n_columns)`` array per file, and carries an
710741 ``n_rows`` attribute read from Parquet metadata (no data scan) so that
711742 ``DataLoader(parquet_source(dir), ..., total_size="auto")`` resolves the
712743 dataset size for free. Pass ``shuffle=True`` to the :class:`DataLoader` (or
0 commit comments