Skip to content

Commit 0b8bf36

Browse files
Promote single samples before the shuffle buffer
shuffle_buffer concatenates yields along the leading axis, so a raw array source under shuffle=True had its rows flattened (2-D) or crashed on shape[0] (scalars). Promote single samples to one-row blocks before the shuffle wrap, with the same helper the re-batcher uses. Also tighten a few docstring claims: the parquet dtype follows the file columns, and the shuffle buffer bound is stated as rows held.
1 parent 23563ac commit 0b8bf36

2 files changed

Lines changed: 70 additions & 15 deletions

File tree

pymc/variational/streaming.py

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,7 @@ class DataLoader:
132132
Like ``torch.utils.data.DataLoader``, it batches (and optionally
133133
shuffles) an :class:`IterableDataset` into the minibatch stream that
134134
:class:`Trainer` feeds to the model. It is iterable and sized (``len(loader)``
135-
is the dataset size ``N``). The full dataset never enters memory; only one
136-
``(batch_size, *sample_shape)`` batch does.
135+
is the dataset size ``N``). The full dataset never enters memory.
137136
138137
Parameters
139138
----------
@@ -199,8 +198,13 @@ def __init__(
199198
if shuffle:
200199
if buffer_size is None:
201200
buffer_size = 50 * int(batch_size)
201+
# shuffle_buffer concatenates yields along the leading axis, so single
202+
# samples must be promoted to one-row blocks before shuffling.
202203
source_factory = shuffle_buffer(
203-
raw_factory, buffer_size=buffer_size, batch_size=batch_size, seed=seed
204+
_block_factory(raw_factory, tuple(sample_shape)),
205+
buffer_size=buffer_size,
206+
batch_size=batch_size,
207+
seed=seed,
204208
)
205209
self._source_factory = source_factory
206210

@@ -490,8 +494,8 @@ def shuffle_buffer(
490494
buffer always accumulates at least ``max(buffer_size, batch_size)`` rows before
491495
emitting (so a ``buffer_size`` smaller than ``batch_size`` still yields full
492496
batches instead of silently dropping the stream), and a single chunk larger
493-
than that is taken whole, so peak buffer memory is
494-
``max(buffer_size, batch_size, largest_chunk_rows)``.
497+
than that is taken whole, so the buffer holds at most
498+
``max(buffer_size, batch_size, largest_chunk_rows)`` rows.
495499
496500
Each epoch (each call of the returned factory) draws a fresh permutation from
497501
a sub-stream of ``seed``, so the shuffle order differs across epochs while
@@ -549,6 +553,40 @@ def factory() -> Iterator[np.ndarray]:
549553
return factory
550554

551555

556+
557+
def _promote_to_block(a: np.ndarray, sample_shape: tuple[int, ...]) -> np.ndarray:
558+
"""Return ``a`` as a ``(rows, *sample_shape)`` block; a single sample becomes one row."""
559+
if a.shape == sample_shape:
560+
return a[None, ...]
561+
if a.ndim != len(sample_shape) + 1 or a.shape[1:] != sample_shape:
562+
raise ValueError(
563+
f"source yielded shape {a.shape}; expected a single sample of shape "
564+
f"{sample_shape} or a block of shape (rows, *{sample_shape})"
565+
)
566+
return a
567+
568+
569+
def _block_factory(
570+
factory: Callable[[], Iterator[np.ndarray]],
571+
sample_shape: tuple[int, ...],
572+
) -> Callable[[], Iterator[np.ndarray]]:
573+
"""Wrap ``factory`` so every yield is a block, promoting single samples.
574+
575+
:func:`shuffle_buffer` counts and concatenates yields along the leading axis,
576+
so single-sample yields (e.g. the rows of a raw array) must be promoted to
577+
one-row blocks before shuffling. A known ``.n_rows`` is forwarded.
578+
"""
579+
580+
def f() -> Iterator[np.ndarray]:
581+
for arr in factory():
582+
yield _promote_to_block(np.asarray(arr), sample_shape)
583+
584+
n_rows = getattr(factory, "n_rows", None)
585+
if n_rows is not None:
586+
f.n_rows = n_rows # type: ignore[attr-defined]
587+
return f
588+
589+
552590
def _rebatch(
553591
blocks: Iterable[np.ndarray],
554592
batch_size: int,
@@ -567,14 +605,7 @@ def _rebatch(
567605
buf: list[np.ndarray] = []
568606
have = 0
569607
for arr in blocks:
570-
a = np.asarray(arr)
571-
if a.shape == sample_shape: # a single sample, not a block
572-
a = a[None, ...]
573-
elif a.ndim != len(sample_shape) + 1 or a.shape[1:] != sample_shape:
574-
raise ValueError(
575-
f"source yielded shape {a.shape}; expected a single sample of shape "
576-
f"{sample_shape} or a block of shape (rows, *{sample_shape})"
577-
)
608+
a = _promote_to_block(np.asarray(arr), sample_shape)
578609
buf.append(a)
579610
have += a.shape[0]
580611
if have < batch_size:
@@ -681,7 +712,7 @@ def _auto_total_size(
681712
class _ParquetDataset(IterableDataset):
682713
"""An :class:`IterableDataset` over a directory of Parquet shards.
683714
684-
Yields one ``(rows, n_columns)`` ``float64`` array per file and exposes
715+
Yields one ``(rows, n_columns)`` array per file and exposes
685716
:attr:`n_rows` read from Parquet metadata (no data scan).
686717
"""
687718

@@ -706,7 +737,7 @@ def parquet_source(
706737
) -> _ParquetDataset:
707738
"""An :class:`IterableDataset` over a directory of Parquet files.
708739
709-
Yields one ``(rows, n_columns)`` ``float64`` array per file, and carries an
740+
Yields one ``(rows, n_columns)`` array per file, and carries an
710741
``n_rows`` attribute read from Parquet metadata (no data scan) so that
711742
``DataLoader(parquet_source(dir), ..., total_size="auto")`` resolves the
712743
dataset size for free. Pass ``shuffle=True`` to the :class:`DataLoader` (or

tests/variational/test_streaming.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,30 @@ def test_factory_returning_reiterable_is_accepted():
385385
assert next(iter(ds)).shape == (4, 1)
386386

387387

388+
def test_raw_array_with_shuffle_true():
    """Shuffling a raw 2-D array keeps its rows intact.

    Each row is promoted to a one-row block ahead of the shuffle buffer, so the
    loader emits full ``(8, 2)`` batches made of distinct rows of the input.
    """
    data = np.arange(40, dtype="float64").reshape(20, 2)
    loader = DataLoader(
        data,
        batch_size=8,
        shuffle=True,
        buffer_size=16,
        seed=0,
        sample_shape=(2,),
        total_size=20,
    )
    batches = list(loader)
    shapes = [batch.shape for batch in batches]
    assert shapes == [(8, 2), (8, 2)]
    # Two full batches of 8 rows: 16 distinct rows, all drawn from the input.
    seen = {tuple(row) for batch in batches for row in batch}
    source_rows = {tuple(row) for row in data}
    assert len(seen) == 16
    assert seen <= source_rows
399+
400+
401+
def test_scalar_raw_array_with_shuffle_true():
    """Shuffling a raw 1-D array of scalar samples yields a permutation of it."""
    data = np.arange(12, dtype="float64")
    loader = DataLoader(
        data,
        batch_size=4,
        shuffle=True,
        buffer_size=6,
        seed=0,
        sample_shape=(),
        total_size=12,
    )
    batches = list(loader)
    shapes = [batch.shape for batch in batches]
    assert shapes == [(4,), (4,), (4,)]
    # Every input value must appear exactly once across the three batches.
    recovered = np.sort(np.concatenate(batches))
    np.testing.assert_array_equal(recovered, data)
410+
411+
388412
def test_scalar_samples_are_batched():
389413
"""With sample_shape=() a 0-D yield is one scalar sample, exactly what
390414
iterating a raw 1-D array produces; the loader batches scalars."""

0 commit comments

Comments
 (0)