From 30d0ce264b0cb5e302efd20e3ff397f65e045d6f Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Fri, 5 Jun 2026 07:14:59 -0500 Subject: [PATCH 01/27] feat(variational): StreamingDataset for out-of-core minibatch VI pm.Minibatch random-indexes a fully-resident array (peak memory O(N)). StreamingDataset feeds minibatches from an arbitrary source into a small pytensor.shared buffer (peak memory O(batch_size)), reusing the existing total_size / create_minibatch_rv rescaling unchanged. Adds a shuffle_buffer helper and an equivalence test (streaming ADVI == in-RAM pm.Minibatch ADVI). --- pymc/variational/__init__.py | 3 + pymc/variational/streaming.py | 312 ++++++++++++++++++++++++++++ tests/variational/test_streaming.py | 180 ++++++++++++++++ 3 files changed, 495 insertions(+) create mode 100644 pymc/variational/streaming.py create mode 100644 tests/variational/test_streaming.py diff --git a/pymc/variational/__init__.py b/pymc/variational/__init__.py index 17b3cf3f7f..d9fe170822 100644 --- a/pymc/variational/__init__.py +++ b/pymc/variational/__init__.py @@ -44,6 +44,7 @@ # special from pymc.variational.stein import Stein +from pymc.variational.streaming import StreamingDataset, shuffle_buffer from pymc.variational.updates import ( adadelta, adagrad, @@ -69,6 +70,8 @@ "FullRankADVI", "Group", "MeanField", + "StreamingDataset", + "shuffle_buffer", "adadelta", "adagrad", "adagrad_window", diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py new file mode 100644 index 0000000000..843961c233 --- /dev/null +++ b/pymc/variational/streaming.py @@ -0,0 +1,312 @@ +# Copyright 2024 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Out-of-core minibatching for variational inference. + +``pm.Minibatch`` random-indexes an array that is *fully resident in memory*; its +peak memory is therefore O(N) in the dataset size. ``StreamingDataset`` instead +feeds minibatches from an arbitrary source (a generator, a directory of Parquet +shards, ...) into a small fixed-size ``pytensor.shared`` buffer, so peak memory is +O(buffer) -- the batch buffer plus, if used, the shuffle buffer -- and +independent of N. The unbiased-gradient rescaling is the *same* as for +``pm.Minibatch``: pass ``total_size=N`` to the observed distribution and PyMC +scales the minibatch log-likelihood by ``N / batch_size`` through the existing +:func:`~pymc.variational.minibatch_rv.create_minibatch_rv`. + +The one extra obligation relative to ``pm.Minibatch`` is **shuffling**. +``pm.Minibatch`` draws a fresh uniform index over all N rows every step, so its +minibatches are i.i.d. by construction. A streaming source is only as well +mixed as the order it yields rows in: reading time/row-ordered data through a +*bounded* buffer is merely a block-shuffle and biases the variational posterior. +Pre-shuffle the data once (or interleave shards) and/or use :func:`shuffle_buffer`. + +Example +------- +.. 
code-block:: python + + import pymc as pm + from pymc.variational.streaming import StreamingDataset, shuffle_buffer + + N = 10_000_000 # rows on disk; never all in memory at once + + + def chunks(): # yields (rows, n_features+1) float64 blocks off disk + for shard in shards: + yield read(shard) + + + ds = StreamingDataset( + shuffle_buffer(chunks, buffer_size=1_000_000, batch_size=4096, seed=0), + batch_size=4096, + sample_shape=(4,), # 3 features + 1 observed column + total_size=N, + ) + ds.advance() # seed the buffer + + with pm.Model(): + b = pm.Normal("b", 0.0, 3.0, shape=4) + buf = ds.as_tensor() # (batch_size, 4) shared + logit = b[0] + b[1] * buf[:, 0] + b[2] * buf[:, 1] + b[3] * buf[:, 2] + pm.Bernoulli("y", logit_p=logit, observed=buf[:, 3], total_size=ds.total_size) + approx = pm.fit(20_000, method="advi", callbacks=[ds.fit_callback()]) +""" + +from __future__ import annotations + +import warnings + +from collections.abc import Callable, Iterable, Iterator + +import numpy as np +import pytensor +import pytensor.tensor as pt + + +class StreamingDataset: + """Feed minibatches to variational inference from an out-of-core source. + + Parameters + ---------- + source : Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]] + Yields ``np.ndarray`` batches of shape ``(batch_size, *sample_shape)``. + Pass a zero-arg *callable* (a factory) so the stream can be restarted + when ``cycle=True``; a bare generator can only be consumed once. + batch_size : int + Leading dimension of every yielded batch (and of the buffer). + sample_shape : tuple of int, default () + Trailing shape of a single observation. ``()`` for scalar observations, + ``(k,)`` to stream ``k`` columns (e.g. features + the observed column). + dtype : str, default "float64" + Dtype of the shared buffer. If it differs from ``pytensor.config.floatX`` + the model will insert a per-step cast on the observed tensor. + total_size : int, optional + The true dataset size ``N``. 
Pass it to the observed distribution as + ``total_size=ds.total_size`` so the minibatch log-likelihood is rescaled + by ``N / batch_size`` (the same mechanism as ``pm.Minibatch``). Unlike + ``pm.Minibatch`` it cannot be inferred from a resident array, so it must + be supplied; a warning is issued at construction if it is left ``None``. + preprocess_fn : callable, optional + Pure transform applied to each batch before it lands in the buffer. + cycle : bool, default True + Restart the source when exhausted (the usual case: many epochs). If + ``False``, :meth:`advance` raises ``StopIteration`` once exhausted. + name : str + Name of the underlying ``pytensor.shared`` variable. + """ + + def __init__( + self, + source: Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]], + *, + batch_size: int, + sample_shape: tuple[int, ...] = (), + dtype: str = "float64", + total_size: int | None = None, + preprocess_fn: Callable[[np.ndarray], np.ndarray] | None = None, + cycle: bool = True, + name: str = "streaming_buffer", + ): + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError(f"batch_size must be a positive integer, got {batch_size}") + if total_size is None: + warnings.warn( + "StreamingDataset created with total_size=None: the minibatch " + "log-likelihood will NOT be rescaled and the posterior will be " + "biased. 
Pass total_size=N (the true dataset size).", + UserWarning, + stacklevel=2, + ) + + self._source_factory = _make_factory(source) + self._source_iter: Iterator[np.ndarray] = self._source_factory() + self._batch_size = batch_size + self._sample_shape = tuple(sample_shape) + self._dtype = dtype + self._total_size = total_size + self._preprocess_fn = preprocess_fn + self._cycle = cycle + + self._batches_seen = 0 + self._rows_streamed = 0 + + self._shared = pytensor.shared( + np.zeros((batch_size, *self._sample_shape), dtype=dtype), name=name + ) + + # ----- read-only state --------------------------------------------------- + + @property + def batch_size(self) -> int: + return self._batch_size + + @property + def total_size(self) -> int | None: + """The dataset size ``N`` (pass to the distribution's ``total_size``).""" + return self._total_size + + @property + def batches_seen(self) -> int: + return self._batches_seen + + @property + def rows_streamed(self) -> int: + """Total rows pushed through the buffer (grows past ``N`` across epochs).""" + return self._rows_streamed + + # ----- the model-facing tensor ------------------------------------------ + + def as_tensor(self) -> pt.TensorVariable: + """The ``pytensor.shared`` buffer the model observes (mutates each step).""" + return self._shared + + # ----- the only mutator -------------------------------------------------- + + def advance(self) -> None: + """Pull the next batch from the source into the buffer.""" + batch = self._next_batch() + if self._preprocess_fn is not None: + batch = self._preprocess_fn(batch) + self._validate(batch) + # Own a fresh contiguous copy before borrowing into the shared variable: + # the source may legitimately yield *views* into a reused array, so we + # must not alias it. np.array(copy default) guarantees an owned array. 
+ arr = np.array(batch, dtype=self._dtype) + self._shared.set_value(arr, borrow=True) + self._batches_seen += 1 + self._rows_streamed += int(arr.shape[0]) + + def fit_callback(self) -> Callable: + """A 3-arg callback ``(approx, losses, i)`` for ``pm.fit(callbacks=...)``.""" + + def _cb(*_): + self.advance() + + return _cb + + # ----- iterator sugar ---------------------------------------------------- + + def __iter__(self) -> StreamingDataset: + return self + + def __next__(self) -> np.ndarray: + self.advance() + return self._shared.get_value(borrow=False) # an owned copy, safe to keep + + # ----- internals --------------------------------------------------------- + + def _next_batch(self) -> np.ndarray: + try: + return next(self._source_iter) + except StopIteration: + if not self._cycle: + raise + self._source_iter = self._source_factory() + return next(self._source_iter) + + def _validate(self, batch: np.ndarray) -> None: + if not isinstance(batch, np.ndarray): + raise TypeError(f"expected np.ndarray batch, got {type(batch).__name__}") + if batch.shape[0] != self._batch_size: + raise ValueError( + f"batch shape[0] = {batch.shape[0]} does not match batch_size = " + f"{self._batch_size}; partial batches are not allowed (drop them in " + "the source, e.g. via shuffle_buffer)." + ) + if batch.shape[1:] != self._sample_shape: + raise ValueError( + f"batch sample-shape {batch.shape[1:]} does not match declared " + f"sample_shape={self._sample_shape}" + ) + + +def shuffle_buffer( + chunk_source: Callable[[], Iterator[np.ndarray]], + *, + buffer_size: int, + batch_size: int, + seed: int | None = None, +) -> Callable[[], Iterator[np.ndarray]]: + """Wrap a chunk source into a shuffled, fixed-size batch source. 
+ + Accumulates rows from ``chunk_source`` into a buffer of at least + ``buffer_size`` rows, shuffles it, and yields ``batch_size`` slices; rows that + do not fill a final batch are **carried over** into the next buffer (never + dropped) until the source is exhausted, at which point a single trailing + partial batch (< ``batch_size`` rows) is dropped. This approximates i.i.d. + minibatches from an *unordered* or pre-shuffled stream. + + It does **not** by itself fix a strongly time/row-ordered stream (a bounded + buffer only block-shuffles such data) -- pre-shuffle on disk, or interleave + shards into ``chunk_source``, for that. Note ``buffer_size`` is a *lower* + bound: a single yielded chunk larger than ``buffer_size`` is taken whole, so + peak buffer memory is ``max(buffer_size, largest_chunk_rows)``. + """ + + def factory() -> Iterator[np.ndarray]: + rng = np.random.default_rng(seed) + it = chunk_source() + carry: np.ndarray | None = None # leftover (< batch_size) from last fill + exhausted = False + while not exhausted: + bufs: list[np.ndarray] = [] + have = 0 + if carry is not None: + bufs.append(carry) + have += carry.shape[0] + carry = None + for arr in it: + a = np.asarray(arr) + bufs.append(a) + have += a.shape[0] + if have >= buffer_size: + break + else: + exhausted = True # for-loop ran to completion: source is done + if have < batch_size: + return # nothing left that can form a batch + buf = np.concatenate(bufs, axis=0) # always a fresh, owned copy + rng.shuffle(buf) + n_full = buf.shape[0] // batch_size + for i in range(n_full): + yield buf[i * batch_size : (i + 1) * batch_size] + rem = buf.shape[0] - n_full * batch_size + carry = buf[n_full * batch_size :].copy() if rem else None + + return factory + + +def _make_factory( + source: Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]], +) -> Callable[[], Iterator[np.ndarray]]: + """Coerce ``source`` into a zero-arg callable returning a fresh iterator. 
+ + A callable that is not itself an iterator is treated as the factory; a bare + iterator is wrapped (and refuses a second epoch); any other iterable is + re-``iter``-ed each epoch. + """ + if callable(source) and not isinstance(source, Iterator): + return source # type: ignore[return-value] + if isinstance(source, Iterator): + consumed = {"done": False} + + def _factory() -> Iterator[np.ndarray]: + if consumed["done"]: + raise RuntimeError( + "source is a bare iterator and cycle=True was requested; pass a " + "zero-arg factory or a re-iterable instead" + ) + consumed["done"] = True + return source + + return _factory + return lambda: iter(source) diff --git a/tests/variational/test_streaming.py b/tests/variational/test_streaming.py new file mode 100644 index 0000000000..5cfd3cc4db --- /dev/null +++ b/tests/variational/test_streaming.py @@ -0,0 +1,180 @@ +# Copyright 2024 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import numpy as np +import pytest + +import pymc as pm + +from pymc.variational.streaming import StreamingDataset, shuffle_buffer + + +def _chunks(data, size): + def factory(): + for i in range(0, len(data), size): + yield data[i : i + size] + + return factory + + +def test_advance_shape_and_counters(): + data = np.arange(40, dtype="float64").reshape(20, 2) + ds = StreamingDataset(_chunks(data, 4), batch_size=4, sample_shape=(2,), total_size=20) + assert ds.batches_seen == 0 + ds.advance() + assert ds.as_tensor().get_value().shape == (4, 2) + assert ds.batches_seen == 1 and ds.rows_streamed == 4 + ds.advance() + assert ds.batches_seen == 2 and ds.rows_streamed == 8 + + +def test_wrong_batch_shape_rejected(): + data = np.zeros((10, 2)) + ds = StreamingDataset(_chunks(data, 3), batch_size=4, sample_shape=(2,), total_size=10) + with pytest.raises(ValueError, match="does not match batch_size"): + ds.advance() + + +def test_total_size_none_warns_at_construction(): + data = np.zeros((8, 1)) + with pytest.warns(UserWarning, match="total_size=None"): + StreamingDataset(_chunks(data, 4), batch_size=4, sample_shape=(1,)) + + +def test_cycle_true_restarts_source(): + data = np.arange(8, dtype="float64").reshape(8, 1) + ds = StreamingDataset( + _chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=8, cycle=True + ) + for _ in range(4): # two epochs worth + ds.advance() + assert ds.batches_seen == 4 + + +def test_cycle_false_raises_when_exhausted(): + data = np.arange(8, dtype="float64").reshape(8, 1) + ds = StreamingDataset( + _chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=8, cycle=False + ) + ds.advance() + ds.advance() + with pytest.raises(StopIteration): + ds.advance() + + +def test_preprocess_fn_applied(): + data = np.ones((8, 1)) + ds = StreamingDataset( + _chunks(data, 4), + batch_size=4, + sample_shape=(1,), + total_size=8, + preprocess_fn=lambda b: b * 3.0, + ) + ds.advance() + np.testing.assert_array_equal(ds.as_tensor().get_value(), 
np.full((4, 1), 3.0)) + + +def test_shuffle_buffer_conserves_rows_non_dividing(): + # buffer_size and chunk size are deliberately NOT multiples of batch_size: the + # carry-over must not lose or duplicate any row (regression for the drop bug). + data = np.arange(140, dtype="float64").reshape(140, 1) + src = shuffle_buffer(_chunks(data, 7), buffer_size=55, batch_size=10, seed=0) + batches = list(src()) + assert all(b.shape == (10, 1) for b in batches) + seen = np.sort(np.concatenate([b.ravel() for b in batches])) + # 140 rows, batch 10 -> 14 full batches, nothing dropped (140 % 10 == 0) + np.testing.assert_array_equal(seen, data.ravel()) + + +def test_shuffle_buffer_does_not_mutate_source(): + data = np.arange(100, dtype="float64").reshape(100, 1) + original = data.copy() + src = shuffle_buffer(_chunks(data, 25), buffer_size=40, batch_size=10, seed=1) + list(src()) + np.testing.assert_array_equal(data, original) # source untouched + + +def test_total_size_rescales_logp_like_minibatch(): + # observed=buf[:, k] + total_size=N must scale the observed log-likelihood by + # N / batch_size via the existing create_minibatch_rv path -- pin this without + # training anything. 
+ rng = np.random.default_rng(0) + N, bs = 1000, 16 + data = rng.normal(size=(bs, 1)) + ds = StreamingDataset(lambda: iter([data]), batch_size=bs, sample_shape=(1,), total_size=N) + ds.advance() + + with pm.Model() as scaled: + mu = pm.Normal("mu", 0, 1) + pm.Normal("y", mu, 1, observed=ds.as_tensor()[:, 0], total_size=ds.total_size) + with pm.Model() as plain: + mu = pm.Normal("mu", 0, 1) + pm.Normal("y", mu, 1, observed=data[:, 0]) # no total_size + + point = {"mu": np.array(0.3)} + obs_scaled = scaled.compile_logp(scaled.observed_RVs)(point) + obs_plain = plain.compile_logp(plain.observed_RVs)(point) + np.testing.assert_allclose(obs_scaled, obs_plain * (N / bs), rtol=1e-6) + + +def test_equivalence_with_in_ram_minibatch(): + """End-to-end: streaming ADVI reproduces in-RAM pm.Minibatch ADVI.""" + seed = 0 + rng = np.random.default_rng(seed) + N, bs = 60_000, 2048 + X = rng.normal(size=(N, 2)) + b_true = np.array([0.3, -1.1, 0.7]) + y = (rng.random(N) < 1 / (1 + np.exp(-(b_true[0] + X @ b_true[1:])))).astype("float64") + data = np.column_stack([X, y]) + + with pm.Model(): + b = pm.Normal("b", 0, 3, shape=3) + xb, zb, yb = pm.Minibatch(X[:, 0].copy(), X[:, 1].copy(), y, batch_size=bs) + pm.Bernoulli("o", logit_p=b[0] + b[1] * xb + b[2] * zb, observed=yb, total_size=N) + ap = pm.fit( + 6000, + method="advi", + obj_optimizer=pm.adam(learning_rate=0.02), + progressbar=False, + random_seed=seed, + ) + in_ram = ap.sample(400).posterior["b"].values.reshape(-1, 3).mean(0) + + ds = StreamingDataset( + shuffle_buffer(_chunks(data, 20_000), buffer_size=40_000, batch_size=bs, seed=seed), + batch_size=bs, + sample_shape=(3,), + total_size=N, + ) + ds.advance() + with pm.Model(): + b = pm.Normal("b", 0, 3, shape=3) + buf = ds.as_tensor() + pm.Bernoulli( + "o", + logit_p=b[0] + b[1] * buf[:, 0] + b[2] * buf[:, 1], + observed=buf[:, 2], + total_size=ds.total_size, + ) + ap = pm.fit( + 6000, + method="advi", + obj_optimizer=pm.adam(learning_rate=0.02), + 
callbacks=[ds.fit_callback()], + progressbar=False, + random_seed=seed, + ) + stream = ap.sample(400).posterior["b"].values.reshape(-1, 3).mean(0) + + np.testing.assert_allclose(in_ram, stream, atol=0.1) From cc3658d5a1e040a43e2fa400acf1b1b0e30ab650 Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Fri, 5 Jun 2026 10:28:29 -0500 Subject: [PATCH 02/27] harden StreamingDataset validation (deep-review fixes) Close three silent-corruption holes found in a 5-lens review: - reject total_size <= 0 in __init__: 0 is falsy and skips the N/batch_size rescaling entirely (posterior collapses to prior); negative flips the data log-likelihood's sign via get_scaling. - shuffle_buffer now accumulates max(buffer_size, batch_size) rows before emitting, so buffer_size < batch_size no longer silently discards the whole stream; also validate buffer_size/batch_size as positive ints. - positive-int checks use numbers.Integral (accept numpy ints, reject bool). +5 regression tests; existing 10 unchanged and passing. --- pymc/variational/streaming.py | 57 +++++++++++++++++++++++------ tests/variational/test_streaming.py | 49 +++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 12 deletions(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index 843961c233..54ff62af81 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -63,6 +63,7 @@ def chunks(): # yields (rows, n_features+1) float64 blocks off disk from __future__ import annotations +import numbers import warnings from collections.abc import Callable, Iterable, Iterator @@ -72,6 +73,11 @@ def chunks(): # yields (rows, n_features+1) float64 blocks off disk import pytensor.tensor as pt +def _is_positive_int(value: object) -> bool: + """True for a strictly positive integer (incl. 
numpy integer types), excluding bool.""" + return isinstance(value, numbers.Integral) and not isinstance(value, bool) and value > 0 + + class StreamingDataset: """Feed minibatches to variational inference from an out-of-core source. @@ -90,11 +96,13 @@ class StreamingDataset: Dtype of the shared buffer. If it differs from ``pytensor.config.floatX`` the model will insert a per-step cast on the observed tensor. total_size : int, optional - The true dataset size ``N``. Pass it to the observed distribution as - ``total_size=ds.total_size`` so the minibatch log-likelihood is rescaled - by ``N / batch_size`` (the same mechanism as ``pm.Minibatch``). Unlike - ``pm.Minibatch`` it cannot be inferred from a resident array, so it must - be supplied; a warning is issued at construction if it is left ``None``. + The true dataset size ``N`` (a positive integer). Pass it to the observed + distribution as ``total_size=ds.total_size`` so the minibatch + log-likelihood is rescaled by ``N / batch_size`` (the same mechanism as + ``pm.Minibatch``). Unlike ``pm.Minibatch`` it cannot be inferred from a + resident array, so it must be supplied; ``None`` warns at construction and + a non-positive value raises (it would otherwise silently disable or invert + the rescaling). preprocess_fn : callable, optional Pure transform applied to each batch before it lands in the buffer. 
cycle : bool, default True @@ -116,8 +124,8 @@ def __init__( cycle: bool = True, name: str = "streaming_buffer", ): - if not isinstance(batch_size, int) or batch_size <= 0: - raise ValueError(f"batch_size must be a positive integer, got {batch_size}") + if not _is_positive_int(batch_size): + raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}") if total_size is None: warnings.warn( "StreamingDataset created with total_size=None: the minibatch " @@ -126,6 +134,17 @@ def __init__( UserWarning, stacklevel=2, ) + elif not _is_positive_int(total_size): + # A non-positive total_size is silently dangerous: 0 is falsy, so the + # model never wraps the observed RV and the N/batch_size rescaling is + # skipped (posterior collapses toward the prior); a negative value + # yields a negative scaling coefficient that flips the data + # log-likelihood's sign (VI then maximizes mis-fit). Reject it loudly. + raise ValueError( + "total_size must be a positive integer (the true dataset size N) so " + "the minibatch log-likelihood is rescaled by N / batch_size; got " + f"{total_size!r}." + ) self._source_factory = _make_factory(source) self._source_iter: Iterator[np.ndarray] = self._source_factory() @@ -247,16 +266,28 @@ def shuffle_buffer( It does **not** by itself fix a strongly time/row-ordered stream (a bounded buffer only block-shuffles such data) -- pre-shuffle on disk, or interleave - shards into ``chunk_source``, for that. Note ``buffer_size`` is a *lower* - bound: a single yielded chunk larger than ``buffer_size`` is taken whole, so - peak buffer memory is ``max(buffer_size, largest_chunk_rows)``. + shards into ``chunk_source``, for that. 
``buffer_size`` is a *lower* bound: the + buffer always accumulates at least ``max(buffer_size, batch_size)`` rows before + emitting (so a ``buffer_size`` smaller than ``batch_size`` still yields full + batches instead of silently dropping the stream), and a single chunk larger + than that is taken whole, so peak buffer memory is + ``max(buffer_size, batch_size, largest_chunk_rows)``. """ + if not _is_positive_int(batch_size): + raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}") + if not _is_positive_int(buffer_size): + raise ValueError(f"buffer_size must be a positive integer, got {buffer_size!r}") def factory() -> Iterator[np.ndarray]: rng = np.random.default_rng(seed) it = chunk_source() carry: np.ndarray | None = None # leftover (< batch_size) from last fill exhausted = False + # Accumulate at least one full batch's worth even when buffer_size < + # batch_size: otherwise the inner loop would break early with fewer than + # batch_size rows and the `have < batch_size` guard below would silently + # discard the entire stream. + target = max(buffer_size, batch_size) while not exhausted: bufs: list[np.ndarray] = [] have = 0 @@ -268,12 +299,14 @@ def factory() -> Iterator[np.ndarray]: a = np.asarray(arr) bufs.append(a) have += a.shape[0] - if have >= buffer_size: + if have >= target: break else: exhausted = True # for-loop ran to completion: source is done if have < batch_size: - return # nothing left that can form a batch + # Only reachable once the source is exhausted: drop the final + # sub-batch remainder (it cannot form a full batch). 
+ return buf = np.concatenate(bufs, axis=0) # always a fresh, owned copy rng.shuffle(buf) n_full = buf.shape[0] // batch_size diff --git a/tests/variational/test_streaming.py b/tests/variational/test_streaming.py index 5cfd3cc4db..0e9ca7f69f 100644 --- a/tests/variational/test_streaming.py +++ b/tests/variational/test_streaming.py @@ -178,3 +178,52 @@ def test_equivalence_with_in_ram_minibatch(): stream = ap.sample(400).posterior["b"].values.reshape(-1, 3).mean(0) np.testing.assert_allclose(in_ram, stream, atol=0.1) + + +def test_total_size_zero_raises(): + # total_size=0 is falsy: it slips past a None-only check and the model's truthy + # `if total_size:` guard, silently skipping the N/batch_size rescaling. + data = np.zeros((8, 1)) + with pytest.raises(ValueError, match="positive integer"): + StreamingDataset(_chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=0) + + +def test_total_size_negative_raises(): + # negative total_size is truthy but yields a negative scaling coefficient + # (the data log-likelihood's sign flips, so VI maximizes mis-fit). + data = np.zeros((8, 1)) + with pytest.raises(ValueError, match="positive integer"): + StreamingDataset(_chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=-100) + + +def test_shuffle_buffer_small_buffer_conserves_rows(): + # buffer_size < batch_size must NOT silently discard the dataset: the buffer + # accumulates to at least batch_size before emitting (regression for the + # early-return data-loss bug). 
+ data = np.arange(120, dtype="float64").reshape(120, 1) + src = shuffle_buffer(_chunks(data, 7), buffer_size=3, batch_size=10, seed=0) + batches = list(src()) + assert batches, "buffer_size < batch_size silently produced zero batches" + assert all(b.shape == (10, 1) for b in batches) + seen = np.sort(np.concatenate([b.ravel() for b in batches])) + np.testing.assert_array_equal(seen, data.ravel()) # 120 % 10 == 0, nothing dropped + + +def test_shuffle_buffer_rejects_nonpositive_sizes(): + data = np.zeros((10, 1)) + with pytest.raises(ValueError, match="buffer_size"): + shuffle_buffer(_chunks(data, 5), buffer_size=0, batch_size=4) + with pytest.raises(ValueError, match="batch_size"): + shuffle_buffer(_chunks(data, 5), buffer_size=10, batch_size=0) + + +def test_accepts_numpy_integer_sizes_rejects_bool(): + # the positive-int check uses numbers.Integral: numpy ints are valid, bool is not. + data = np.zeros((8, 1)) + ds = StreamingDataset( + _chunks(data, 4), batch_size=np.int64(4), sample_shape=(1,), total_size=np.int64(8) + ) + ds.advance() + assert ds.batch_size == 4 + with pytest.raises(ValueError): + StreamingDataset(_chunks(data, 4), batch_size=True, sample_shape=(1,), total_size=8) From 2c47255d0373990d2e679321014b73029bc20c5b Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Fri, 5 Jun 2026 10:45:49 -0500 Subject: [PATCH 03/27] strengthen shuffle_buffer: reshuffle each epoch A seeded shuffle_buffer rebuilt its RNG from the same seed on every factory call, so under cycle=True every epoch replayed one fixed permutation -- which weakens the very mixing the buffer exists to provide and compounds the block-shuffle bias on ordered data. Derive a fresh per-epoch sub-stream from a SeedSequence so the order differs across epochs while staying reproducible for a given seed. +2 tests. 
--- pymc/variational/streaming.py | 11 +++++++++- tests/variational/test_streaming.py | 31 +++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index 54ff62af81..7b90a425de 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -272,14 +272,23 @@ def shuffle_buffer( batches instead of silently dropping the stream), and a single chunk larger than that is taken whole, so peak buffer memory is ``max(buffer_size, batch_size, largest_chunk_rows)``. + + Each epoch (each call of the returned factory) draws a fresh permutation from + a sub-stream of ``seed``, so the shuffle order differs across epochs -- a + seeded buffer must not replay one fixed order forever -- while staying + reproducible for a given ``seed``. """ if not _is_positive_int(batch_size): raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}") if not _is_positive_int(buffer_size): raise ValueError(f"buffer_size must be a positive integer, got {buffer_size!r}") + seed_seq = np.random.SeedSequence(seed) def factory() -> Iterator[np.ndarray]: - rng = np.random.default_rng(seed) + # Spawn a fresh sub-stream per epoch so re-iterating (cycle=True) reshuffles + # rather than replaying one fixed permutation forever; still reproducible + # across runs for a given seed. 
+ rng = np.random.default_rng(seed_seq.spawn(1)[0]) it = chunk_source() carry: np.ndarray | None = None # leftover (< batch_size) from last fill exhausted = False diff --git a/tests/variational/test_streaming.py b/tests/variational/test_streaming.py index 0e9ca7f69f..ec184dcb6e 100644 --- a/tests/variational/test_streaming.py +++ b/tests/variational/test_streaming.py @@ -227,3 +227,34 @@ def test_accepts_numpy_integer_sizes_rejects_bool(): assert ds.batch_size == 4 with pytest.raises(ValueError): StreamingDataset(_chunks(data, 4), batch_size=True, sample_shape=(1,), total_size=8) + + +def test_shuffle_buffer_reshuffles_across_epochs(): + # a seeded buffer must NOT replay one fixed permutation every epoch (that + # would weaken shuffling under cycle=True); each epoch reshuffles, but rows + # are conserved. + data = np.arange(60, dtype="float64").reshape(60, 1) + factory = shuffle_buffer(_chunks(data, 10), buffer_size=60, batch_size=10, seed=0) + epoch1 = np.concatenate([b.ravel() for b in factory()]) + epoch2 = np.concatenate([b.ravel() for b in factory()]) + assert not np.array_equal(epoch1, epoch2) # different order across epochs + np.testing.assert_array_equal(np.sort(epoch1), data.ravel()) # but conserves rows + np.testing.assert_array_equal(np.sort(epoch2), data.ravel()) + + +def test_shuffle_buffer_seed_reproducible_across_runs(): + # same seed => identical first-epoch order across independent constructions. 
+ data = np.arange(60, dtype="float64").reshape(60, 1) + a = np.concatenate( + [ + b.ravel() + for b in shuffle_buffer(_chunks(data, 10), buffer_size=60, batch_size=10, seed=7)() + ] + ) + b = np.concatenate( + [ + b.ravel() + for b in shuffle_buffer(_chunks(data, 10), buffer_size=60, batch_size=10, seed=7)() + ] + ) + np.testing.assert_array_equal(a, b) From 90b5b83e4f6161f9d2ce98cb3216920571e9c369 Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Sat, 6 Jun 2026 08:31:40 -0500 Subject: [PATCH 04/27] prototype: total_size="auto" + rows_streamed sanity warning Cuts the "user must pass total_size" burden (open question #1 for the design review): - total_size="auto" resolves N from a source's .n_rows (cheap -- e.g. Parquet footer metadata via the new parquet_source) else one counting pass over a finite, re-readable source. One-shot / infinite sources still pass total_size explicitly (and are rejected with a clear error under "auto"). - a free sanity check using the existing rows_streamed counter: at the first epoch boundary, warn if total_size grossly disagrees with the rows actually streamed in one pass (catches a wrong-but-positive total_size). - parquet_source(directory): a finite, re-readable source carrying .n_rows read from Parquet metadata (no data scan). +7 tests; the existing 17 are unchanged and still pass. 
--- pymc/variational/__init__.py | 5 +- pymc/variational/streaming.py | 105 +++++++++++++++++- tests/variational/test_streaming_autosize.py | 107 +++++++++++++++++++ 3 files changed, 211 insertions(+), 6 deletions(-) create mode 100644 tests/variational/test_streaming_autosize.py diff --git a/pymc/variational/__init__.py b/pymc/variational/__init__.py index d9fe170822..414d50a6b1 100644 --- a/pymc/variational/__init__.py +++ b/pymc/variational/__init__.py @@ -44,7 +44,7 @@ # special from pymc.variational.stein import Stein -from pymc.variational.streaming import StreamingDataset, shuffle_buffer +from pymc.variational.streaming import StreamingDataset, parquet_source, shuffle_buffer from pymc.variational.updates import ( adadelta, adagrad, @@ -71,7 +71,6 @@ "Group", "MeanField", "StreamingDataset", - "shuffle_buffer", "adadelta", "adagrad", "adagrad_window", @@ -83,8 +82,10 @@ "momentum", "nesterov_momentum", "norm_constraint", + "parquet_source", "rmsprop", "sample_approx", "sgd", + "shuffle_buffer", "total_norm_constraint", ) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index 7b90a425de..bfd3cb97eb 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -119,18 +119,28 @@ def __init__( batch_size: int, sample_shape: tuple[int, ...] = (), dtype: str = "float64", - total_size: int | None = None, + total_size: int | str | None = None, preprocess_fn: Callable[[np.ndarray], np.ndarray] | None = None, cycle: bool = True, name: str = "streaming_buffer", ): if not _is_positive_int(batch_size): raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}") - if total_size is None: + + self._source_factory = _make_factory(source) + + if isinstance(total_size, str): + if total_size != "auto": + raise ValueError(f"total_size string must be 'auto', got {total_size!r}") + # Resolve N automatically: a source-provided .n_rows (cheap, e.g. 
from + # parquet_source's metadata) else one counting pass over a finite, + # re-readable source. One-shot / infinite sources cannot be auto-counted. + total_size = _auto_total_size(self._source_factory, source) + elif total_size is None: warnings.warn( "StreamingDataset created with total_size=None: the minibatch " "log-likelihood will NOT be rescaled and the posterior will be " - "biased. Pass total_size=N (the true dataset size).", + "biased. Pass total_size=N (the true dataset size) or total_size='auto'.", UserWarning, stacklevel=2, ) @@ -146,7 +156,6 @@ def __init__( f"{total_size!r}." ) - self._source_factory = _make_factory(source) self._source_iter: Iterator[np.ndarray] = self._source_factory() self._batch_size = batch_size self._sample_shape = tuple(sample_shape) @@ -157,6 +166,7 @@ def __init__( self._batches_seen = 0 self._rows_streamed = 0 + self._warned_size = False # the sanity check below fires at most once self._shared = pytensor.shared( np.zeros((batch_size, *self._sample_shape), dtype=dtype), name=name @@ -229,9 +239,28 @@ def _next_batch(self) -> np.ndarray: except StopIteration: if not self._cycle: raise + # First exhaustion == one full pass: rows_streamed now equals the real + # row count, so we can sanity-check the user's total_size for free. + self._maybe_warn_total_size() self._source_iter = self._source_factory() return next(self._source_iter) + def _maybe_warn_total_size(self) -> None: + """Warn once if total_size grossly disagrees with the rows seen in one pass.""" + if self._warned_size or self._total_size is None: + return + self._warned_size = True + seen = self._rows_streamed + if seen and abs(self._total_size - seen) > 0.1 * seen: + warnings.warn( + f"total_size={self._total_size} disagrees with the {seen} rows streamed " + f"in one full pass; the N/batch_size rescaling -- and therefore the " + f"posterior width -- is likely wrong. 
Pass the true dataset size, or " + f"total_size='auto'.", + UserWarning, + stacklevel=3, + ) + def _validate(self, batch: np.ndarray) -> None: if not isinstance(batch, np.ndarray): raise TypeError(f"expected np.ndarray batch, got {type(batch).__name__}") @@ -352,3 +381,71 @@ def _factory() -> Iterator[np.ndarray]: return _factory return lambda: iter(source) + + +def _auto_total_size( + factory: Callable[[], Iterator[np.ndarray]], + source: object, +) -> int: + """Resolve ``total_size="auto"``: a source ``.n_rows`` (cheap) else a counting pass. + + Fast path: if ``source`` advertises ``.n_rows`` (e.g. :func:`parquet_source`, which + reads it from Parquet metadata without scanning the data) use it directly. Otherwise + do a single counting pass over a finite, re-readable source. A bare one-shot iterator + cannot be auto-counted (counting consumes it) and an infinite stream would make the + pass hang -- both must pass ``total_size`` explicitly. + """ + n = getattr(source, "n_rows", None) + if n is not None: + if not _is_positive_int(n): + raise ValueError(f"source.n_rows must be a positive integer, got {n!r}") + return int(n) + if isinstance(source, Iterator): + raise ValueError( + "total_size='auto' needs a re-readable source (a zero-arg factory or an " + "iterable), not a one-shot iterator; pass total_size=N explicitly instead." + ) + warnings.warn( + "total_size='auto' is doing a full counting pass over the source; for a cheap " + "path use a source exposing .n_rows (e.g. 
parquet_source, from Parquet metadata).", + UserWarning, + stacklevel=3, + ) + count = 0 + for chunk in factory(): + count += int(np.asarray(chunk).shape[0]) + if count <= 0: + raise ValueError("total_size='auto' counted 0 rows (empty or non-re-readable source).") + return count + + +def parquet_source( + directory: str, + *, + columns: list[str] | None = None, + pattern: str = "*.parquet", +) -> Callable[[], Iterator[np.ndarray]]: + """A finite, re-readable streaming source over a directory of Parquet files. + + Yields one ``(rows, n_columns)`` ``float64`` array per file, and carries an + ``n_rows`` attribute read from Parquet *metadata* (no data scan) so that + ``StreamingDataset(parquet_source(dir), ..., total_size="auto")`` resolves the + dataset size for free. Wrap it in :func:`shuffle_buffer` to get fixed-size, + shuffled batches. + """ + import glob as _glob + import os + + import pyarrow.parquet as pq + + paths = sorted(_glob.glob(os.path.join(directory, pattern))) + if not paths: + raise ValueError(f"no Parquet files match {os.path.join(directory, pattern)!r}") + + def factory() -> Iterator[np.ndarray]: + for path in paths: + table = pq.read_table(path, columns=columns) + yield np.column_stack([table.column(c).to_numpy() for c in table.column_names]) + + factory.n_rows = sum(pq.read_metadata(p).num_rows for p in paths) # type: ignore[attr-defined] + return factory diff --git a/tests/variational/test_streaming_autosize.py b/tests/variational/test_streaming_autosize.py new file mode 100644 index 0000000000..1c33791572 --- /dev/null +++ b/tests/variational/test_streaming_autosize.py @@ -0,0 +1,107 @@ +# Copyright 2024 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Prototype: total_size='auto' resolution + the rows_streamed sanity warning.""" + +import warnings + +import numpy as np +import pytest + +from pymc.variational.streaming import StreamingDataset, parquet_source + + +def _factory(data, size): + """A re-readable zero-arg factory yielding `size`-row chunks of `data`.""" + + def f(): + for i in range(0, len(data), size): + yield data[i : i + size] + + return f + + +def test_auto_counts_finite_source(): + # no .n_rows -> auto does one counting pass and resolves the true N. + data = np.arange(60, dtype="float64").reshape(60, 1) + with pytest.warns(UserWarning, match="counting pass"): + ds = StreamingDataset( + _factory(data, 7), batch_size=10, sample_shape=(1,), total_size="auto" + ) + assert ds.total_size == 60 + + +def test_auto_uses_n_rows_fast_path(): + # source advertises .n_rows -> auto trusts it WITHOUT counting (the factory only + # really yields 8 rows, but n_rows says 999; auto must return 999). + data = np.zeros((8, 1)) + f = _factory(data, 4) + f.n_rows = 999 + ds = StreamingDataset(f, batch_size=4, sample_shape=(1,), total_size="auto") + assert ds.total_size == 999 + + +def test_auto_rejects_one_shot_iterator(): + # a bare generator is consumed by counting -> auto must refuse it. 
+ data = np.zeros((20, 1)) + one_shot = (data[i : i + 4] for i in range(0, 20, 4)) + with pytest.raises(ValueError, match="re-readable"): + StreamingDataset(one_shot, batch_size=4, sample_shape=(1,), total_size="auto") + + +def test_auto_rejects_bad_n_rows(): + f = _factory(np.zeros((8, 1)), 4) + f.n_rows = 0 + with pytest.raises(ValueError, match="n_rows must be a positive integer"): + StreamingDataset(f, batch_size=4, sample_shape=(1,), total_size="auto") + + +def test_sanity_warns_on_grossly_wrong_total_size(): + # one full pass = 20 rows, but total_size=100 -> at the first epoch boundary, warn. + data = np.arange(20, dtype="float64").reshape(20, 1) + ds = StreamingDataset(_factory(data, 4), batch_size=4, sample_shape=(1,), total_size=100) + with pytest.warns(UserWarning, match="disagrees with"): + for _ in range(6): # 5 batches = one epoch, the 6th crosses the boundary + ds.advance() + + +def test_sanity_silent_when_total_size_matches(): + data = np.arange(20, dtype="float64").reshape(20, 1) + ds = StreamingDataset(_factory(data, 4), batch_size=4, sample_shape=(1,), total_size=20) + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) # any UserWarning fails the test + for _ in range(6): + ds.advance() + + +def test_parquet_source_n_rows_from_metadata(tmp_path): + pa = pytest.importorskip("pyarrow") + pq = pytest.importorskip("pyarrow.parquet") + rng = np.random.default_rng(0) + total = 0 + for i in range(3): + n = 100 + 50 * i + total += n + block = rng.normal(size=(n, 2)) + pq.write_table( + pa.table({"a": block[:, 0], "b": block[:, 1]}), + f"{tmp_path}/part_{i:02d}.parquet", + ) + src = parquet_source(str(tmp_path)) + assert src.n_rows == total # read from metadata, no data scan + + # and total_size='auto' picks it up for free (no counting pass / warning) + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + ds = StreamingDataset(src, batch_size=10, sample_shape=(2,), total_size="auto") + assert 
ds.total_size == total From a75e7cfa60449e00c8e72f3aa4a3c801f6adbcc6 Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Sat, 6 Jun 2026 11:27:32 -0500 Subject: [PATCH 05/27] fix(streaming): normalize int sizes; harden factory, callback, validation An adversarial re-review surfaced edge cases the first hardening pass missed: - total_size / batch_size: numpy integers were accepted but stored unchanged, so a stored np.int64 reached create_minibatch_rv and raised "Invalid type for total_size". Normalize to Python int at construction. - _make_factory: a zero-arg factory returning a non-iterator iterable (e.g. a list) crashed in __next__ ("'list' object is not an iterator"); wrap in iter(). - total_size="auto": a factory that returns the same one-shot iterator each call now raises, instead of leaving the first advance() empty. - fit_callback: seeds the buffer by default. PyMC runs callbacks after each step, so an unseeded first step trained on the zero-initialized placeholder. - _validate: a 0-D batch now raises a clear ValueError instead of IndexError. Adds 7 regression tests (31 total). --- pymc/variational/streaming.py | 46 ++++++++++++--- tests/variational/test_streaming.py | 62 ++++++++++++++++++++ tests/variational/test_streaming_autosize.py | 9 +++ 3 files changed, 110 insertions(+), 7 deletions(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index bfd3cb97eb..b54497529e 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -51,7 +51,6 @@ def chunks(): # yields (rows, n_features+1) float64 blocks off disk sample_shape=(4,), # 3 features + 1 observed column total_size=N, ) - ds.advance() # seed the buffer with pm.Model(): b = pm.Normal("b", 0.0, 3.0, shape=4) @@ -157,10 +156,14 @@ def __init__( ) self._source_iter: Iterator[np.ndarray] = self._source_factory() - self._batch_size = batch_size + # Normalize integer-like sizes to plain Python ints. 
``_is_positive_int`` + # accepts numpy integers (via ``numbers.Integral``), but the downstream + # ``create_minibatch_rv`` type-checks ``isinstance(total_size, int)`` and + # would raise on a stored ``np.int64`` ("Invalid type for total_size"). + self._batch_size = int(batch_size) self._sample_shape = tuple(sample_shape) self._dtype = dtype - self._total_size = total_size + self._total_size = None if total_size is None else int(total_size) self._preprocess_fn = preprocess_fn self._cycle = cycle @@ -214,8 +217,17 @@ def advance(self) -> None: self._batches_seen += 1 self._rows_streamed += int(arr.shape[0]) - def fit_callback(self) -> Callable: - """A 3-arg callback ``(approx, losses, i)`` for ``pm.fit(callbacks=...)``.""" + def fit_callback(self, *, seed: bool = True) -> Callable: + """A 3-arg callback ``(approx, losses, i)`` for ``pm.fit(callbacks=...)``. + + PyMC runs callbacks *after* each optimization step, so the buffer must + already hold a real batch before step 0 -- otherwise the first step trains + on the zero-initialized placeholder. By default this seeds the buffer (one + :meth:`advance`) if it has not been advanced yet; pass ``seed=False`` to + opt out (e.g. when you seed manually). + """ + if seed and self._batches_seen == 0: + self.advance() def _cb(*_): self.advance() @@ -264,6 +276,11 @@ def _maybe_warn_total_size(self) -> None: def _validate(self, batch: np.ndarray) -> None: if not isinstance(batch, np.ndarray): raise TypeError(f"expected np.ndarray batch, got {type(batch).__name__}") + if batch.ndim < 1: + raise ValueError( + "batch needs a leading batch dimension; got a scalar array with " + f"shape {batch.shape}." + ) if batch.shape[0] != self._batch_size: raise ValueError( f"batch shape[0] = {batch.shape[0]} does not match batch_size = " @@ -366,7 +383,13 @@ def _make_factory( re-``iter``-ed each epoch. 
""" if callable(source) and not isinstance(source, Iterator): - return source # type: ignore[return-value] + # A factory may return any iterable (a list of batches, a generator, ...), + # not only an iterator; normalize so ``__next__`` always has an iterator to + # pull from (a bare ``list`` would otherwise fail ``next(...)``). + def _factory() -> Iterator[np.ndarray]: + return iter(source()) # type: ignore[operator] + + return _factory if isinstance(source, Iterator): consumed = {"done": False} @@ -411,11 +434,20 @@ def _auto_total_size( UserWarning, stacklevel=3, ) + first_iter = factory() count = 0 - for chunk in factory(): + for chunk in first_iter: count += int(np.asarray(chunk).shape[0]) if count <= 0: raise ValueError("total_size='auto' counted 0 rows (empty or non-re-readable source).") + if factory() is first_iter: + # A genuine factory yields a FRESH iterator each call; one that returns the + # same (now-exhausted) iterator would leave advance() with nothing to pull. + raise ValueError( + "total_size='auto' got a factory that returns the same one-shot iterator " + "each call; pass a factory that creates a fresh iterator each call, or " + "total_size=N explicitly." + ) return count diff --git a/tests/variational/test_streaming.py b/tests/variational/test_streaming.py index ec184dcb6e..5a9abb027e 100644 --- a/tests/variational/test_streaming.py +++ b/tests/variational/test_streaming.py @@ -258,3 +258,65 @@ def test_shuffle_buffer_seed_reproducible_across_runs(): ] ) np.testing.assert_array_equal(a, b) + + +def test_sizes_normalized_to_python_int(): + # numpy integer sizes must be stored as plain Python ints so ds.total_size is + # accepted downstream by create_minibatch_rv (regression for the np.int64 trap). 
+ data = np.zeros((8, 1)) + ds = StreamingDataset( + _chunks(data, 4), batch_size=np.int64(4), sample_shape=(1,), total_size=np.int64(8) + ) + assert type(ds.batch_size) is int + assert type(ds.total_size) is int + + +def test_numpy_total_size_accepted_by_observed_rv(): + # a stored np.int64 total_size used to reach create_minibatch_rv and raise + # "Invalid type for total_size"; it must now build a valid observed RV. + data = np.zeros((4, 1), dtype="float64") + ds = StreamingDataset( + lambda: iter([data]), batch_size=4, sample_shape=(1,), total_size=np.int64(4) + ) + ds.advance() + with pm.Model() as model: + mu = pm.Normal("mu", 0, 1) + pm.Normal("y", mu, 1, observed=ds.as_tensor()[:, 0], total_size=ds.total_size) + # compiling the observed logp exercises the create_minibatch_rv scaling path + model.compile_logp(model.observed_RVs)({"mu": np.array(0.0)}) + + +def test_factory_returning_reiterable_is_accepted(): + # a zero-arg factory may return ANY iterable (e.g. a list), not just an + # iterator; advance() used to crash with "'list' object is not an iterator". + data = [np.zeros((4, 1), dtype="float64")] + ds = StreamingDataset(lambda: data, batch_size=4, sample_shape=(1,), total_size=4) + ds.advance() + assert ds.as_tensor().get_value().shape == (4, 1) + + +def test_scalar_batch_rejected_with_clear_error(): + # a 0-D batch used to raise an opaque IndexError on batch.shape[0]. + ds = StreamingDataset( + lambda: iter([np.array(1.0)]), batch_size=1, sample_shape=(), total_size=1 + ) + with pytest.raises(ValueError, match="leading batch dimension"): + ds.advance() + + +def test_fit_callback_seeds_buffer_by_default(): + # PyMC runs callbacks AFTER each step, so the buffer must be seeded before the + # first step; fit_callback() seeds on creation unless seed=False. 
+ data = np.ones((4, 1)) + ds = StreamingDataset(lambda: iter([data, data]), batch_size=4, sample_shape=(1,), total_size=8) + assert ds.batches_seen == 0 + ds.fit_callback() # default seed=True + assert ds.batches_seen == 1 + np.testing.assert_array_equal(ds.as_tensor().get_value(), data) # not the zero placeholder + + +def test_fit_callback_seed_false_does_not_advance(): + data = np.ones((4, 1)) + ds = StreamingDataset(lambda: iter([data]), batch_size=4, sample_shape=(1,), total_size=4) + ds.fit_callback(seed=False) + assert ds.batches_seen == 0 diff --git a/tests/variational/test_streaming_autosize.py b/tests/variational/test_streaming_autosize.py index 1c33791572..34d7048ae8 100644 --- a/tests/variational/test_streaming_autosize.py +++ b/tests/variational/test_streaming_autosize.py @@ -59,6 +59,15 @@ def test_auto_rejects_one_shot_iterator(): StreamingDataset(one_shot, batch_size=4, sample_shape=(1,), total_size="auto") +def test_auto_rejects_factory_returning_same_one_shot_iterator(): + # a "factory" that hands back the SAME already-consumable iterator each call is + # not re-readable: the counting pass consumes it and advance() would get nothing. + data = np.zeros((20, 1)) + one_shot = (data[i : i + 4] for i in range(0, 20, 4)) + with pytest.raises(ValueError, match="fresh iterator"): + StreamingDataset(lambda: one_shot, batch_size=4, sample_shape=(1,), total_size="auto") + + def test_auto_rejects_bad_n_rows(): f = _factory(np.zeros((8, 1)), 4) f.n_rows = 0 From 7b9fe85ed3e416ea1059a5d0d5510e7cc14e46a9 Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Sat, 6 Jun 2026 12:00:34 -0500 Subject: [PATCH 06/27] feat(streaming): forward shuffle_buffer's source .n_rows for total_size="auto" shuffle_buffer now propagates a known .n_rows (e.g. 
parquet_source's, read from Parquet metadata) to its wrapped factory, so the common composition StreamingDataset(shuffle_buffer(parquet_source(dir)), total_size="auto") resolves N for free instead of doing a full counting pass over the data. The only discrepancy is the single dropped trailing partial batch (< batch_size rows), which stays within the 10% auto-size sanity tolerance once N is at least 11 * batch_size. Adds 2 regression tests (33 total). --- pymc/variational/streaming.py | 10 ++++++++ tests/variational/test_streaming_autosize.py | 25 +++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index b54497529e..574b4dcfe6 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -370,6 +370,16 @@ def factory() -> Iterator[np.ndarray]: rem = buf.shape[0] - n_full * batch_size carry = buf[n_full * batch_size :].copy() if rem else None + # Forward a known row count (e.g. parquet_source's .n_rows from Parquet + # metadata) to the wrapped factory, so + # ``StreamingDataset(shuffle_buffer(parquet_source(dir)), total_size="auto")`` + # resolves N for free instead of doing a counting pass. The only discrepancy is + # the single dropped trailing partial batch (< batch_size rows), well within the + # auto-size sanity tolerance.
+ source_n_rows = getattr(chunk_source, "n_rows", None) + if source_n_rows is not None: + factory.n_rows = source_n_rows # type: ignore[attr-defined] + return factory diff --git a/tests/variational/test_streaming_autosize.py b/tests/variational/test_streaming_autosize.py index 34d7048ae8..94d0a77aed 100644 --- a/tests/variational/test_streaming_autosize.py +++ b/tests/variational/test_streaming_autosize.py @@ -18,7 +18,7 @@ import numpy as np import pytest -from pymc.variational.streaming import StreamingDataset, parquet_source +from pymc.variational.streaming import StreamingDataset, parquet_source, shuffle_buffer def _factory(data, size): @@ -59,6 +59,29 @@ def test_auto_rejects_one_shot_iterator(): StreamingDataset(one_shot, batch_size=4, sample_shape=(1,), total_size="auto") +def test_shuffle_buffer_forwards_n_rows_for_auto(): + # shuffle_buffer must forward a known .n_rows so total_size="auto" works through + # the common shuffle_buffer(parquet_source(...)) composition WITHOUT a counting + # pass (the realistic way users wrap a Parquet source). + data = np.arange(40, dtype="float64").reshape(40, 1) + src = _factory(data, 8) + src.n_rows = 40 + wrapped = shuffle_buffer(src, buffer_size=20, batch_size=10, seed=0) + assert wrapped.n_rows == 40 + + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) # a counting pass would warn -> fail + ds = StreamingDataset(wrapped, batch_size=10, sample_shape=(1,), total_size="auto") + assert ds.total_size == 40 + + +def test_shuffle_buffer_without_n_rows_has_no_attribute(): + # a plain source without .n_rows must not gain a bogus one. 
+ data = np.arange(40, dtype="float64").reshape(40, 1) + wrapped = shuffle_buffer(_factory(data, 8), buffer_size=20, batch_size=10, seed=0) + assert not hasattr(wrapped, "n_rows") + + def test_auto_rejects_factory_returning_same_one_shot_iterator(): # a "factory" that hands back the SAME already-consumable iterator each call is # not re-readable: the counting pass consumes it and advance() would get nothing. From 20d78703ebd633988484129fd43bb64e6aac1fae Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Sat, 6 Jun 2026 16:24:52 -0500 Subject: [PATCH 07/27] refactor(streaming): adopt PyTorch-style Dataset/DataLoader + add Trainer Design-review feedback from Rob (mentor): the streaming API should mirror torch.utils.data so the mental model transfers, and the user-facing callback should go away in favour of a Lightning-style trainer. - IterableDataset: re-iterable out-of-core source base (parquet_source now returns one); carries an optional .n_rows for total_size="auto". - DataLoader: the former StreamingDataset, renamed; gains PyTorch-style shuffle=/buffer_size=/seed= (wraps shuffle_buffer internally). Still owns the fixed pytensor.shared buffer the model observes; advance()/as_tensor() kept. - Trainer: Trainer(method="advi").fit(model, loader, n) drives VI with NO user-facing callbacks -- it seeds the buffer and advances it each step internally. The per-step advance is wired into pm.fit privately. All hardening preserved (int normalization, total_size guards + "auto", shuffle row-conservation + per-epoch reshuffle, copy-before-borrow, validation). shuffle_buffer/parquet_source stay public. 36 tests pass (1 skipped: pyarrow). total_size still appears in the model (total_size=loader.total_size); removing it is an open design question for Rob -- see notes. It is compiled into the logp graph at register_rv time (MinibatchRandomVariable Op), so fit-time injection needs either Trainer graph surgery or a dims-based rule in core. 
--- pymc/variational/__init__.py | 12 +- pymc/variational/streaming.py | 376 ++++++++++++++----- tests/variational/test_streaming.py | 153 +++++--- tests/variational/test_streaming_autosize.py | 54 ++- 4 files changed, 434 insertions(+), 161 deletions(-) diff --git a/pymc/variational/__init__.py b/pymc/variational/__init__.py index 414d50a6b1..61896fb068 100644 --- a/pymc/variational/__init__.py +++ b/pymc/variational/__init__.py @@ -44,7 +44,13 @@ # special from pymc.variational.stein import Stein -from pymc.variational.streaming import StreamingDataset, parquet_source, shuffle_buffer +from pymc.variational.streaming import ( + DataLoader, + IterableDataset, + Trainer, + parquet_source, + shuffle_buffer, +) from pymc.variational.updates import ( adadelta, adagrad, @@ -65,12 +71,14 @@ "ADVI", "ASVGD", "SVGD", + "DataLoader", "Empirical", "FullRank", "FullRankADVI", "Group", + "IterableDataset", "MeanField", - "StreamingDataset", + "Trainer", "adadelta", "adagrad", "adagrad_window", diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index 574b4dcfe6..a25f2bee31 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -14,50 +14,67 @@ """Out-of-core minibatching for variational inference. ``pm.Minibatch`` random-indexes an array that is *fully resident in memory*; its -peak memory is therefore O(N) in the dataset size. ``StreamingDataset`` instead -feeds minibatches from an arbitrary source (a generator, a directory of Parquet -shards, ...) into a small fixed-size ``pytensor.shared`` buffer, so peak memory is -O(buffer) -- the batch buffer plus, if used, the shuffle buffer -- and -independent of N. The unbiased-gradient rescaling is the *same* as for -``pm.Minibatch``: pass ``total_size=N`` to the observed distribution and PyMC -scales the minibatch log-likelihood by ``N / batch_size`` through the existing -:func:`~pymc.variational.minibatch_rv.create_minibatch_rv`. +peak memory is therefore O(N) in the dataset size. 
This module instead feeds +minibatches from an out-of-core source into a small fixed-size +``pytensor.shared`` buffer, so peak memory is O(buffer) -- the batch buffer plus, +if used, the shuffle buffer -- and independent of N. + +The API deliberately mirrors PyTorch's ``torch.utils.data`` so the mental model +transfers directly: + +* :class:`IterableDataset` -- a re-iterable, out-of-core source of rows + (e.g. :func:`parquet_source` over a directory of shards). It never loads the + whole dataset; it yields it a chunk at a time. +* :class:`DataLoader` -- turns a dataset into fixed-size (optionally shuffled) + minibatches and owns the small ``pytensor.shared`` buffer the model observes. +* :class:`Trainer` -- drives variational inference (ADVI, ...) over a + ``DataLoader`` with **no user-facing callbacks**; ``Trainer.fit(model, loader)`` + advances the buffer each step internally. + +**The full data never enters RAM.** The model graph observes only the +``(batch_size, *sample_shape)`` shared buffer -- a *placeholder* that the loader +overwrites with the next minibatch every step. Passing a directory of 122 GB of +Parquet shards still gives a model whose resident footprint is one batch. + +The unbiased-gradient rescaling is the *same* as for ``pm.Minibatch``: the +observed log-likelihood must be scaled by ``N / batch_size`` through the existing +:func:`~pymc.variational.minibatch_rv.create_minibatch_rv`. Today that means +passing ``total_size=loader.total_size`` to the observed distribution; see +:class:`Trainer` for the in-progress effort to inject it at fit time so it no +longer has to appear in the model body. The one extra obligation relative to ``pm.Minibatch`` is **shuffling**. ``pm.Minibatch`` draws a fresh uniform index over all N rows every step, so its minibatches are i.i.d. by construction. 
A streaming source is only as well mixed as the order it yields rows in: reading time/row-ordered data through a *bounded* buffer is merely a block-shuffle and biases the variational posterior. -Pre-shuffle the data once (or interleave shards) and/or use :func:`shuffle_buffer`. +Pre-shuffle the data once on disk (or interleave shards) and/or pass +``shuffle=True``. Example ------- .. code-block:: python import pymc as pm - from pymc.variational.streaming import StreamingDataset, shuffle_buffer + from pymc.variational.streaming import DataLoader, Trainer, parquet_source - N = 10_000_000 # rows on disk; never all in memory at once - - - def chunks(): # yields (rows, n_features+1) float64 blocks off disk - for shard in shards: - yield read(shard) - - - ds = StreamingDataset( - shuffle_buffer(chunks, buffer_size=1_000_000, batch_size=4096, seed=0), + # The data was pre-shuffled on disk once (see the module note on shuffling), + # so the loader streams it sequentially. The full table stays on disk. + loader = DataLoader( + parquet_source("shuffled/"), # an IterableDataset over the shards batch_size=4096, sample_shape=(4,), # 3 features + 1 observed column - total_size=N, + total_size="auto", # infer N from Parquet metadata ) - with pm.Model(): + with pm.Model() as model: b = pm.Normal("b", 0.0, 3.0, shape=4) - buf = ds.as_tensor() # (batch_size, 4) shared + buf = loader.as_tensor() # (4096, 4) shared buffer -- the ONLY data in RAM logit = b[0] + b[1] * buf[:, 0] + b[2] * buf[:, 1] + b[3] * buf[:, 2] - pm.Bernoulli("y", logit_p=logit, observed=buf[:, 3], total_size=ds.total_size) - approx = pm.fit(20_000, method="advi", callbacks=[ds.fit_callback()]) + pm.Bernoulli("y", logit_p=logit, observed=buf[:, 3], total_size=loader.total_size) + + # No callbacks: the Trainer advances the buffer each step internally. 
+ approx = Trainer(method="advi").fit(model, loader, 20_000) """ from __future__ import annotations @@ -77,31 +94,73 @@ def _is_positive_int(value: object) -> bool: return isinstance(value, numbers.Integral) and not isinstance(value, bool) and value > 0 -class StreamingDataset: - """Feed minibatches to variational inference from an out-of-core source. +class IterableDataset: + """A re-iterable, out-of-core source of rows -- the analogue of ``torch.utils.data.IterableDataset``. + + Subclass and implement :meth:`__iter__` to yield ``np.ndarray`` blocks of rows + (shape ``(rows, *sample_shape)``); :class:`DataLoader` re-batches those blocks + into fixed-size minibatches. ``__iter__`` must return a **fresh** iterator each + call so the dataset can be replayed across epochs. + + Optionally set :attr:`n_rows` (the total row count, if known cheaply -- e.g. + from file metadata) so a :class:`DataLoader` with ``total_size="auto"`` can + resolve ``N`` without a counting pass. + + A plain zero-arg factory (``Callable[[], Iterator[np.ndarray]]``) or any + re-iterable is also accepted directly by :class:`DataLoader`; this base class + is only needed when you want to attach behaviour or ``n_rows`` to a custom + source. + """ + + n_rows: int | None = None + + def __iter__(self) -> Iterator[np.ndarray]: + raise NotImplementedError("IterableDataset subclasses must implement __iter__") + + +class DataLoader: + """Turn an out-of-core dataset into fixed-size minibatches for variational inference. + + The analogue of ``torch.utils.data.DataLoader``: it batches (and optionally + shuffles) an :class:`IterableDataset` and owns the small ``pytensor.shared`` + buffer the model observes. The full dataset never enters memory -- only one + ``(batch_size, *sample_shape)`` buffer does. Parameters ---------- - source : Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]] - Yields ``np.ndarray`` batches of shape ``(batch_size, *sample_shape)``. 
- Pass a zero-arg *callable* (a factory) so the stream can be restarted - when ``cycle=True``; a bare generator can only be consumed once. + dataset : IterableDataset | Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]] + The source of rows. An :class:`IterableDataset`, a re-iterable, or a + zero-arg *factory* returning a fresh iterator (preferred, so the stream can + be restarted when ``cycle=True``). It may yield individual rows or + multi-row blocks; the loader re-batches to exactly ``batch_size`` rows. batch_size : int - Leading dimension of every yielded batch (and of the buffer). + Leading dimension of every yielded minibatch (and of the buffer). + shuffle : bool, default False + If ``True``, wrap the source in a bounded :func:`shuffle_buffer` of + ``buffer_size`` rows. This only approximates i.i.d. batches for an + *already unordered* stream; a bounded buffer cannot fix strongly + time/row-ordered data (pre-shuffle on disk for that -- see the module + docstring). + buffer_size : int, optional + Shuffle-buffer size in rows when ``shuffle=True``. Defaults to + ``50 * batch_size``. Ignored when ``shuffle=False``. + seed : int, optional + Seed for the shuffle buffer (ignored when ``shuffle=False``). sample_shape : tuple of int, default () Trailing shape of a single observation. ``()`` for scalar observations, ``(k,)`` to stream ``k`` columns (e.g. features + the observed column). dtype : str, default "float64" Dtype of the shared buffer. If it differs from ``pytensor.config.floatX`` the model will insert a per-step cast on the observed tensor. - total_size : int, optional - The true dataset size ``N`` (a positive integer). Pass it to the observed - distribution as ``total_size=ds.total_size`` so the minibatch - log-likelihood is rescaled by ``N / batch_size`` (the same mechanism as - ``pm.Minibatch``). 
Unlike ``pm.Minibatch`` it cannot be inferred from a - resident array, so it must be supplied; ``None`` warns at construction and - a non-positive value raises (it would otherwise silently disable or invert - the rescaling). + total_size : int or "auto", optional + The true dataset size ``N`` (a positive integer), or ``"auto"`` to infer + it (from the source's ``n_rows`` if available, else a single counting + pass). Pass it on to the observed distribution as + ``total_size=loader.total_size`` so the minibatch log-likelihood is + rescaled by ``N / batch_size`` (the same mechanism as ``pm.Minibatch``). + Unlike ``pm.Minibatch`` it cannot be inferred from a resident array; + ``None`` warns at construction and a non-positive value raises (it would + otherwise silently disable or invert the rescaling). preprocess_fn : callable, optional Pure transform applied to each batch before it lands in the buffer. cycle : bool, default True @@ -113,20 +172,32 @@ class StreamingDataset: def __init__( self, - source: Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]], + dataset: IterableDataset | Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]], *, batch_size: int, + shuffle: bool = False, + buffer_size: int | None = None, + seed: int | None = None, sample_shape: tuple[int, ...] = (), dtype: str = "float64", total_size: int | str | None = None, preprocess_fn: Callable[[np.ndarray], np.ndarray] | None = None, cycle: bool = True, - name: str = "streaming_buffer", + name: str = "dataloader_buffer", ): if not _is_positive_int(batch_size): raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}") - self._source_factory = _make_factory(source) + source_factory = _make_factory(dataset) + if shuffle: + if buffer_size is None: + buffer_size = 50 * int(batch_size) + # shuffle_buffer forwards a known .n_rows, so total_size="auto" still + # resolves cheaply through the shuffle wrapper. 
+ source_factory = shuffle_buffer( + source_factory, buffer_size=buffer_size, batch_size=batch_size, seed=seed + ) + self._source_factory = source_factory if isinstance(total_size, str): if total_size != "auto": @@ -134,10 +205,10 @@ def __init__( # Resolve N automatically: a source-provided .n_rows (cheap, e.g. from # parquet_source's metadata) else one counting pass over a finite, # re-readable source. One-shot / infinite sources cannot be auto-counted. - total_size = _auto_total_size(self._source_factory, source) + total_size = _auto_total_size(self._source_factory, dataset) elif total_size is None: warnings.warn( - "StreamingDataset created with total_size=None: the minibatch " + "DataLoader created with total_size=None: the minibatch " "log-likelihood will NOT be rescaled and the posterior will be " "biased. Pass total_size=N (the true dataset size) or total_size='auto'.", UserWarning, @@ -204,7 +275,11 @@ def as_tensor(self) -> pt.TensorVariable: # ----- the only mutator -------------------------------------------------- def advance(self) -> None: - """Pull the next batch from the source into the buffer.""" + """Pull the next batch from the source into the buffer. + + :class:`Trainer` calls this once per optimization step, so end users never + need to. Power users driving ``pm.fit`` directly can call it themselves. + """ batch = self._next_batch() if self._preprocess_fn is not None: batch = self._preprocess_fn(batch) @@ -217,26 +292,9 @@ def advance(self) -> None: self._batches_seen += 1 self._rows_streamed += int(arr.shape[0]) - def fit_callback(self, *, seed: bool = True) -> Callable: - """A 3-arg callback ``(approx, losses, i)`` for ``pm.fit(callbacks=...)``. - - PyMC runs callbacks *after* each optimization step, so the buffer must - already hold a real batch before step 0 -- otherwise the first step trains - on the zero-initialized placeholder. 
By default this seeds the buffer (one - :meth:`advance`) if it has not been advanced yet; pass ``seed=False`` to - opt out (e.g. when you seed manually). - """ - if seed and self._batches_seen == 0: - self.advance() - - def _cb(*_): - self.advance() - - return _cb - # ----- iterator sugar ---------------------------------------------------- - def __iter__(self) -> StreamingDataset: + def __iter__(self) -> DataLoader: return self def __next__(self) -> np.ndarray: @@ -245,6 +303,29 @@ def __next__(self) -> np.ndarray: # ----- internals --------------------------------------------------------- + def _seed_buffer(self) -> None: + """Load the first real batch if the buffer still holds the zero placeholder. + + PyMC runs ``pm.fit`` callbacks *after* each optimization step, so the + buffer must already hold a real batch before step 0 -- otherwise the first + step trains on the zero-initialized placeholder. :class:`Trainer` calls + this before fitting. + """ + if self._batches_seen == 0: + self.advance() + + def _advance_callback(self) -> Callable: + """A 3-arg ``(approx, losses, i)`` callback that advances the buffer. + + Internal: :class:`Trainer` wires this into ``pm.fit`` so the user never has + to. Kept private deliberately -- the user-facing design has no callbacks. + """ + + def _cb(*_): + self.advance() + + return _cb + def _next_batch(self) -> np.ndarray: try: return next(self._source_iter) @@ -285,7 +366,7 @@ def _validate(self, batch: np.ndarray) -> None: raise ValueError( f"batch shape[0] = {batch.shape[0]} does not match batch_size = " f"{self._batch_size}; partial batches are not allowed (drop them in " - "the source, e.g. via shuffle_buffer)." + "the source, e.g. via shuffle=True / shuffle_buffer)." ) if batch.shape[1:] != self._sample_shape: raise ValueError( @@ -294,6 +375,96 @@ def _validate(self, batch: np.ndarray) -> None: ) +class Trainer: + """Drive variational inference over a :class:`DataLoader` -- without callbacks. 
+ + Mirrors the PyTorch Lightning ``Trainer``/``fit`` split: the ``Trainer`` owns + the training loop, the :class:`DataLoader` owns batching, and the model owns + the math. ``Trainer(method="advi").fit(model, loader, n_steps)`` seeds the + buffer, then runs ``pm.fit`` while advancing the loader once per step. The + per-step advance is wired in internally, so the user-facing API has **no** + callbacks (the design Rob asked for). + + Parameters + ---------- + method : str or Inference, default "advi" + Passed straight through to :func:`pymc.fit` (``"advi"``, + ``"fullrank_advi"``, ...). + **fit_kwargs + Default keyword arguments forwarded to :func:`pymc.fit` on every + :meth:`fit` call (e.g. ``obj_optimizer``); per-call kwargs override them. + + Notes + ----- + This is the *starting point* Rob suggested: the streaming step logic lives in + the ``Trainer`` rather than in the inference operator. The longer-term plan is + to fold it into ADVI itself once the variational-inference rework lands. + """ + + def __init__(self, *, method: str = "advi", **fit_kwargs): + self.method = method + self._fit_kwargs = fit_kwargs + + def fit( + self, + model, + data: DataLoader, + n_steps: int = 10_000, + *, + random_seed: int | None = None, + progressbar: bool = False, + **kwargs, + ): + """Fit ``model`` on the stream from ``data`` for ``n_steps`` steps. + + Parameters + ---------- + model : pymc.Model + The model. Its observed RV should read ``data.as_tensor()`` and (for + now) pass ``total_size=data.total_size`` so the log-likelihood is + rescaled by ``N / batch_size``. + data : DataLoader + The minibatch source. Its buffer is seeded before step 0 and advanced + once after every optimization step. + n_steps : int + Number of optimization steps. + random_seed, progressbar, **kwargs + Forwarded to :func:`pymc.fit` (per-call kwargs override the Trainer's + defaults). + + Returns + ------- + Approximation + Whatever :func:`pymc.fit` returns for the chosen method. 
+ """ + from pymc.variational.inference import fit as _fit + + if not isinstance(data, DataLoader): + raise TypeError( + f"Trainer.fit expects a DataLoader for `data`, got {type(data).__name__}." + ) + if data.total_size is None: + warnings.warn( + "Trainer.fit: the DataLoader has total_size=None, so the minibatch " + "log-likelihood is not rescaled and the posterior will be biased. " + "Construct the DataLoader with total_size=N or total_size='auto'.", + UserWarning, + stacklevel=2, + ) + + data._seed_buffer() + merged = {**self._fit_kwargs, **kwargs} + return _fit( + n_steps, + method=self.method, + model=model, + random_seed=random_seed, + progressbar=progressbar, + callbacks=[data._advance_callback()], + **merged, + ) + + def shuffle_buffer( chunk_source: Callable[[], Iterator[np.ndarray]], *, @@ -310,6 +481,10 @@ def shuffle_buffer( partial batch (< ``batch_size`` rows) is dropped. This approximates i.i.d. minibatches from an *unordered* or pre-shuffled stream. + :class:`DataLoader` calls this for you when ``shuffle=True``; use it directly + when you want explicit control over ``buffer_size`` independently of the + loader. + It does **not** by itself fix a strongly time/row-ordered stream (a bounded buffer only block-shuffles such data) -- pre-shuffle on disk, or interleave shards into ``chunk_source``, for that. ``buffer_size`` is a *lower* bound: the @@ -372,10 +547,10 @@ def factory() -> Iterator[np.ndarray]: # Forward a known row count (e.g. parquet_source's .n_rows from Parquet # metadata) to the wrapped factory, so - # ``StreamingDataset(shuffle_buffer(parquet_source(dir)), total_size="auto")`` - # resolves N for free instead of doing a counting pass. The only discrepancy is - # the single dropped trailing partial batch (< batch_size rows), well within the - # auto-size sanity tolerance. + # ``DataLoader(source, shuffle=True, total_size="auto")`` resolves N for free + # instead of doing a counting pass. 
The only discrepancy is the single dropped + # trailing partial batch (< batch_size rows), well within the auto-size + # sanity tolerance. source_n_rows = getattr(chunk_source, "n_rows", None) if source_n_rows is not None: factory.n_rows = source_n_rows # type: ignore[attr-defined] @@ -389,8 +564,9 @@ def _make_factory( """Coerce ``source`` into a zero-arg callable returning a fresh iterator. A callable that is not itself an iterator is treated as the factory; a bare - iterator is wrapped (and refuses a second epoch); any other iterable is - re-``iter``-ed each epoch. + iterator is wrapped (and refuses a second epoch); any other iterable (incl. an + :class:`IterableDataset`) is re-``iter``-ed each epoch. A known ``.n_rows`` is + forwarded onto the returned factory so ``total_size="auto"`` stays cheap. """ if callable(source) and not isinstance(source, Iterator): # A factory may return any iterable (a list of batches, a generator, ...), @@ -399,8 +575,7 @@ def _make_factory( def _factory() -> Iterator[np.ndarray]: return iter(source()) # type: ignore[operator] - return _factory - if isinstance(source, Iterator): + elif isinstance(source, Iterator): consumed = {"done": False} def _factory() -> Iterator[np.ndarray]: @@ -412,8 +587,15 @@ def _factory() -> Iterator[np.ndarray]: consumed["done"] = True return source - return _factory - return lambda: iter(source) + else: + + def _factory() -> Iterator[np.ndarray]: + return iter(source) + + n_rows = getattr(source, "n_rows", None) + if n_rows is not None: + _factory.n_rows = n_rows # type: ignore[attr-defined] + return _factory def _auto_total_size( @@ -429,6 +611,10 @@ def _auto_total_size( pass hang -- both must pass ``total_size`` explicitly. """ n = getattr(source, "n_rows", None) + if n is None: + # The user's source may not carry .n_rows even when the (shuffle-wrapped) + # factory does; fall back to the factory's own forwarded count. 
+ n = getattr(factory, "n_rows", None) if n is not None: if not _is_positive_int(n): raise ValueError(f"source.n_rows must be a positive integer, got {n!r}") @@ -461,19 +647,39 @@ def _auto_total_size( return count +class _ParquetDataset(IterableDataset): + """An :class:`IterableDataset` over a directory of Parquet shards. + + Yields one ``(rows, n_columns)`` ``float64`` array per file and exposes + :attr:`n_rows` read from Parquet *metadata* (no data scan). + """ + + def __init__(self, paths: list[str], columns: list[str] | None, n_rows: int): + self._paths = paths + self._columns = columns + self.n_rows = n_rows + + def __iter__(self) -> Iterator[np.ndarray]: + import pyarrow.parquet as pq + + for path in self._paths: + table = pq.read_table(path, columns=self._columns) + yield np.column_stack([table.column(c).to_numpy() for c in table.column_names]) + + def parquet_source( directory: str, *, columns: list[str] | None = None, pattern: str = "*.parquet", -) -> Callable[[], Iterator[np.ndarray]]: - """A finite, re-readable streaming source over a directory of Parquet files. +) -> _ParquetDataset: + """An :class:`IterableDataset` over a directory of Parquet files. Yields one ``(rows, n_columns)`` ``float64`` array per file, and carries an ``n_rows`` attribute read from Parquet *metadata* (no data scan) so that - ``StreamingDataset(parquet_source(dir), ..., total_size="auto")`` resolves the - dataset size for free. Wrap it in :func:`shuffle_buffer` to get fixed-size, - shuffled batches. + ``DataLoader(parquet_source(dir), ..., total_size="auto")`` resolves the + dataset size for free. Pass ``shuffle=True`` to the :class:`DataLoader` (or + wrap in :func:`shuffle_buffer`) to get shuffled batches. 
""" import glob as _glob import os @@ -483,11 +689,5 @@ def parquet_source( paths = sorted(_glob.glob(os.path.join(directory, pattern))) if not paths: raise ValueError(f"no Parquet files match {os.path.join(directory, pattern)!r}") - - def factory() -> Iterator[np.ndarray]: - for path in paths: - table = pq.read_table(path, columns=columns) - yield np.column_stack([table.column(c).to_numpy() for c in table.column_names]) - - factory.n_rows = sum(pq.read_metadata(p).num_rows for p in paths) # type: ignore[attr-defined] - return factory + n_rows = sum(pq.read_metadata(p).num_rows for p in paths) + return _ParquetDataset(paths, columns, n_rows) diff --git a/tests/variational/test_streaming.py b/tests/variational/test_streaming.py index 5a9abb027e..47f1eac231 100644 --- a/tests/variational/test_streaming.py +++ b/tests/variational/test_streaming.py @@ -16,7 +16,12 @@ import pymc as pm -from pymc.variational.streaming import StreamingDataset, shuffle_buffer +from pymc.variational.streaming import ( + DataLoader, + IterableDataset, + Trainer, + shuffle_buffer, +) def _chunks(data, size): @@ -29,7 +34,7 @@ def factory(): def test_advance_shape_and_counters(): data = np.arange(40, dtype="float64").reshape(20, 2) - ds = StreamingDataset(_chunks(data, 4), batch_size=4, sample_shape=(2,), total_size=20) + ds = DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(2,), total_size=20) assert ds.batches_seen == 0 ds.advance() assert ds.as_tensor().get_value().shape == (4, 2) @@ -40,7 +45,7 @@ def test_advance_shape_and_counters(): def test_wrong_batch_shape_rejected(): data = np.zeros((10, 2)) - ds = StreamingDataset(_chunks(data, 3), batch_size=4, sample_shape=(2,), total_size=10) + ds = DataLoader(_chunks(data, 3), batch_size=4, sample_shape=(2,), total_size=10) with pytest.raises(ValueError, match="does not match batch_size"): ds.advance() @@ -48,14 +53,12 @@ def test_wrong_batch_shape_rejected(): def test_total_size_none_warns_at_construction(): data = np.zeros((8, 1)) 
with pytest.warns(UserWarning, match="total_size=None"): - StreamingDataset(_chunks(data, 4), batch_size=4, sample_shape=(1,)) + DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(1,)) def test_cycle_true_restarts_source(): data = np.arange(8, dtype="float64").reshape(8, 1) - ds = StreamingDataset( - _chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=8, cycle=True - ) + ds = DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=8, cycle=True) for _ in range(4): # two epochs worth ds.advance() assert ds.batches_seen == 4 @@ -63,9 +66,7 @@ def test_cycle_true_restarts_source(): def test_cycle_false_raises_when_exhausted(): data = np.arange(8, dtype="float64").reshape(8, 1) - ds = StreamingDataset( - _chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=8, cycle=False - ) + ds = DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=8, cycle=False) ds.advance() ds.advance() with pytest.raises(StopIteration): @@ -74,7 +75,7 @@ def test_cycle_false_raises_when_exhausted(): def test_preprocess_fn_applied(): data = np.ones((8, 1)) - ds = StreamingDataset( + ds = DataLoader( _chunks(data, 4), batch_size=4, sample_shape=(1,), @@ -105,6 +106,28 @@ def test_shuffle_buffer_does_not_mutate_source(): np.testing.assert_array_equal(data, original) # source untouched +def test_dataloader_shuffle_true_yields_full_batches(): + # shuffle=True wraps the source in a bounded shuffle_buffer internally; batches + # are full and rows are conserved (nothing dropped when N % batch_size == 0). 
+ data = np.arange(120, dtype="float64").reshape(120, 1) + ds = DataLoader( + _chunks(data, 8), + batch_size=10, + shuffle=True, + buffer_size=40, + seed=0, + sample_shape=(1,), + total_size=120, + cycle=False, + ) + seen = [] + for _ in range(12): # 120 / 10 + ds.advance() + seen.append(ds.as_tensor().get_value().copy()) + assert all(b.shape == (10, 1) for b in seen) + np.testing.assert_array_equal(np.sort(np.concatenate([b.ravel() for b in seen])), data.ravel()) + + def test_total_size_rescales_logp_like_minibatch(): # observed=buf[:, k] + total_size=N must scale the observed log-likelihood by # N / batch_size via the existing create_minibatch_rv path -- pin this without @@ -112,7 +135,7 @@ def test_total_size_rescales_logp_like_minibatch(): rng = np.random.default_rng(0) N, bs = 1000, 16 data = rng.normal(size=(bs, 1)) - ds = StreamingDataset(lambda: iter([data]), batch_size=bs, sample_shape=(1,), total_size=N) + ds = DataLoader(lambda: iter([data]), batch_size=bs, sample_shape=(1,), total_size=N) ds.advance() with pm.Model() as scaled: @@ -128,8 +151,12 @@ def test_total_size_rescales_logp_like_minibatch(): np.testing.assert_allclose(obs_scaled, obs_plain * (N / bs), rtol=1e-6) -def test_equivalence_with_in_ram_minibatch(): - """End-to-end: streaming ADVI reproduces in-RAM pm.Minibatch ADVI.""" +def test_trainer_end_to_end_matches_in_ram_minibatch(): + """End-to-end: Trainer-driven streaming ADVI reproduces in-RAM pm.Minibatch ADVI. + + Also exercises the no-callback design: the Trainer seeds and advances the + DataLoader internally -- the user wires up nothing. 
+ """ seed = 0 rng = np.random.default_rng(seed) N, bs = 60_000, 2048 @@ -151,41 +178,71 @@ def test_equivalence_with_in_ram_minibatch(): ) in_ram = ap.sample(400).posterior["b"].values.reshape(-1, 3).mean(0) - ds = StreamingDataset( - shuffle_buffer(_chunks(data, 20_000), buffer_size=40_000, batch_size=bs, seed=seed), + loader = DataLoader( + _chunks(data, 20_000), batch_size=bs, + shuffle=True, + buffer_size=40_000, + seed=seed, sample_shape=(3,), total_size=N, ) - ds.advance() - with pm.Model(): + with pm.Model() as model: b = pm.Normal("b", 0, 3, shape=3) - buf = ds.as_tensor() + buf = loader.as_tensor() pm.Bernoulli( "o", logit_p=b[0] + b[1] * buf[:, 0] + b[2] * buf[:, 1], observed=buf[:, 2], - total_size=ds.total_size, - ) - ap = pm.fit( - 6000, - method="advi", - obj_optimizer=pm.adam(learning_rate=0.02), - callbacks=[ds.fit_callback()], - progressbar=False, - random_seed=seed, + total_size=loader.total_size, ) + ap = Trainer(method="advi", obj_optimizer=pm.adam(learning_rate=0.02)).fit( + model, loader, 6000, random_seed=seed + ) + with model: stream = ap.sample(400).posterior["b"].values.reshape(-1, 3).mean(0) np.testing.assert_allclose(in_ram, stream, atol=0.1) +def test_trainer_seeds_buffer_before_first_step(): + # PyMC runs fit callbacks AFTER each step, so the Trainer must seed the buffer + # before step 0 (otherwise step 0 trains on the zero placeholder). After a + # short fit, the buffer holds a real batch and batches_seen == n_steps + 1. 
+ data = np.ones((4, 1)) + loader = DataLoader(lambda: iter([data] * 100), batch_size=4, sample_shape=(1,), total_size=4) + with pm.Model() as model: + mu = pm.Normal("mu", 0, 1) + pm.Normal("y", mu, 1, observed=loader.as_tensor()[:, 0], total_size=loader.total_size) + Trainer(method="advi").fit(model, loader, 5, progressbar=False, random_seed=0) + assert loader.batches_seen == 6 # one seed + five steps + np.testing.assert_array_equal(loader.as_tensor().get_value(), data) # not the zero placeholder + + +def test_trainer_rejects_non_dataloader(): + with pm.Model() as model: + pm.Normal("x", 0, 1) + with pytest.raises(TypeError, match="DataLoader"): + Trainer(method="advi").fit(model, object(), 10) + + +def test_trainer_warns_when_total_size_missing(): + data = np.ones((4, 1)) + with pytest.warns(UserWarning, match="total_size=None"): + loader = DataLoader(lambda: iter([data] * 50), batch_size=4, sample_shape=(1,)) + with pm.Model() as model: + mu = pm.Normal("mu", 0, 1) + pm.Normal("y", mu, 1, observed=loader.as_tensor()[:, 0]) # unscaled + with pytest.warns(UserWarning, match="total_size=None"): + Trainer(method="advi").fit(model, loader, 3, progressbar=False, random_seed=0) + + def test_total_size_zero_raises(): # total_size=0 is falsy: it slips a None-only check and the model's truthy # `if total_size:` guard, silently skipping the N/batch_size rescaling. data = np.zeros((8, 1)) with pytest.raises(ValueError, match="positive integer"): - StreamingDataset(_chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=0) + DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=0) def test_total_size_negative_raises(): @@ -193,7 +250,7 @@ def test_total_size_negative_raises(): # (the data log-likelihood's sign flips, so VI maximizes mis-fit). 
data = np.zeros((8, 1)) with pytest.raises(ValueError, match="positive integer"): - StreamingDataset(_chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=-100) + DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=-100) def test_shuffle_buffer_small_buffer_conserves_rows(): @@ -220,13 +277,13 @@ def test_shuffle_buffer_rejects_nonpositive_sizes(): def test_accepts_numpy_integer_sizes_rejects_bool(): # the positive-int check uses numbers.Integral: numpy ints are valid, bool is not. data = np.zeros((8, 1)) - ds = StreamingDataset( + ds = DataLoader( _chunks(data, 4), batch_size=np.int64(4), sample_shape=(1,), total_size=np.int64(8) ) ds.advance() assert ds.batch_size == 4 with pytest.raises(ValueError): - StreamingDataset(_chunks(data, 4), batch_size=True, sample_shape=(1,), total_size=8) + DataLoader(_chunks(data, 4), batch_size=True, sample_shape=(1,), total_size=8) def test_shuffle_buffer_reshuffles_across_epochs(): @@ -264,7 +321,7 @@ def test_sizes_normalized_to_python_int(): # numpy integer sizes must be stored as plain Python ints so ds.total_size is # accepted downstream by create_minibatch_rv (regression for the np.int64 trap). data = np.zeros((8, 1)) - ds = StreamingDataset( + ds = DataLoader( _chunks(data, 4), batch_size=np.int64(4), sample_shape=(1,), total_size=np.int64(8) ) assert type(ds.batch_size) is int @@ -275,9 +332,7 @@ def test_numpy_total_size_accepted_by_observed_rv(): # a stored np.int64 total_size used to reach create_minibatch_rv and raise # "Invalid type for total_size"; it must now build a valid observed RV. 
data = np.zeros((4, 1), dtype="float64") - ds = StreamingDataset( - lambda: iter([data]), batch_size=4, sample_shape=(1,), total_size=np.int64(4) - ) + ds = DataLoader(lambda: iter([data]), batch_size=4, sample_shape=(1,), total_size=np.int64(4)) ds.advance() with pm.Model() as model: mu = pm.Normal("mu", 0, 1) @@ -290,33 +345,19 @@ def test_factory_returning_reiterable_is_accepted(): # a zero-arg factory may return ANY iterable (e.g. a list), not just an # iterator; advance() used to crash with "'list' object is not an iterator". data = [np.zeros((4, 1), dtype="float64")] - ds = StreamingDataset(lambda: data, batch_size=4, sample_shape=(1,), total_size=4) + ds = DataLoader(lambda: data, batch_size=4, sample_shape=(1,), total_size=4) ds.advance() assert ds.as_tensor().get_value().shape == (4, 1) def test_scalar_batch_rejected_with_clear_error(): # a 0-D batch used to raise an opaque IndexError on batch.shape[0]. - ds = StreamingDataset( - lambda: iter([np.array(1.0)]), batch_size=1, sample_shape=(), total_size=1 - ) + ds = DataLoader(lambda: iter([np.array(1.0)]), batch_size=1, sample_shape=(), total_size=1) with pytest.raises(ValueError, match="leading batch dimension"): ds.advance() -def test_fit_callback_seeds_buffer_by_default(): - # PyMC runs callbacks AFTER each step, so the buffer must be seeded before the - # first step; fit_callback() seeds on creation unless seed=False. 
- data = np.ones((4, 1)) - ds = StreamingDataset(lambda: iter([data, data]), batch_size=4, sample_shape=(1,), total_size=8) - assert ds.batches_seen == 0 - ds.fit_callback() # default seed=True - assert ds.batches_seen == 1 - np.testing.assert_array_equal(ds.as_tensor().get_value(), data) # not the zero placeholder - - -def test_fit_callback_seed_false_does_not_advance(): - data = np.ones((4, 1)) - ds = StreamingDataset(lambda: iter([data]), batch_size=4, sample_shape=(1,), total_size=4) - ds.fit_callback(seed=False) - assert ds.batches_seen == 0 +def test_iterable_dataset_base_is_abstract(): + # the base class is a contract: __iter__ must be overridden. + with pytest.raises(NotImplementedError): + iter(IterableDataset()) diff --git a/tests/variational/test_streaming_autosize.py b/tests/variational/test_streaming_autosize.py index 94d0a77aed..c53107f688 100644 --- a/tests/variational/test_streaming_autosize.py +++ b/tests/variational/test_streaming_autosize.py @@ -11,14 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Prototype: total_size='auto' resolution + the rows_streamed sanity warning.""" +"""total_size='auto' resolution + the rows_streamed sanity warning.""" import warnings import numpy as np import pytest -from pymc.variational.streaming import StreamingDataset, parquet_source, shuffle_buffer +from pymc.variational.streaming import ( + DataLoader, + IterableDataset, + parquet_source, + shuffle_buffer, +) def _factory(data, size): @@ -35,9 +40,7 @@ def test_auto_counts_finite_source(): # no .n_rows -> auto does one counting pass and resolves the true N. 
data = np.arange(60, dtype="float64").reshape(60, 1) with pytest.warns(UserWarning, match="counting pass"): - ds = StreamingDataset( - _factory(data, 7), batch_size=10, sample_shape=(1,), total_size="auto" - ) + ds = DataLoader(_factory(data, 7), batch_size=10, sample_shape=(1,), total_size="auto") assert ds.total_size == 60 @@ -47,7 +50,7 @@ def test_auto_uses_n_rows_fast_path(): data = np.zeros((8, 1)) f = _factory(data, 4) f.n_rows = 999 - ds = StreamingDataset(f, batch_size=4, sample_shape=(1,), total_size="auto") + ds = DataLoader(f, batch_size=4, sample_shape=(1,), total_size="auto") assert ds.total_size == 999 @@ -56,13 +59,13 @@ def test_auto_rejects_one_shot_iterator(): data = np.zeros((20, 1)) one_shot = (data[i : i + 4] for i in range(0, 20, 4)) with pytest.raises(ValueError, match="re-readable"): - StreamingDataset(one_shot, batch_size=4, sample_shape=(1,), total_size="auto") + DataLoader(one_shot, batch_size=4, sample_shape=(1,), total_size="auto") def test_shuffle_buffer_forwards_n_rows_for_auto(): # shuffle_buffer must forward a known .n_rows so total_size="auto" works through - # the common shuffle_buffer(parquet_source(...)) composition WITHOUT a counting - # pass (the realistic way users wrap a Parquet source). + # the explicit shuffle_buffer(parquet_source(...)) composition WITHOUT a counting + # pass (the realistic way power users wrap a Parquet source). 
data = np.arange(40, dtype="float64").reshape(40, 1) src = _factory(data, 8) src.n_rows = 40 @@ -71,7 +74,27 @@ def test_shuffle_buffer_forwards_n_rows_for_auto(): with warnings.catch_warnings(): warnings.simplefilter("error", UserWarning) # a counting pass would warn -> fail - ds = StreamingDataset(wrapped, batch_size=10, sample_shape=(1,), total_size="auto") + ds = DataLoader(wrapped, batch_size=10, sample_shape=(1,), total_size="auto") + assert ds.total_size == 40 + + +def test_dataloader_shuffle_auto_resolves_via_n_rows(): + # DataLoader(shuffle=True, total_size="auto") must resolve N from the source's + # .n_rows WITHOUT a counting pass, even though shuffle wraps the source. + data = np.arange(40, dtype="float64").reshape(40, 1) + src = _factory(data, 8) + src.n_rows = 40 + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + ds = DataLoader( + src, + batch_size=10, + shuffle=True, + buffer_size=20, + seed=0, + sample_shape=(1,), + total_size="auto", + ) assert ds.total_size == 40 @@ -88,20 +111,20 @@ def test_auto_rejects_factory_returning_same_one_shot_iterator(): data = np.zeros((20, 1)) one_shot = (data[i : i + 4] for i in range(0, 20, 4)) with pytest.raises(ValueError, match="fresh iterator"): - StreamingDataset(lambda: one_shot, batch_size=4, sample_shape=(1,), total_size="auto") + DataLoader(lambda: one_shot, batch_size=4, sample_shape=(1,), total_size="auto") def test_auto_rejects_bad_n_rows(): f = _factory(np.zeros((8, 1)), 4) f.n_rows = 0 with pytest.raises(ValueError, match="n_rows must be a positive integer"): - StreamingDataset(f, batch_size=4, sample_shape=(1,), total_size="auto") + DataLoader(f, batch_size=4, sample_shape=(1,), total_size="auto") def test_sanity_warns_on_grossly_wrong_total_size(): # one full pass = 20 rows, but total_size=100 -> at the first epoch boundary, warn. 
data = np.arange(20, dtype="float64").reshape(20, 1) - ds = StreamingDataset(_factory(data, 4), batch_size=4, sample_shape=(1,), total_size=100) + ds = DataLoader(_factory(data, 4), batch_size=4, sample_shape=(1,), total_size=100) with pytest.warns(UserWarning, match="disagrees with"): for _ in range(6): # 5 batches = one epoch, the 6th crosses the boundary ds.advance() @@ -109,7 +132,7 @@ def test_sanity_warns_on_grossly_wrong_total_size(): def test_sanity_silent_when_total_size_matches(): data = np.arange(20, dtype="float64").reshape(20, 1) - ds = StreamingDataset(_factory(data, 4), batch_size=4, sample_shape=(1,), total_size=20) + ds = DataLoader(_factory(data, 4), batch_size=4, sample_shape=(1,), total_size=20) with warnings.catch_warnings(): warnings.simplefilter("error", UserWarning) # any UserWarning fails the test for _ in range(6): @@ -130,10 +153,11 @@ def test_parquet_source_n_rows_from_metadata(tmp_path): f"{tmp_path}/part_{i:02d}.parquet", ) src = parquet_source(str(tmp_path)) + assert isinstance(src, IterableDataset) # parquet_source is a dataset now assert src.n_rows == total # read from metadata, no data scan # and total_size='auto' picks it up for free (no counting pass / warning) with warnings.catch_warnings(): warnings.simplefilter("error", UserWarning) - ds = StreamingDataset(src, batch_size=10, sample_shape=(2,), total_size="auto") + ds = DataLoader(src, batch_size=10, sample_shape=(2,), total_size="auto") assert ds.total_size == total From 77229a8ba0ddbc2b03aece7817ccdf21a9e3e6a7 Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Sat, 6 Jun 2026 21:36:37 -0500 Subject: [PATCH 08/27] refactor(streaming): align Trainer/DataLoader with VI-rework blueprint Follows jessegrabowski/pymc VI_Overview.ipynb (the VI rework Rob/Jesse are building) instead of my ad-hoc shapes: - DataLoader.__len__ == total_size N (sized like a PyTorch DataLoader), and __iter__ yields the validated minibatch stream. 
This is the answer to Rob's open question: total_size leaves the model and becomes len(loader). - Trainer takes (method=, dataloader=, model=, data_name=) and fit(n); it streams each minibatch into the model's pm.Data placeholder via model.set_data, so the model is fully decoupled from the loader and the user writes no callbacks. - Model idiom is now pm.Data("batch", placeholder) + total_size=len(loader), matching the blueprint; verified end-to-end (recovers in-RAM pm.Minibatch ADVI). - Kept the as_tensor()/advance() shared-buffer path as a documented advanced escape hatch; dropped the now-unused _seed_buffer/_advance_callback. 38 tests pass (1 skipped: pyarrow). Open for Rob: spelling DataLoader (PyTorch, per his "match PyTorch") vs Dataloader (Jesse's draft); method-as-string until the ADVI(Inference).step rework lands. --- pymc/variational/streaming.py | 259 +++++++++++++++------------- tests/variational/test_streaming.py | 76 +++++--- 2 files changed, 191 insertions(+), 144 deletions(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index a25f2bee31..462f2cdf15 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -26,22 +26,24 @@ (e.g. :func:`parquet_source` over a directory of shards). It never loads the whole dataset; it yields it a chunk at a time. * :class:`DataLoader` -- turns a dataset into fixed-size (optionally shuffled) - minibatches and owns the small ``pytensor.shared`` buffer the model observes. + minibatches; it is iterable (the minibatch stream) and sized (``len(loader)`` + is the dataset size ``N``), exactly like a PyTorch one. * :class:`Trainer` -- drives variational inference (ADVI, ...) over a - ``DataLoader`` with **no user-facing callbacks**; ``Trainer.fit(model, loader)`` - advances the buffer each step internally. 
+ ``DataLoader`` with **no user-facing callbacks**; + ``Trainer(method=..., dataloader=...).fit(n)`` streams each minibatch into the + model's ``pm.Data`` placeholder with ``set_data``. -**The full data never enters RAM.** The model graph observes only the -``(batch_size, *sample_shape)`` shared buffer -- a *placeholder* that the loader +**The full data never enters RAM.** The model graph observes only a +``(batch_size, *sample_shape)`` ``pm.Data`` *placeholder* that the ``Trainer`` overwrites with the next minibatch every step. Passing a directory of 122 GB of Parquet shards still gives a model whose resident footprint is one batch. The unbiased-gradient rescaling is the *same* as for ``pm.Minibatch``: the observed log-likelihood must be scaled by ``N / batch_size`` through the existing -:func:`~pymc.variational.minibatch_rv.create_minibatch_rv`. Today that means -passing ``total_size=loader.total_size`` to the observed distribution; see -:class:`Trainer` for the in-progress effort to inject it at fit time so it no -longer has to appear in the model body. +:func:`~pymc.variational.minibatch_rv.create_minibatch_rv`. ``N`` is exactly +``len(loader)`` -- a :class:`DataLoader` is sized like a PyTorch one -- so the +model passes ``total_size=len(loader)``. (Folding that scaling into the inference +step, so it drops out of the model body, is the next step in PyMC's VI rework.) The one extra obligation relative to ``pm.Minibatch`` is **shuffling**. ``pm.Minibatch`` draws a fresh uniform index over all N rows every step, so its @@ -55,6 +57,7 @@ ------- .. 
code-block:: python + import numpy as np import pymc as pm from pymc.variational.streaming import DataLoader, Trainer, parquet_source @@ -64,17 +67,18 @@ parquet_source("shuffled/"), # an IterableDataset over the shards batch_size=4096, sample_shape=(4,), # 3 features + 1 observed column - total_size="auto", # infer N from Parquet metadata + total_size="auto", # infer N from Parquet metadata; N == len(loader) ) with pm.Model() as model: b = pm.Normal("b", 0.0, 3.0, shape=4) - buf = loader.as_tensor() # (4096, 4) shared buffer -- the ONLY data in RAM - logit = b[0] + b[1] * buf[:, 0] + b[2] * buf[:, 1] + b[3] * buf[:, 2] - pm.Bernoulli("y", logit_p=logit, observed=buf[:, 3], total_size=loader.total_size) + batch = pm.Data("batch", np.zeros((4096, 4))) # placeholder -- the ONLY data in RAM + logit = b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1] + b[3] * batch[:, 2] + pm.Bernoulli("y", logit_p=logit, observed=batch[:, 3], total_size=len(loader)) - # No callbacks: the Trainer advances the buffer each step internally. - approx = Trainer(method="advi").fit(model, loader, 20_000) + # No callbacks: the Trainer streams each minibatch into "batch" with set_data. + with model: + approx = Trainer(method="advi", dataloader=loader, data_name="batch").fit(20_000) """ from __future__ import annotations @@ -122,9 +126,10 @@ class DataLoader: """Turn an out-of-core dataset into fixed-size minibatches for variational inference. The analogue of ``torch.utils.data.DataLoader``: it batches (and optionally - shuffles) an :class:`IterableDataset` and owns the small ``pytensor.shared`` - buffer the model observes. The full dataset never enters memory -- only one - ``(batch_size, *sample_shape)`` buffer does. + shuffles) an :class:`IterableDataset` into the minibatch stream that + :class:`Trainer` feeds to the model. It is iterable and sized (``len(loader)`` + is the dataset size ``N``). The full dataset never enters memory -- only one + ``(batch_size, *sample_shape)`` batch does. 
Parameters ---------- @@ -266,65 +271,60 @@ def rows_streamed(self) -> int: """Total rows pushed through the buffer (grows past ``N`` across epochs).""" return self._rows_streamed - # ----- the model-facing tensor ------------------------------------------ + # ----- iteration: the minibatch stream ---------------------------------- - def as_tensor(self) -> pt.TensorVariable: - """The ``pytensor.shared`` buffer the model observes (mutates each step).""" - return self._shared - - # ----- the only mutator -------------------------------------------------- - - def advance(self) -> None: - """Pull the next batch from the source into the buffer. + def __iter__(self) -> Iterator[np.ndarray]: + """Yield one epoch of validated ``(batch_size, *sample_shape)`` minibatches. - :class:`Trainer` calls this once per optimization step, so end users never - need to. Power users driving ``pm.fit`` directly can call it themselves. + This is the stream :class:`Trainer` pushes into the model's ``pm.Data`` + placeholder via ``set_data``. Re-iterate the loader for another epoch. """ - batch = self._next_batch() - if self._preprocess_fn is not None: - batch = self._preprocess_fn(batch) - self._validate(batch) - # Own a fresh contiguous copy before borrowing into the shared variable: - # the source may legitimately yield *views* into a reused array, so we - # must not alias it. np.array(copy default) guarantees an owned array. - arr = np.array(batch, dtype=self._dtype) - self._shared.set_value(arr, borrow=True) - self._batches_seen += 1 - self._rows_streamed += int(arr.shape[0]) + for batch in self._source_factory(): + yield self._prepare(batch) - # ----- iterator sugar ---------------------------------------------------- + def __len__(self) -> int: + """The dataset size ``N`` -- pass to the observed distribution's ``total_size``. 
- def __iter__(self) -> DataLoader: - return self + Sized like a PyTorch ``DataLoader``; ``total_size=len(loader)`` is how the + model gets the ``N / batch_size`` rescaling. + """ + if self._total_size is None: + raise TypeError( + "len(DataLoader) is the dataset size N, but this loader was built with " + "total_size=None; construct it with total_size=N or total_size='auto'." + ) + return self._total_size - def __next__(self) -> np.ndarray: - self.advance() - return self._shared.get_value(borrow=False) # an owned copy, safe to keep + # ----- shared-buffer path (advanced; the Trainer uses pm.Data instead) --- - # ----- internals --------------------------------------------------------- - - def _seed_buffer(self) -> None: - """Load the first real batch if the buffer still holds the zero placeholder. + def as_tensor(self) -> pt.TensorVariable: + """A ``pytensor.shared`` buffer the model can observe directly. - PyMC runs ``pm.fit`` callbacks *after* each optimization step, so the - buffer must already hold a real batch before step 0 -- otherwise the first - step trains on the zero-initialized placeholder. :class:`Trainer` calls - this before fitting. + An alternative to the recommended ``pm.Data`` + :class:`Trainer` path: drive + ``pm.fit`` yourself and call :meth:`advance` (e.g. from a callback) to refill + this buffer each step. """ - if self._batches_seen == 0: - self.advance() + return self._shared - def _advance_callback(self) -> Callable: - """A 3-arg ``(approx, losses, i)`` callback that advances the buffer. + def advance(self) -> None: + """Pull the next batch from the source into the :meth:`as_tensor` buffer.""" + arr = self._prepare(self._next_batch()) + self._shared.set_value(arr, borrow=True) + self._batches_seen += 1 + self._rows_streamed += int(arr.shape[0]) - Internal: :class:`Trainer` wires this into ``pm.fit`` so the user never has - to. Kept private deliberately -- the user-facing design has no callbacks. 
- """ + # ----- internals --------------------------------------------------------- - def _cb(*_): - self.advance() + def _prepare(self, batch: np.ndarray) -> np.ndarray: + """Preprocess, validate, and return an *owned* copy of one batch. - return _cb + The owned copy matters: a source may legitimately yield *views* into a + reused array, so neither the buffer nor an iteration consumer may alias it. + """ + if self._preprocess_fn is not None: + batch = self._preprocess_fn(batch) + self._validate(batch) + return np.array(batch, dtype=self._dtype) # np.array(copy default) owns it def _next_batch(self) -> np.ndarray: try: @@ -378,89 +378,116 @@ def _validate(self, batch: np.ndarray) -> None: class Trainer: """Drive variational inference over a :class:`DataLoader` -- without callbacks. - Mirrors the PyTorch Lightning ``Trainer``/``fit`` split: the ``Trainer`` owns - the training loop, the :class:`DataLoader` owns batching, and the model owns - the math. ``Trainer(method="advi").fit(model, loader, n_steps)`` seeds the - buffer, then runs ``pm.fit`` while advancing the loader once per step. The - per-step advance is wired in internally, so the user-facing API has **no** - callbacks (the design Rob asked for). + Follows the design in PyMC's variational-inference rework (Grabowski, *VI + Overview*) and PyTorch Lightning: the ``Trainer`` owns the training loop, the + :class:`DataLoader` owns batching (and ``len(dataloader)`` is the dataset size + ``N``), and the model owns the math. The model exposes a ``pm.Data`` placeholder; + the ``Trainer`` streams minibatches into it with ``model.set_data`` once per + step, so the user wires up **no** callbacks (the design Rob asked for). + + .. 
code-block:: python + + import numpy as np + + loader = DataLoader( + parquet_source("shuffled/"), batch_size=4096, sample_shape=(4,), total_size="auto" + ) + with pm.Model() as model: + b = pm.Normal("b", 0.0, 3.0, shape=4) + batch = pm.Data("batch", np.zeros((4096, 4))) # placeholder + logit = b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1] + b[3] * batch[:, 2] + pm.Bernoulli("y", logit_p=logit, observed=batch[:, 3], total_size=len(loader)) + approx = Trainer(method="advi", dataloader=loader, data_name="batch").fit(20_000) Parameters ---------- - method : str or Inference, default "advi" - Passed straight through to :func:`pymc.fit` (``"advi"``, - ``"fullrank_advi"``, ...). + method : str, default "advi" + Variational method, forwarded to :func:`pymc.fit` (``"advi"``, + ``"fullrank_advi"``, ...). When PyMC's VI rework lands this will also accept + an inference *instance* (e.g. ``ADVI()``); a string drives today's ``pm.fit``. + dataloader : DataLoader + The minibatch source. ``len(dataloader)`` is ``N`` -- the model should pass + it to the observed distribution's ``total_size``. + model : pymc.Model, optional + Defaults to the model on the context stack. + data_name : str, default "data" + Name of the ``pm.Data`` placeholder minibatches are streamed into. **fit_kwargs - Default keyword arguments forwarded to :func:`pymc.fit` on every - :meth:`fit` call (e.g. ``obj_optimizer``); per-call kwargs override them. + Default keyword arguments forwarded to :func:`pymc.fit` (e.g. + ``obj_optimizer``); per-call kwargs to :meth:`fit` override them. Notes ----- - This is the *starting point* Rob suggested: the streaming step logic lives in - the ``Trainer`` rather than in the inference operator. The longer-term plan is - to fold it into ADVI itself once the variational-inference rework lands. + This is the *starting point* Rob suggested: the per-step ``set_data`` logic + lives in the ``Trainer``. 
The longer-term plan is to fold it into the inference + object's ``step(batch)`` once the VI rework lands, at which point the + ``total_size`` rescaling can be derived from ``len(dataloader)`` and dropped + from the model body entirely. """ - def __init__(self, *, method: str = "advi", **fit_kwargs): + def __init__( + self, + *, + method: str = "advi", + dataloader: DataLoader, + model=None, + data_name: str = "data", + **fit_kwargs, + ): self.method = method + self.dataloader = dataloader + self.model = model + self.data_name = data_name self._fit_kwargs = fit_kwargs def fit( self, - model, - data: DataLoader, - n_steps: int = 10_000, + n: int = 10_000, *, random_seed: int | None = None, progressbar: bool = False, **kwargs, ): - """Fit ``model`` on the stream from ``data`` for ``n_steps`` steps. - - Parameters - ---------- - model : pymc.Model - The model. Its observed RV should read ``data.as_tensor()`` and (for - now) pass ``total_size=data.total_size`` so the log-likelihood is - rescaled by ``N / batch_size``. - data : DataLoader - The minibatch source. Its buffer is seeded before step 0 and advanced - once after every optimization step. - n_steps : int - Number of optimization steps. - random_seed, progressbar, **kwargs - Forwarded to :func:`pymc.fit` (per-call kwargs override the Trainer's - defaults). - - Returns - ------- - Approximation - Whatever :func:`pymc.fit` returns for the chosen method. + """Fit for ``n`` steps, streaming minibatches into the model's placeholder. + + Returns whatever :func:`pymc.fit` returns for the chosen method. """ + from pymc.model import modelcontext from pymc.variational.inference import fit as _fit - if not isinstance(data, DataLoader): + loader = self.dataloader + if not isinstance(loader, DataLoader): raise TypeError( - f"Trainer.fit expects a DataLoader for `data`, got {type(data).__name__}." 
- ) - if data.total_size is None: - warnings.warn( - "Trainer.fit: the DataLoader has total_size=None, so the minibatch " - "log-likelihood is not rescaled and the posterior will be biased. " - "Construct the DataLoader with total_size=N or total_size='auto'.", - UserWarning, - stacklevel=2, + f"Trainer needs a DataLoader for `dataloader`, got {type(loader).__name__}." ) + model = modelcontext(self.model) + + # An endless minibatch stream: re-iterate the loader across epochs. + def _stream() -> Iterator[np.ndarray]: + while True: + empty = True + for batch in loader: + empty = False + yield batch + if empty: + raise RuntimeError("dataloader yielded no batches") + + batches = _stream() + # Seed the placeholder before step 0: pm.fit runs callbacks AFTER each step, + # so without this the first step would train on the placeholder's contents. + model.set_data(self.data_name, next(batches)) + + def _advance(*_): + model.set_data(self.data_name, next(batches)) - data._seed_buffer() merged = {**self._fit_kwargs, **kwargs} return _fit( - n_steps, + n, method=self.method, model=model, random_seed=random_seed, progressbar=progressbar, - callbacks=[data._advance_callback()], + callbacks=[_advance], **merged, ) diff --git a/tests/variational/test_streaming.py b/tests/variational/test_streaming.py index 47f1eac231..5d88b1256e 100644 --- a/tests/variational/test_streaming.py +++ b/tests/variational/test_streaming.py @@ -154,8 +154,8 @@ def test_total_size_rescales_logp_like_minibatch(): def test_trainer_end_to_end_matches_in_ram_minibatch(): """End-to-end: Trainer-driven streaming ADVI reproduces in-RAM pm.Minibatch ADVI. - Also exercises the no-callback design: the Trainer seeds and advances the - DataLoader internally -- the user wires up nothing. + Exercises the whole blueprint API: a pm.Data placeholder + total_size=len(loader), + and a Trainer that streams minibatches into it with set_data -- no user callbacks. 
""" seed = 0 rng = np.random.default_rng(seed) @@ -189,52 +189,72 @@ def test_trainer_end_to_end_matches_in_ram_minibatch(): ) with pm.Model() as model: b = pm.Normal("b", 0, 3, shape=3) - buf = loader.as_tensor() + batch = pm.Data("batch", np.zeros((bs, 3))) # placeholder; full data stays out of RAM pm.Bernoulli( "o", - logit_p=b[0] + b[1] * buf[:, 0] + b[2] * buf[:, 1], - observed=buf[:, 2], - total_size=loader.total_size, + logit_p=b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1], + observed=batch[:, 2], + total_size=len(loader), # N comes from the loader, PyTorch-style ) - ap = Trainer(method="advi", obj_optimizer=pm.adam(learning_rate=0.02)).fit( - model, loader, 6000, random_seed=seed - ) - with model: + ap = Trainer( + method="advi", + dataloader=loader, + data_name="batch", + obj_optimizer=pm.adam(learning_rate=0.02), + ).fit(6000, random_seed=seed) stream = ap.sample(400).posterior["b"].values.reshape(-1, 3).mean(0) np.testing.assert_allclose(in_ram, stream, atol=0.1) -def test_trainer_seeds_buffer_before_first_step(): - # PyMC runs fit callbacks AFTER each step, so the Trainer must seed the buffer - # before step 0 (otherwise step 0 trains on the zero placeholder). After a - # short fit, the buffer holds a real batch and batches_seen == n_steps + 1. +def test_trainer_streams_into_placeholder(): + # The Trainer seeds the pm.Data placeholder before step 0 (pm.fit runs callbacks + # AFTER each step) and overwrites it each step; after fitting it holds a real + # batch, not the zero seed -- with the user writing no callbacks. 
data = np.ones((4, 1)) loader = DataLoader(lambda: iter([data] * 100), batch_size=4, sample_shape=(1,), total_size=4) with pm.Model() as model: mu = pm.Normal("mu", 0, 1) - pm.Normal("y", mu, 1, observed=loader.as_tensor()[:, 0], total_size=loader.total_size) - Trainer(method="advi").fit(model, loader, 5, progressbar=False, random_seed=0) - assert loader.batches_seen == 6 # one seed + five steps - np.testing.assert_array_equal(loader.as_tensor().get_value(), data) # not the zero placeholder + batch = pm.Data("batch", np.zeros((4, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + Trainer(method="advi", dataloader=loader, data_name="batch").fit( + 5, progressbar=False, random_seed=0 + ) + np.testing.assert_array_equal(model["batch"].get_value(), data) # not the zero seed def test_trainer_rejects_non_dataloader(): - with pm.Model() as model: - pm.Normal("x", 0, 1) + # the isinstance guard fires before any model lookup, so no context is needed. with pytest.raises(TypeError, match="DataLoader"): - Trainer(method="advi").fit(model, object(), 10) + Trainer(method="advi", dataloader=object()).fit(10) -def test_trainer_warns_when_total_size_missing(): +def test_len_returns_total_size(): + data = np.zeros((40, 1)) + loader = DataLoader(_chunks(data, 8), batch_size=8, sample_shape=(1,), total_size=40) + assert len(loader) == 40 + + +def test_len_raises_when_total_size_none(): + # len(loader) IS N; with total_size=None there is no N to hand the model, so it + # raises rather than silently skipping the N/batch_size rescaling. 
data = np.ones((4, 1)) with pytest.warns(UserWarning, match="total_size=None"): - loader = DataLoader(lambda: iter([data] * 50), batch_size=4, sample_shape=(1,)) - with pm.Model() as model: - mu = pm.Normal("mu", 0, 1) - pm.Normal("y", mu, 1, observed=loader.as_tensor()[:, 0]) # unscaled - with pytest.warns(UserWarning, match="total_size=None"): - Trainer(method="advi").fit(model, loader, 3, progressbar=False, random_seed=0) + loader = DataLoader(lambda: iter([data] * 5), batch_size=4, sample_shape=(1,)) + with pytest.raises(TypeError, match="total_size=None"): + len(loader) + + +def test_iter_yields_clean_batches_and_reiterates(): + # __iter__ yields validated (batch_size, *sample_shape) batches and can be + # re-iterated for another epoch -- this is the stream the Trainer consumes. + data = np.arange(40, dtype="float64").reshape(40, 1) + loader = DataLoader(_chunks(data, 10), batch_size=10, sample_shape=(1,), total_size=40) + e1 = list(loader) + e2 = list(loader) # re-iterable + assert len(e1) == 4 and all(b.shape == (10, 1) for b in e1) + np.testing.assert_array_equal(np.sort(np.concatenate([b.ravel() for b in e1])), data.ravel()) + np.testing.assert_array_equal(np.sort(np.concatenate([b.ravel() for b in e2])), data.ravel()) def test_total_size_zero_raises(): From 03f3a7611cbf3593dbc0cd6b3c70474c7855925a Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Mon, 8 Jun 2026 22:41:06 -0500 Subject: [PATCH 09/27] Refine streaming VI after mentor design review - Trainer's stream now updates batches_seen/rows_streamed and runs the one-shot total_size sanity check at each epoch boundary (previously dead on the Trainer path; __iter__ stays side-effect-free). - total_size="auto" with shuffle=True counts the unshuffled source, fixing an undercount of up to batch_size-1 rows. - Trainer default data_name "data" -> "batch" to match the examples/tests. - Clarify len(loader)==N (rows, not batches) in docstrings; raise a clear error when a cycled source restarts empty. 
- Register the streaming API in docs/source/api/vi.rst. - Add regression tests for the auto-size shuffle count and Trainer counters. --- docs/source/api/vi.rst | 15 ++++ pymc/variational/streaming.py | 72 ++++++++++++++------ tests/variational/test_streaming_autosize.py | 40 +++++++++++ 3 files changed, 108 insertions(+), 19 deletions(-) diff --git a/docs/source/api/vi.rst b/docs/source/api/vi.rst index cca88dfde2..3e59294f67 100644 --- a/docs/source/api/vi.rst +++ b/docs/source/api/vi.rst @@ -68,6 +68,21 @@ Special Stein +Streaming +--------- +Out-of-core minibatching for variational inference on datasets that do not fit in +memory (see :mod:`pymc.variational.streaming`). + +.. currentmodule:: pymc.variational +.. autosummary:: + :toctree: generated/ + + DataLoader + IterableDataset + Trainer + shuffle_buffer + parquet_source + .. currentmodule:: pymc .. autosummary:: :toctree: generated/ diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index 462f2cdf15..2f3ff4fbe9 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -26,8 +26,9 @@ (e.g. :func:`parquet_source` over a directory of shards). It never loads the whole dataset; it yields it a chunk at a time. * :class:`DataLoader` -- turns a dataset into fixed-size (optionally shuffled) - minibatches; it is iterable (the minibatch stream) and sized (``len(loader)`` - is the dataset size ``N``), exactly like a PyTorch one. + minibatches; it is iterable (the minibatch stream) and sized. Note ``len(loader)`` + is the row count ``N`` (what the observed distribution needs for ``total_size``), + *not* the batch count ``torch.utils.data.DataLoader.__len__`` returns. * :class:`Trainer` -- drives variational inference (ADVI, ...) 
over a ``DataLoader`` with **no user-facing callbacks**; ``Trainer(method=..., dataloader=...).fit(n)`` streams each minibatch into the @@ -41,7 +42,7 @@ The unbiased-gradient rescaling is the *same* as for ``pm.Minibatch``: the observed log-likelihood must be scaled by ``N / batch_size`` through the existing :func:`~pymc.variational.minibatch_rv.create_minibatch_rv`. ``N`` is exactly -``len(loader)`` -- a :class:`DataLoader` is sized like a PyTorch one -- so the +``len(loader)`` (the loader is sized; ``len`` returns the row count ``N``) -- so the model passes ``total_size=len(loader)``. (Folding that scaling into the inference step, so it drops out of the model body, is the next step in PyMC's VI rework.) @@ -193,24 +194,26 @@ def __init__( if not _is_positive_int(batch_size): raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}") - source_factory = _make_factory(dataset) + raw_factory = _make_factory(dataset) + source_factory = raw_factory if shuffle: if buffer_size is None: buffer_size = 50 * int(batch_size) - # shuffle_buffer forwards a known .n_rows, so total_size="auto" still - # resolves cheaply through the shuffle wrapper. source_factory = shuffle_buffer( - source_factory, buffer_size=buffer_size, batch_size=batch_size, seed=seed + raw_factory, buffer_size=buffer_size, batch_size=batch_size, seed=seed ) self._source_factory = source_factory if isinstance(total_size, str): if total_size != "auto": raise ValueError(f"total_size string must be 'auto', got {total_size!r}") - # Resolve N automatically: a source-provided .n_rows (cheap, e.g. from - # parquet_source's metadata) else one counting pass over a finite, - # re-readable source. One-shot / infinite sources cannot be auto-counted. - total_size = _auto_total_size(self._source_factory, dataset) + # Resolve N automatically from the UNSHUFFLED source: a source-provided + # .n_rows (cheap, e.g. from parquet_source's metadata) else one counting + # pass over a finite, re-readable source. 
We count raw_factory, never the + # shuffle-wrapped one: the shuffle buffer drops the final partial batch, so + # counting through it would undercount N by up to batch_size-1. One-shot / + # infinite sources cannot be auto-counted. + total_size = _auto_total_size(raw_factory, dataset) elif total_size is None: warnings.warn( "DataLoader created with total_size=None: the minibatch " @@ -283,10 +286,12 @@ def __iter__(self) -> Iterator[np.ndarray]: yield self._prepare(batch) def __len__(self) -> int: - """The dataset size ``N`` -- pass to the observed distribution's ``total_size``. + """The dataset size ``N`` (row count) -- pass to the distribution's ``total_size``. - Sized like a PyTorch ``DataLoader``; ``total_size=len(loader)`` is how the - model gets the ``N / batch_size`` rescaling. + ``total_size=len(loader)`` is how the model gets the ``N / batch_size`` + rescaling. Note this returns the *row* count ``N``, not the *batch* count + (``ceil(N / batch_size)``) that ``torch.utils.data.DataLoader.__len__`` + returns; ``total_size`` needs ``N``. :attr:`total_size` is the same value. """ if self._total_size is None: raise TypeError( @@ -295,6 +300,24 @@ def __len__(self) -> int: ) return self._total_size + def _stream_batches(self) -> Iterator[np.ndarray]: + """One epoch of prepared minibatches, with accounting (the Trainer's path). + + Like :meth:`__iter__` but it updates :attr:`batches_seen` / + :attr:`rows_streamed` and fires the one-shot ``total_size`` sanity check at + the epoch boundary. :meth:`__iter__` stays side-effect-free so a plain + ``for b in loader`` / ``list(loader)`` does not mutate counters; the + :class:`Trainer` iterates *this* instead, so its (documented, primary) + workflow still benefits from the wrong-``total_size`` guard and reports + non-zero counters. 
+ """ + for batch in self._source_factory(): + prepared = self._prepare(batch) + self._batches_seen += 1 + self._rows_streamed += int(prepared.shape[0]) + yield prepared + self._maybe_warn_total_size() + # ----- shared-buffer path (advanced; the Trainer uses pm.Data instead) --- def as_tensor(self) -> pt.TensorVariable: @@ -336,7 +359,17 @@ def _next_batch(self) -> np.ndarray: # row count, so we can sanity-check the user's total_size for free. self._maybe_warn_total_size() self._source_iter = self._source_factory() - return next(self._source_iter) + try: + return next(self._source_iter) + except StopIteration: + # A cycled source that restarts empty would otherwise leak a bare + # StopIteration out of advance(); surface it as a clear error (mirrors + # Trainer._stream's "yielded no batches" guard). + raise RuntimeError( + "DataLoader source yielded no rows on restart (cycle=True); an " + "empty/exhausted source cannot be cycled. Ensure the source factory " + "returns a non-empty iterator each epoch." + ) from None def _maybe_warn_total_size(self) -> None: """Warn once if total_size grossly disagrees with the rows seen in one pass.""" @@ -410,8 +443,9 @@ class Trainer: it to the observed distribution's ``total_size``. model : pymc.Model, optional Defaults to the model on the context stack. - data_name : str, default "data" - Name of the ``pm.Data`` placeholder minibatches are streamed into. + data_name : str, default "batch" + Name of the ``pm.Data`` placeholder minibatches are streamed into. Must + match the name used for ``pm.Data(name, ...)`` in the model. **fit_kwargs Default keyword arguments forwarded to :func:`pymc.fit` (e.g. ``obj_optimizer``); per-call kwargs to :meth:`fit` override them. 
@@ -431,7 +465,7 @@ def __init__( method: str = "advi", dataloader: DataLoader, model=None, - data_name: str = "data", + data_name: str = "batch", **fit_kwargs, ): self.method = method @@ -466,7 +500,7 @@ def fit( def _stream() -> Iterator[np.ndarray]: while True: empty = True - for batch in loader: + for batch in loader._stream_batches(): empty = False yield batch if empty: diff --git a/tests/variational/test_streaming_autosize.py b/tests/variational/test_streaming_autosize.py index c53107f688..f5a433ffad 100644 --- a/tests/variational/test_streaming_autosize.py +++ b/tests/variational/test_streaming_autosize.py @@ -161,3 +161,43 @@ def test_parquet_source_n_rows_from_metadata(tmp_path): warnings.simplefilter("error", UserWarning) ds = DataLoader(src, batch_size=10, sample_shape=(2,), total_size="auto") assert ds.total_size == total + + +def test_auto_counts_unshuffled_source_when_shuffling_non_divisible(): + # total_size="auto" with shuffle=True must count the UNSHUFFLED source: the + # shuffle buffer drops the final partial batch, so counting through it would + # undercount N by up to batch_size-1. N=125 is not divisible by batch_size=10. + data = np.arange(125, dtype="float64").reshape(125, 1) + with pytest.warns(UserWarning, match="counting pass"): + ds = DataLoader( + _factory(data, 125), # one chunk, NO .n_rows -> forces a counting pass + batch_size=10, + shuffle=True, + buffer_size=30, + seed=0, + sample_shape=(1,), + total_size="auto", + ) + assert ds.total_size == 125 # exact N, not 120 (was undercounted via the shuffle wrap) + + +def test_stream_batches_updates_counters_and_warns_on_wrong_total_size(): + # The accounting-aware stream the Trainer iterates (loader._stream_batches) must + # update the public counters AND fire the one-shot sanity check at the epoch + # boundary -- so a grossly wrong hand-passed total_size is still caught on the + # Trainer's primary path, not only via advance(); plain iteration stays pure. 
+ data = np.arange(40, dtype="float64").reshape(20, 2) + ds = DataLoader( + _factory(data, 5), # 4 chunks of 5 rows + batch_size=5, + sample_shape=(2,), + total_size=10_000, # grossly wrong vs the 20 rows actually streamed + ) + assert ds.batches_seen == 0 and ds.rows_streamed == 0 + list(ds) # plain __iter__ must NOT mutate counters + assert ds.batches_seen == 0 and ds.rows_streamed == 0 + with pytest.warns(UserWarning, match="disagrees with"): + batches = list(ds._stream_batches()) # one epoch through the Trainer's path + assert len(batches) == 4 + assert ds.batches_seen == 4 + assert ds.rows_streamed == 20 From 36f0955ede9ae2d727512200c2fb4d76a2faae4c Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Mon, 8 Jun 2026 22:57:00 -0500 Subject: [PATCH 10/27] Register streaming tests in CI matrix (check_all_tests_are_covered) --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1d9278dd34..c299fbfa55 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -190,7 +190,7 @@ jobs: linker: [cvm, numba] python-version: ["3.12"] test-subset: - - tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/test_initial_point.py + - tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/variational/test_streaming.py tests/variational/test_streaming_autosize.py tests/test_initial_point.py - tests/model/test_core.py tests/sampling/test_mcmc.py - tests/gp/test_cov.py tests/gp/test_gp.py tests/gp/test_mean.py tests/gp/test_util.py tests/ode/test_ode.py tests/ode/test_utils.py tests/smc/test_smc.py tests/sampling/test_parallel.py - tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py 
tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py tests/step_methods/test_state.py From 9b8a914e804275998725b1bdd04de95363967103 Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Tue, 9 Jun 2026 19:08:47 -0500 Subject: [PATCH 11/27] Re-batch arbitrary block sizes in the plain DataLoader path The non-shuffle path previously required the source to yield exact batch_size blocks and raised on anything else, while the docstrings promised re-batching. Now both paths re-batch: blocks of any size are sliced in order with remainders carried across blocks, and a raw array (or any single-sample stream) is accepted directly, so the VI-rework sketch usage Dataloader(, batch_size=...) works as written. Trailing rows that do not fill a final batch are dropped, like drop_last=True in torch, since the model observes a fixed-shape placeholder. Also: total_size="auto" counts a single-sample stream as rows rather than flattened elements; Trainer.fit(callbacks=...) appends user callbacks after the internal advance instead of raising a duplicate keyword error. --- pymc/variational/streaming.py | 79 +++++++++++++++++--- tests/variational/test_streaming.py | 75 +++++++++++++++++-- tests/variational/test_streaming_autosize.py | 23 ++++++ 3 files changed, 158 insertions(+), 19 deletions(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index 2f3ff4fbe9..7bbae569c3 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -135,10 +135,14 @@ class DataLoader: Parameters ---------- dataset : IterableDataset | Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]] - The source of rows. An :class:`IterableDataset`, a re-iterable, or a - zero-arg *factory* returning a fresh iterator (preferred, so the stream can - be restarted when ``cycle=True``). It may yield individual rows or - multi-row blocks; the loader re-batches to exactly ``batch_size`` rows. + The source of rows. 
An :class:`IterableDataset`, a re-iterable (including a + plain ``np.ndarray``), or a zero-arg *factory* returning a fresh iterator + (preferred, so the stream can be restarted when ``cycle=True``). It may + yield single samples (e.g. the rows of a raw array) or blocks of any size; + the loader re-batches them, in order, to exactly ``batch_size`` rows. + Trailing rows that do not fill a final batch are dropped at the end of a + pass, like ``drop_last=True`` in PyTorch (required here because the model + observes a fixed-shape placeholder). batch_size : int Leading dimension of every yielded minibatch (and of the buffer). shuffle : bool, default False @@ -213,7 +217,7 @@ def __init__( # shuffle-wrapped one: the shuffle buffer drops the final partial batch, so # counting through it would undercount N by up to batch_size-1. One-shot / # infinite sources cannot be auto-counted. - total_size = _auto_total_size(raw_factory, dataset) + total_size = _auto_total_size(raw_factory, dataset, tuple(sample_shape)) elif total_size is None: warnings.warn( "DataLoader created with total_size=None: the minibatch " @@ -234,7 +238,6 @@ def __init__( f"{total_size!r}." ) - self._source_iter: Iterator[np.ndarray] = self._source_factory() # Normalize integer-like sizes to plain Python ints. ``_is_positive_int`` # accepts numpy integers (via ``numbers.Integral``), but the downstream # ``create_minibatch_rv`` type-checks ``isinstance(total_size, int)`` and @@ -250,6 +253,9 @@ def __init__( self._rows_streamed = 0 self._warned_size = False # the sanity check below fires at most once + # The persistent stream behind advance()/as_tensor(); re-batched like __iter__. 
+ self._source_iter: Iterator[np.ndarray] = self._rebatched() + self._shared = pytensor.shared( np.zeros((batch_size, *self._sample_shape), dtype=dtype), name=name ) @@ -276,13 +282,17 @@ def rows_streamed(self) -> int: # ----- iteration: the minibatch stream ---------------------------------- + def _rebatched(self) -> Iterator[np.ndarray]: + """A fresh pass of exactly ``batch_size``-row batches from the source.""" + return _rebatch(self._source_factory(), self._batch_size, self._sample_shape) + def __iter__(self) -> Iterator[np.ndarray]: """Yield one epoch of validated ``(batch_size, *sample_shape)`` minibatches. This is the stream :class:`Trainer` pushes into the model's ``pm.Data`` placeholder via ``set_data``. Re-iterate the loader for another epoch. """ - for batch in self._source_factory(): + for batch in self._rebatched(): yield self._prepare(batch) def __len__(self) -> int: @@ -311,7 +321,7 @@ def _stream_batches(self) -> Iterator[np.ndarray]: workflow still benefits from the wrong-``total_size`` guard and reports non-zero counters. """ - for batch in self._source_factory(): + for batch in self._rebatched(): prepared = self._prepare(batch) self._batches_seen += 1 self._rows_streamed += int(prepared.shape[0]) @@ -358,7 +368,7 @@ def _next_batch(self) -> np.ndarray: # First exhaustion == one full pass: rows_streamed now equals the real # row count, so we can sanity-check the user's total_size for free. self._maybe_warn_total_size() - self._source_iter = self._source_factory() + self._source_iter = self._rebatched() try: return next(self._source_iter) except StopIteration: @@ -515,13 +525,16 @@ def _advance(*_): model.set_data(self.data_name, next(batches)) merged = {**self._fit_kwargs, **kwargs} + # User callbacks (e.g. convergence trackers) run AFTER the internal advance, + # appended rather than colliding with it on the `callbacks` keyword. 
+ user_callbacks = merged.pop("callbacks", None) or [] return _fit( n, method=self.method, model=model, random_seed=random_seed, progressbar=progressbar, - callbacks=[_advance], + callbacks=[_advance, *user_callbacks], **merged, ) @@ -619,6 +632,46 @@ def factory() -> Iterator[np.ndarray]: return factory +def _rebatch( + blocks: Iterable[np.ndarray], + batch_size: int, + sample_shape: tuple[int, ...], +) -> Iterator[np.ndarray]: + """Slice a stream of samples/blocks into exact ``batch_size``-row batches, in order. + + Accepts single samples (shape ``sample_shape``, e.g. the rows of a raw array) + and blocks of any size (shape ``(rows, *sample_shape)``), carrying remainders + across blocks so no row is lost mid-stream. Trailing rows that do not fill a + final batch are dropped when the stream ends (``drop_last=True`` behavior; the + model observes a fixed-shape placeholder, so a partial batch cannot be fed). + Sources that already yield exact ``batch_size`` blocks (e.g. + :func:`shuffle_buffer`) pass through without copying. + """ + buf: list[np.ndarray] = [] + have = 0 + for arr in blocks: + a = np.asarray(arr) + if a.shape == sample_shape: # a single sample, not a block + a = a[None, ...] 
+ elif a.ndim != len(sample_shape) + 1 or a.shape[1:] != sample_shape: + raise ValueError( + f"source yielded shape {a.shape}; expected a single sample of shape " + f"{sample_shape} or a block of shape (rows, *{sample_shape})" + ) + buf.append(a) + have += a.shape[0] + if have < batch_size: + continue + merged = np.concatenate(buf, axis=0) if len(buf) > 1 else buf[0] + n_full = merged.shape[0] // batch_size + for i in range(n_full): + yield merged[i * batch_size : (i + 1) * batch_size] + rem = merged.shape[0] - n_full * batch_size + buf = [merged[n_full * batch_size :].copy()] if rem else [] + have = rem + # stream ended: the (< batch_size) remainder in buf is dropped + + def _make_factory( source: Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]], ) -> Callable[[], Iterator[np.ndarray]]: @@ -662,6 +715,7 @@ def _factory() -> Iterator[np.ndarray]: def _auto_total_size( factory: Callable[[], Iterator[np.ndarray]], source: object, + sample_shape: tuple[int, ...] = (), ) -> int: """Resolve ``total_size="auto"``: a source ``.n_rows`` (cheap) else a counting pass. @@ -694,7 +748,10 @@ def _auto_total_size( first_iter = factory() count = 0 for chunk in first_iter: - count += int(np.asarray(chunk).shape[0]) + a = np.asarray(chunk) + # A yield of shape exactly `sample_shape` is ONE sample (e.g. one row of a + # raw array), not a block of a.shape[0] rows; count it as a single row. 
+ count += 1 if a.shape == sample_shape else int(a.shape[0]) if count <= 0: raise ValueError("total_size='auto' counted 0 rows (empty or non-re-readable source).") if factory() is first_iter: diff --git a/tests/variational/test_streaming.py b/tests/variational/test_streaming.py index 5d88b1256e..1084e34b12 100644 --- a/tests/variational/test_streaming.py +++ b/tests/variational/test_streaming.py @@ -43,10 +43,47 @@ def test_advance_shape_and_counters(): assert ds.batches_seen == 2 and ds.rows_streamed == 8 -def test_wrong_batch_shape_rejected(): - data = np.zeros((10, 2)) +def test_plain_loader_rebatches_arbitrary_blocks(): + # blocks of 3 with batch_size=4: re-batched in order; the trailing 2 rows that + # cannot fill a final batch are dropped (drop_last semantics). + data = np.arange(20, dtype="float64").reshape(10, 2) ds = DataLoader(_chunks(data, 3), batch_size=4, sample_shape=(2,), total_size=10) - with pytest.raises(ValueError, match="does not match batch_size"): + batches = list(ds) + assert [b.shape for b in batches] == [(4, 2), (4, 2)] + np.testing.assert_array_equal(np.concatenate(batches), data[:8]) # order preserved + + +def test_raw_array_source_like_vi_rework_sketch(): + # the VI-rework sketch usage: Dataloader(, batch_size=...); rows are + # yielded one sample at a time and the loader re-batches them. 
+ data = np.arange(40, dtype="float64").reshape(20, 2) + with pytest.warns(UserWarning, match="counting pass"): + ds = DataLoader(data, batch_size=8, sample_shape=(2,), total_size="auto") + assert ds.total_size == 20 # counted as 20 rows, not 40 flattened elements + batches = list(ds) + assert [b.shape for b in batches] == [(8, 2), (8, 2)] # trailing 4 rows dropped + np.testing.assert_array_equal(np.concatenate(batches), data[:16]) + + +def test_wrong_sample_shape_rejected(): + data = np.zeros((12, 3)) # 3 columns, but the loader declares 2 + ds = DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(2,), total_size=12) + with pytest.raises(ValueError, match="source yielded shape"): + ds.advance() + + +def test_cycle_true_empty_restart_raises_clear_error(): + calls = {"n": 0} + + def factory(): + calls["n"] += 1 + if calls["n"] == 1: + yield np.zeros((4, 1)) + # second epoch: nothing (an exhausted/empty source cannot be cycled) + + ds = DataLoader(factory, batch_size=4, sample_shape=(1,), total_size=4) + ds.advance() + with pytest.raises(RuntimeError, match="restart"): ds.advance() @@ -370,11 +407,33 @@ def test_factory_returning_reiterable_is_accepted(): assert ds.as_tensor().get_value().shape == (4, 1) -def test_scalar_batch_rejected_with_clear_error(): - # a 0-D batch used to raise an opaque IndexError on batch.shape[0]. - ds = DataLoader(lambda: iter([np.array(1.0)]), batch_size=1, sample_shape=(), total_size=1) - with pytest.raises(ValueError, match="leading batch dimension"): - ds.advance() +def test_scalar_samples_are_batched(): + # with sample_shape=() a 0-D yield is ONE scalar sample (exactly what iterating + # a raw 1-D array produces); the loader batches scalars instead of erroring. 
+ data = np.arange(6, dtype="float64") + ds = DataLoader(data, batch_size=3, sample_shape=(), total_size=6) + batches = list(ds) + assert [b.shape for b in batches] == [(3,), (3,)] + np.testing.assert_array_equal(np.concatenate(batches), data) + + +def test_trainer_appends_user_callbacks_and_streams_distinct_batches(): + # user callbacks (e.g. convergence trackers) must compose with the internal + # advance callback, not collide with it on the `callbacks` keyword; and the + # placeholder must hold a DIFFERENT batch on successive steps. + blocks = [np.full((4, 1), float(i)) for i in range(60)] + loader = DataLoader(lambda: iter(blocks), batch_size=4, sample_shape=(1,), total_size=240) + seen = [] + with pm.Model() as model: + x = pm.Normal("x", 0.0, 1.0) + batch = pm.Data("batch", np.zeros((4, 1))) + pm.Normal("y", x, 1.0, observed=batch[:, 0], total_size=len(loader)) + # no data_name passed: the default ("batch") must match this placeholder + Trainer(method="advi", dataloader=loader).fit( + 5, callbacks=[lambda *_: seen.append(float(model["batch"].get_value()[0, 0]))] + ) + assert len(seen) == 5 # the user callback ran every step + assert len(set(seen)) > 1 # the placeholder advanced to new batches def test_iterable_dataset_base_is_abstract(): diff --git a/tests/variational/test_streaming_autosize.py b/tests/variational/test_streaming_autosize.py index f5a433ffad..7a50093e7b 100644 --- a/tests/variational/test_streaming_autosize.py +++ b/tests/variational/test_streaming_autosize.py @@ -163,6 +163,29 @@ def test_parquet_source_n_rows_from_metadata(tmp_path): assert ds.total_size == total +def test_parquet_source_columns_and_shard_order(tmp_path): + pa = pytest.importorskip("pyarrow") + pq = pytest.importorskip("pyarrow.parquet") + for i in range(2): + pq.write_table( + pa.table( + {"a": [float(i)] * 2, "b": [9.0] * 2, "c": [float(10 + i)] * 2}, + ), + f"{tmp_path}/part_{i}.parquet", + ) + src = parquet_source(str(tmp_path), columns=["a", "c"]) + blocks = list(src) + 
assert [b.shape for b in blocks] == [(2, 2), (2, 2)] # column b filtered out + np.testing.assert_array_equal(blocks[0][:, 0], [0.0, 0.0]) # sorted shard order + np.testing.assert_array_equal(blocks[1][:, 1], [11.0, 11.0]) + + +def test_parquet_source_empty_dir_raises(tmp_path): + pytest.importorskip("pyarrow") + with pytest.raises(ValueError, match="no Parquet files match"): + parquet_source(str(tmp_path)) + + def test_auto_counts_unshuffled_source_when_shuffling_non_divisible(): # total_size="auto" with shuffle=True must count the UNSHUFFLED source: the # shuffle buffer drops the final partial batch, so counting through it would From 32ec88fad324ece75d5876267100a3c482bf6e29 Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Wed, 10 Jun 2026 06:11:53 -0500 Subject: [PATCH 12/27] Address review comments - Drop the shared-buffer path (as_tensor/advance and the cycle/name parameters): neither exists in torch.utils.data and the Trainer never used it. Manual stepping stays available through plain iteration plus set_data. - Move modelcontext/fit imports to module level. - Replace test comments with docstrings, drop redundant comments and section banners, and rename the reshuffle test descriptively. --- pymc/variational/streaming.py | 228 ++++++------------ tests/variational/test_streaming.py | 233 +++++++++---------- tests/variational/test_streaming_autosize.py | 73 +++--- 3 files changed, 209 insertions(+), 325 deletions(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index 7bbae569c3..a388c3ac02 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -14,13 +14,12 @@ """Out-of-core minibatching for variational inference. ``pm.Minibatch`` random-indexes an array that is *fully resident in memory*; its -peak memory is therefore O(N) in the dataset size. 
This module instead feeds -minibatches from an out-of-core source into a small fixed-size -``pytensor.shared`` buffer, so peak memory is O(buffer) -- the batch buffer plus, -if used, the shuffle buffer -- and independent of N. +peak memory is therefore O(N) in the dataset size. This module instead streams +minibatches from an out-of-core source into a ``pm.Data`` placeholder, so peak +memory is O(batch) plus, if used, the shuffle buffer, independent of N. -The API deliberately mirrors PyTorch's ``torch.utils.data`` so the mental model -transfers directly: +The API mirrors PyTorch's ``torch.utils.data`` so the mental model transfers +directly: * :class:`IterableDataset` -- a re-iterable, out-of-core source of rows (e.g. :func:`parquet_source` over a directory of shards). It never loads the @@ -36,8 +35,9 @@ **The full data never enters RAM.** The model graph observes only a ``(batch_size, *sample_shape)`` ``pm.Data`` *placeholder* that the ``Trainer`` -overwrites with the next minibatch every step. Passing a directory of 122 GB of -Parquet shards still gives a model whose resident footprint is one batch. +overwrites with the next minibatch every step. Passing a directory of Parquet +shards far larger than RAM still gives a model whose resident footprint is one +batch. 
The unbiased-gradient rescaling is the *same* as for ``pm.Minibatch``: the observed log-likelihood must be scaled by ``N / batch_size`` through the existing @@ -73,7 +73,7 @@ with pm.Model() as model: b = pm.Normal("b", 0.0, 3.0, shape=4) - batch = pm.Data("batch", np.zeros((4096, 4))) # placeholder -- the ONLY data in RAM + batch = pm.Data("batch", np.zeros((4096, 4))) # placeholder, the only data in RAM logit = b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1] + b[3] * batch[:, 2] pm.Bernoulli("y", logit_p=logit, observed=batch[:, 3], total_size=len(loader)) @@ -90,8 +90,9 @@ from collections.abc import Callable, Iterable, Iterator import numpy as np -import pytensor -import pytensor.tensor as pt + +from pymc.model import modelcontext +from pymc.variational.inference import fit as _fit def _is_positive_int(value: object) -> bool: @@ -137,14 +138,14 @@ class DataLoader: dataset : IterableDataset | Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]] The source of rows. An :class:`IterableDataset`, a re-iterable (including a plain ``np.ndarray``), or a zero-arg *factory* returning a fresh iterator - (preferred, so the stream can be restarted when ``cycle=True``). It may - yield single samples (e.g. the rows of a raw array) or blocks of any size; - the loader re-batches them, in order, to exactly ``batch_size`` rows. - Trailing rows that do not fill a final batch are dropped at the end of a - pass, like ``drop_last=True`` in PyTorch (required here because the model - observes a fixed-shape placeholder). + (preferred, so the stream can be restarted each epoch). It may yield single + samples (e.g. the rows of a raw array) or blocks of any size; the loader + re-batches them, in order, to exactly ``batch_size`` rows. Trailing rows + that do not fill a final batch are dropped at the end of a pass, like + ``drop_last=True`` in PyTorch (required here because the model observes a + fixed-shape placeholder). 
batch_size : int - Leading dimension of every yielded minibatch (and of the buffer). + Leading dimension of every yielded minibatch. shuffle : bool, default False If ``True``, wrap the source in a bounded :func:`shuffle_buffer` of ``buffer_size`` rows. This only approximates i.i.d. batches for an @@ -160,24 +161,19 @@ class DataLoader: Trailing shape of a single observation. ``()`` for scalar observations, ``(k,)`` to stream ``k`` columns (e.g. features + the observed column). dtype : str, default "float64" - Dtype of the shared buffer. If it differs from ``pytensor.config.floatX`` - the model will insert a per-step cast on the observed tensor. + Dtype each prepared batch is cast to; match the dtype of the ``pm.Data`` + placeholder the batches are streamed into. total_size : int or "auto", optional The true dataset size ``N`` (a positive integer), or ``"auto"`` to infer it (from the source's ``n_rows`` if available, else a single counting pass). Pass it on to the observed distribution as - ``total_size=loader.total_size`` so the minibatch log-likelihood is - rescaled by ``N / batch_size`` (the same mechanism as ``pm.Minibatch``). - Unlike ``pm.Minibatch`` it cannot be inferred from a resident array; - ``None`` warns at construction and a non-positive value raises (it would - otherwise silently disable or invert the rescaling). + ``total_size=len(loader)`` so the minibatch log-likelihood is rescaled by + ``N / batch_size`` (the same mechanism as ``pm.Minibatch``). Unlike + ``pm.Minibatch`` it cannot be inferred from a resident array; ``None`` + warns at construction and a non-positive value raises (it would otherwise + silently disable or invert the rescaling). preprocess_fn : callable, optional - Pure transform applied to each batch before it lands in the buffer. - cycle : bool, default True - Restart the source when exhausted (the usual case: many epochs). If - ``False``, :meth:`advance` raises ``StopIteration`` once exhausted. 
- name : str - Name of the underlying ``pytensor.shared`` variable. + Pure transform applied to each batch before it is yielded. """ def __init__( @@ -192,8 +188,6 @@ def __init__( dtype: str = "float64", total_size: int | str | None = None, preprocess_fn: Callable[[np.ndarray], np.ndarray] | None = None, - cycle: bool = True, - name: str = "dataloader_buffer", ): if not _is_positive_int(batch_size): raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}") @@ -211,12 +205,8 @@ def __init__( if isinstance(total_size, str): if total_size != "auto": raise ValueError(f"total_size string must be 'auto', got {total_size!r}") - # Resolve N automatically from the UNSHUFFLED source: a source-provided - # .n_rows (cheap, e.g. from parquet_source's metadata) else one counting - # pass over a finite, re-readable source. We count raw_factory, never the - # shuffle-wrapped one: the shuffle buffer drops the final partial batch, so - # counting through it would undercount N by up to batch_size-1. One-shot / - # infinite sources cannot be auto-counted. + # Count the unshuffled source: the shuffle wrapper drops the trailing + # partial batch, so counting through it would undercount N. total_size = _auto_total_size(raw_factory, dataset, tuple(sample_shape)) elif total_size is None: warnings.warn( @@ -227,40 +217,24 @@ def __init__( stacklevel=2, ) elif not _is_positive_int(total_size): - # A non-positive total_size is silently dangerous: 0 is falsy, so the - # model never wraps the observed RV and the N/batch_size rescaling is - # skipped (posterior collapses toward the prior); a negative value - # yields a negative scaling coefficient that flips the data - # log-likelihood's sign (VI then maximizes mis-fit). Reject it loudly. + # 0 is falsy (the rescaling would be silently skipped) and a negative + # value flips the sign of the data log-likelihood; reject both loudly. 
raise ValueError( "total_size must be a positive integer (the true dataset size N) so " "the minibatch log-likelihood is rescaled by N / batch_size; got " f"{total_size!r}." ) - # Normalize integer-like sizes to plain Python ints. ``_is_positive_int`` - # accepts numpy integers (via ``numbers.Integral``), but the downstream - # ``create_minibatch_rv`` type-checks ``isinstance(total_size, int)`` and - # would raise on a stored ``np.int64`` ("Invalid type for total_size"). + # Plain Python ints: create_minibatch_rv rejects np.int64 for total_size. self._batch_size = int(batch_size) self._sample_shape = tuple(sample_shape) self._dtype = dtype self._total_size = None if total_size is None else int(total_size) self._preprocess_fn = preprocess_fn - self._cycle = cycle self._batches_seen = 0 self._rows_streamed = 0 - self._warned_size = False # the sanity check below fires at most once - - # The persistent stream behind advance()/as_tensor(); re-batched like __iter__. - self._source_iter: Iterator[np.ndarray] = self._rebatched() - - self._shared = pytensor.shared( - np.zeros((batch_size, *self._sample_shape), dtype=dtype), name=name - ) - - # ----- read-only state --------------------------------------------------- + self._warned_size = False @property def batch_size(self) -> int: @@ -277,11 +251,9 @@ def batches_seen(self) -> int: @property def rows_streamed(self) -> int: - """Total rows pushed through the buffer (grows past ``N`` across epochs).""" + """Total rows streamed into the model (grows past ``N`` across epochs).""" return self._rows_streamed - # ----- iteration: the minibatch stream ---------------------------------- - def _rebatched(self) -> Iterator[np.ndarray]: """A fresh pass of exactly ``batch_size``-row batches from the source.""" return _rebatch(self._source_factory(), self._batch_size, self._sample_shape) @@ -314,12 +286,9 @@ def _stream_batches(self) -> Iterator[np.ndarray]: """One epoch of prepared minibatches, with accounting (the Trainer's path). 
Like :meth:`__iter__` but it updates :attr:`batches_seen` / - :attr:`rows_streamed` and fires the one-shot ``total_size`` sanity check at - the epoch boundary. :meth:`__iter__` stays side-effect-free so a plain - ``for b in loader`` / ``list(loader)`` does not mutate counters; the - :class:`Trainer` iterates *this* instead, so its (documented, primary) - workflow still benefits from the wrong-``total_size`` guard and reports - non-zero counters. + :attr:`rows_streamed` and runs the one-shot ``total_size`` sanity check at + the epoch boundary. :meth:`__iter__` stays side-effect-free so plain + iteration does not mutate counters. """ for batch in self._rebatched(): prepared = self._prepare(batch) @@ -328,58 +297,16 @@ def _stream_batches(self) -> Iterator[np.ndarray]: yield prepared self._maybe_warn_total_size() - # ----- shared-buffer path (advanced; the Trainer uses pm.Data instead) --- - - def as_tensor(self) -> pt.TensorVariable: - """A ``pytensor.shared`` buffer the model can observe directly. - - An alternative to the recommended ``pm.Data`` + :class:`Trainer` path: drive - ``pm.fit`` yourself and call :meth:`advance` (e.g. from a callback) to refill - this buffer each step. - """ - return self._shared - - def advance(self) -> None: - """Pull the next batch from the source into the :meth:`as_tensor` buffer.""" - arr = self._prepare(self._next_batch()) - self._shared.set_value(arr, borrow=True) - self._batches_seen += 1 - self._rows_streamed += int(arr.shape[0]) - - # ----- internals --------------------------------------------------------- - def _prepare(self, batch: np.ndarray) -> np.ndarray: - """Preprocess, validate, and return an *owned* copy of one batch. + """Preprocess, validate, and return an owned copy of one batch. - The owned copy matters: a source may legitimately yield *views* into a - reused array, so neither the buffer nor an iteration consumer may alias it. 
+ A source may legitimately yield views into a reused array; the copy + prevents the consumer from aliasing it. """ if self._preprocess_fn is not None: batch = self._preprocess_fn(batch) self._validate(batch) - return np.array(batch, dtype=self._dtype) # np.array(copy default) owns it - - def _next_batch(self) -> np.ndarray: - try: - return next(self._source_iter) - except StopIteration: - if not self._cycle: - raise - # First exhaustion == one full pass: rows_streamed now equals the real - # row count, so we can sanity-check the user's total_size for free. - self._maybe_warn_total_size() - self._source_iter = self._rebatched() - try: - return next(self._source_iter) - except StopIteration: - # A cycled source that restarts empty would otherwise leak a bare - # StopIteration out of advance(); surface it as a clear error (mirrors - # Trainer._stream's "yielded no batches" guard). - raise RuntimeError( - "DataLoader source yielded no rows on restart (cycle=True); an " - "empty/exhausted source cannot be cycled. Ensure the source factory " - "returns a non-empty iterator each epoch." - ) from None + return np.array(batch, dtype=self._dtype) def _maybe_warn_total_size(self) -> None: """Warn once if total_size grossly disagrees with the rows seen in one pass.""" @@ -407,9 +334,7 @@ def _validate(self, batch: np.ndarray) -> None: ) if batch.shape[0] != self._batch_size: raise ValueError( - f"batch shape[0] = {batch.shape[0]} does not match batch_size = " - f"{self._batch_size}; partial batches are not allowed (drop them in " - "the source, e.g. via shuffle=True / shuffle_buffer)." + f"batch shape[0] = {batch.shape[0]} does not match batch_size = {self._batch_size}." ) if batch.shape[1:] != self._sample_shape: raise ValueError( @@ -426,7 +351,7 @@ class Trainer: :class:`DataLoader` owns batching (and ``len(dataloader)`` is the dataset size ``N``), and the model owns the math. 
The model exposes a ``pm.Data`` placeholder; the ``Trainer`` streams minibatches into it with ``model.set_data`` once per - step, so the user wires up **no** callbacks (the design Rob asked for). + step, so the user wires up no callbacks. .. code-block:: python @@ -446,7 +371,7 @@ class Trainer: ---------- method : str, default "advi" Variational method, forwarded to :func:`pymc.fit` (``"advi"``, - ``"fullrank_advi"``, ...). When PyMC's VI rework lands this will also accept + ``"fullrank_advi"``, ...). Once the VI rework lands this will also accept an inference *instance* (e.g. ``ADVI()``); a string drives today's ``pm.fit``. dataloader : DataLoader The minibatch source. ``len(dataloader)`` is ``N`` -- the model should pass @@ -462,9 +387,8 @@ class Trainer: Notes ----- - This is the *starting point* Rob suggested: the per-step ``set_data`` logic - lives in the ``Trainer``. The longer-term plan is to fold it into the inference - object's ``step(batch)`` once the VI rework lands, at which point the + The per-step ``set_data`` currently lives in the ``Trainer``. Once the VI + rework's ``Inference.step(batch)`` lands it moves there, at which point the ``total_size`` rescaling can be derived from ``len(dataloader)`` and dropped from the model body entirely. """ @@ -496,9 +420,6 @@ def fit( Returns whatever :func:`pymc.fit` returns for the chosen method. """ - from pymc.model import modelcontext - from pymc.variational.inference import fit as _fit - loader = self.dataloader if not isinstance(loader, DataLoader): raise TypeError( @@ -506,7 +427,6 @@ def fit( ) model = modelcontext(self.model) - # An endless minibatch stream: re-iterate the loader across epochs. def _stream() -> Iterator[np.ndarray]: while True: empty = True @@ -525,8 +445,8 @@ def _advance(*_): model.set_data(self.data_name, next(batches)) merged = {**self._fit_kwargs, **kwargs} - # User callbacks (e.g. 
convergence trackers) run AFTER the internal advance, - # appended rather than colliding with it on the `callbacks` keyword. + # User callbacks (e.g. convergence trackers) are appended after the + # internal advance instead of colliding with it on the keyword. user_callbacks = merged.pop("callbacks", None) or [] return _fit( n, @@ -569,9 +489,8 @@ def shuffle_buffer( ``max(buffer_size, batch_size, largest_chunk_rows)``. Each epoch (each call of the returned factory) draws a fresh permutation from - a sub-stream of ``seed``, so the shuffle order differs across epochs -- a - seeded buffer must not replay one fixed order forever -- while staying - reproducible for a given ``seed``. + a sub-stream of ``seed``, so the shuffle order differs across epochs while + staying reproducible for a given ``seed``. """ if not _is_positive_int(batch_size): raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}") @@ -580,17 +499,14 @@ def shuffle_buffer( seed_seq = np.random.SeedSequence(seed) def factory() -> Iterator[np.ndarray]: - # Spawn a fresh sub-stream per epoch so re-iterating (cycle=True) reshuffles - # rather than replaying one fixed permutation forever; still reproducible - # across runs for a given seed. + # A fresh sub-stream per epoch: re-iterating reshuffles instead of + # replaying one fixed permutation, yet stays reproducible per seed. rng = np.random.default_rng(seed_seq.spawn(1)[0]) it = chunk_source() - carry: np.ndarray | None = None # leftover (< batch_size) from last fill + carry: np.ndarray | None = None exhausted = False - # Accumulate at least one full batch's worth even when buffer_size < - # batch_size: otherwise the inner loop would break early with fewer than - # batch_size rows and the `have < batch_size` guard below would silently - # discard the entire stream. + # Accumulate at least one batch even when buffer_size < batch_size, + # otherwise the guard below would silently discard the whole stream. 
target = max(buffer_size, batch_size) while not exhausted: bufs: list[np.ndarray] = [] @@ -606,12 +522,12 @@ def factory() -> Iterator[np.ndarray]: if have >= target: break else: - exhausted = True # for-loop ran to completion: source is done + exhausted = True if have < batch_size: # Only reachable once the source is exhausted: drop the final - # sub-batch remainder (it cannot form a full batch). + # partial batch. return - buf = np.concatenate(bufs, axis=0) # always a fresh, owned copy + buf = np.concatenate(bufs, axis=0) rng.shuffle(buf) n_full = buf.shape[0] // batch_size for i in range(n_full): @@ -619,12 +535,8 @@ def factory() -> Iterator[np.ndarray]: rem = buf.shape[0] - n_full * batch_size carry = buf[n_full * batch_size :].copy() if rem else None - # Forward a known row count (e.g. parquet_source's .n_rows from Parquet - # metadata) to the wrapped factory, so - # ``DataLoader(source, shuffle=True, total_size="auto")`` resolves N for free - # instead of doing a counting pass. The only discrepancy is the single dropped - # trailing partial batch (< batch_size rows), well within the auto-size - # sanity tolerance. + # Forward a known row count so total_size="auto" stays metadata-cheap + # through the shuffle wrapper. source_n_rows = getattr(chunk_source, "n_rows", None) if source_n_rows is not None: factory.n_rows = source_n_rows # type: ignore[attr-defined] @@ -669,7 +581,6 @@ def _rebatch( rem = merged.shape[0] - n_full * batch_size buf = [merged[n_full * batch_size :].copy()] if rem else [] have = rem - # stream ended: the (< batch_size) remainder in buf is dropped def _make_factory( @@ -683,9 +594,8 @@ def _make_factory( forwarded onto the returned factory so ``total_size="auto"`` stays cheap. 
""" if callable(source) and not isinstance(source, Iterator): - # A factory may return any iterable (a list of batches, a generator, ...), - # not only an iterator; normalize so ``__next__`` always has an iterator to - # pull from (a bare ``list`` would otherwise fail ``next(...)``). + # A factory may return any iterable (a list of batches, a generator, ...); + # normalize so the loader always pulls from a true iterator. def _factory() -> Iterator[np.ndarray]: return iter(source()) # type: ignore[operator] @@ -695,8 +605,9 @@ def _factory() -> Iterator[np.ndarray]: def _factory() -> Iterator[np.ndarray]: if consumed["done"]: raise RuntimeError( - "source is a bare iterator and cycle=True was requested; pass a " - "zero-arg factory or a re-iterable instead" + "source is a bare iterator and was already consumed; the loader " + "restarts the stream each epoch, so pass a zero-arg factory or a " + "re-iterable instead" ) consumed["done"] = True return source @@ -727,8 +638,6 @@ def _auto_total_size( """ n = getattr(source, "n_rows", None) if n is None: - # The user's source may not carry .n_rows even when the (shuffle-wrapped) - # factory does; fall back to the factory's own forwarded count. n = getattr(factory, "n_rows", None) if n is not None: if not _is_positive_int(n): @@ -749,14 +658,13 @@ def _auto_total_size( count = 0 for chunk in first_iter: a = np.asarray(chunk) - # A yield of shape exactly `sample_shape` is ONE sample (e.g. one row of a - # raw array), not a block of a.shape[0] rows; count it as a single row. + # A yield of shape exactly `sample_shape` is ONE sample, not a block. count += 1 if a.shape == sample_shape else int(a.shape[0]) if count <= 0: raise ValueError("total_size='auto' counted 0 rows (empty or non-re-readable source).") if factory() is first_iter: - # A genuine factory yields a FRESH iterator each call; one that returns the - # same (now-exhausted) iterator would leave advance() with nothing to pull. 
+ # A genuine factory yields a fresh iterator each call; one that returns the + # same exhausted iterator would leave the loader with nothing to stream. raise ValueError( "total_size='auto' got a factory that returns the same one-shot iterator " "each call; pass a factory that creates a fresh iterator each call, or " diff --git a/tests/variational/test_streaming.py b/tests/variational/test_streaming.py index 1084e34b12..e898a7f75e 100644 --- a/tests/variational/test_streaming.py +++ b/tests/variational/test_streaming.py @@ -32,85 +32,46 @@ def factory(): return factory -def test_advance_shape_and_counters(): - data = np.arange(40, dtype="float64").reshape(20, 2) - ds = DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(2,), total_size=20) - assert ds.batches_seen == 0 - ds.advance() - assert ds.as_tensor().get_value().shape == (4, 2) - assert ds.batches_seen == 1 and ds.rows_streamed == 4 - ds.advance() - assert ds.batches_seen == 2 and ds.rows_streamed == 8 - - def test_plain_loader_rebatches_arbitrary_blocks(): - # blocks of 3 with batch_size=4: re-batched in order; the trailing 2 rows that - # cannot fill a final batch are dropped (drop_last semantics). + """Blocks of 3 with batch_size=4 are re-batched in order; the trailing rows + that cannot fill a final batch are dropped (drop_last semantics).""" data = np.arange(20, dtype="float64").reshape(10, 2) ds = DataLoader(_chunks(data, 3), batch_size=4, sample_shape=(2,), total_size=10) batches = list(ds) assert [b.shape for b in batches] == [(4, 2), (4, 2)] - np.testing.assert_array_equal(np.concatenate(batches), data[:8]) # order preserved + np.testing.assert_array_equal(np.concatenate(batches), data[:8]) def test_raw_array_source_like_vi_rework_sketch(): - # the VI-rework sketch usage: Dataloader(, batch_size=...); rows are - # yielded one sample at a time and the loader re-batches them. 
+ """A raw array works directly, as in the VI-rework sketch + ``Dataloader(np.random.normal(...), batch_size=...)``: rows are yielded one + sample at a time, re-batched, and counted as rows by total_size='auto'.""" data = np.arange(40, dtype="float64").reshape(20, 2) with pytest.warns(UserWarning, match="counting pass"): ds = DataLoader(data, batch_size=8, sample_shape=(2,), total_size="auto") - assert ds.total_size == 20 # counted as 20 rows, not 40 flattened elements + assert ds.total_size == 20 batches = list(ds) - assert [b.shape for b in batches] == [(8, 2), (8, 2)] # trailing 4 rows dropped + assert [b.shape for b in batches] == [(8, 2), (8, 2)] np.testing.assert_array_equal(np.concatenate(batches), data[:16]) def test_wrong_sample_shape_rejected(): - data = np.zeros((12, 3)) # 3 columns, but the loader declares 2 + """A source whose trailing shape does not match sample_shape raises.""" + data = np.zeros((12, 3)) ds = DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(2,), total_size=12) with pytest.raises(ValueError, match="source yielded shape"): - ds.advance() - - -def test_cycle_true_empty_restart_raises_clear_error(): - calls = {"n": 0} - - def factory(): - calls["n"] += 1 - if calls["n"] == 1: - yield np.zeros((4, 1)) - # second epoch: nothing (an exhausted/empty source cannot be cycled) - - ds = DataLoader(factory, batch_size=4, sample_shape=(1,), total_size=4) - ds.advance() - with pytest.raises(RuntimeError, match="restart"): - ds.advance() + next(iter(ds)) def test_total_size_none_warns_at_construction(): + """total_size=None disables the N/batch_size rescaling, so it warns.""" data = np.zeros((8, 1)) with pytest.warns(UserWarning, match="total_size=None"): DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(1,)) -def test_cycle_true_restarts_source(): - data = np.arange(8, dtype="float64").reshape(8, 1) - ds = DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=8, cycle=True) - for _ in range(4): # two epochs worth - 
ds.advance() - assert ds.batches_seen == 4 - - -def test_cycle_false_raises_when_exhausted(): - data = np.arange(8, dtype="float64").reshape(8, 1) - ds = DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=8, cycle=False) - ds.advance() - ds.advance() - with pytest.raises(StopIteration): - ds.advance() - - def test_preprocess_fn_applied(): + """preprocess_fn transforms each batch before it is yielded.""" data = np.ones((8, 1)) ds = DataLoader( _chunks(data, 4), @@ -119,33 +80,32 @@ def test_preprocess_fn_applied(): total_size=8, preprocess_fn=lambda b: b * 3.0, ) - ds.advance() - np.testing.assert_array_equal(ds.as_tensor().get_value(), np.full((4, 1), 3.0)) + np.testing.assert_array_equal(next(iter(ds)), np.full((4, 1), 3.0)) -def test_shuffle_buffer_conserves_rows_non_dividing(): - # buffer_size and chunk size deliberately do NOT divide batch_size: the - # carry-over must not lose or duplicate any row (regression for the drop bug). +def test_shuffle_buffer_conserves_rows_with_non_dividing_chunks(): + """Chunk and buffer sizes that do not divide batch_size must not lose or + duplicate rows; the remainder is carried into the next buffer fill.""" data = np.arange(140, dtype="float64").reshape(140, 1) src = shuffle_buffer(_chunks(data, 7), buffer_size=55, batch_size=10, seed=0) batches = list(src()) assert all(b.shape == (10, 1) for b in batches) seen = np.sort(np.concatenate([b.ravel() for b in batches])) - # 140 rows, batch 10 -> 14 full batches, nothing dropped (140 % 10 == 0) np.testing.assert_array_equal(seen, data.ravel()) def test_shuffle_buffer_does_not_mutate_source(): + """Shuffling happens on an owned copy, never in place on the source arrays.""" data = np.arange(100, dtype="float64").reshape(100, 1) original = data.copy() src = shuffle_buffer(_chunks(data, 25), buffer_size=40, batch_size=10, seed=1) list(src()) - np.testing.assert_array_equal(data, original) # source untouched + np.testing.assert_array_equal(data, original) def 
test_dataloader_shuffle_true_yields_full_batches(): - # shuffle=True wraps the source in a bounded shuffle_buffer internally; batches - # are full and rows are conserved (nothing dropped when N % batch_size == 0). + """shuffle=True wraps the source in a bounded shuffle_buffer; one epoch yields + full batches and conserves every row when batch_size divides N.""" data = np.arange(120, dtype="float64").reshape(120, 1) ds = DataLoader( _chunks(data, 8), @@ -155,32 +115,30 @@ def test_dataloader_shuffle_true_yields_full_batches(): seed=0, sample_shape=(1,), total_size=120, - cycle=False, ) - seen = [] - for _ in range(12): # 120 / 10 - ds.advance() - seen.append(ds.as_tensor().get_value().copy()) - assert all(b.shape == (10, 1) for b in seen) - np.testing.assert_array_equal(np.sort(np.concatenate([b.ravel() for b in seen])), data.ravel()) + batches = list(ds) + assert all(b.shape == (10, 1) for b in batches) + np.testing.assert_array_equal( + np.sort(np.concatenate([b.ravel() for b in batches])), data.ravel() + ) def test_total_size_rescales_logp_like_minibatch(): - # observed=buf[:, k] + total_size=N must scale the observed log-likelihood by - # N / batch_size via the existing create_minibatch_rv path -- pin this without - # training anything.
+ """total_size=len(loader) scales the observed minibatch log-likelihood by + exactly N / batch_size, through the same create_minibatch_rv mechanism as + pm.Minibatch: logp(scaled) == logp(plain) * N / batch_size.""" rng = np.random.default_rng(0) N, bs = 1000, 16 data = rng.normal(size=(bs, 1)) - ds = DataLoader(lambda: iter([data]), batch_size=bs, sample_shape=(1,), total_size=N) - ds.advance() + loader = DataLoader(lambda: iter([data]), batch_size=bs, sample_shape=(1,), total_size=N) with pm.Model() as scaled: mu = pm.Normal("mu", 0, 1) - pm.Normal("y", mu, 1, observed=ds.as_tensor()[:, 0], total_size=ds.total_size) + batch = pm.Data("batch", data) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) with pm.Model() as plain: mu = pm.Normal("mu", 0, 1) - pm.Normal("y", mu, 1, observed=data[:, 0]) # no total_size + pm.Normal("y", mu, 1, observed=data[:, 0]) point = {"mu": np.array(0.3)} obs_scaled = scaled.compile_logp(scaled.observed_RVs)(point) @@ -191,8 +149,9 @@ def test_total_size_rescales_logp_like_minibatch(): def test_trainer_end_to_end_matches_in_ram_minibatch(): """End-to-end: Trainer-driven streaming ADVI reproduces in-RAM pm.Minibatch ADVI. - Exercises the whole blueprint API: a pm.Data placeholder + total_size=len(loader), - and a Trainer that streams minibatches into it with set_data -- no user callbacks. + Exercises the whole API: a pm.Data placeholder, total_size=len(loader), and a + Trainer that streams minibatches into the placeholder with set_data while the + user writes no callbacks. Runs long enough to cycle the loader across epochs. 
""" seed = 0 rng = np.random.default_rng(seed) @@ -226,12 +185,12 @@ def test_trainer_end_to_end_matches_in_ram_minibatch(): ) with pm.Model() as model: b = pm.Normal("b", 0, 3, shape=3) - batch = pm.Data("batch", np.zeros((bs, 3))) # placeholder; full data stays out of RAM + batch = pm.Data("batch", np.zeros((bs, 3))) pm.Bernoulli( "o", logit_p=b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1], observed=batch[:, 2], - total_size=len(loader), # N comes from the loader, PyTorch-style + total_size=len(loader), ) ap = Trainer( method="advi", @@ -245,9 +204,9 @@ def test_trainer_end_to_end_matches_in_ram_minibatch(): def test_trainer_streams_into_placeholder(): - # The Trainer seeds the pm.Data placeholder before step 0 (pm.fit runs callbacks - # AFTER each step) and overwrites it each step; after fitting it holds a real - # batch, not the zero seed -- with the user writing no callbacks. + """The Trainer seeds the pm.Data placeholder before step 0 (pm.fit runs + callbacks after each step) and overwrites it each step; after fitting it holds + a real batch, not the zero seed.""" data = np.ones((4, 1)) loader = DataLoader(lambda: iter([data] * 100), batch_size=4, sample_shape=(1,), total_size=4) with pm.Model() as model: @@ -257,24 +216,46 @@ def test_trainer_streams_into_placeholder(): Trainer(method="advi", dataloader=loader, data_name="batch").fit( 5, progressbar=False, random_seed=0 ) - np.testing.assert_array_equal(model["batch"].get_value(), data) # not the zero seed + np.testing.assert_array_equal(model["batch"].get_value(), data) + + +def test_trainer_raises_when_loader_cannot_restart(): + """A source that streams one epoch and then comes back empty cannot be cycled; + the Trainer surfaces a clear error instead of training on stale data.""" + calls = {"n": 0} + + def factory(): + calls["n"] += 1 + if calls["n"] == 1: + yield np.zeros((4, 1)) + + loader = DataLoader(factory, batch_size=4, sample_shape=(1,), total_size=4) + with pm.Model(): + mu = pm.Normal("mu", 0, 1) 
+ batch = pm.Data("batch", np.zeros((4, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + with pytest.raises(RuntimeError, match="yielded no batches"): + Trainer(method="advi", dataloader=loader, data_name="batch").fit( + 5, progressbar=False, random_seed=0 + ) def test_trainer_rejects_non_dataloader(): - # the isinstance guard fires before any model lookup, so no context is needed. + """The isinstance guard fires before any model lookup.""" with pytest.raises(TypeError, match="DataLoader"): Trainer(method="advi", dataloader=object()).fit(10) def test_len_returns_total_size(): + """len(loader) is the dataset row count N, the value total_size needs.""" data = np.zeros((40, 1)) loader = DataLoader(_chunks(data, 8), batch_size=8, sample_shape=(1,), total_size=40) assert len(loader) == 40 def test_len_raises_when_total_size_none(): - # len(loader) IS N; with total_size=None there is no N to hand the model, so it - # raises rather than silently skipping the N/batch_size rescaling. + """With total_size=None there is no N to hand the model, so len() raises + rather than silently skipping the N/batch_size rescaling.""" data = np.ones((4, 1)) with pytest.warns(UserWarning, match="total_size=None"): loader = DataLoader(lambda: iter([data] * 5), batch_size=4, sample_shape=(1,)) @@ -283,44 +264,41 @@ def test_len_raises_when_total_size_none(): def test_iter_yields_clean_batches_and_reiterates(): - # __iter__ yields validated (batch_size, *sample_shape) batches and can be - # re-iterated for another epoch -- this is the stream the Trainer consumes. 
+ """__iter__ yields validated (batch_size, *sample_shape) batches and can be + re-iterated for another epoch.""" data = np.arange(40, dtype="float64").reshape(40, 1) loader = DataLoader(_chunks(data, 10), batch_size=10, sample_shape=(1,), total_size=40) e1 = list(loader) - e2 = list(loader) # re-iterable + e2 = list(loader) assert len(e1) == 4 and all(b.shape == (10, 1) for b in e1) np.testing.assert_array_equal(np.sort(np.concatenate([b.ravel() for b in e1])), data.ravel()) np.testing.assert_array_equal(np.sort(np.concatenate([b.ravel() for b in e2])), data.ravel()) def test_total_size_zero_raises(): - # total_size=0 is falsy: it slips a None-only check and the model's truthy - # `if total_size:` guard, silently skipping the N/batch_size rescaling. + """total_size=0 is falsy and would silently skip the rescaling, so it raises.""" data = np.zeros((8, 1)) with pytest.raises(ValueError, match="positive integer"): DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=0) def test_total_size_negative_raises(): - # negative total_size is truthy but yields a negative scaling coefficient - # (the data log-likelihood's sign flips, so VI maximizes mis-fit). + """A negative total_size would flip the sign of the data log-likelihood.""" data = np.zeros((8, 1)) with pytest.raises(ValueError, match="positive integer"): DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=-100) def test_shuffle_buffer_small_buffer_conserves_rows(): - # buffer_size < batch_size must NOT silently discard the dataset: the buffer - # accumulates to at least batch_size before emitting (regression for the - # early-return data-loss bug). 
+ """buffer_size < batch_size must not silently discard the dataset: the buffer + accumulates to at least batch_size before emitting.""" data = np.arange(120, dtype="float64").reshape(120, 1) src = shuffle_buffer(_chunks(data, 7), buffer_size=3, batch_size=10, seed=0) batches = list(src()) - assert batches, "buffer_size < batch_size silently produced zero batches" + assert batches assert all(b.shape == (10, 1) for b in batches) seen = np.sort(np.concatenate([b.ravel() for b in batches])) - np.testing.assert_array_equal(seen, data.ravel()) # 120 % 10 == 0, nothing dropped + np.testing.assert_array_equal(seen, data.ravel()) def test_shuffle_buffer_rejects_nonpositive_sizes(): @@ -332,32 +310,31 @@ def test_shuffle_buffer_rejects_nonpositive_sizes(): def test_accepts_numpy_integer_sizes_rejects_bool(): - # the positive-int check uses numbers.Integral: numpy ints are valid, bool is not. + """The positive-int check uses numbers.Integral: numpy ints pass, bool does not.""" data = np.zeros((8, 1)) ds = DataLoader( _chunks(data, 4), batch_size=np.int64(4), sample_shape=(1,), total_size=np.int64(8) ) - ds.advance() + assert next(iter(ds)).shape == (4, 1) assert ds.batch_size == 4 with pytest.raises(ValueError): DataLoader(_chunks(data, 4), batch_size=True, sample_shape=(1,), total_size=8) -def test_shuffle_buffer_reshuffles_across_epochs(): - # a seeded buffer must NOT replay one fixed permutation every epoch (that - # would weaken shuffling under cycle=True); each epoch reshuffles, but rows - # are conserved. 
+def test_shuffle_buffer_draws_fresh_permutation_each_epoch(): + """A seeded buffer must not replay one fixed permutation every epoch; each + epoch reshuffles while conserving rows.""" data = np.arange(60, dtype="float64").reshape(60, 1) factory = shuffle_buffer(_chunks(data, 10), buffer_size=60, batch_size=10, seed=0) epoch1 = np.concatenate([b.ravel() for b in factory()]) epoch2 = np.concatenate([b.ravel() for b in factory()]) - assert not np.array_equal(epoch1, epoch2) # different order across epochs - np.testing.assert_array_equal(np.sort(epoch1), data.ravel()) # but conserves rows + assert not np.array_equal(epoch1, epoch2) + np.testing.assert_array_equal(np.sort(epoch1), data.ravel()) np.testing.assert_array_equal(np.sort(epoch2), data.ravel()) def test_shuffle_buffer_seed_reproducible_across_runs(): - # same seed => identical first-epoch order across independent constructions. + """The same seed gives an identical first-epoch order across constructions.""" data = np.arange(60, dtype="float64").reshape(60, 1) a = np.concatenate( [ @@ -375,8 +352,8 @@ def test_shuffle_buffer_seed_reproducible_across_runs(): def test_sizes_normalized_to_python_int(): - # numpy integer sizes must be stored as plain Python ints so ds.total_size is - # accepted downstream by create_minibatch_rv (regression for the np.int64 trap). + """Numpy integer sizes are stored as plain Python ints so total_size is + accepted downstream by create_minibatch_rv.""" data = np.zeros((8, 1)) ds = DataLoader( _chunks(data, 4), batch_size=np.int64(4), sample_shape=(1,), total_size=np.int64(8) @@ -386,30 +363,30 @@ def test_sizes_normalized_to_python_int(): def test_numpy_total_size_accepted_by_observed_rv(): - # a stored np.int64 total_size used to reach create_minibatch_rv and raise - # "Invalid type for total_size"; it must now build a valid observed RV. 
+ """A numpy-integer total_size used to reach create_minibatch_rv and raise; the + normalized value must build and compile a valid observed RV.""" data = np.zeros((4, 1), dtype="float64") - ds = DataLoader(lambda: iter([data]), batch_size=4, sample_shape=(1,), total_size=np.int64(4)) - ds.advance() + loader = DataLoader( + lambda: iter([data]), batch_size=4, sample_shape=(1,), total_size=np.int64(4) + ) with pm.Model() as model: mu = pm.Normal("mu", 0, 1) - pm.Normal("y", mu, 1, observed=ds.as_tensor()[:, 0], total_size=ds.total_size) - # compiling the observed logp exercises the create_minibatch_rv scaling path + batch = pm.Data("batch", data) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=loader.total_size) model.compile_logp(model.observed_RVs)({"mu": np.array(0.0)}) def test_factory_returning_reiterable_is_accepted(): - # a zero-arg factory may return ANY iterable (e.g. a list), not just an - # iterator; advance() used to crash with "'list' object is not an iterator". + """A zero-arg factory may return any iterable (e.g. a list), not just an + iterator.""" data = [np.zeros((4, 1), dtype="float64")] ds = DataLoader(lambda: data, batch_size=4, sample_shape=(1,), total_size=4) - ds.advance() - assert ds.as_tensor().get_value().shape == (4, 1) + assert next(iter(ds)).shape == (4, 1) def test_scalar_samples_are_batched(): - # with sample_shape=() a 0-D yield is ONE scalar sample (exactly what iterating - # a raw 1-D array produces); the loader batches scalars instead of erroring. + """With sample_shape=() a 0-D yield is one scalar sample, exactly what + iterating a raw 1-D array produces; the loader batches scalars.""" data = np.arange(6, dtype="float64") ds = DataLoader(data, batch_size=3, sample_shape=(), total_size=6) batches = list(ds) @@ -418,9 +395,10 @@ def test_scalar_samples_are_batched(): def test_trainer_appends_user_callbacks_and_streams_distinct_batches(): - # user callbacks (e.g. 
convergence trackers) must compose with the internal - # advance callback, not collide with it on the `callbacks` keyword; and the - # placeholder must hold a DIFFERENT batch on successive steps. + """User callbacks (e.g. convergence trackers) compose with the internal + advance callback instead of colliding on the keyword, and the placeholder + holds a different batch on successive steps. Also exercises the default + data_name ("batch").""" blocks = [np.full((4, 1), float(i)) for i in range(60)] loader = DataLoader(lambda: iter(blocks), batch_size=4, sample_shape=(1,), total_size=240) seen = [] @@ -428,15 +406,14 @@ def test_trainer_appends_user_callbacks_and_streams_distinct_batches(): x = pm.Normal("x", 0.0, 1.0) batch = pm.Data("batch", np.zeros((4, 1))) pm.Normal("y", x, 1.0, observed=batch[:, 0], total_size=len(loader)) - # no data_name passed: the default ("batch") must match this placeholder Trainer(method="advi", dataloader=loader).fit( 5, callbacks=[lambda *_: seen.append(float(model["batch"].get_value()[0, 0]))] ) - assert len(seen) == 5 # the user callback ran every step - assert len(set(seen)) > 1 # the placeholder advanced to new batches + assert len(seen) == 5 + assert len(set(seen)) > 1 def test_iterable_dataset_base_is_abstract(): - # the base class is a contract: __iter__ must be overridden. + """The base class is a contract: __iter__ must be overridden.""" with pytest.raises(NotImplementedError): iter(IterableDataset()) diff --git a/tests/variational/test_streaming_autosize.py b/tests/variational/test_streaming_autosize.py index 7a50093e7b..968dbb55e7 100644 --- a/tests/variational/test_streaming_autosize.py +++ b/tests/variational/test_streaming_autosize.py @@ -37,7 +37,7 @@ def f(): def test_auto_counts_finite_source(): - # no .n_rows -> auto does one counting pass and resolves the true N. 
+ """Without .n_rows, 'auto' does one counting pass and resolves the true N.""" data = np.arange(60, dtype="float64").reshape(60, 1) with pytest.warns(UserWarning, match="counting pass"): ds = DataLoader(_factory(data, 7), batch_size=10, sample_shape=(1,), total_size="auto") @@ -45,8 +45,7 @@ def test_auto_counts_finite_source(): def test_auto_uses_n_rows_fast_path(): - # source advertises .n_rows -> auto trusts it WITHOUT counting (the factory only - # really yields 8 rows, but n_rows says 999; auto must return 999). + """A source-advertised .n_rows is trusted without a counting pass.""" data = np.zeros((8, 1)) f = _factory(data, 4) f.n_rows = 999 @@ -55,7 +54,7 @@ def test_auto_uses_n_rows_fast_path(): def test_auto_rejects_one_shot_iterator(): - # a bare generator is consumed by counting -> auto must refuse it. + """A bare generator would be consumed by the counting pass, so 'auto' refuses it.""" data = np.zeros((20, 1)) one_shot = (data[i : i + 4] for i in range(0, 20, 4)) with pytest.raises(ValueError, match="re-readable"): @@ -63,9 +62,8 @@ def test_auto_rejects_one_shot_iterator(): def test_shuffle_buffer_forwards_n_rows_for_auto(): - # shuffle_buffer must forward a known .n_rows so total_size="auto" works through - # the explicit shuffle_buffer(parquet_source(...)) composition WITHOUT a counting - # pass (the realistic way power users wrap a Parquet source). 
+ """shuffle_buffer forwards a known .n_rows so total_size='auto' works through + an explicit shuffle_buffer(parquet_source(...)) composition without counting.""" data = np.arange(40, dtype="float64").reshape(40, 1) src = _factory(data, 8) src.n_rows = 40 @@ -73,14 +71,14 @@ def test_shuffle_buffer_forwards_n_rows_for_auto(): assert wrapped.n_rows == 40 with warnings.catch_warnings(): - warnings.simplefilter("error", UserWarning) # a counting pass would warn -> fail + warnings.simplefilter("error", UserWarning) ds = DataLoader(wrapped, batch_size=10, sample_shape=(1,), total_size="auto") assert ds.total_size == 40 def test_dataloader_shuffle_auto_resolves_via_n_rows(): - # DataLoader(shuffle=True, total_size="auto") must resolve N from the source's - # .n_rows WITHOUT a counting pass, even though shuffle wraps the source. + """DataLoader(shuffle=True, total_size='auto') resolves N from the source's + .n_rows without a counting pass, even though shuffle wraps the source.""" data = np.arange(40, dtype="float64").reshape(40, 1) src = _factory(data, 8) src.n_rows = 40 @@ -99,15 +97,15 @@ def test_dataloader_shuffle_auto_resolves_via_n_rows(): def test_shuffle_buffer_without_n_rows_has_no_attribute(): - # a plain source without .n_rows must not gain a bogus one. + """A source without .n_rows must not gain a bogus one through the wrapper.""" data = np.arange(40, dtype="float64").reshape(40, 1) wrapped = shuffle_buffer(_factory(data, 8), buffer_size=20, batch_size=10, seed=0) assert not hasattr(wrapped, "n_rows") def test_auto_rejects_factory_returning_same_one_shot_iterator(): - # a "factory" that hands back the SAME already-consumable iterator each call is - # not re-readable: the counting pass consumes it and advance() would get nothing. 
+ """A factory that returns the same already-consumed iterator each call is not + re-readable; the counting pass detects and refuses it.""" data = np.zeros((20, 1)) one_shot = (data[i : i + 4] for i in range(0, 20, 4)) with pytest.raises(ValueError, match="fresh iterator"): @@ -122,24 +120,26 @@ def test_auto_rejects_bad_n_rows(): def test_sanity_warns_on_grossly_wrong_total_size(): - # one full pass = 20 rows, but total_size=100 -> at the first epoch boundary, warn. + """A hand-passed total_size that grossly disagrees with the rows actually + streamed in one pass triggers the one-shot warning at the epoch boundary.""" data = np.arange(20, dtype="float64").reshape(20, 1) ds = DataLoader(_factory(data, 4), batch_size=4, sample_shape=(1,), total_size=100) with pytest.warns(UserWarning, match="disagrees with"): - for _ in range(6): # 5 batches = one epoch, the 6th crosses the boundary - ds.advance() + list(ds._stream_batches()) def test_sanity_silent_when_total_size_matches(): + """No warning when total_size matches the rows streamed in one pass.""" data = np.arange(20, dtype="float64").reshape(20, 1) ds = DataLoader(_factory(data, 4), batch_size=4, sample_shape=(1,), total_size=20) with warnings.catch_warnings(): - warnings.simplefilter("error", UserWarning) # any UserWarning fails the test - for _ in range(6): - ds.advance() + warnings.simplefilter("error", UserWarning) + list(ds._stream_batches()) def test_parquet_source_n_rows_from_metadata(tmp_path): + """parquet_source reads n_rows from file metadata (no data scan) and + total_size='auto' picks it up without a counting pass.""" pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") rng = np.random.default_rng(0) @@ -153,10 +153,9 @@ def test_parquet_source_n_rows_from_metadata(tmp_path): f"{tmp_path}/part_{i:02d}.parquet", ) src = parquet_source(str(tmp_path)) - assert isinstance(src, IterableDataset) # parquet_source is a dataset now - assert src.n_rows == total # read from metadata, no 
data scan + assert isinstance(src, IterableDataset) + assert src.n_rows == total - # and total_size='auto' picks it up for free (no counting pass / warning) with warnings.catch_warnings(): warnings.simplefilter("error", UserWarning) ds = DataLoader(src, batch_size=10, sample_shape=(2,), total_size="auto") @@ -164,6 +163,7 @@ def test_parquet_source_n_rows_from_metadata(tmp_path): def test_parquet_source_columns_and_shard_order(tmp_path): + """columns= selects a column subset and shards are read in sorted path order.""" pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") for i in range(2): @@ -175,8 +175,8 @@ def test_parquet_source_columns_and_shard_order(tmp_path): ) src = parquet_source(str(tmp_path), columns=["a", "c"]) blocks = list(src) - assert [b.shape for b in blocks] == [(2, 2), (2, 2)] # column b filtered out - np.testing.assert_array_equal(blocks[0][:, 0], [0.0, 0.0]) # sorted shard order + assert [b.shape for b in blocks] == [(2, 2), (2, 2)] + np.testing.assert_array_equal(blocks[0][:, 0], [0.0, 0.0]) np.testing.assert_array_equal(blocks[1][:, 1], [11.0, 11.0]) @@ -187,13 +187,13 @@ def test_parquet_source_empty_dir_raises(tmp_path): def test_auto_counts_unshuffled_source_when_shuffling_non_divisible(): - # total_size="auto" with shuffle=True must count the UNSHUFFLED source: the - # shuffle buffer drops the final partial batch, so counting through it would - # undercount N by up to batch_size-1. N=125 is not divisible by batch_size=10. 
+ """total_size='auto' with shuffle=True counts the unshuffled source: the + shuffle buffer drops the trailing partial batch, so counting through it would + undercount N by up to batch_size - 1 (here 125 vs 120).""" data = np.arange(125, dtype="float64").reshape(125, 1) with pytest.warns(UserWarning, match="counting pass"): ds = DataLoader( - _factory(data, 125), # one chunk, NO .n_rows -> forces a counting pass + _factory(data, 125), batch_size=10, shuffle=True, buffer_size=30, @@ -201,26 +201,25 @@ def test_auto_counts_unshuffled_source_when_shuffling_non_divisible(): sample_shape=(1,), total_size="auto", ) - assert ds.total_size == 125 # exact N, not 120 (was undercounted via the shuffle wrap) + assert ds.total_size == 125 def test_stream_batches_updates_counters_and_warns_on_wrong_total_size(): - # The accounting-aware stream the Trainer iterates (loader._stream_batches) must - # update the public counters AND fire the one-shot sanity check at the epoch - # boundary -- so a grossly wrong hand-passed total_size is still caught on the - # Trainer's primary path, not only via advance(); plain iteration stays pure. 
+ """The accounting stream the Trainer iterates updates the public counters and + fires the one-shot total_size sanity check at the epoch boundary, while plain + __iter__ stays side-effect-free.""" data = np.arange(40, dtype="float64").reshape(20, 2) ds = DataLoader( - _factory(data, 5), # 4 chunks of 5 rows + _factory(data, 5), batch_size=5, sample_shape=(2,), - total_size=10_000, # grossly wrong vs the 20 rows actually streamed + total_size=10_000, ) assert ds.batches_seen == 0 and ds.rows_streamed == 0 - list(ds) # plain __iter__ must NOT mutate counters + list(ds) assert ds.batches_seen == 0 and ds.rows_streamed == 0 with pytest.warns(UserWarning, match="disagrees with"): - batches = list(ds._stream_batches()) # one epoch through the Trainer's path + batches = list(ds._stream_batches()) assert len(batches) == 4 assert ds.batches_seen == 4 assert ds.rows_streamed == 20 From c59cbb151f774fcd3a20a40deeac735f1dbabfbb Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Wed, 10 Jun 2026 06:24:03 -0500 Subject: [PATCH 13/27] Fix mypy error comparing Integral with int --- pymc/variational/streaming.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index a388c3ac02..871a267858 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -97,7 +97,7 @@ def _is_positive_int(value: object) -> bool: """True for a strictly positive integer (incl. 
numpy integer types), excluding bool.""" - return isinstance(value, numbers.Integral) and not isinstance(value, bool) and value > 0 + return isinstance(value, numbers.Integral) and not isinstance(value, bool) and int(value) > 0 class IterableDataset: From c5a029e94ca99537dac054ab244ba084231b1517 Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Wed, 10 Jun 2026 06:28:54 -0500 Subject: [PATCH 14/27] Match docstring punctuation to the rest of the codebase --- pymc/variational/streaming.py | 74 +++++++++++++++++------------------ 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index 871a267858..3fa07f54b4 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -13,7 +13,7 @@ # limitations under the License. """Out-of-core minibatching for variational inference. -``pm.Minibatch`` random-indexes an array that is *fully resident in memory*; its +``pm.Minibatch`` random-indexes an array that is fully resident in memory; its peak memory is therefore O(N) in the dataset size. This module instead streams minibatches from an out-of-core source into a ``pm.Data`` placeholder, so peak memory is O(batch) plus, if used, the shuffle buffer, independent of N. @@ -21,36 +21,36 @@ The API mirrors PyTorch's ``torch.utils.data`` so the mental model transfers directly: -* :class:`IterableDataset` -- a re-iterable, out-of-core source of rows +* :class:`IterableDataset`: a re-iterable, out-of-core source of rows (e.g. :func:`parquet_source` over a directory of shards). It never loads the whole dataset; it yields it a chunk at a time. -* :class:`DataLoader` -- turns a dataset into fixed-size (optionally shuffled) +* :class:`DataLoader`: turns a dataset into fixed-size (optionally shuffled) minibatches; it is iterable (the minibatch stream) and sized. 
Note ``len(loader)`` is the row count ``N`` (what the observed distribution needs for ``total_size``), - *not* the batch count ``torch.utils.data.DataLoader.__len__`` returns. -* :class:`Trainer` -- drives variational inference (ADVI, ...) over a - ``DataLoader`` with **no user-facing callbacks**; + not the batch count ``torch.utils.data.DataLoader.__len__`` returns. +* :class:`Trainer`: drives variational inference (ADVI, ...) over a + ``DataLoader`` with no user-facing callbacks; ``Trainer(method=..., dataloader=...).fit(n)`` streams each minibatch into the model's ``pm.Data`` placeholder with ``set_data``. -**The full data never enters RAM.** The model graph observes only a -``(batch_size, *sample_shape)`` ``pm.Data`` *placeholder* that the ``Trainer`` +The full data never enters RAM. The model graph observes only a +``(batch_size, *sample_shape)`` ``pm.Data`` placeholder that the ``Trainer`` overwrites with the next minibatch every step. Passing a directory of Parquet shards far larger than RAM still gives a model whose resident footprint is one batch. -The unbiased-gradient rescaling is the *same* as for ``pm.Minibatch``: the +The unbiased-gradient rescaling is the same as for ``pm.Minibatch``: the observed log-likelihood must be scaled by ``N / batch_size`` through the existing :func:`~pymc.variational.minibatch_rv.create_minibatch_rv`. ``N`` is exactly -``len(loader)`` (the loader is sized; ``len`` returns the row count ``N``) -- so the +``len(loader)`` (the loader is sized; ``len`` returns the row count ``N``), so the model passes ``total_size=len(loader)``. (Folding that scaling into the inference step, so it drops out of the model body, is the next step in PyMC's VI rework.) -The one extra obligation relative to ``pm.Minibatch`` is **shuffling**. +The one extra obligation relative to ``pm.Minibatch`` is shuffling. ``pm.Minibatch`` draws a fresh uniform index over all N rows every step, so its minibatches are i.i.d. by construction. 
A streaming source is only as well mixed as the order it yields rows in: reading time/row-ordered data through a -*bounded* buffer is merely a block-shuffle and biases the variational posterior. +bounded buffer is merely a block-shuffle and biases the variational posterior. Pre-shuffle the data once on disk (or interleave shards) and/or pass ``shuffle=True``. @@ -101,14 +101,14 @@ def _is_positive_int(value: object) -> bool: class IterableDataset: - """A re-iterable, out-of-core source of rows -- the analogue of ``torch.utils.data.IterableDataset``. + """A re-iterable, out-of-core source of rows, the analogue of ``torch.utils.data.IterableDataset``. Subclass and implement :meth:`__iter__` to yield ``np.ndarray`` blocks of rows (shape ``(rows, *sample_shape)``); :class:`DataLoader` re-batches those blocks - into fixed-size minibatches. ``__iter__`` must return a **fresh** iterator each + into fixed-size minibatches. ``__iter__`` must return a fresh iterator each call so the dataset can be replayed across epochs. - Optionally set :attr:`n_rows` (the total row count, if known cheaply -- e.g. + Optionally set :attr:`n_rows` (the total row count, if known cheaply, e.g. from file metadata) so a :class:`DataLoader` with ``total_size="auto"`` can resolve ``N`` without a counting pass. @@ -130,14 +130,14 @@ class DataLoader: The analogue of ``torch.utils.data.DataLoader``: it batches (and optionally shuffles) an :class:`IterableDataset` into the minibatch stream that :class:`Trainer` feeds to the model. It is iterable and sized (``len(loader)`` - is the dataset size ``N``). The full dataset never enters memory -- only one + is the dataset size ``N``). The full dataset never enters memory; only one ``(batch_size, *sample_shape)`` batch does. Parameters ---------- dataset : IterableDataset | Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]] The source of rows. 
An :class:`IterableDataset`, a re-iterable (including a - plain ``np.ndarray``), or a zero-arg *factory* returning a fresh iterator + plain ``np.ndarray``), or a zero-arg factory returning a fresh iterator (preferred, so the stream can be restarted each epoch). It may yield single samples (e.g. the rows of a raw array) or blocks of any size; the loader re-batches them, in order, to exactly ``batch_size`` rows. Trailing rows @@ -149,8 +149,8 @@ class DataLoader: shuffle : bool, default False If ``True``, wrap the source in a bounded :func:`shuffle_buffer` of ``buffer_size`` rows. This only approximates i.i.d. batches for an - *already unordered* stream; a bounded buffer cannot fix strongly - time/row-ordered data (pre-shuffle on disk for that -- see the module + already unordered stream; a bounded buffer cannot fix strongly + time/row-ordered data (pre-shuffle on disk for that; see the module docstring). buffer_size : int, optional Shuffle-buffer size in rows when ``shuffle=True``. Defaults to @@ -211,7 +211,7 @@ def __init__( elif total_size is None: warnings.warn( "DataLoader created with total_size=None: the minibatch " - "log-likelihood will NOT be rescaled and the posterior will be " + "log-likelihood will not be rescaled and the posterior will be " "biased. Pass total_size=N (the true dataset size) or total_size='auto'.", UserWarning, stacklevel=2, @@ -268,10 +268,10 @@ def __iter__(self) -> Iterator[np.ndarray]: yield self._prepare(batch) def __len__(self) -> int: - """The dataset size ``N`` (row count) -- pass to the distribution's ``total_size``. + """The dataset size ``N`` (row count); pass it to the distribution's ``total_size``. ``total_size=len(loader)`` is how the model gets the ``N / batch_size`` - rescaling. Note this returns the *row* count ``N``, not the *batch* count + rescaling. Note this returns the row count ``N``, not the batch count (``ceil(N / batch_size)``) that ``torch.utils.data.DataLoader.__len__`` returns; ``total_size`` needs ``N``. 
:attr:`total_size` is the same value. """ @@ -317,8 +317,8 @@ def _maybe_warn_total_size(self) -> None: if seen and abs(self._total_size - seen) > 0.1 * seen: warnings.warn( f"total_size={self._total_size} disagrees with the {seen} rows streamed " - f"in one full pass; the N/batch_size rescaling -- and therefore the " - f"posterior width -- is likely wrong. Pass the true dataset size, or " + f"in one full pass; the N/batch_size rescaling, and therefore the " + f"posterior width, is likely wrong. Pass the true dataset size, or " f"total_size='auto'.", UserWarning, stacklevel=3, @@ -344,7 +344,7 @@ def _validate(self, batch: np.ndarray) -> None: class Trainer: - """Drive variational inference over a :class:`DataLoader` -- without callbacks. + """Drive variational inference over a :class:`DataLoader` without user callbacks. Follows the design in PyMC's variational-inference rework (Grabowski, *VI Overview*) and PyTorch Lightning: the ``Trainer`` owns the training loop, the @@ -372,9 +372,9 @@ class Trainer: method : str, default "advi" Variational method, forwarded to :func:`pymc.fit` (``"advi"``, ``"fullrank_advi"``, ...). Once the VI rework lands this will also accept - an inference *instance* (e.g. ``ADVI()``); a string drives today's ``pm.fit``. + an inference instance (e.g. ``ADVI()``); a string drives today's ``pm.fit``. dataloader : DataLoader - The minibatch source. ``len(dataloader)`` is ``N`` -- the model should pass + The minibatch source. ``len(dataloader)`` is ``N``; the model should pass it to the observed distribution's ``total_size``. model : pymc.Model, optional Defaults to the model on the context stack. 
@@ -437,7 +437,7 @@ def _stream() -> Iterator[np.ndarray]: raise RuntimeError("dataloader yielded no batches") batches = _stream() - # Seed the placeholder before step 0: pm.fit runs callbacks AFTER each step, + # Seed the placeholder before step 0: pm.fit runs callbacks after each step, # so without this the first step would train on the placeholder's contents. model.set_data(self.data_name, next(batches)) @@ -470,18 +470,18 @@ def shuffle_buffer( Accumulates rows from ``chunk_source`` into a buffer of at least ``buffer_size`` rows, shuffles it, and yields ``batch_size`` slices; rows that - do not fill a final batch are **carried over** into the next buffer (never + do not fill a final batch are carried over into the next buffer (never dropped) until the source is exhausted, at which point a single trailing partial batch (< ``batch_size`` rows) is dropped. This approximates i.i.d. - minibatches from an *unordered* or pre-shuffled stream. + minibatches from an unordered or pre-shuffled stream. :class:`DataLoader` calls this for you when ``shuffle=True``; use it directly when you want explicit control over ``buffer_size`` independently of the loader. - It does **not** by itself fix a strongly time/row-ordered stream (a bounded - buffer only block-shuffles such data) -- pre-shuffle on disk, or interleave - shards into ``chunk_source``, for that. ``buffer_size`` is a *lower* bound: the + It does not by itself fix a strongly time/row-ordered stream (a bounded + buffer only block-shuffles such data); pre-shuffle on disk, or interleave + shards into ``chunk_source``, for that. ``buffer_size`` is a lower bound: the buffer always accumulates at least ``max(buffer_size, batch_size)`` rows before emitting (so a ``buffer_size`` smaller than ``batch_size`` still yields full batches instead of silently dropping the stream), and a single chunk larger @@ -634,7 +634,7 @@ def _auto_total_size( reads it from Parquet metadata without scanning the data) use it directly. 
Otherwise do a single counting pass over a finite, re-readable source. A bare one-shot iterator cannot be auto-counted (counting consumes it) and an infinite stream would make the - pass hang -- both must pass ``total_size`` explicitly. + pass hang; both must pass ``total_size`` explicitly. """ n = getattr(source, "n_rows", None) if n is None: @@ -658,7 +658,7 @@ def _auto_total_size( count = 0 for chunk in first_iter: a = np.asarray(chunk) - # A yield of shape exactly `sample_shape` is ONE sample, not a block. + # A yield of shape exactly `sample_shape` is one sample, not a block. count += 1 if a.shape == sample_shape else int(a.shape[0]) if count <= 0: raise ValueError("total_size='auto' counted 0 rows (empty or non-re-readable source).") @@ -677,7 +677,7 @@ class _ParquetDataset(IterableDataset): """An :class:`IterableDataset` over a directory of Parquet shards. Yields one ``(rows, n_columns)`` ``float64`` array per file and exposes - :attr:`n_rows` read from Parquet *metadata* (no data scan). + :attr:`n_rows` read from Parquet metadata (no data scan). """ def __init__(self, paths: list[str], columns: list[str] | None, n_rows: int): @@ -702,7 +702,7 @@ def parquet_source( """An :class:`IterableDataset` over a directory of Parquet files. Yields one ``(rows, n_columns)`` ``float64`` array per file, and carries an - ``n_rows`` attribute read from Parquet *metadata* (no data scan) so that + ``n_rows`` attribute read from Parquet metadata (no data scan) so that ``DataLoader(parquet_source(dir), ..., total_size="auto")`` resolves the dataset size for free. Pass ``shuffle=True`` to the :class:`DataLoader` (or wrap in :func:`shuffle_buffer`) to get shuffled batches. 
From 6f9eed25fd462c9b01eae982d6953ec298513761 Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Wed, 10 Jun 2026 06:40:26 -0500 Subject: [PATCH 15/27] Follow numpydoc section conventions --- pymc/variational/streaming.py | 37 +++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index 3fa07f54b4..ca8dbb10eb 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -54,8 +54,8 @@ Pre-shuffle the data once on disk (or interleave shards) and/or pass ``shuffle=True``. -Example -------- +Examples +-------- .. code-block:: python import numpy as np @@ -353,20 +353,6 @@ class Trainer: the ``Trainer`` streams minibatches into it with ``model.set_data`` once per step, so the user wires up no callbacks. - .. code-block:: python - - import numpy as np - - loader = DataLoader( - parquet_source("shuffled/"), batch_size=4096, sample_shape=(4,), total_size="auto" - ) - with pm.Model() as model: - b = pm.Normal("b", 0.0, 3.0, shape=4) - batch = pm.Data("batch", np.zeros((4096, 4))) # placeholder - logit = b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1] + b[3] * batch[:, 2] - pm.Bernoulli("y", logit_p=logit, observed=batch[:, 3], total_size=len(loader)) - approx = Trainer(method="advi", dataloader=loader, data_name="batch").fit(20_000) - Parameters ---------- method : str, default "advi" @@ -391,6 +377,20 @@ class Trainer: rework's ``Inference.step(batch)`` lands it moves there, at which point the ``total_size`` rescaling can be derived from ``len(dataloader)`` and dropped from the model body entirely. + + Examples + -------- + .. 
code-block:: python + + loader = DataLoader( + parquet_source("shuffled/"), batch_size=4096, sample_shape=(4,), total_size="auto" + ) + with pm.Model() as model: + b = pm.Normal("b", 0.0, 3.0, shape=4) + batch = pm.Data("batch", np.zeros((4096, 4))) # placeholder + logit = b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1] + b[3] * batch[:, 2] + pm.Bernoulli("y", logit_p=logit, observed=batch[:, 3], total_size=len(loader)) + approx = Trainer(method="advi", dataloader=loader, data_name="batch").fit(20_000) """ def __init__( @@ -418,7 +418,10 @@ def fit( ): """Fit for ``n`` steps, streaming minibatches into the model's placeholder. - Returns whatever :func:`pymc.fit` returns for the chosen method. + Returns + ------- + :class:`Approximation` + The fitted approximation, as returned by :func:`pymc.fit`. """ loader = self.dataloader if not isinstance(loader, DataLoader): From 45cb513d7a05e1510425733aa7fee3ba9a76d5e4 Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Wed, 10 Jun 2026 06:41:34 -0500 Subject: [PATCH 16/27] Declare __all__ like the neighboring modules --- pymc/variational/streaming.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index ca8dbb10eb..4df59d57e8 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -94,6 +94,8 @@ from pymc.model import modelcontext from pymc.variational.inference import fit as _fit +__all__ = ["DataLoader", "IterableDataset", "Trainer", "parquet_source", "shuffle_buffer"] + def _is_positive_int(value: object) -> bool: """True for a strictly positive integer (incl. 
numpy integer types), excluding bool.""" From 14568b126a43e0068ebbf21b589bf1c51b92da3e Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Wed, 10 Jun 2026 22:29:55 -0500 Subject: [PATCH 17/27] Add docstrings to the last three tests --- tests/variational/test_streaming.py | 1 + tests/variational/test_streaming_autosize.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/tests/variational/test_streaming.py b/tests/variational/test_streaming.py index e898a7f75e..274e480333 100644 --- a/tests/variational/test_streaming.py +++ b/tests/variational/test_streaming.py @@ -302,6 +302,7 @@ def test_shuffle_buffer_small_buffer_conserves_rows(): def test_shuffle_buffer_rejects_nonpositive_sizes(): + """Zero or negative buffer/batch sizes raise at construction.""" data = np.zeros((10, 1)) with pytest.raises(ValueError, match="buffer_size"): shuffle_buffer(_chunks(data, 5), buffer_size=0, batch_size=4) diff --git a/tests/variational/test_streaming_autosize.py b/tests/variational/test_streaming_autosize.py index 968dbb55e7..0d092cc257 100644 --- a/tests/variational/test_streaming_autosize.py +++ b/tests/variational/test_streaming_autosize.py @@ -113,6 +113,7 @@ def test_auto_rejects_factory_returning_same_one_shot_iterator(): def test_auto_rejects_bad_n_rows(): + """A non-positive source .n_rows is rejected instead of trusted.""" f = _factory(np.zeros((8, 1)), 4) f.n_rows = 0 with pytest.raises(ValueError, match="n_rows must be a positive integer"): @@ -181,6 +182,7 @@ def test_parquet_source_columns_and_shard_order(tmp_path): def test_parquet_source_empty_dir_raises(tmp_path): + """A directory with no matching Parquet files raises a clear error.""" pytest.importorskip("pyarrow") with pytest.raises(ValueError, match="no Parquet files match"): parquet_source(str(tmp_path)) From 51ee22c8a0054381f70928d509ce6887b3fee507 Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Wed, 10 Jun 2026 22:37:35 -0500 Subject: [PATCH 18/27] Plainer docstring wording --- 
pymc/variational/streaming.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index 4df59d57e8..c8d35302f7 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -18,8 +18,7 @@ minibatches from an out-of-core source into a ``pm.Data`` placeholder, so peak memory is O(batch) plus, if used, the shuffle buffer, independent of N. -The API mirrors PyTorch's ``torch.utils.data`` so the mental model transfers -directly: +The API follows PyTorch's ``torch.utils.data``: * :class:`IterableDataset`: a re-iterable, out-of-core source of rows (e.g. :func:`parquet_source` over a directory of shards). It never loads the @@ -46,7 +45,7 @@ model passes ``total_size=len(loader)``. (Folding that scaling into the inference step, so it drops out of the model body, is the next step in PyMC's VI rework.) -The one extra obligation relative to ``pm.Minibatch`` is shuffling. +One difference from ``pm.Minibatch`` is shuffling. ``pm.Minibatch`` draws a fresh uniform index over all N rows every step, so its minibatches are i.i.d. by construction. A streaming source is only as well mixed as the order it yields rows in: reading time/row-ordered data through a @@ -103,7 +102,7 @@ def _is_positive_int(value: object) -> bool: class IterableDataset: - """A re-iterable, out-of-core source of rows, the analogue of ``torch.utils.data.IterableDataset``. + """A re-iterable, out-of-core source of rows, like ``torch.utils.data.IterableDataset``. Subclass and implement :meth:`__iter__` to yield ``np.ndarray`` blocks of rows (shape ``(rows, *sample_shape)``); :class:`DataLoader` re-batches those blocks @@ -129,7 +128,7 @@ def __iter__(self) -> Iterator[np.ndarray]: class DataLoader: """Turn an out-of-core dataset into fixed-size minibatches for variational inference. 
- The analogue of ``torch.utils.data.DataLoader``: it batches (and optionally + Like ``torch.utils.data.DataLoader``, it batches (and optionally shuffles) an :class:`IterableDataset` into the minibatch stream that :class:`Trainer` feeds to the model. It is iterable and sized (``len(loader)`` is the dataset size ``N``). The full dataset never enters memory; only one @@ -220,7 +219,7 @@ def __init__( ) elif not _is_positive_int(total_size): # 0 is falsy (the rescaling would be silently skipped) and a negative - # value flips the sign of the data log-likelihood; reject both loudly. + # value flips the sign of the data log-likelihood; raise on both. raise ValueError( "total_size must be a positive integer (the true dataset size N) so " "the minibatch log-likelihood is rescaled by N / batch_size; got " @@ -311,7 +310,7 @@ def _prepare(self, batch: np.ndarray) -> np.ndarray: return np.array(batch, dtype=self._dtype) def _maybe_warn_total_size(self) -> None: - """Warn once if total_size grossly disagrees with the rows seen in one pass.""" + """Warn once if total_size differs from the rows seen in one pass by more than 10%.""" if self._warned_size or self._total_size is None: return self._warned_size = True @@ -353,7 +352,7 @@ class Trainer: :class:`DataLoader` owns batching (and ``len(dataloader)`` is the dataset size ``N``), and the model owns the math. The model exposes a ``pm.Data`` placeholder; the ``Trainer`` streams minibatches into it with ``model.set_data`` once per - step, so the user wires up no callbacks. + step; no user callbacks are needed. 
Parameters ---------- From 54f13e4b0aacbebbfd7ed1f95dcc0abb8b62c0e8 Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Wed, 10 Jun 2026 22:44:41 -0500 Subject: [PATCH 19/27] State the memory bound precisely --- pymc/variational/streaming.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index c8d35302f7..fca0e93970 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -16,7 +16,8 @@ ``pm.Minibatch`` random-indexes an array that is fully resident in memory; its peak memory is therefore O(N) in the dataset size. This module instead streams minibatches from an out-of-core source into a ``pm.Data`` placeholder, so peak -memory is O(batch) plus, if used, the shuffle buffer, independent of N. +memory is set by the batch, the source chunk, and the optional shuffle buffer, +independent of N. The API follows PyTorch's ``torch.utils.data``: @@ -72,7 +73,7 @@ with pm.Model() as model: b = pm.Normal("b", 0.0, 3.0, shape=4) - batch = pm.Data("batch", np.zeros((4096, 4))) # placeholder, the only data in RAM + batch = pm.Data("batch", np.zeros((4096, 4))) # placeholder for one minibatch logit = b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1] + b[3] * batch[:, 2] pm.Bernoulli("y", logit_p=logit, observed=batch[:, 3], total_size=len(loader)) From fc5eb11d3780292893e590e0b21c86007ff6e5ac Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Wed, 10 Jun 2026 22:53:28 -0500 Subject: [PATCH 20/27] Promote single samples before the shuffle buffer shuffle_buffer concatenates yields along the leading axis, so a raw array source under shuffle=True had its rows flattened (2-D) or crashed on shape[0] (scalars). Promote single samples to one-row blocks before the shuffle wrap, with the same helper the re-batcher uses. Also tighten a few docstring claims: the parquet dtype follows the file columns, and the shuffle buffer bound is stated as rows held. 
--- pymc/variational/streaming.py | 61 ++++++++++++++++++++++------- tests/variational/test_streaming.py | 24 ++++++++++++ 2 files changed, 70 insertions(+), 15 deletions(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index fca0e93970..d2fdaaf5ac 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -132,8 +132,7 @@ class DataLoader: Like ``torch.utils.data.DataLoader``, it batches (and optionally shuffles) an :class:`IterableDataset` into the minibatch stream that :class:`Trainer` feeds to the model. It is iterable and sized (``len(loader)`` - is the dataset size ``N``). The full dataset never enters memory; only one - ``(batch_size, *sample_shape)`` batch does. + is the dataset size ``N``). The full dataset never enters memory. Parameters ---------- @@ -199,8 +198,13 @@ def __init__( if shuffle: if buffer_size is None: buffer_size = 50 * int(batch_size) + # shuffle_buffer concatenates yields along the leading axis, so single + # samples must be promoted to one-row blocks before shuffling. source_factory = shuffle_buffer( - raw_factory, buffer_size=buffer_size, batch_size=batch_size, seed=seed + _block_factory(raw_factory, tuple(sample_shape)), + buffer_size=buffer_size, + batch_size=batch_size, + seed=seed, ) self._source_factory = source_factory @@ -490,8 +494,8 @@ def shuffle_buffer( buffer always accumulates at least ``max(buffer_size, batch_size)`` rows before emitting (so a ``buffer_size`` smaller than ``batch_size`` still yields full batches instead of silently dropping the stream), and a single chunk larger - than that is taken whole, so peak buffer memory is - ``max(buffer_size, batch_size, largest_chunk_rows)``. + than that is taken whole, so the buffer holds at most + ``max(buffer_size, batch_size, largest_chunk_rows)`` rows. 
Each epoch (each call of the returned factory) draws a fresh permutation from a sub-stream of ``seed``, so the shuffle order differs across epochs while @@ -549,6 +553,40 @@ def factory() -> Iterator[np.ndarray]: return factory + +def _promote_to_block(a: np.ndarray, sample_shape: tuple[int, ...]) -> np.ndarray: + """Return ``a`` as a ``(rows, *sample_shape)`` block; a single sample becomes one row.""" + if a.shape == sample_shape: + return a[None, ...] + if a.ndim != len(sample_shape) + 1 or a.shape[1:] != sample_shape: + raise ValueError( + f"source yielded shape {a.shape}; expected a single sample of shape " + f"{sample_shape} or a block of shape (rows, *{sample_shape})" + ) + return a + + +def _block_factory( + factory: Callable[[], Iterator[np.ndarray]], + sample_shape: tuple[int, ...], +) -> Callable[[], Iterator[np.ndarray]]: + """Wrap ``factory`` so every yield is a block, promoting single samples. + + :func:`shuffle_buffer` counts and concatenates yields along the leading axis, + so single-sample yields (e.g. the rows of a raw array) must be promoted to + one-row blocks before shuffling. A known ``.n_rows`` is forwarded. + """ + + def f() -> Iterator[np.ndarray]: + for arr in factory(): + yield _promote_to_block(np.asarray(arr), sample_shape) + + n_rows = getattr(factory, "n_rows", None) + if n_rows is not None: + f.n_rows = n_rows # type: ignore[attr-defined] + return f + + def _rebatch( blocks: Iterable[np.ndarray], batch_size: int, @@ -567,14 +605,7 @@ def _rebatch( buf: list[np.ndarray] = [] have = 0 for arr in blocks: - a = np.asarray(arr) - if a.shape == sample_shape: # a single sample, not a block - a = a[None, ...] 
- elif a.ndim != len(sample_shape) + 1 or a.shape[1:] != sample_shape: - raise ValueError( - f"source yielded shape {a.shape}; expected a single sample of shape " - f"{sample_shape} or a block of shape (rows, *{sample_shape})" - ) + a = _promote_to_block(np.asarray(arr), sample_shape) buf.append(a) have += a.shape[0] if have < batch_size: @@ -681,7 +712,7 @@ def _auto_total_size( class _ParquetDataset(IterableDataset): """An :class:`IterableDataset` over a directory of Parquet shards. - Yields one ``(rows, n_columns)`` ``float64`` array per file and exposes + Yields one ``(rows, n_columns)`` array per file and exposes :attr:`n_rows` read from Parquet metadata (no data scan). """ @@ -706,7 +737,7 @@ def parquet_source( ) -> _ParquetDataset: """An :class:`IterableDataset` over a directory of Parquet files. - Yields one ``(rows, n_columns)`` ``float64`` array per file, and carries an + Yields one ``(rows, n_columns)`` array per file, and carries an ``n_rows`` attribute read from Parquet metadata (no data scan) so that ``DataLoader(parquet_source(dir), ..., total_size="auto")`` resolves the dataset size for free. 
Pass ``shuffle=True`` to the :class:`DataLoader` (or diff --git a/tests/variational/test_streaming.py b/tests/variational/test_streaming.py index 274e480333..5a8da8f2e7 100644 --- a/tests/variational/test_streaming.py +++ b/tests/variational/test_streaming.py @@ -385,6 +385,30 @@ def test_factory_returning_reiterable_is_accepted(): assert next(iter(ds)).shape == (4, 1) +def test_raw_array_with_shuffle_true(): + """A raw array source composes with shuffle=True: rows are promoted to + one-row blocks before the shuffle buffer instead of being flattened by it.""" + data = np.arange(40, dtype="float64").reshape(20, 2) + ds = DataLoader( + data, batch_size=8, shuffle=True, buffer_size=16, seed=0, sample_shape=(2,), total_size=20 + ) + batches = list(ds) + assert [b.shape for b in batches] == [(8, 2), (8, 2)] + rows = {tuple(r) for b in batches for r in b} + assert len(rows) == 16 and rows <= {tuple(r) for r in data} + + +def test_scalar_raw_array_with_shuffle_true(): + """Scalar samples from a raw 1-D array compose with shuffle=True.""" + data = np.arange(12, dtype="float64") + ds = DataLoader( + data, batch_size=4, shuffle=True, buffer_size=6, seed=0, sample_shape=(), total_size=12 + ) + batches = list(ds) + assert [b.shape for b in batches] == [(4,), (4,), (4,)] + np.testing.assert_array_equal(np.sort(np.concatenate(batches)), data) + + def test_scalar_samples_are_batched(): """With sample_shape=() a 0-D yield is one scalar sample, exactly what iterating a raw 1-D array produces; the loader batches scalars.""" From 5769fbffe58614bf4b1977832462e2f2583c842b Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Wed, 10 Jun 2026 22:55:31 -0500 Subject: [PATCH 21/27] Apply formatter --- pymc/variational/streaming.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index d2fdaaf5ac..3218062a57 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -553,7 +553,6 @@ def factory() -> 
Iterator[np.ndarray]: return factory - def _promote_to_block(a: np.ndarray, sample_shape: tuple[int, ...]) -> np.ndarray: """Return ``a`` as a ``(rows, *sample_shape)`` block; a single sample becomes one row.""" if a.shape == sample_shape: From 13a6a05c144f3110b68adc061410c0cce8012dee Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Wed, 10 Jun 2026 23:53:43 -0500 Subject: [PATCH 22/27] Fix loader and trainer edge cases; tighten docstring claims DataLoader infers sample_shape from a raw array source, so DataLoader(arr, batch_size=...) batches rows instead of silently flattening them to scalars. The total_size check no longer warns on an exact N when drop-last truncates the final batch, and its advice covers a wrong source n_rows. Trainer.fit routes all kwargs through one merge so constructor defaults like random_seed work as documented, accepts an Inference instance, and rejects an unknown data_name before consuming a batch. parquet_source validates columns against the schema up front. The shuffle_buffer docstring states the true buffer bound. --- pymc/variational/streaming.py | 124 ++++++++++++------- tests/variational/test_streaming.py | 72 +++++++++++ tests/variational/test_streaming_autosize.py | 35 +++++- 3 files changed, 184 insertions(+), 47 deletions(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index 3218062a57..1ce013d884 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -33,11 +33,11 @@ ``Trainer(method=..., dataloader=...).fit(n)`` streams each minibatch into the model's ``pm.Data`` placeholder with ``set_data``. -The full data never enters RAM. The model graph observes only a -``(batch_size, *sample_shape)`` ``pm.Data`` placeholder that the ``Trainer`` -overwrites with the next minibatch every step. Passing a directory of Parquet -shards far larger than RAM still gives a model whose resident footprint is one -batch. +With an out-of-core source the full data never enters RAM. 
The model graph +observes only a ``(batch_size, *sample_shape)`` ``pm.Data`` placeholder that the +``Trainer`` overwrites with the next minibatch every step. Passing a directory +of Parquet shards far larger than RAM still gives a model whose resident +footprint is one batch. The unbiased-gradient rescaling is the same as for ``pm.Minibatch``: the observed log-likelihood must be scaled by ``N / batch_size`` through the existing @@ -84,7 +84,9 @@ from __future__ import annotations +import glob import numbers +import os import warnings from collections.abc import Callable, Iterable, Iterator @@ -92,6 +94,7 @@ import numpy as np from pymc.model import modelcontext +from pymc.variational.inference import Inference from pymc.variational.inference import fit as _fit __all__ = ["DataLoader", "IterableDataset", "Trainer", "parquet_source", "shuffle_buffer"] @@ -158,9 +161,11 @@ class DataLoader: ``50 * batch_size``. Ignored when ``shuffle=False``. seed : int, optional Seed for the shuffle buffer (ignored when ``shuffle=False``). - sample_shape : tuple of int, default () + sample_shape : tuple of int, optional Trailing shape of a single observation. ``()`` for scalar observations, ``(k,)`` to stream ``k`` columns (e.g. features + the observed column). + Defaults to ``dataset.shape[1:]`` for a raw ``np.ndarray`` source (its + rows are the samples, like torch's ``TensorDataset``), else ``()``. dtype : str, default "float64" Dtype each prepared batch is cast to; match the dtype of the ``pm.Data`` placeholder the batches are streamed into. @@ -174,7 +179,10 @@ class DataLoader: warns at construction and a non-positive value raises (it would otherwise silently disable or invert the rescaling). preprocess_fn : callable, optional - Pure transform applied to each batch before it is yielded. + Pure transform applied to each batch before validation (e.g. + normalization). 
It must preserve the row count and ``sample_shape``; + to select columns, do it at the source instead + (``parquet_source(columns=...)``). """ def __init__( @@ -185,13 +193,18 @@ def __init__( shuffle: bool = False, buffer_size: int | None = None, seed: int | None = None, - sample_shape: tuple[int, ...] = (), + sample_shape: tuple[int, ...] | None = None, dtype: str = "float64", total_size: int | str | None = None, preprocess_fn: Callable[[np.ndarray], np.ndarray] | None = None, ): if not _is_positive_int(batch_size): raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}") + if sample_shape is None: + # A raw array is rows-of-samples; without this default a 2-D array + # would be read as blocks of scalars and silently flattened. + sample_shape = dataset.shape[1:] if isinstance(dataset, np.ndarray) else () + sample_shape = tuple(sample_shape) raw_factory = _make_factory(dataset) source_factory = raw_factory @@ -201,7 +214,7 @@ def __init__( # shuffle_buffer concatenates yields along the leading axis, so single # samples must be promoted to one-row blocks before shuffling. source_factory = shuffle_buffer( - _block_factory(raw_factory, tuple(sample_shape)), + _block_factory(raw_factory, sample_shape), buffer_size=buffer_size, batch_size=batch_size, seed=seed, @@ -213,7 +226,7 @@ def __init__( raise ValueError(f"total_size string must be 'auto', got {total_size!r}") # Count the unshuffled source: the shuffle wrapper drops the trailing # partial batch, so counting through it would undercount N. - total_size = _auto_total_size(raw_factory, dataset, tuple(sample_shape)) + total_size = _auto_total_size(raw_factory, dataset, sample_shape) elif total_size is None: warnings.warn( "DataLoader created with total_size=None: the minibatch " @@ -233,7 +246,7 @@ def __init__( # Plain Python ints: create_minibatch_rv rejects np.int64 for total_size. 
self._batch_size = int(batch_size) - self._sample_shape = tuple(sample_shape) + self._sample_shape = sample_shape self._dtype = dtype self._total_size = None if total_size is None else int(total_size) self._preprocess_fn = preprocess_fn @@ -267,8 +280,10 @@ def _rebatched(self) -> Iterator[np.ndarray]: def __iter__(self) -> Iterator[np.ndarray]: """Yield one epoch of validated ``(batch_size, *sample_shape)`` minibatches. - This is the stream :class:`Trainer` pushes into the model's ``pm.Data`` - placeholder via ``set_data``. Re-iterate the loader for another epoch. + The same batches the :class:`Trainer` streams into the model's ``pm.Data`` + placeholder (it consumes them through an accounting wrapper, so plain + iteration leaves the counters untouched). Re-iterate the loader for + another epoch. """ for batch in self._rebatched(): yield self._prepare(batch) @@ -278,8 +293,8 @@ def __len__(self) -> int: ``total_size=len(loader)`` is how the model gets the ``N / batch_size`` rescaling. Note this returns the row count ``N``, not the batch count - (``ceil(N / batch_size)``) that ``torch.utils.data.DataLoader.__len__`` - returns; ``total_size`` needs ``N``. :attr:`total_size` is the same value. + that ``torch.utils.data.DataLoader.__len__`` returns; ``total_size`` + needs ``N``. :attr:`total_size` is the same value. """ if self._total_size is None: raise TypeError( @@ -315,17 +330,24 @@ def _prepare(self, batch: np.ndarray) -> np.ndarray: return np.array(batch, dtype=self._dtype) def _maybe_warn_total_size(self) -> None: - """Warn once if total_size differs from the rows seen in one pass by more than 10%.""" + """Warn once if ``total_size`` is inconsistent with the rows seen in one pass. + + A correct ``N`` satisfies ``seen <= N < seen + batch_size`` after a full + pass (the trailing partial batch is dropped), so that window never warns; + outside it a 10% slack absorbs sources that are only approximately sized. 
+ """ if self._warned_size or self._total_size is None: return self._warned_size = True seen = self._rows_streamed - if seen and abs(self._total_size - seen) > 0.1 * seen: + if not seen or seen <= self._total_size < seen + self._batch_size: + return + if abs(self._total_size - seen) > 0.1 * seen: warnings.warn( f"total_size={self._total_size} disagrees with the {seen} rows streamed " f"in one full pass; the N/batch_size rescaling, and therefore the " - f"posterior width, is likely wrong. Pass the true dataset size, or " - f"total_size='auto'.", + f"posterior width, is likely wrong. Pass the true dataset size (or, if " + f"'auto' resolved it from the source's n_rows, fix that attribute).", UserWarning, stacklevel=3, ) @@ -352,8 +374,8 @@ def _validate(self, batch: np.ndarray) -> None: class Trainer: """Drive variational inference over a :class:`DataLoader` without user callbacks. - Follows the design in PyMC's variational-inference rework (Grabowski, *VI - Overview*) and PyTorch Lightning: the ``Trainer`` owns the training loop, the + Follows the design in PyMC's variational-inference rework and PyTorch + Lightning: the ``Trainer`` owns the training loop, the :class:`DataLoader` owns batching (and ``len(dataloader)`` is the dataset size ``N``), and the model owns the math. The model exposes a ``pm.Data`` placeholder; the ``Trainer`` streams minibatches into it with ``model.set_data`` once per @@ -361,10 +383,12 @@ class Trainer: Parameters ---------- - method : str, default "advi" - Variational method, forwarded to :func:`pymc.fit` (``"advi"``, - ``"fullrank_advi"``, ...). Once the VI rework lands this will also accept - an inference instance (e.g. ``ADVI()``); a string drives today's ``pm.fit``. + method : str or Inference, default "advi" + Variational method, forwarded to :func:`pymc.fit`: a name (``"advi"``, + ``"fullrank_advi"``, ...) or an :class:`~pymc.variational.inference.Inference` + instance. 
``pm.fit`` applies ``model`` and ``random_seed`` only to a name; + an instance is already bound to a model, so configure it at construction + (e.g. ``ADVI(random_seed=...)``). dataloader : DataLoader The minibatch source. ``len(dataloader)`` is ``N``; the model should pass it to the observed distribution's ``total_size``. @@ -402,7 +426,7 @@ class Trainer: def __init__( self, *, - method: str = "advi", + method: str | Inference = "advi", dataloader: DataLoader, model=None, data_name: str = "batch", @@ -414,16 +438,13 @@ def __init__( self.data_name = data_name self._fit_kwargs = fit_kwargs - def fit( - self, - n: int = 10_000, - *, - random_seed: int | None = None, - progressbar: bool = False, - **kwargs, - ): + def fit(self, n: int = 10_000, **kwargs): """Fit for ``n`` steps, streaming minibatches into the model's placeholder. + Keyword arguments are forwarded to :func:`pymc.fit` on top of the + constructor's ``fit_kwargs`` (per-call wins); ``progressbar`` defaults to + ``False`` unless either sets it. + Returns ------- :class:`Approximation` @@ -435,6 +456,13 @@ def fit( f"Trainer needs a DataLoader for `dataloader`, got {type(loader).__name__}." ) model = modelcontext(self.model) + if self.data_name not in model: + # Checked before the stream starts so no batch is consumed (and no + # counter advances) on a typo. + raise KeyError( + f"data_name {self.data_name!r} is not a variable in the model; it " + f"must name the pm.Data placeholder the minibatches are streamed into." + ) def _stream() -> Iterator[np.ndarray]: while True: @@ -454,6 +482,7 @@ def _advance(*_): model.set_data(self.data_name, next(batches)) merged = {**self._fit_kwargs, **kwargs} + merged.setdefault("progressbar", False) # User callbacks (e.g. convergence trackers) are appended after the # internal advance instead of colliding with it on the keyword. 
user_callbacks = merged.pop("callbacks", None) or [] @@ -461,8 +490,6 @@ def _advance(*_): n, method=self.method, model=model, - random_seed=random_seed, - progressbar=progressbar, callbacks=[_advance, *user_callbacks], **merged, ) @@ -490,12 +517,12 @@ def shuffle_buffer( It does not by itself fix a strongly time/row-ordered stream (a bounded buffer only block-shuffles such data); pre-shuffle on disk, or interleave - shards into ``chunk_source``, for that. ``buffer_size`` is a lower bound: the - buffer always accumulates at least ``max(buffer_size, batch_size)`` rows before - emitting (so a ``buffer_size`` smaller than ``batch_size`` still yields full - batches instead of silently dropping the stream), and a single chunk larger - than that is taken whole, so the buffer holds at most - ``max(buffer_size, batch_size, largest_chunk_rows)`` rows. + shards into ``chunk_source``, for that. ``buffer_size`` is a lower bound: + each fill accumulates at least ``max(buffer_size, batch_size)`` rows before + shuffling (so a ``buffer_size`` smaller than ``batch_size`` still yields full + batches; the final fill stops at whatever the source has left), and the chunk + that crosses the threshold is kept whole, so the buffer holds fewer than + ``max(buffer_size, batch_size)`` plus one chunk's rows. Each epoch (each call of the returned factory) draws a fresh permutation from a sub-stream of ``seed``, so the shuffle order differs across epochs while @@ -742,13 +769,18 @@ def parquet_source( dataset size for free. Pass ``shuffle=True`` to the :class:`DataLoader` (or wrap in :func:`shuffle_buffer`) to get shuffled batches. """ - import glob as _glob - import os - + # pyarrow is an optional dependency, so it is imported on use. 
import pyarrow.parquet as pq - paths = sorted(_glob.glob(os.path.join(directory, pattern))) + paths = sorted(glob.glob(os.path.join(directory, pattern))) if not paths: raise ValueError(f"no Parquet files match {os.path.join(directory, pattern)!r}") + if columns is not None: + available = set(pq.read_schema(paths[0]).names) + missing = sorted(set(columns) - available) + if missing: + raise ValueError( + f"columns {missing} not found in {paths[0]!r}; available: {sorted(available)}" + ) n_rows = sum(pq.read_metadata(p).num_rows for p in paths) return _ParquetDataset(paths, columns, n_rows) diff --git a/tests/variational/test_streaming.py b/tests/variational/test_streaming.py index 5a8da8f2e7..115d73fc25 100644 --- a/tests/variational/test_streaming.py +++ b/tests/variational/test_streaming.py @@ -442,3 +442,75 @@ def test_iterable_dataset_base_is_abstract(): """The base class is a contract: __iter__ must be overridden.""" with pytest.raises(NotImplementedError): iter(IterableDataset()) + + +def test_raw_2d_array_infers_sample_shape(): + """A raw 2-D array defaults sample_shape to its trailing shape, so the + VI-rework sketch ``DataLoader(arr, batch_size=...)`` batches rows instead of + flattening them into scalars.""" + data = np.arange(40, dtype="float64").reshape(20, 2) + with pytest.warns(UserWarning, match="counting pass"): + ds = DataLoader(data, batch_size=8, total_size="auto") + assert ds.total_size == 20 + batches = list(ds) + assert [b.shape for b in batches] == [(8, 2), (8, 2)] + np.testing.assert_array_equal(np.concatenate(batches), data[:16]) + + +def test_explicit_sample_shape_overrides_inference(): + """An explicit sample_shape=() reads each row of a 2-D array as a block of + scalar samples, the pre-inference behavior.""" + data = np.arange(40, dtype="float64").reshape(20, 2) + ds = DataLoader(data, batch_size=8, sample_shape=(), total_size=40) + batches = list(ds) + assert [b.shape for b in batches] == [(8,)] * 5 + + +def 
test_trainer_accepts_inference_instance(): + """An Inference instance is forwarded to pm.fit unchanged; it is bound to + the model it was built under, so the Trainer only streams the batches.""" + data = np.ones((4, 1)) + loader = DataLoader(lambda: iter([data] * 50), batch_size=4, sample_shape=(1,), total_size=4) + with pm.Model() as model: + mu = pm.Normal("mu", 0, 1) + batch = pm.Data("batch", np.zeros((4, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + approx = Trainer(method=pm.ADVI(random_seed=0), dataloader=loader).fit(5) + assert len(approx.hist) == 5 + np.testing.assert_array_equal(model["batch"].get_value(), data) + + +def test_constructor_fit_kwargs_take_random_seed(): + """random_seed works as a constructor default, as the docstring promises, + and a per-call value overrides the constructor's.""" + data = np.ones((4, 1)) + + def fit_with(ctor_kwargs, fit_kwargs): + loader = DataLoader( + lambda: iter([data] * 50), batch_size=4, sample_shape=(1,), total_size=4 + ) + with pm.Model(): + mu = pm.Normal("mu", 0, 1) + batch = pm.Data("batch", np.zeros((4, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + return Trainer(method="advi", dataloader=loader, data_name="batch", **ctor_kwargs).fit( + 5, **fit_kwargs + ) + + a = fit_with({"random_seed": 7}, {}) + b = fit_with({"random_seed": 0}, {"random_seed": 7}) + np.testing.assert_array_equal(a.hist, b.hist) + + +def test_unknown_data_name_raises_before_consuming(): + """A data_name that is not in the model raises a guided KeyError before any + batch is pulled from the loader.""" + loader = DataLoader( + lambda: iter([np.zeros((4, 1))] * 3), batch_size=4, sample_shape=(1,), total_size=4 + ) + with pm.Model(): + pm.Normal("mu", 0, 1) + with pytest.raises(KeyError, match="pm.Data placeholder"): + Trainer(method="advi", dataloader=loader, data_name="nope").fit(2) + assert loader.batches_seen == 0 + assert loader.rows_streamed == 0 diff --git 
a/tests/variational/test_streaming_autosize.py b/tests/variational/test_streaming_autosize.py index 0d092cc257..759460d922 100644 --- a/tests/variational/test_streaming_autosize.py +++ b/tests/variational/test_streaming_autosize.py @@ -108,7 +108,10 @@ def test_auto_rejects_factory_returning_same_one_shot_iterator(): re-readable; the counting pass detects and refuses it.""" data = np.zeros((20, 1)) one_shot = (data[i : i + 4] for i in range(0, 20, 4)) - with pytest.raises(ValueError, match="fresh iterator"): + with ( + pytest.warns(UserWarning, match="counting pass"), + pytest.raises(ValueError, match="fresh iterator"), + ): DataLoader(lambda: one_shot, batch_size=4, sample_shape=(1,), total_size="auto") @@ -225,3 +228,33 @@ def test_stream_batches_updates_counters_and_warns_on_wrong_total_size(): assert len(batches) == 4 assert ds.batches_seen == 4 assert ds.rows_streamed == 20 + + +def test_sanity_silent_when_drop_last_truncates(): + """An exactly-correct total_size does not warn when batch_size does not + divide N: the trailing partial batch is dropped by design.""" + data = np.arange(25, dtype="float64").reshape(25, 1) + ds = DataLoader(_factory(data, 5), batch_size=10, sample_shape=(1,), total_size=25) + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + list(ds._stream_batches()) + + +def test_sanity_silent_for_auto_resolved_non_divisible_n(): + """total_size='auto' must not warn against the N it just resolved.""" + data = np.arange(25, dtype="float64").reshape(25, 1) + with pytest.warns(UserWarning, match="counting pass"): + ds = DataLoader(_factory(data, 5), batch_size=10, sample_shape=(1,), total_size="auto") + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + list(ds._stream_batches()) + + +def test_parquet_source_rejects_unknown_columns(tmp_path): + """A typo in columns= raises a clear ValueError at construction instead of a + pyarrow error at first iteration.""" + pa = 
pytest.importorskip("pyarrow") + pq = pytest.importorskip("pyarrow.parquet") + pq.write_table(pa.table({"a": [1.0], "b": [2.0]}), f"{tmp_path}/p.parquet") + with pytest.raises(ValueError, match="not found"): + parquet_source(str(tmp_path), columns=["a", "nope"]) From 4bbb47096b5516240a6938d01c57be5c403cf31f Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Thu, 11 Jun 2026 10:45:35 -0500 Subject: [PATCH 23/27] Fix batch accounting and parquet edge cases from review - Trainer.fit(n) consumes exactly n batches: the advance after the final step is skipped, so a finite source is not over-consumed - the total_size sanity check counts the pass that completed instead of the cumulative row counter, which inflated across partial streams - parquet_source freezes the column order at construction and reads one row group at a time, so a permuted shard schema cannot silently swap features and peak read memory is a row group, not a file - warn at construction when a fixed-order loader would drop the same non-divisible tail every pass - total_size='auto' probes that the factory can actually be re-read, catching factories that close over an already-consumed iterator - document the shuffle-buffer transient concatenation copy and the full-buffer case --- pymc/variational/streaming.py | 132 +++++++++++++------ tests/variational/test_streaming.py | 47 ++++++- tests/variational/test_streaming_autosize.py | 107 ++++++++++++++- 3 files changed, 238 insertions(+), 48 deletions(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index 1ce013d884..22ab5b1d32 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -33,11 +33,12 @@ ``Trainer(method=..., dataloader=...).fit(n)`` streams each minibatch into the model's ``pm.Data`` placeholder with ``set_data``. -With an out-of-core source the full data never enters RAM. 
The model graph -observes only a ``(batch_size, *sample_shape)`` ``pm.Data`` placeholder that the -``Trainer`` overwrites with the next minibatch every step. Passing a directory -of Parquet shards far larger than RAM still gives a model whose resident -footprint is one batch. +With bounded source chunks the full data never sits in RAM at once. The model +graph observes only a ``(batch_size, *sample_shape)`` ``pm.Data`` placeholder +that the ``Trainer`` overwrites with the next minibatch every step. Passing a +directory of Parquet shards far larger than RAM still gives a model whose +resident footprint is one batch (:func:`parquet_source` reads one row group at +a time). The unbiased-gradient rescaling is the same as for ``pm.Minibatch``: the observed log-likelihood must be scaled by ``N / batch_size`` through the existing @@ -46,11 +47,18 @@ model passes ``total_size=len(loader)``. (Folding that scaling into the inference step, so it drops out of the model body, is the next step in PyMC's VI rework.) +Batches have exactly ``batch_size`` rows, so each pass drops the final +``N mod batch_size`` rows (torch's ``drop_last``). With ``shuffle=True`` that +remainder is re-drawn every epoch, so all rows participate across epochs; with +a source that replays a fixed order, the same rows would be dropped every pass, +and the loader warns at construction. + One difference from ``pm.Minibatch`` is shuffling. ``pm.Minibatch`` draws a fresh uniform index over all N rows every step, so its minibatches are i.i.d. by construction. A streaming source is only as well mixed as the order it yields rows in: reading time/row-ordered data through a -bounded buffer is merely a block-shuffle and biases the variational posterior. +bounded buffer is merely a block-shuffle, and the resulting non-representative +minibatches can bias the variational posterior. Pre-shuffle the data once on disk (or interleave shards) and/or pass ``shuffle=True``. 
@@ -147,7 +155,9 @@ class DataLoader: re-batches them, in order, to exactly ``batch_size`` rows. Trailing rows that do not fill a final batch are dropped at the end of a pass, like ``drop_last=True`` in PyTorch (required here because the model observes a - fixed-shape placeholder). + fixed-shape placeholder). With ``shuffle=True`` the dropped remainder + differs per epoch; with a fixed replay order it is the same rows every + pass (warned at construction). batch_size : int Leading dimension of every yielded minibatch. shuffle : bool, default False @@ -158,7 +168,8 @@ class DataLoader: docstring). buffer_size : int, optional Shuffle-buffer size in rows when ``shuffle=True``. Defaults to - ``50 * batch_size``. Ignored when ``shuffle=False``. + ``50 * batch_size``. Ignored when ``shuffle=False``. A buffer at least + as large as the dataset holds all of it in memory (a full shuffle). seed : int, optional Seed for the shuffle buffer (ignored when ``shuffle=False``). sample_shape : tuple of int, optional @@ -244,6 +255,21 @@ def __init__( f"{total_size!r}." ) + if total_size is not None and not shuffle and int(total_size) % int(batch_size): + # Exact-size batches drop the tail of every pass (drop_last). Without + # shuffling the drop is not re-randomized, so a source that replays + # the same order would never show those rows to the fit at all. + warnings.warn( + f"shuffle=False with total_size={int(total_size)} not divisible by " + f"batch_size={int(batch_size)}: the trailing " + f"{int(total_size) % int(batch_size)} rows are dropped every pass, and " + f"if the source replays the same order each epoch they are never seen " + f"by the fit. Pass shuffle=True (the dropped remainder is then re-drawn " + f"each epoch) or choose a batch_size that divides the dataset size.", + UserWarning, + stacklevel=2, + ) + # Plain Python ints: create_minibatch_rv rejects np.int64 for total_size. 
self._batch_size = int(batch_size) self._sample_shape = sample_shape @@ -311,12 +337,14 @@ def _stream_batches(self) -> Iterator[np.ndarray]: the epoch boundary. :meth:`__iter__` stays side-effect-free so plain iteration does not mutate counters. """ + seen_this_pass = 0 for batch in self._rebatched(): prepared = self._prepare(batch) self._batches_seen += 1 self._rows_streamed += int(prepared.shape[0]) + seen_this_pass += int(prepared.shape[0]) yield prepared - self._maybe_warn_total_size() + self._maybe_warn_total_size(seen_this_pass) def _prepare(self, batch: np.ndarray) -> np.ndarray: """Preprocess, validate, and return an owned copy of one batch. @@ -329,17 +357,19 @@ def _prepare(self, batch: np.ndarray) -> np.ndarray: self._validate(batch) return np.array(batch, dtype=self._dtype) - def _maybe_warn_total_size(self) -> None: - """Warn once if ``total_size`` is inconsistent with the rows seen in one pass. + def _maybe_warn_total_size(self, seen: int) -> None: + """Warn once if ``total_size`` is inconsistent with the rows of one full pass. - A correct ``N`` satisfies ``seen <= N < seen + batch_size`` after a full - pass (the trailing partial batch is dropped), so that window never warns; - outside it a 10% slack absorbs sources that are only approximately sized. + ``seen`` is the row count of the pass that just completed (not the + cumulative :attr:`rows_streamed`, which keeps growing across partial + streams and earlier fits). A correct ``N`` satisfies + ``seen <= N < seen + batch_size`` after a full pass (the trailing partial + batch is dropped), so that window never warns; outside it a 10% slack + absorbs sources that are only approximately sized. 
""" if self._warned_size or self._total_size is None: return self._warned_size = True - seen = self._rows_streamed if not seen or seen <= self._total_size < seen + self._batch_size: return if abs(self._total_size - seen) > 0.1 * seen: @@ -441,9 +471,11 @@ def __init__( def fit(self, n: int = 10_000, **kwargs): """Fit for ``n`` steps, streaming minibatches into the model's placeholder. - Keyword arguments are forwarded to :func:`pymc.fit` on top of the - constructor's ``fit_kwargs`` (per-call wins); ``progressbar`` defaults to - ``False`` unless either sets it. + Exactly ``n`` minibatches are consumed: the first seeds the placeholder + before step 0, and the advance after the final step is skipped. Keyword + arguments are forwarded to :func:`pymc.fit` on top of the constructor's + ``fit_kwargs`` (per-call wins); ``progressbar`` defaults to ``False`` + unless either sets it. Returns ------- @@ -478,8 +510,16 @@ def _stream() -> Iterator[np.ndarray]: # so without this the first step would train on the placeholder's contents. model.set_data(self.data_name, next(batches)) + steps_done = 0 + def _advance(*_): - model.set_data(self.data_name, next(batches)) + # pm.fit fires callbacks after every step including the last; skip the + # advance there so exactly n batches are consumed (an (n+1)-th would be + # fetched, never trained on, and could exhaust a finite source). + nonlocal steps_done + steps_done += 1 + if steps_done < n: + model.set_data(self.data_name, next(batches)) merged = {**self._fit_kwargs, **kwargs} merged.setdefault("progressbar", False) @@ -522,7 +562,9 @@ def shuffle_buffer( shuffling (so a ``buffer_size`` smaller than ``batch_size`` still yields full batches; the final fill stops at whatever the source has left), and the chunk that crosses the threshold is kept whole, so the buffer holds fewer than - ``max(buffer_size, batch_size)`` plus one chunk's rows. + ``max(buffer_size, batch_size)`` plus one chunk's rows. 
Concatenating a fill + into one shuffleable array transiently allocates a second copy of those + rows, so peak allocation is about twice that bound. Each epoch (each call of the returned factory) draws a fresh permutation from a sub-stream of ``seed``, so the shuffle order differs across epochs while @@ -724,13 +766,17 @@ def _auto_total_size( count += 1 if a.shape == sample_shape else int(a.shape[0]) if count <= 0: raise ValueError("total_size='auto' counted 0 rows (empty or non-re-readable source).") - if factory() is first_iter: - # A genuine factory yields a fresh iterator each call; one that returns the - # same exhausted iterator would leave the loader with nothing to stream. + # A genuine factory yields a fresh, non-empty stream each call; one that + # returns the same exhausted iterator (or a new generator over consumed + # state) would leave the loader with nothing to stream. The probe costs one + # chunk, which the counting pass has already dwarfed. + second_iter = factory() + if second_iter is first_iter or next(second_iter, None) is None: raise ValueError( - "total_size='auto' got a factory that returns the same one-shot iterator " - "each call; pass a factory that creates a fresh iterator each call, or " - "total_size=N explicitly." + "total_size='auto' counted rows but the factory's next stream was empty " + "(it returns the same one-shot iterator, or closes over an already-" + "consumed one); pass a factory that creates a fresh iterator each call, " + "or total_size=N explicitly." ) return count @@ -738,11 +784,13 @@ def _auto_total_size( class _ParquetDataset(IterableDataset): """An :class:`IterableDataset` over a directory of Parquet shards. - Yields one ``(rows, n_columns)`` array per file and exposes - :attr:`n_rows` read from Parquet metadata (no data scan). 
+ Yields one ``(rows, n_columns)`` array per row group (so peak read memory is + one row group, not one file), in the fixed column order chosen at + construction, and exposes :attr:`n_rows` read from Parquet metadata (no data + scan). """ - def __init__(self, paths: list[str], columns: list[str] | None, n_rows: int): + def __init__(self, paths: list[str], columns: list[str], n_rows: int): self._paths = paths self._columns = columns self.n_rows = n_rows @@ -751,8 +799,12 @@ def __iter__(self) -> Iterator[np.ndarray]: import pyarrow.parquet as pq for path in self._paths: - table = pq.read_table(path, columns=self._columns) - yield np.column_stack([table.column(c).to_numpy() for c in table.column_names]) + file = pq.ParquetFile(path) + for i in range(file.metadata.num_row_groups): + table = file.read_row_group(i, columns=self._columns) + # Stack by the frozen column names, not the file's own order, so + # a shard with a permuted schema cannot silently swap features. + yield np.column_stack([table.column(c).to_numpy() for c in self._columns]) def parquet_source( @@ -763,8 +815,12 @@ def parquet_source( ) -> _ParquetDataset: """An :class:`IterableDataset` over a directory of Parquet files. - Yields one ``(rows, n_columns)`` array per file, and carries an - ``n_rows`` attribute read from Parquet metadata (no data scan) so that + Yields one ``(rows, n_columns)`` array per row group (one or more per file), + so peak read memory is one row group, not one file. The column order is + frozen at construction — ``columns`` if given, else the first file's schema + order — and every shard is read in that order, so a shard with a permuted + schema cannot silently reorder features mid-stream. Carries an ``n_rows`` + attribute read from Parquet metadata (no data scan) so that ``DataLoader(parquet_source(dir), ..., total_size="auto")`` resolves the dataset size for free. Pass ``shuffle=True`` to the :class:`DataLoader` (or wrap in :func:`shuffle_buffer`) to get shuffled batches. 
@@ -775,12 +831,14 @@ def parquet_source( paths = sorted(glob.glob(os.path.join(directory, pattern))) if not paths: raise ValueError(f"no Parquet files match {os.path.join(directory, pattern)!r}") - if columns is not None: - available = set(pq.read_schema(paths[0]).names) - missing = sorted(set(columns) - available) + schema_names = list(pq.read_schema(paths[0]).names) + if columns is None: + columns = schema_names + else: + missing = sorted(set(columns) - set(schema_names)) if missing: raise ValueError( - f"columns {missing} not found in {paths[0]!r}; available: {sorted(available)}" + f"columns {missing} not found in {paths[0]!r}; available: {sorted(schema_names)}" ) n_rows = sum(pq.read_metadata(p).num_rows for p in paths) return _ParquetDataset(paths, columns, n_rows) diff --git a/tests/variational/test_streaming.py b/tests/variational/test_streaming.py index 115d73fc25..48f927e189 100644 --- a/tests/variational/test_streaming.py +++ b/tests/variational/test_streaming.py @@ -34,9 +34,11 @@ def factory(): def test_plain_loader_rebatches_arbitrary_blocks(): """Blocks of 3 with batch_size=4 are re-batched in order; the trailing rows - that cannot fill a final batch are dropped (drop_last semantics).""" + that cannot fill a final batch are dropped (drop_last semantics), which the + unshuffled loader warns about up front.""" data = np.arange(20, dtype="float64").reshape(10, 2) - ds = DataLoader(_chunks(data, 3), batch_size=4, sample_shape=(2,), total_size=10) + with pytest.warns(UserWarning, match="dropped every pass"): + ds = DataLoader(_chunks(data, 3), batch_size=4, sample_shape=(2,), total_size=10) batches = list(ds) assert [b.shape for b in batches] == [(4, 2), (4, 2)] np.testing.assert_array_equal(np.concatenate(batches), data[:8]) @@ -47,7 +49,10 @@ def test_raw_array_source_like_vi_rework_sketch(): ``Dataloader(np.random.normal(...), batch_size=...)``: rows are yielded one sample at a time, re-batched, and counted as rows by total_size='auto'.""" data = 
np.arange(40, dtype="float64").reshape(20, 2) - with pytest.warns(UserWarning, match="counting pass"): + with ( + pytest.warns(UserWarning, match="counting pass"), + pytest.warns(UserWarning, match="dropped every pass"), + ): ds = DataLoader(data, batch_size=8, sample_shape=(2,), total_size="auto") assert ds.total_size == 20 batches = list(ds) @@ -128,7 +133,7 @@ def test_total_size_rescales_logp_like_minibatch(): exactly N / batch_size, through the same create_minibatch_rv mechanism as pm.Minibatch: logp(scaled) == logp(plain) * N / batch_size.""" rng = np.random.default_rng(0) - N, bs = 1000, 16 + N, bs = 1000, 20 data = rng.normal(size=(bs, 1)) loader = DataLoader(lambda: iter([data]), batch_size=bs, sample_shape=(1,), total_size=N) @@ -449,7 +454,10 @@ def test_raw_2d_array_infers_sample_shape(): VI-rework sketch ``DataLoader(arr, batch_size=...)`` batches rows instead of flattening them into scalars.""" data = np.arange(40, dtype="float64").reshape(20, 2) - with pytest.warns(UserWarning, match="counting pass"): + with ( + pytest.warns(UserWarning, match="counting pass"), + pytest.warns(UserWarning, match="dropped every pass"), + ): ds = DataLoader(data, batch_size=8, total_size="auto") assert ds.total_size == 20 batches = list(ds) @@ -502,6 +510,35 @@ def fit_with(ctor_kwargs, fit_kwargs): np.testing.assert_array_equal(a.hist, b.hist) +def test_fit_consumes_exactly_n_batches(): + """fit(n) consumes exactly n minibatches: one seeds the placeholder before + step 0 and the advance after the final step is skipped, so an (n+1)-th + batch is never fetched.""" + blocks = [np.full((2, 1), float(i)) for i in range(2)] + loader = DataLoader(lambda: iter(blocks), batch_size=2, sample_shape=(1,), total_size=4) + with pm.Model(): + mu = pm.Normal("mu", 0, 1) + batch = pm.Data("batch", np.zeros((2, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + Trainer(method="advi", dataloader=loader).fit(3, random_seed=0) + assert loader.batches_seen == 3 + 
assert loader.rows_streamed == 6 + + +def test_fit_one_step_on_single_batch_one_shot_source(): + """A finite stream with exactly the batches needed must not be over-consumed: + fit(1) on a one-batch, one-shot source trains and returns instead of failing + on a post-final restart.""" + loader = DataLoader(iter([np.ones((2, 1))]), batch_size=2, sample_shape=(1,), total_size=2) + with pm.Model(): + mu = pm.Normal("mu", 0, 1) + batch = pm.Data("batch", np.zeros((2, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + approx = Trainer(method="advi", dataloader=loader).fit(1, random_seed=0) + assert len(approx.hist) == 1 + assert loader.batches_seen == 1 + + def test_unknown_data_name_raises_before_consuming(): """A data_name that is not in the model raises a guided KeyError before any batch is pulled from the loader.""" diff --git a/tests/variational/test_streaming_autosize.py b/tests/variational/test_streaming_autosize.py index 759460d922..b1115c980f 100644 --- a/tests/variational/test_streaming_autosize.py +++ b/tests/variational/test_streaming_autosize.py @@ -48,9 +48,9 @@ def test_auto_uses_n_rows_fast_path(): """A source-advertised .n_rows is trusted without a counting pass.""" data = np.zeros((8, 1)) f = _factory(data, 4) - f.n_rows = 999 + f.n_rows = 1000 ds = DataLoader(f, batch_size=4, sample_shape=(1,), total_size="auto") - assert ds.total_size == 999 + assert ds.total_size == 1000 def test_auto_rejects_one_shot_iterator(): @@ -231,10 +231,12 @@ def test_stream_batches_updates_counters_and_warns_on_wrong_total_size(): def test_sanity_silent_when_drop_last_truncates(): - """An exactly-correct total_size does not warn when batch_size does not - divide N: the trailing partial batch is dropped by design.""" + """An exactly-correct total_size does not warn at the epoch boundary when + batch_size does not divide N: the trailing partial batch is dropped by + design (the fixed-order construction warning is separate).""" data = np.arange(25, 
dtype="float64").reshape(25, 1) - ds = DataLoader(_factory(data, 5), batch_size=10, sample_shape=(1,), total_size=25) + with pytest.warns(UserWarning, match="dropped every pass"): + ds = DataLoader(_factory(data, 5), batch_size=10, sample_shape=(1,), total_size=25) with warnings.catch_warnings(): warnings.simplefilter("error", UserWarning) list(ds._stream_batches()) @@ -243,13 +245,106 @@ def test_sanity_silent_when_drop_last_truncates(): def test_sanity_silent_for_auto_resolved_non_divisible_n(): """total_size='auto' must not warn against the N it just resolved.""" data = np.arange(25, dtype="float64").reshape(25, 1) - with pytest.warns(UserWarning, match="counting pass"): + with ( + pytest.warns(UserWarning, match="counting pass"), + pytest.warns(UserWarning, match="dropped every pass"), + ): ds = DataLoader(_factory(data, 5), batch_size=10, sample_shape=(1,), total_size="auto") with warnings.catch_warnings(): warnings.simplefilter("error", UserWarning) list(ds._stream_batches()) +def test_sanity_check_counts_the_completed_pass_not_cumulative_rows(): + """A partially consumed stray stream must not inflate the epoch-boundary + check: with a correct total_size, the next full pass stays silent.""" + data = np.arange(100, dtype="float64").reshape(100, 1) + ds = DataLoader(_factory(data, 10), batch_size=10, sample_shape=(1,), total_size=100) + stray = ds._stream_batches() + for _ in range(3): + next(stray) + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + list(ds._stream_batches()) + + +def test_sanity_check_not_fooled_by_cumulative_rows_matching_total_size(): + """The converse: a wrong total_size that happens to equal the cumulative + row counter must still warn after a true full pass.""" + data = np.arange(100, dtype="float64").reshape(100, 1) + ds = DataLoader(_factory(data, 10), batch_size=10, sample_shape=(1,), total_size=130) + stray = ds._stream_batches() + for _ in range(3): + next(stray) + with pytest.warns(UserWarning, 
match="disagrees with"): + list(ds._stream_batches()) + + +def test_fixed_order_non_divisible_total_size_warns_at_construction(): + """shuffle=False with batch_size not dividing N would drop the same trailing + rows every pass, so the loader says so up front.""" + data = np.arange(25, dtype="float64").reshape(25, 1) + with pytest.warns(UserWarning, match="dropped every pass"): + DataLoader(_factory(data, 5), batch_size=10, sample_shape=(1,), total_size=25) + + +def test_shuffled_non_divisible_total_size_is_silent_at_construction(): + """With shuffle=True the dropped remainder is re-drawn each epoch, so the + fixed-tail warning does not apply.""" + data = np.arange(25, dtype="float64").reshape(25, 1) + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + DataLoader( + _factory(data, 5), + batch_size=10, + shuffle=True, + buffer_size=20, + seed=0, + sample_shape=(1,), + total_size=25, + ) + + +def test_auto_rejects_factory_closing_over_consumed_iterator(): + """A generator function over a one-shot iterator returns a new (so not + identical) but empty stream after the counting pass; the re-read probe + catches it at construction.""" + data = np.zeros((20, 1)) + underlying = iter([data[i : i + 4] for i in range(0, 20, 4)]) + + def gen(): + yield from underlying + + with ( + pytest.warns(UserWarning, match="counting pass"), + pytest.raises(ValueError, match="fresh iterator"), + ): + DataLoader(gen, batch_size=4, sample_shape=(1,), total_size="auto") + + +def test_parquet_source_freezes_column_order_across_permuted_shards(tmp_path): + """A shard whose schema permutes the columns is read back in the first + shard's order instead of silently swapping features.""" + pa = pytest.importorskip("pyarrow") + pq = pytest.importorskip("pyarrow.parquet") + pq.write_table(pa.table({"a": [1.0, 1.0], "b": [10.0, 10.0]}), f"{tmp_path}/p0.parquet") + pq.write_table(pa.table({"b": [20.0, 20.0], "a": [2.0, 2.0]}), f"{tmp_path}/p1.parquet") + blocks = 
list(parquet_source(str(tmp_path))) + np.testing.assert_array_equal(blocks[0], [[1.0, 10.0], [1.0, 10.0]]) + np.testing.assert_array_equal(blocks[1], [[2.0, 20.0], [2.0, 20.0]]) + + +def test_parquet_source_streams_row_groups_not_whole_files(tmp_path): + """A multi-row-group file is yielded one row group at a time, so peak read + memory is a row group rather than the whole file.""" + pa = pytest.importorskip("pyarrow") + pq = pytest.importorskip("pyarrow.parquet") + pq.write_table(pa.table({"a": np.arange(30.0)}), f"{tmp_path}/p.parquet", row_group_size=10) + blocks = list(parquet_source(str(tmp_path))) + assert [b.shape for b in blocks] == [(10, 1), (10, 1), (10, 1)] + np.testing.assert_array_equal(np.concatenate(blocks).ravel(), np.arange(30.0)) + + def test_parquet_source_rejects_unknown_columns(tmp_path): """A typo in columns= raises a clear ValueError at construction instead of a pyarrow error at first iteration.""" From 4b977746bdacac5b905b75f44b172c3bb4d00a76 Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Fri, 12 Jun 2026 07:35:00 -0500 Subject: [PATCH 24/27] Address fourth review round: refine support, boundary check, parquet diagnostics - the internal advance skips only fit's own final step, so Inference.refine on a method instance keeps streaming instead of silently retraining on the last batch - keep the rebatcher one batch ahead in the accounting stream, so the total_size sanity check still fires when fit(n) stops exactly at the pass boundary - drop the fixed-order divisibility warning: it false-alarmed on the module's own pre-shuffled-on-disk example and on manual shuffle_buffer wrapping; the drop-last caveat lives in the docs instead - validate n in Trainer.fit (fit(0) consumed the seed batch; fit(-1) failed deep inside PyTensor) - normalize shuffle_buffer's factory output with iter(), which a re-iterable-returning factory would otherwise restart every fill - parquet_source rejects non-numeric columns at construction and names the shard when a 
later file is missing a frozen column - name the sample_shape remedy in the block-shape error; spell behavior consistently --- pymc/variational/streaming.py | 90 ++++++++++++-------- tests/variational/test_streaming.py | 76 ++++++++++++++--- tests/variational/test_streaming_autosize.py | 63 +++++++------- 3 files changed, 149 insertions(+), 80 deletions(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index 22ab5b1d32..e3db6ef83e 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -50,8 +50,8 @@ Batches have exactly ``batch_size`` rows, so each pass drops the final ``N mod batch_size`` rows (torch's ``drop_last``). With ``shuffle=True`` that remainder is re-drawn every epoch, so all rows participate across epochs; with -a source that replays a fixed order, the same rows would be dropped every pass, -and the loader warns at construction. +a source that replays a fixed order, the same rows are dropped every pass (after +a one-time on-disk pre-shuffle that fixed remainder is a random subset). One difference from ``pm.Minibatch`` is shuffling. ``pm.Minibatch`` draws a fresh uniform index over all N rows every step, so its @@ -127,7 +127,7 @@ class IterableDataset: A plain zero-arg factory (``Callable[[], Iterator[np.ndarray]]``) or any re-iterable is also accepted directly by :class:`DataLoader`; this base class - is only needed when you want to attach behaviour or ``n_rows`` to a custom + is only needed when you want to attach behavior or ``n_rows`` to a custom source. """ @@ -157,7 +157,7 @@ class DataLoader: ``drop_last=True`` in PyTorch (required here because the model observes a fixed-shape placeholder). With ``shuffle=True`` the dropped remainder differs per epoch; with a fixed replay order it is the same rows every - pass (warned at construction). + pass. batch_size : int Leading dimension of every yielded minibatch. shuffle : bool, default False @@ -255,21 +255,6 @@ def __init__( f"{total_size!r}." 
) - if total_size is not None and not shuffle and int(total_size) % int(batch_size): - # Exact-size batches drop the tail of every pass (drop_last). Without - # shuffling the drop is not re-randomized, so a source that replays - # the same order would never show those rows to the fit at all. - warnings.warn( - f"shuffle=False with total_size={int(total_size)} not divisible by " - f"batch_size={int(batch_size)}: the trailing " - f"{int(total_size) % int(batch_size)} rows are dropped every pass, and " - f"if the source replays the same order each epoch they are never seen " - f"by the fit. Pass shuffle=True (the dropped remainder is then re-drawn " - f"each epoch) or choose a batch_size that divides the dataset size.", - UserWarning, - stacklevel=2, - ) - # Plain Python ints: create_minibatch_rv rejects np.int64 for total_size. self._batch_size = int(batch_size) self._sample_shape = sample_shape @@ -333,18 +318,26 @@ def _stream_batches(self) -> Iterator[np.ndarray]: """One epoch of prepared minibatches, with accounting (the Trainer's path). Like :meth:`__iter__` but it updates :attr:`batches_seen` / - :attr:`rows_streamed` and runs the one-shot ``total_size`` sanity check at - the epoch boundary. :meth:`__iter__` stays side-effect-free so plain - iteration does not mutate counters. + :attr:`rows_streamed` and runs the one-shot ``total_size`` sanity check on + the pass's final batch. The rebatcher is kept one batch ahead so the check + still fires when a fit stops exactly at the pass boundary; without the + lookahead the generator would be abandoned right before its epilogue. + :meth:`__iter__` stays side-effect-free so plain iteration does not mutate + counters. 
""" seen_this_pass = 0 - for batch in self._rebatched(): + it = self._rebatched() + batch = next(it, None) + while batch is not None: + following = next(it, None) prepared = self._prepare(batch) self._batches_seen += 1 self._rows_streamed += int(prepared.shape[0]) seen_this_pass += int(prepared.shape[0]) + if following is None: + self._maybe_warn_total_size(seen_this_pass) yield prepared - self._maybe_warn_total_size(seen_this_pass) + batch = following def _prepare(self, batch: np.ndarray) -> np.ndarray: """Preprocess, validate, and return an owned copy of one batch. @@ -482,6 +475,8 @@ def fit(self, n: int = 10_000, **kwargs): :class:`Approximation` The fitted approximation, as returned by :func:`pymc.fit`. """ + if not _is_positive_int(n): + raise ValueError(f"n must be a positive integer (the number of fit steps), got {n!r}") loader = self.dataloader if not isinstance(loader, DataLoader): raise TypeError( @@ -514,11 +509,12 @@ def _stream() -> Iterator[np.ndarray]: def _advance(*_): # pm.fit fires callbacks after every step including the last; skip the - # advance there so exactly n batches are consumed (an (n+1)-th would be - # fetched, never trained on, and could exhaust a finite source). + # advance on this fit's final step so exactly n batches are consumed. + # Only that one call is skipped (not every call past n): Inference.refine + # replays the saved callbacks and must keep streaming fresh batches. nonlocal steps_done steps_done += 1 - if steps_done < n: + if steps_done != n: model.set_data(self.data_name, next(batches)) merged = {**self._fit_kwargs, **kwargs} @@ -580,7 +576,9 @@ def factory() -> Iterator[np.ndarray]: # A fresh sub-stream per epoch: re-iterating reshuffles instead of # replaying one fixed permutation, yet stays reproducible per seed. 
rng = np.random.default_rng(seed_seq.spawn(1)[0]) - it = chunk_source() + # A factory may return a re-iterable (a list of chunks, ...); normalize so + # each buffer fill continues one stream instead of restarting it forever. + it = iter(chunk_source()) carry: np.ndarray | None = None exhausted = False # Accumulate at least one batch even when buffer_size < batch_size, @@ -628,8 +626,9 @@ def _promote_to_block(a: np.ndarray, sample_shape: tuple[int, ...]) -> np.ndarra return a[None, ...] if a.ndim != len(sample_shape) + 1 or a.shape[1:] != sample_shape: raise ValueError( - f"source yielded shape {a.shape}; expected a single sample of shape " - f"{sample_shape} or a block of shape (rows, *{sample_shape})" + f"source yielded shape {a.shape}; expected one sample of shape " + f"{sample_shape} or a (rows, *sample_shape) block; if the source is " + f"right, declare its trailing shape with DataLoader(sample_shape=...)" ) return a @@ -800,6 +799,12 @@ def __iter__(self) -> Iterator[np.ndarray]: for path in self._paths: file = pq.ParquetFile(path) + missing = [c for c in self._columns if c not in file.schema_arrow.names] + if missing: + # read_row_group(columns=...) silently drops unknown names, so a + # malformed shard must be named here, not surface as a bare + # KeyError with no path. + raise ValueError(f"columns {missing} not found in {path!r}") for i in range(file.metadata.num_row_groups): table = file.read_row_group(i, columns=self._columns) # Stack by the frozen column names, not the file's own order, so @@ -826,19 +831,36 @@ def parquet_source( wrap in :func:`shuffle_buffer`) to get shuffled batches. """ # pyarrow is an optional dependency, so it is imported on use. 
+ import pyarrow as pa import pyarrow.parquet as pq paths = sorted(glob.glob(os.path.join(directory, pattern))) if not paths: raise ValueError(f"no Parquet files match {os.path.join(directory, pattern)!r}") - schema_names = list(pq.read_schema(paths[0]).names) + schema = pq.read_schema(paths[0]) if columns is None: - columns = schema_names + columns = list(schema.names) else: - missing = sorted(set(columns) - set(schema_names)) + missing = sorted(set(columns) - set(schema.names)) if missing: raise ValueError( - f"columns {missing} not found in {paths[0]!r}; available: {sorted(schema_names)}" + f"columns {missing} not found in {paths[0]!r}; available: {sorted(schema.names)}" ) + non_numeric = [ + c + for c in columns + if not ( + pa.types.is_integer(schema.field(c).type) + or pa.types.is_floating(schema.field(c).type) + or pa.types.is_boolean(schema.field(c).type) + ) + ] + if non_numeric: + # A string/dictionary column would turn whole chunks object-dtype and only + # fail later at the batch cast, without naming the column. + raise ValueError( + f"columns {non_numeric} in {paths[0]!r} are not numeric and cannot be " + f"streamed into a float batch; select numeric columns with columns=." 
+ ) n_rows = sum(pq.read_metadata(p).num_rows for p in paths) return _ParquetDataset(paths, columns, n_rows) diff --git a/tests/variational/test_streaming.py b/tests/variational/test_streaming.py index 48f927e189..e6b277dee5 100644 --- a/tests/variational/test_streaming.py +++ b/tests/variational/test_streaming.py @@ -34,11 +34,9 @@ def factory(): def test_plain_loader_rebatches_arbitrary_blocks(): """Blocks of 3 with batch_size=4 are re-batched in order; the trailing rows - that cannot fill a final batch are dropped (drop_last semantics), which the - unshuffled loader warns about up front.""" + that cannot fill a final batch are dropped (drop_last semantics).""" data = np.arange(20, dtype="float64").reshape(10, 2) - with pytest.warns(UserWarning, match="dropped every pass"): - ds = DataLoader(_chunks(data, 3), batch_size=4, sample_shape=(2,), total_size=10) + ds = DataLoader(_chunks(data, 3), batch_size=4, sample_shape=(2,), total_size=10) batches = list(ds) assert [b.shape for b in batches] == [(4, 2), (4, 2)] np.testing.assert_array_equal(np.concatenate(batches), data[:8]) @@ -49,10 +47,7 @@ def test_raw_array_source_like_vi_rework_sketch(): ``Dataloader(np.random.normal(...), batch_size=...)``: rows are yielded one sample at a time, re-batched, and counted as rows by total_size='auto'.""" data = np.arange(40, dtype="float64").reshape(20, 2) - with ( - pytest.warns(UserWarning, match="counting pass"), - pytest.warns(UserWarning, match="dropped every pass"), - ): + with pytest.warns(UserWarning, match="counting pass"): ds = DataLoader(data, batch_size=8, sample_shape=(2,), total_size="auto") assert ds.total_size == 20 batches = list(ds) @@ -454,10 +449,7 @@ def test_raw_2d_array_infers_sample_shape(): VI-rework sketch ``DataLoader(arr, batch_size=...)`` batches rows instead of flattening them into scalars.""" data = np.arange(40, dtype="float64").reshape(20, 2) - with ( - pytest.warns(UserWarning, match="counting pass"), - pytest.warns(UserWarning, match="dropped 
every pass"), - ): + with pytest.warns(UserWarning, match="counting pass"): ds = DataLoader(data, batch_size=8, total_size="auto") assert ds.total_size == 20 batches = list(ds) @@ -539,6 +531,66 @@ def test_fit_one_step_on_single_batch_one_shot_source(): assert loader.batches_seen == 1 +def test_refine_after_fit_keeps_streaming(): + """Inference.refine replays pm.fit's saved callbacks; the internal advance + skips only fit's own final step, so refine keeps streaming fresh batches + instead of going permanently dead and retraining on one batch.""" + data = np.ones((4, 1)) + loader = DataLoader(lambda: iter([data] * 50), batch_size=4, sample_shape=(1,), total_size=4) + with pm.Model(): + mu = pm.Normal("mu", 0, 1) + batch = pm.Data("batch", np.zeros((4, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + inference = pm.ADVI(random_seed=0) + Trainer(method=inference, dataloader=loader).fit(3) + assert loader.batches_seen == 3 + inference.refine(4, progressbar=False) + assert loader.batches_seen == 7 + + +def test_total_size_check_fires_when_fit_ends_at_pass_boundary(): + """fit(n) with n exactly the batches in one pass still runs the total_size + sanity check: the stream is kept one batch ahead, so stopping at the + boundary does not abandon the check right before it would fire.""" + data = np.zeros((40, 1)) + loader = DataLoader(_chunks(data, 10), batch_size=10, sample_shape=(1,), total_size=400) + with pm.Model(): + mu = pm.Normal("mu", 0, 1) + batch = pm.Data("batch", np.zeros((10, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + with pytest.warns(UserWarning, match="disagrees with"): + Trainer(method="advi", dataloader=loader).fit(4, random_seed=0) + + +def test_fit_rejects_nonpositive_n(): + """fit consumes the seed batch before pm.fit could reject n itself, so a + non-positive n is refused up front, before touching the stream.""" + loader = DataLoader( + lambda: iter([np.zeros((2, 1))]), batch_size=2, 
sample_shape=(1,), total_size=2 + ) + with pm.Model(): + mu = pm.Normal("mu", 0, 1) + batch = pm.Data("batch", np.zeros((2, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + with pytest.raises(ValueError, match="positive integer"): + Trainer(method="advi", dataloader=loader).fit(0) + assert loader.batches_seen == 0 + + +def test_shuffle_buffer_accepts_factory_returning_reiterable(): + """A factory returning a re-iterable (which _make_factory tolerates for the + loader) must not restart per buffer fill and loop forever; the stream is + normalized to a single iterator.""" + data = np.arange(120, dtype="float64").reshape(120, 1) + chunks = [data[i : i + 20] for i in range(0, 120, 20)] + src = shuffle_buffer(lambda: chunks, buffer_size=50, batch_size=10, seed=0) + batches = list(src()) + assert len(batches) == 12 + np.testing.assert_array_equal( + np.sort(np.concatenate([b.ravel() for b in batches])), data.ravel() + ) + + def test_unknown_data_name_raises_before_consuming(): """A data_name that is not in the model raises a guided KeyError before any batch is pulled from the loader.""" diff --git a/tests/variational/test_streaming_autosize.py b/tests/variational/test_streaming_autosize.py index b1115c980f..c06df07c7f 100644 --- a/tests/variational/test_streaming_autosize.py +++ b/tests/variational/test_streaming_autosize.py @@ -231,12 +231,10 @@ def test_stream_batches_updates_counters_and_warns_on_wrong_total_size(): def test_sanity_silent_when_drop_last_truncates(): - """An exactly-correct total_size does not warn at the epoch boundary when - batch_size does not divide N: the trailing partial batch is dropped by - design (the fixed-order construction warning is separate).""" + """An exactly-correct total_size does not warn when batch_size does not + divide N: the trailing partial batch is dropped by design.""" data = np.arange(25, dtype="float64").reshape(25, 1) - with pytest.warns(UserWarning, match="dropped every pass"): - ds = 
DataLoader(_factory(data, 5), batch_size=10, sample_shape=(1,), total_size=25) + ds = DataLoader(_factory(data, 5), batch_size=10, sample_shape=(1,), total_size=25) with warnings.catch_warnings(): warnings.simplefilter("error", UserWarning) list(ds._stream_batches()) @@ -245,10 +243,7 @@ def test_sanity_silent_when_drop_last_truncates(): def test_sanity_silent_for_auto_resolved_non_divisible_n(): """total_size='auto' must not warn against the N it just resolved.""" data = np.arange(25, dtype="float64").reshape(25, 1) - with ( - pytest.warns(UserWarning, match="counting pass"), - pytest.warns(UserWarning, match="dropped every pass"), - ): + with pytest.warns(UserWarning, match="counting pass"): ds = DataLoader(_factory(data, 5), batch_size=10, sample_shape=(1,), total_size="auto") with warnings.catch_warnings(): warnings.simplefilter("error", UserWarning) @@ -280,31 +275,6 @@ def test_sanity_check_not_fooled_by_cumulative_rows_matching_total_size(): list(ds._stream_batches()) -def test_fixed_order_non_divisible_total_size_warns_at_construction(): - """shuffle=False with batch_size not dividing N would drop the same trailing - rows every pass, so the loader says so up front.""" - data = np.arange(25, dtype="float64").reshape(25, 1) - with pytest.warns(UserWarning, match="dropped every pass"): - DataLoader(_factory(data, 5), batch_size=10, sample_shape=(1,), total_size=25) - - -def test_shuffled_non_divisible_total_size_is_silent_at_construction(): - """With shuffle=True the dropped remainder is re-drawn each epoch, so the - fixed-tail warning does not apply.""" - data = np.arange(25, dtype="float64").reshape(25, 1) - with warnings.catch_warnings(): - warnings.simplefilter("error", UserWarning) - DataLoader( - _factory(data, 5), - batch_size=10, - shuffle=True, - buffer_size=20, - seed=0, - sample_shape=(1,), - total_size=25, - ) - - def test_auto_rejects_factory_closing_over_consumed_iterator(): """A generator function over a one-shot iterator returns a new (so not 
identical) but empty stream after the counting pass; the re-read probe @@ -345,6 +315,31 @@ def test_parquet_source_streams_row_groups_not_whole_files(tmp_path): np.testing.assert_array_equal(np.concatenate(blocks).ravel(), np.arange(30.0)) +def test_parquet_source_rejects_non_numeric_columns(tmp_path): + """A string column cannot be streamed into a float batch; the default + all-columns freeze rejects it at construction, naming the column and the + columns= remedy, instead of failing later at the batch cast.""" + pa = pytest.importorskip("pyarrow") + pq = pytest.importorskip("pyarrow.parquet") + pq.write_table(pa.table({"x": [1.0, 2.0], "id": ["a", "b"]}), f"{tmp_path}/p.parquet") + with pytest.raises(ValueError, match="not numeric"): + parquet_source(str(tmp_path)) + src = parquet_source(str(tmp_path), columns=["x"]) + np.testing.assert_array_equal(next(iter(src)), [[1.0], [2.0]]) + + +def test_parquet_source_names_the_shard_missing_a_column(tmp_path): + """read_row_group silently drops unknown column names, so a later shard + missing a frozen column must raise an error that names that shard.""" + pa = pytest.importorskip("pyarrow") + pq = pytest.importorskip("pyarrow.parquet") + pq.write_table(pa.table({"a": [1.0], "b": [2.0]}), f"{tmp_path}/p0.parquet") + pq.write_table(pa.table({"a": [3.0]}), f"{tmp_path}/p1.parquet") + src = parquet_source(str(tmp_path)) + with pytest.raises(ValueError, match="p1.parquet"): + list(src) + + def test_parquet_source_rejects_unknown_columns(tmp_path): """A typo in columns= raises a clear ValueError at construction instead of a pyarrow error at first iteration.""" From ab178a0e6454a9f8387280e896467452f43e30dc Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Sat, 13 Jun 2026 13:21:37 -0500 Subject: [PATCH 25/27] Align the DataLoader summary with the module's bounded-chunks wording The class summary still claimed the full dataset never enters memory in the absolute; match the module docstring's bounded-source-chunks framing and fix 
a double space. --- pymc/variational/streaming.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index e3db6ef83e..2627659db1 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -55,7 +55,7 @@ One difference from ``pm.Minibatch`` is shuffling. ``pm.Minibatch`` draws a fresh uniform index over all N rows every step, so its -minibatches are i.i.d. by construction. A streaming source is only as well +minibatches are i.i.d. by construction. A streaming source is only as well mixed as the order it yields rows in: reading time/row-ordered data through a bounded buffer is merely a block-shuffle, and the resulting non-representative minibatches can bias the variational posterior. @@ -143,7 +143,8 @@ class DataLoader: Like ``torch.utils.data.DataLoader``, it batches (and optionally shuffles) an :class:`IterableDataset` into the minibatch stream that :class:`Trainer` feeds to the model. It is iterable and sized (``len(loader)`` - is the dataset size ``N``). The full dataset never enters memory. + is the dataset size ``N``). With bounded source chunks the full dataset is + never resident at once. 
Parameters ---------- From cb85e8cc2d9967787fc8d972974b351d4e996d68 Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Sat, 13 Jun 2026 13:41:11 -0500 Subject: [PATCH 26/27] Tighten parquet type checks, fit/refine wording from review - _ParquetDataset checks each shard's column types, so a later shard whose column turned non-numeric is named instead of failing as an opaque float cast downstream (parquet_source only saw the first shard) - the fit docstring no longer says 'exactly n consumed'; it feeds exactly n batches to the model, but the one-batch lookahead can read a re-readable source one batch further - the refine test now uses distinct batch markers and pins the honest resume-not-reseed behavior (its first step reuses fit's last batch) instead of only checking counters on all-ones data --- pymc/variational/streaming.py | 28 ++++++++++++++++-- tests/variational/test_streaming.py | 30 ++++++++++++++------ tests/variational/test_streaming_autosize.py | 13 +++++++++ 3 files changed, 60 insertions(+), 11 deletions(-) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index 2627659db1..9317e61627 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -465,8 +465,11 @@ def __init__( def fit(self, n: int = 10_000, **kwargs): """Fit for ``n`` steps, streaming minibatches into the model's placeholder. - Exactly ``n`` minibatches are consumed: the first seeds the placeholder - before step 0, and the advance after the final step is skipped. Keyword + Exactly ``n`` minibatches are fed to the model: the first seeds the + placeholder before step 0, and the advance after the final step is skipped. + The accounting stream reads one batch ahead so the pass-size check can fire + at a pass boundary, so a re-readable source (the only kind the loader + accepts) may be read one batch past the ``n`` the model uses. 
Keyword arguments are forwarded to :func:`pymc.fit` on top of the constructor's ``fit_kwargs`` (per-call wins); ``progressbar`` defaults to ``False`` unless either sets it. @@ -796,16 +799,35 @@ def __init__(self, paths: list[str], columns: list[str], n_rows: int): self.n_rows = n_rows def __iter__(self) -> Iterator[np.ndarray]: + import pyarrow as pa import pyarrow.parquet as pq for path in self._paths: file = pq.ParquetFile(path) - missing = [c for c in self._columns if c not in file.schema_arrow.names] + schema = file.schema_arrow + missing = [c for c in self._columns if c not in schema.names] if missing: # read_row_group(columns=...) silently drops unknown names, so a # malformed shard must be named here, not surface as a bare # KeyError with no path. raise ValueError(f"columns {missing} not found in {path!r}") + non_numeric = [ + c + for c in self._columns + if not ( + pa.types.is_integer(schema.field(c).type) + or pa.types.is_floating(schema.field(c).type) + or pa.types.is_boolean(schema.field(c).type) + ) + ] + if non_numeric: + # parquet_source validates types against the first shard only; a + # later shard whose column turned non-numeric would otherwise + # become an object array and fail at the batch cast with no path. + raise ValueError( + f"columns {non_numeric} in {path!r} are not numeric and cannot be " + f"streamed into a float batch; select numeric columns with columns=." 
+ ) for i in range(file.metadata.num_row_groups): table = file.read_row_group(i, columns=self._columns) # Stack by the frozen column names, not the file's own order, so diff --git a/tests/variational/test_streaming.py b/tests/variational/test_streaming.py index e6b277dee5..06c1c47507 100644 --- a/tests/variational/test_streaming.py +++ b/tests/variational/test_streaming.py @@ -531,20 +531,34 @@ def test_fit_one_step_on_single_batch_one_shot_source(): assert loader.batches_seen == 1 -def test_refine_after_fit_keeps_streaming(): - """Inference.refine replays pm.fit's saved callbacks; the internal advance - skips only fit's own final step, so refine keeps streaming fresh batches - instead of going permanently dead and retraining on one batch.""" - data = np.ones((4, 1)) - loader = DataLoader(lambda: iter([data] * 50), batch_size=4, sample_shape=(1,), total_size=4) - with pm.Model(): +def test_refine_after_fit_resumes_the_stream(): + """Inference.refine replays pm.fit's saved callbacks. Because the advance + skips only fit's own final step (and not every step past n), refine resumes + advancing the stream instead of going permanently dead on the last batch. + + refine does not re-seed, so its first step still trains on the batch fit left + in the placeholder; this pins that resume-not-reseed behavior with distinct + batch markers rather than claiming every refine step is fresh. 
+ """ + blocks = [np.full((4, 1), float(i)) for i in range(50)] + loader = DataLoader(lambda: iter(blocks), batch_size=4, sample_shape=(1,), total_size=4) + sets = [] + with pm.Model() as model: mu = pm.Normal("mu", 0, 1) batch = pm.Data("batch", np.zeros((4, 1))) pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + original = model.set_data + model.set_data = lambda name, values, *a, **k: ( # type: ignore[method-assign] + sets.append(float(np.asarray(values)[0, 0])), + original(name, values, *a, **k), + )[1] inference = pm.ADVI(random_seed=0) Trainer(method=inference, dataloader=loader).fit(3) - assert loader.batches_seen == 3 + assert sets == [0.0, 1.0, 2.0] # fit seeds 0, advances to 1 and 2, skips its last + sets.clear() inference.refine(4, progressbar=False) + # refine resumes from where the stream stopped (3, 4, 5, ...), not stuck on 2 + assert sets == [3.0, 4.0, 5.0, 6.0] assert loader.batches_seen == 7 diff --git a/tests/variational/test_streaming_autosize.py b/tests/variational/test_streaming_autosize.py index c06df07c7f..f3342ddd18 100644 --- a/tests/variational/test_streaming_autosize.py +++ b/tests/variational/test_streaming_autosize.py @@ -315,6 +315,19 @@ def test_parquet_source_streams_row_groups_not_whole_files(tmp_path): np.testing.assert_array_equal(np.concatenate(blocks).ravel(), np.arange(30.0)) +def test_parquet_source_names_a_later_shard_with_a_non_numeric_column(tmp_path): + """parquet_source type-checks only the first shard at construction; a later + shard whose same-named column turned non-numeric is caught at iteration with + that shard's path, not as an opaque float-cast error downstream.""" + pa = pytest.importorskip("pyarrow") + pq = pytest.importorskip("pyarrow.parquet") + pq.write_table(pa.table({"a": [1.0, 2.0]}), f"{tmp_path}/p0.parquet") + pq.write_table(pa.table({"a": ["bad", "worse"]}), f"{tmp_path}/p1.parquet") + src = parquet_source(str(tmp_path)) # construction sees only the numeric p0 + with 
pytest.raises(ValueError, match=r"p1\.parquet.*not numeric"): + list(src) + + def test_parquet_source_rejects_non_numeric_columns(tmp_path): """A string column cannot be streamed into a float batch; the default all-columns freeze rejects it at construction, naming the column and the From 796306db6093bdb5c366040a7cb477fd23ad82f7 Mon Sep 17 00:00:00 2001 From: Yicheng Yang Date: Tue, 16 Jun 2026 07:37:35 -0500 Subject: [PATCH 27/27] Split the Trainer into a follow-up PR; share the loader test helper Keep this PR to the dataset/loader layer (IterableDataset, DataLoader, shuffle_buffer, parquet_source); the Trainer and its tests move to a stacked follow-up PR. Fold the re-readable chunked-source factory the loader tests share into tests/variational/streaming_helpers.py, which doubles as a place to explain why a re-readable factory (not a one-shot generator) is needed. --- docs/source/api/vi.rst | 1 - pymc/variational/__init__.py | 2 - pymc/variational/streaming.py | 179 +---------- tests/variational/streaming_helpers.py | 34 ++ tests/variational/test_streaming.py | 312 ++----------------- tests/variational/test_streaming_autosize.py | 49 ++- 6 files changed, 98 insertions(+), 479 deletions(-) create mode 100644 tests/variational/streaming_helpers.py diff --git a/docs/source/api/vi.rst b/docs/source/api/vi.rst index 3e59294f67..54969287f1 100644 --- a/docs/source/api/vi.rst +++ b/docs/source/api/vi.rst @@ -79,7 +79,6 @@ memory (see :mod:`pymc.variational.streaming`). 
DataLoader IterableDataset - Trainer shuffle_buffer parquet_source diff --git a/pymc/variational/__init__.py b/pymc/variational/__init__.py index 61896fb068..2e8ccd066b 100644 --- a/pymc/variational/__init__.py +++ b/pymc/variational/__init__.py @@ -47,7 +47,6 @@ from pymc.variational.streaming import ( DataLoader, IterableDataset, - Trainer, parquet_source, shuffle_buffer, ) @@ -78,7 +77,6 @@ "Group", "IterableDataset", "MeanField", - "Trainer", "adadelta", "adagrad", "adagrad_window", diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py index 9317e61627..e614eda282 100644 --- a/pymc/variational/streaming.py +++ b/pymc/variational/streaming.py @@ -28,14 +28,10 @@ minibatches; it is iterable (the minibatch stream) and sized. Note ``len(loader)`` is the row count ``N`` (what the observed distribution needs for ``total_size``), not the batch count ``torch.utils.data.DataLoader.__len__`` returns. -* :class:`Trainer`: drives variational inference (ADVI, ...) over a - ``DataLoader`` with no user-facing callbacks; - ``Trainer(method=..., dataloader=...).fit(n)`` streams each minibatch into the - model's ``pm.Data`` placeholder with ``set_data``. With bounded source chunks the full data never sits in RAM at once. The model graph observes only a ``(batch_size, *sample_shape)`` ``pm.Data`` placeholder -that the ``Trainer`` overwrites with the next minibatch every step. Passing a +that is overwritten with the next minibatch every step. Passing a directory of Parquet shards far larger than RAM still gives a model whose resident footprint is one batch (:func:`parquet_source` reads one row group at a time). @@ -68,7 +64,7 @@ import numpy as np import pymc as pm - from pymc.variational.streaming import DataLoader, Trainer, parquet_source + from pymc.variational.streaming import DataLoader, parquet_source # The data was pre-shuffled on disk once (see the module note on shuffling), # so the loader streams it sequentially. The full table stays on disk. 
@@ -85,9 +81,13 @@ logit = b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1] + b[3] * batch[:, 2] pm.Bernoulli("y", logit_p=logit, observed=batch[:, 3], total_size=len(loader)) - # No callbacks: the Trainer streams each minibatch into "batch" with set_data. + # The loader is sized (len(loader) == N, what total_size needs) and iterable: + # each epoch yields validated (batch_size, *sample_shape) minibatches. Stream + # each into the "batch" placeholder with model.set_data before a step. with model: - approx = Trainer(method="advi", dataloader=loader, data_name="batch").fit(20_000) + for minibatch in loader: + model.set_data("batch", minibatch) + ... # one variational step over this minibatch """ from __future__ import annotations @@ -101,11 +101,7 @@ import numpy as np -from pymc.model import modelcontext -from pymc.variational.inference import Inference -from pymc.variational.inference import fit as _fit - -__all__ = ["DataLoader", "IterableDataset", "Trainer", "parquet_source", "shuffle_buffer"] +__all__ = ["DataLoader", "IterableDataset", "parquet_source", "shuffle_buffer"] def _is_positive_int(value: object) -> bool: @@ -141,10 +137,9 @@ class DataLoader: """Turn an out-of-core dataset into fixed-size minibatches for variational inference. Like ``torch.utils.data.DataLoader``, it batches (and optionally - shuffles) an :class:`IterableDataset` into the minibatch stream that - :class:`Trainer` feeds to the model. It is iterable and sized (``len(loader)`` - is the dataset size ``N``). With bounded source chunks the full dataset is - never resident at once. + shuffles) an :class:`IterableDataset` into a minibatch stream for variational + inference. It is iterable and sized (``len(loader)`` is the dataset size + ``N``). With bounded source chunks the full dataset is never resident at once. 
Parameters ---------- @@ -292,10 +287,10 @@ def _rebatched(self) -> Iterator[np.ndarray]: def __iter__(self) -> Iterator[np.ndarray]: """Yield one epoch of validated ``(batch_size, *sample_shape)`` minibatches. - The same batches the :class:`Trainer` streams into the model's ``pm.Data`` - placeholder (it consumes them through an accounting wrapper, so plain - iteration leaves the counters untouched). Re-iterate the loader for - another epoch. + Stream each into the model's ``pm.Data`` placeholder with ``model.set_data`` + before a step. Plain iteration leaves :attr:`batches_seen` / + :attr:`rows_streamed` untouched (it does not run the internal accounting + path); re-iterate the loader for another epoch. """ for batch in self._rebatched(): yield self._prepare(batch) @@ -316,7 +311,7 @@ def __len__(self) -> int: return self._total_size def _stream_batches(self) -> Iterator[np.ndarray]: - """One epoch of prepared minibatches, with accounting (the Trainer's path). + """One epoch of prepared minibatches, with accounting (the consumer's path). Like :meth:`__iter__` but it updates :attr:`batches_seen` / :attr:`rows_streamed` and runs the one-shot ``total_size`` sanity check on @@ -395,146 +390,6 @@ def _validate(self, batch: np.ndarray) -> None: ) -class Trainer: - """Drive variational inference over a :class:`DataLoader` without user callbacks. - - Follows the design in PyMC's variational-inference rework and PyTorch - Lightning: the ``Trainer`` owns the training loop, the - :class:`DataLoader` owns batching (and ``len(dataloader)`` is the dataset size - ``N``), and the model owns the math. The model exposes a ``pm.Data`` placeholder; - the ``Trainer`` streams minibatches into it with ``model.set_data`` once per - step; no user callbacks are needed. - - Parameters - ---------- - method : str or Inference, default "advi" - Variational method, forwarded to :func:`pymc.fit`: a name (``"advi"``, - ``"fullrank_advi"``, ...) 
or an :class:`~pymc.variational.inference.Inference` - instance. ``pm.fit`` applies ``model`` and ``random_seed`` only to a name; - an instance is already bound to a model, so configure it at construction - (e.g. ``ADVI(random_seed=...)``). - dataloader : DataLoader - The minibatch source. ``len(dataloader)`` is ``N``; the model should pass - it to the observed distribution's ``total_size``. - model : pymc.Model, optional - Defaults to the model on the context stack. - data_name : str, default "batch" - Name of the ``pm.Data`` placeholder minibatches are streamed into. Must - match the name used for ``pm.Data(name, ...)`` in the model. - **fit_kwargs - Default keyword arguments forwarded to :func:`pymc.fit` (e.g. - ``obj_optimizer``); per-call kwargs to :meth:`fit` override them. - - Notes - ----- - The per-step ``set_data`` currently lives in the ``Trainer``. Once the VI - rework's ``Inference.step(batch)`` lands it moves there, at which point the - ``total_size`` rescaling can be derived from ``len(dataloader)`` and dropped - from the model body entirely. - - Examples - -------- - .. 
code-block:: python - - loader = DataLoader( - parquet_source("shuffled/"), batch_size=4096, sample_shape=(4,), total_size="auto" - ) - with pm.Model() as model: - b = pm.Normal("b", 0.0, 3.0, shape=4) - batch = pm.Data("batch", np.zeros((4096, 4))) # placeholder - logit = b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1] + b[3] * batch[:, 2] - pm.Bernoulli("y", logit_p=logit, observed=batch[:, 3], total_size=len(loader)) - approx = Trainer(method="advi", dataloader=loader, data_name="batch").fit(20_000) - """ - - def __init__( - self, - *, - method: str | Inference = "advi", - dataloader: DataLoader, - model=None, - data_name: str = "batch", - **fit_kwargs, - ): - self.method = method - self.dataloader = dataloader - self.model = model - self.data_name = data_name - self._fit_kwargs = fit_kwargs - - def fit(self, n: int = 10_000, **kwargs): - """Fit for ``n`` steps, streaming minibatches into the model's placeholder. - - Exactly ``n`` minibatches are fed to the model: the first seeds the - placeholder before step 0, and the advance after the final step is skipped. - The accounting stream reads one batch ahead so the pass-size check can fire - at a pass boundary, so a re-readable source (the only kind the loader - accepts) may be read one batch past the ``n`` the model uses. Keyword - arguments are forwarded to :func:`pymc.fit` on top of the constructor's - ``fit_kwargs`` (per-call wins); ``progressbar`` defaults to ``False`` - unless either sets it. - - Returns - ------- - :class:`Approximation` - The fitted approximation, as returned by :func:`pymc.fit`. - """ - if not _is_positive_int(n): - raise ValueError(f"n must be a positive integer (the number of fit steps), got {n!r}") - loader = self.dataloader - if not isinstance(loader, DataLoader): - raise TypeError( - f"Trainer needs a DataLoader for `dataloader`, got {type(loader).__name__}." 
- ) - model = modelcontext(self.model) - if self.data_name not in model: - # Checked before the stream starts so no batch is consumed (and no - # counter advances) on a typo. - raise KeyError( - f"data_name {self.data_name!r} is not a variable in the model; it " - f"must name the pm.Data placeholder the minibatches are streamed into." - ) - - def _stream() -> Iterator[np.ndarray]: - while True: - empty = True - for batch in loader._stream_batches(): - empty = False - yield batch - if empty: - raise RuntimeError("dataloader yielded no batches") - - batches = _stream() - # Seed the placeholder before step 0: pm.fit runs callbacks after each step, - # so without this the first step would train on the placeholder's contents. - model.set_data(self.data_name, next(batches)) - - steps_done = 0 - - def _advance(*_): - # pm.fit fires callbacks after every step including the last; skip the - # advance on this fit's final step so exactly n batches are consumed. - # Only that one call is skipped (not every call past n): Inference.refine - # replays the saved callbacks and must keep streaming fresh batches. - nonlocal steps_done - steps_done += 1 - if steps_done != n: - model.set_data(self.data_name, next(batches)) - - merged = {**self._fit_kwargs, **kwargs} - merged.setdefault("progressbar", False) - # User callbacks (e.g. convergence trackers) are appended after the - # internal advance instead of colliding with it on the keyword. 
- user_callbacks = merged.pop("callbacks", None) or [] - return _fit( - n, - method=self.method, - model=model, - callbacks=[_advance, *user_callbacks], - **merged, - ) - - def shuffle_buffer( chunk_source: Callable[[], Iterator[np.ndarray]], *, diff --git a/tests/variational/streaming_helpers.py b/tests/variational/streaming_helpers.py new file mode 100644 index 0000000000..4d7fab5a1a --- /dev/null +++ b/tests/variational/streaming_helpers.py @@ -0,0 +1,34 @@ +# Copyright 2024 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shared helpers for the streaming-dataset tests.""" + + +def chunked_factory(data, size): + """Return a zero-arg factory that replays ``data`` in ``size``-row chunks. + + A ``DataLoader`` restarts its source once per epoch, so the source has to be + re-readable. This returns a *factory* (a zero-arg callable) that produces a + fresh generator each call, the way an out-of-core source like + ``parquet_source`` does; a bare generator would be one-shot and could not be + replayed. The + final chunk may hold fewer than ``size`` rows -- the loader re-batches the + stream to ``batch_size`` regardless -- so this also exercises the loader's + re-batching across uneven source blocks. 
+ """ + + def factory(): + for i in range(0, len(data), size): + yield data[i : i + size] + + return factory diff --git a/tests/variational/test_streaming.py b/tests/variational/test_streaming.py index 06c1c47507..175d145093 100644 --- a/tests/variational/test_streaming.py +++ b/tests/variational/test_streaming.py @@ -19,24 +19,16 @@ from pymc.variational.streaming import ( DataLoader, IterableDataset, - Trainer, shuffle_buffer, ) - - -def _chunks(data, size): - def factory(): - for i in range(0, len(data), size): - yield data[i : i + size] - - return factory +from tests.variational.streaming_helpers import chunked_factory def test_plain_loader_rebatches_arbitrary_blocks(): """Blocks of 3 with batch_size=4 are re-batched in order; the trailing rows that cannot fill a final batch are dropped (drop_last semantics).""" data = np.arange(20, dtype="float64").reshape(10, 2) - ds = DataLoader(_chunks(data, 3), batch_size=4, sample_shape=(2,), total_size=10) + ds = DataLoader(chunked_factory(data, 3), batch_size=4, sample_shape=(2,), total_size=10) batches = list(ds) assert [b.shape for b in batches] == [(4, 2), (4, 2)] np.testing.assert_array_equal(np.concatenate(batches), data[:8]) @@ -58,7 +50,7 @@ def test_raw_array_source_like_vi_rework_sketch(): def test_wrong_sample_shape_rejected(): """A source whose trailing shape does not match sample_shape raises.""" data = np.zeros((12, 3)) - ds = DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(2,), total_size=12) + ds = DataLoader(chunked_factory(data, 4), batch_size=4, sample_shape=(2,), total_size=12) with pytest.raises(ValueError, match="source yielded shape"): next(iter(ds)) @@ -67,14 +59,14 @@ def test_total_size_none_warns_at_construction(): """total_size=None disables the N/batch_size rescaling, so it warns.""" data = np.zeros((8, 1)) with pytest.warns(UserWarning, match="total_size=None"): - DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(1,)) + DataLoader(chunked_factory(data, 4), batch_size=4, 
sample_shape=(1,)) def test_preprocess_fn_applied(): """preprocess_fn transforms each batch before it is yielded.""" data = np.ones((8, 1)) ds = DataLoader( - _chunks(data, 4), + chunked_factory(data, 4), batch_size=4, sample_shape=(1,), total_size=8, @@ -87,7 +79,7 @@ def test_shuffle_buffer_conserves_rows_with_non_dividing_chunks(): """Chunk and buffer sizes that do not divide batch_size must not lose or duplicate rows; the remainder is carried into the next buffer fill.""" data = np.arange(140, dtype="float64").reshape(140, 1) - src = shuffle_buffer(_chunks(data, 7), buffer_size=55, batch_size=10, seed=0) + src = shuffle_buffer(chunked_factory(data, 7), buffer_size=55, batch_size=10, seed=0) batches = list(src()) assert all(b.shape == (10, 1) for b in batches) seen = np.sort(np.concatenate([b.ravel() for b in batches])) @@ -98,7 +90,7 @@ def test_shuffle_buffer_does_not_mutate_source(): """Shuffling happens on an owned copy, never in place on the source arrays.""" data = np.arange(100, dtype="float64").reshape(100, 1) original = data.copy() - src = shuffle_buffer(_chunks(data, 25), buffer_size=40, batch_size=10, seed=1) + src = shuffle_buffer(chunked_factory(data, 25), buffer_size=40, batch_size=10, seed=1) list(src()) np.testing.assert_array_equal(data, original) @@ -108,7 +100,7 @@ def test_dataloader_shuffle_true_yields_full_batches(): full batches and conserves every row when N divides batch_size.""" data = np.arange(120, dtype="float64").reshape(120, 1) ds = DataLoader( - _chunks(data, 8), + chunked_factory(data, 8), batch_size=10, shuffle=True, buffer_size=40, @@ -146,110 +138,10 @@ def test_total_size_rescales_logp_like_minibatch(): np.testing.assert_allclose(obs_scaled, obs_plain * (N / bs), rtol=1e-6) -def test_trainer_end_to_end_matches_in_ram_minibatch(): - """End-to-end: Trainer-driven streaming ADVI reproduces in-RAM pm.Minibatch ADVI. 
- - Exercises the whole API: a pm.Data placeholder, total_size=len(loader), and a - Trainer that streams minibatches into the placeholder with set_data while the - user writes no callbacks. Runs long enough to cycle the loader across epochs. - """ - seed = 0 - rng = np.random.default_rng(seed) - N, bs = 60_000, 2048 - X = rng.normal(size=(N, 2)) - b_true = np.array([0.3, -1.1, 0.7]) - y = (rng.random(N) < 1 / (1 + np.exp(-(b_true[0] + X @ b_true[1:])))).astype("float64") - data = np.column_stack([X, y]) - - with pm.Model(): - b = pm.Normal("b", 0, 3, shape=3) - xb, zb, yb = pm.Minibatch(X[:, 0].copy(), X[:, 1].copy(), y, batch_size=bs) - pm.Bernoulli("o", logit_p=b[0] + b[1] * xb + b[2] * zb, observed=yb, total_size=N) - ap = pm.fit( - 6000, - method="advi", - obj_optimizer=pm.adam(learning_rate=0.02), - progressbar=False, - random_seed=seed, - ) - in_ram = ap.sample(400).posterior["b"].values.reshape(-1, 3).mean(0) - - loader = DataLoader( - _chunks(data, 20_000), - batch_size=bs, - shuffle=True, - buffer_size=40_000, - seed=seed, - sample_shape=(3,), - total_size=N, - ) - with pm.Model() as model: - b = pm.Normal("b", 0, 3, shape=3) - batch = pm.Data("batch", np.zeros((bs, 3))) - pm.Bernoulli( - "o", - logit_p=b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1], - observed=batch[:, 2], - total_size=len(loader), - ) - ap = Trainer( - method="advi", - dataloader=loader, - data_name="batch", - obj_optimizer=pm.adam(learning_rate=0.02), - ).fit(6000, random_seed=seed) - stream = ap.sample(400).posterior["b"].values.reshape(-1, 3).mean(0) - - np.testing.assert_allclose(in_ram, stream, atol=0.1) - - -def test_trainer_streams_into_placeholder(): - """The Trainer seeds the pm.Data placeholder before step 0 (pm.fit runs - callbacks after each step) and overwrites it each step; after fitting it holds - a real batch, not the zero seed.""" - data = np.ones((4, 1)) - loader = DataLoader(lambda: iter([data] * 100), batch_size=4, sample_shape=(1,), total_size=4) - with pm.Model() as 
model: - mu = pm.Normal("mu", 0, 1) - batch = pm.Data("batch", np.zeros((4, 1))) - pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) - Trainer(method="advi", dataloader=loader, data_name="batch").fit( - 5, progressbar=False, random_seed=0 - ) - np.testing.assert_array_equal(model["batch"].get_value(), data) - - -def test_trainer_raises_when_loader_cannot_restart(): - """A source that streams one epoch and then comes back empty cannot be cycled; - the Trainer surfaces a clear error instead of training on stale data.""" - calls = {"n": 0} - - def factory(): - calls["n"] += 1 - if calls["n"] == 1: - yield np.zeros((4, 1)) - - loader = DataLoader(factory, batch_size=4, sample_shape=(1,), total_size=4) - with pm.Model(): - mu = pm.Normal("mu", 0, 1) - batch = pm.Data("batch", np.zeros((4, 1))) - pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) - with pytest.raises(RuntimeError, match="yielded no batches"): - Trainer(method="advi", dataloader=loader, data_name="batch").fit( - 5, progressbar=False, random_seed=0 - ) - - -def test_trainer_rejects_non_dataloader(): - """The isinstance guard fires before any model lookup.""" - with pytest.raises(TypeError, match="DataLoader"): - Trainer(method="advi", dataloader=object()).fit(10) - - def test_len_returns_total_size(): """len(loader) is the dataset row count N, the value total_size needs.""" data = np.zeros((40, 1)) - loader = DataLoader(_chunks(data, 8), batch_size=8, sample_shape=(1,), total_size=40) + loader = DataLoader(chunked_factory(data, 8), batch_size=8, sample_shape=(1,), total_size=40) assert len(loader) == 40 @@ -267,7 +159,7 @@ def test_iter_yields_clean_batches_and_reiterates(): """__iter__ yields validated (batch_size, *sample_shape) batches and can be re-iterated for another epoch.""" data = np.arange(40, dtype="float64").reshape(40, 1) - loader = DataLoader(_chunks(data, 10), batch_size=10, sample_shape=(1,), total_size=40) + loader = DataLoader(chunked_factory(data, 10), 
batch_size=10, sample_shape=(1,), total_size=40) e1 = list(loader) e2 = list(loader) assert len(e1) == 4 and all(b.shape == (10, 1) for b in e1) @@ -279,21 +171,21 @@ def test_total_size_zero_raises(): """total_size=0 is falsy and would silently skip the rescaling, so it raises.""" data = np.zeros((8, 1)) with pytest.raises(ValueError, match="positive integer"): - DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=0) + DataLoader(chunked_factory(data, 4), batch_size=4, sample_shape=(1,), total_size=0) def test_total_size_negative_raises(): """A negative total_size would flip the sign of the data log-likelihood.""" data = np.zeros((8, 1)) with pytest.raises(ValueError, match="positive integer"): - DataLoader(_chunks(data, 4), batch_size=4, sample_shape=(1,), total_size=-100) + DataLoader(chunked_factory(data, 4), batch_size=4, sample_shape=(1,), total_size=-100) def test_shuffle_buffer_small_buffer_conserves_rows(): """buffer_size < batch_size must not silently discard the dataset: the buffer accumulates to at least batch_size before emitting.""" data = np.arange(120, dtype="float64").reshape(120, 1) - src = shuffle_buffer(_chunks(data, 7), buffer_size=3, batch_size=10, seed=0) + src = shuffle_buffer(chunked_factory(data, 7), buffer_size=3, batch_size=10, seed=0) batches = list(src()) assert batches assert all(b.shape == (10, 1) for b in batches) @@ -305,28 +197,28 @@ def test_shuffle_buffer_rejects_nonpositive_sizes(): """Zero or negative buffer/batch sizes raise at construction.""" data = np.zeros((10, 1)) with pytest.raises(ValueError, match="buffer_size"): - shuffle_buffer(_chunks(data, 5), buffer_size=0, batch_size=4) + shuffle_buffer(chunked_factory(data, 5), buffer_size=0, batch_size=4) with pytest.raises(ValueError, match="batch_size"): - shuffle_buffer(_chunks(data, 5), buffer_size=10, batch_size=0) + shuffle_buffer(chunked_factory(data, 5), buffer_size=10, batch_size=0) def test_accepts_numpy_integer_sizes_rejects_bool(): """The 
positive-int check uses numbers.Integral: numpy ints pass, bool does not.""" data = np.zeros((8, 1)) ds = DataLoader( - _chunks(data, 4), batch_size=np.int64(4), sample_shape=(1,), total_size=np.int64(8) + chunked_factory(data, 4), batch_size=np.int64(4), sample_shape=(1,), total_size=np.int64(8) ) assert next(iter(ds)).shape == (4, 1) assert ds.batch_size == 4 with pytest.raises(ValueError): - DataLoader(_chunks(data, 4), batch_size=True, sample_shape=(1,), total_size=8) + DataLoader(chunked_factory(data, 4), batch_size=True, sample_shape=(1,), total_size=8) def test_shuffle_buffer_draws_fresh_permutation_each_epoch(): """A seeded buffer must not replay one fixed permutation every epoch; each epoch reshuffles while conserving rows.""" data = np.arange(60, dtype="float64").reshape(60, 1) - factory = shuffle_buffer(_chunks(data, 10), buffer_size=60, batch_size=10, seed=0) + factory = shuffle_buffer(chunked_factory(data, 10), buffer_size=60, batch_size=10, seed=0) epoch1 = np.concatenate([b.ravel() for b in factory()]) epoch2 = np.concatenate([b.ravel() for b in factory()]) assert not np.array_equal(epoch1, epoch2) @@ -340,13 +232,17 @@ def test_shuffle_buffer_seed_reproducible_across_runs(): a = np.concatenate( [ b.ravel() - for b in shuffle_buffer(_chunks(data, 10), buffer_size=60, batch_size=10, seed=7)() + for b in shuffle_buffer( + chunked_factory(data, 10), buffer_size=60, batch_size=10, seed=7 + )() ] ) b = np.concatenate( [ b.ravel() - for b in shuffle_buffer(_chunks(data, 10), buffer_size=60, batch_size=10, seed=7)() + for b in shuffle_buffer( + chunked_factory(data, 10), buffer_size=60, batch_size=10, seed=7 + )() ] ) np.testing.assert_array_equal(a, b) @@ -357,7 +253,7 @@ def test_sizes_normalized_to_python_int(): accepted downstream by create_minibatch_rv.""" data = np.zeros((8, 1)) ds = DataLoader( - _chunks(data, 4), batch_size=np.int64(4), sample_shape=(1,), total_size=np.int64(8) + chunked_factory(data, 4), batch_size=np.int64(4), sample_shape=(1,), 
total_size=np.int64(8) ) assert type(ds.batch_size) is int assert type(ds.total_size) is int @@ -419,25 +315,6 @@ def test_scalar_samples_are_batched(): np.testing.assert_array_equal(np.concatenate(batches), data) -def test_trainer_appends_user_callbacks_and_streams_distinct_batches(): - """User callbacks (e.g. convergence trackers) compose with the internal - advance callback instead of colliding on the keyword, and the placeholder - holds a different batch on successive steps. Also exercises the default - data_name ("batch").""" - blocks = [np.full((4, 1), float(i)) for i in range(60)] - loader = DataLoader(lambda: iter(blocks), batch_size=4, sample_shape=(1,), total_size=240) - seen = [] - with pm.Model() as model: - x = pm.Normal("x", 0.0, 1.0) - batch = pm.Data("batch", np.zeros((4, 1))) - pm.Normal("y", x, 1.0, observed=batch[:, 0], total_size=len(loader)) - Trainer(method="advi", dataloader=loader).fit( - 5, callbacks=[lambda *_: seen.append(float(model["batch"].get_value()[0, 0]))] - ) - assert len(seen) == 5 - assert len(set(seen)) > 1 - - def test_iterable_dataset_base_is_abstract(): """The base class is a contract: __iter__ must be overridden.""" with pytest.raises(NotImplementedError): @@ -466,131 +343,6 @@ def test_explicit_sample_shape_overrides_inference(): assert [b.shape for b in batches] == [(8,)] * 5 -def test_trainer_accepts_inference_instance(): - """An Inference instance is forwarded to pm.fit unchanged; it is bound to - the model it was built under, so the Trainer only streams the batches.""" - data = np.ones((4, 1)) - loader = DataLoader(lambda: iter([data] * 50), batch_size=4, sample_shape=(1,), total_size=4) - with pm.Model() as model: - mu = pm.Normal("mu", 0, 1) - batch = pm.Data("batch", np.zeros((4, 1))) - pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) - approx = Trainer(method=pm.ADVI(random_seed=0), dataloader=loader).fit(5) - assert len(approx.hist) == 5 - 
np.testing.assert_array_equal(model["batch"].get_value(), data) - - -def test_constructor_fit_kwargs_take_random_seed(): - """random_seed works as a constructor default, as the docstring promises, - and a per-call value overrides the constructor's.""" - data = np.ones((4, 1)) - - def fit_with(ctor_kwargs, fit_kwargs): - loader = DataLoader( - lambda: iter([data] * 50), batch_size=4, sample_shape=(1,), total_size=4 - ) - with pm.Model(): - mu = pm.Normal("mu", 0, 1) - batch = pm.Data("batch", np.zeros((4, 1))) - pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) - return Trainer(method="advi", dataloader=loader, data_name="batch", **ctor_kwargs).fit( - 5, **fit_kwargs - ) - - a = fit_with({"random_seed": 7}, {}) - b = fit_with({"random_seed": 0}, {"random_seed": 7}) - np.testing.assert_array_equal(a.hist, b.hist) - - -def test_fit_consumes_exactly_n_batches(): - """fit(n) consumes exactly n minibatches: one seeds the placeholder before - step 0 and the advance after the final step is skipped, so an (n+1)-th - batch is never fetched.""" - blocks = [np.full((2, 1), float(i)) for i in range(2)] - loader = DataLoader(lambda: iter(blocks), batch_size=2, sample_shape=(1,), total_size=4) - with pm.Model(): - mu = pm.Normal("mu", 0, 1) - batch = pm.Data("batch", np.zeros((2, 1))) - pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) - Trainer(method="advi", dataloader=loader).fit(3, random_seed=0) - assert loader.batches_seen == 3 - assert loader.rows_streamed == 6 - - -def test_fit_one_step_on_single_batch_one_shot_source(): - """A finite stream with exactly the batches needed must not be over-consumed: - fit(1) on a one-batch, one-shot source trains and returns instead of failing - on a post-final restart.""" - loader = DataLoader(iter([np.ones((2, 1))]), batch_size=2, sample_shape=(1,), total_size=2) - with pm.Model(): - mu = pm.Normal("mu", 0, 1) - batch = pm.Data("batch", np.zeros((2, 1))) - pm.Normal("y", mu, 1, observed=batch[:, 0], 
total_size=len(loader)) - approx = Trainer(method="advi", dataloader=loader).fit(1, random_seed=0) - assert len(approx.hist) == 1 - assert loader.batches_seen == 1 - - -def test_refine_after_fit_resumes_the_stream(): - """Inference.refine replays pm.fit's saved callbacks. Because the advance - skips only fit's own final step (and not every step past n), refine resumes - advancing the stream instead of going permanently dead on the last batch. - - refine does not re-seed, so its first step still trains on the batch fit left - in the placeholder; this pins that resume-not-reseed behavior with distinct - batch markers rather than claiming every refine step is fresh. - """ - blocks = [np.full((4, 1), float(i)) for i in range(50)] - loader = DataLoader(lambda: iter(blocks), batch_size=4, sample_shape=(1,), total_size=4) - sets = [] - with pm.Model() as model: - mu = pm.Normal("mu", 0, 1) - batch = pm.Data("batch", np.zeros((4, 1))) - pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) - original = model.set_data - model.set_data = lambda name, values, *a, **k: ( # type: ignore[method-assign] - sets.append(float(np.asarray(values)[0, 0])), - original(name, values, *a, **k), - )[1] - inference = pm.ADVI(random_seed=0) - Trainer(method=inference, dataloader=loader).fit(3) - assert sets == [0.0, 1.0, 2.0] # fit seeds 0, advances to 1 and 2, skips its last - sets.clear() - inference.refine(4, progressbar=False) - # refine resumes from where the stream stopped (3, 4, 5, ...), not stuck on 2 - assert sets == [3.0, 4.0, 5.0, 6.0] - assert loader.batches_seen == 7 - - -def test_total_size_check_fires_when_fit_ends_at_pass_boundary(): - """fit(n) with n exactly the batches in one pass still runs the total_size - sanity check: the stream is kept one batch ahead, so stopping at the - boundary does not abandon the check right before it would fire.""" - data = np.zeros((40, 1)) - loader = DataLoader(_chunks(data, 10), batch_size=10, sample_shape=(1,), total_size=400) 
- with pm.Model(): - mu = pm.Normal("mu", 0, 1) - batch = pm.Data("batch", np.zeros((10, 1))) - pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) - with pytest.warns(UserWarning, match="disagrees with"): - Trainer(method="advi", dataloader=loader).fit(4, random_seed=0) - - -def test_fit_rejects_nonpositive_n(): - """fit consumes the seed batch before pm.fit could reject n itself, so a - non-positive n is refused up front, before touching the stream.""" - loader = DataLoader( - lambda: iter([np.zeros((2, 1))]), batch_size=2, sample_shape=(1,), total_size=2 - ) - with pm.Model(): - mu = pm.Normal("mu", 0, 1) - batch = pm.Data("batch", np.zeros((2, 1))) - pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) - with pytest.raises(ValueError, match="positive integer"): - Trainer(method="advi", dataloader=loader).fit(0) - assert loader.batches_seen == 0 - - def test_shuffle_buffer_accepts_factory_returning_reiterable(): """A factory returning a re-iterable (which _make_factory tolerates for the loader) must not restart per buffer fill and loop forever; the stream is @@ -603,17 +355,3 @@ def test_shuffle_buffer_accepts_factory_returning_reiterable(): np.testing.assert_array_equal( np.sort(np.concatenate([b.ravel() for b in batches])), data.ravel() ) - - -def test_unknown_data_name_raises_before_consuming(): - """A data_name that is not in the model raises a guided KeyError before any - batch is pulled from the loader.""" - loader = DataLoader( - lambda: iter([np.zeros((4, 1))] * 3), batch_size=4, sample_shape=(1,), total_size=4 - ) - with pm.Model(): - pm.Normal("mu", 0, 1) - with pytest.raises(KeyError, match="pm.Data placeholder"): - Trainer(method="advi", dataloader=loader, data_name="nope").fit(2) - assert loader.batches_seen == 0 - assert loader.rows_streamed == 0 diff --git a/tests/variational/test_streaming_autosize.py b/tests/variational/test_streaming_autosize.py index f3342ddd18..c0a71de0ea 100644 --- 
a/tests/variational/test_streaming_autosize.py +++ b/tests/variational/test_streaming_autosize.py @@ -24,30 +24,23 @@ parquet_source, shuffle_buffer, ) - - -def _factory(data, size): - """A re-readable zero-arg factory yielding `size`-row chunks of `data`.""" - - def f(): - for i in range(0, len(data), size): - yield data[i : i + size] - - return f +from tests.variational.streaming_helpers import chunked_factory def test_auto_counts_finite_source(): """Without .n_rows, 'auto' does one counting pass and resolves the true N.""" data = np.arange(60, dtype="float64").reshape(60, 1) with pytest.warns(UserWarning, match="counting pass"): - ds = DataLoader(_factory(data, 7), batch_size=10, sample_shape=(1,), total_size="auto") + ds = DataLoader( + chunked_factory(data, 7), batch_size=10, sample_shape=(1,), total_size="auto" + ) assert ds.total_size == 60 def test_auto_uses_n_rows_fast_path(): """A source-advertised .n_rows is trusted without a counting pass.""" data = np.zeros((8, 1)) - f = _factory(data, 4) + f = chunked_factory(data, 4) f.n_rows = 1000 ds = DataLoader(f, batch_size=4, sample_shape=(1,), total_size="auto") assert ds.total_size == 1000 @@ -65,7 +58,7 @@ def test_shuffle_buffer_forwards_n_rows_for_auto(): """shuffle_buffer forwards a known .n_rows so total_size='auto' works through an explicit shuffle_buffer(parquet_source(...)) composition without counting.""" data = np.arange(40, dtype="float64").reshape(40, 1) - src = _factory(data, 8) + src = chunked_factory(data, 8) src.n_rows = 40 wrapped = shuffle_buffer(src, buffer_size=20, batch_size=10, seed=0) assert wrapped.n_rows == 40 @@ -80,7 +73,7 @@ def test_dataloader_shuffle_auto_resolves_via_n_rows(): """DataLoader(shuffle=True, total_size='auto') resolves N from the source's .n_rows without a counting pass, even though shuffle wraps the source.""" data = np.arange(40, dtype="float64").reshape(40, 1) - src = _factory(data, 8) + src = chunked_factory(data, 8) src.n_rows = 40 with 
warnings.catch_warnings(): warnings.simplefilter("error", UserWarning) @@ -99,7 +92,7 @@ def test_dataloader_shuffle_auto_resolves_via_n_rows(): def test_shuffle_buffer_without_n_rows_has_no_attribute(): """A source without .n_rows must not gain a bogus one through the wrapper.""" data = np.arange(40, dtype="float64").reshape(40, 1) - wrapped = shuffle_buffer(_factory(data, 8), buffer_size=20, batch_size=10, seed=0) + wrapped = shuffle_buffer(chunked_factory(data, 8), buffer_size=20, batch_size=10, seed=0) assert not hasattr(wrapped, "n_rows") @@ -117,7 +110,7 @@ def test_auto_rejects_factory_returning_same_one_shot_iterator(): def test_auto_rejects_bad_n_rows(): """A non-positive source .n_rows is rejected instead of trusted.""" - f = _factory(np.zeros((8, 1)), 4) + f = chunked_factory(np.zeros((8, 1)), 4) f.n_rows = 0 with pytest.raises(ValueError, match="n_rows must be a positive integer"): DataLoader(f, batch_size=4, sample_shape=(1,), total_size="auto") @@ -127,7 +120,7 @@ def test_sanity_warns_on_grossly_wrong_total_size(): """A hand-passed total_size that grossly disagrees with the rows actually streamed in one pass triggers the one-shot warning at the epoch boundary.""" data = np.arange(20, dtype="float64").reshape(20, 1) - ds = DataLoader(_factory(data, 4), batch_size=4, sample_shape=(1,), total_size=100) + ds = DataLoader(chunked_factory(data, 4), batch_size=4, sample_shape=(1,), total_size=100) with pytest.warns(UserWarning, match="disagrees with"): list(ds._stream_batches()) @@ -135,7 +128,7 @@ def test_sanity_warns_on_grossly_wrong_total_size(): def test_sanity_silent_when_total_size_matches(): """No warning when total_size matches the rows streamed in one pass.""" data = np.arange(20, dtype="float64").reshape(20, 1) - ds = DataLoader(_factory(data, 4), batch_size=4, sample_shape=(1,), total_size=20) + ds = DataLoader(chunked_factory(data, 4), batch_size=4, sample_shape=(1,), total_size=20) with warnings.catch_warnings(): warnings.simplefilter("error", 
UserWarning) list(ds._stream_batches()) @@ -198,7 +191,7 @@ def test_auto_counts_unshuffled_source_when_shuffling_non_divisible(): data = np.arange(125, dtype="float64").reshape(125, 1) with pytest.warns(UserWarning, match="counting pass"): ds = DataLoader( - _factory(data, 125), + chunked_factory(data, 125), batch_size=10, shuffle=True, buffer_size=30, @@ -210,12 +203,12 @@ def test_auto_counts_unshuffled_source_when_shuffling_non_divisible(): def test_stream_batches_updates_counters_and_warns_on_wrong_total_size(): - """The accounting stream the Trainer iterates updates the public counters and - fires the one-shot total_size sanity check at the epoch boundary, while plain - __iter__ stays side-effect-free.""" + """The accounting stream (``DataLoader._stream_batches``) updates the public + counters and fires the one-shot total_size sanity check at the epoch boundary, + while plain __iter__ stays side-effect-free.""" data = np.arange(40, dtype="float64").reshape(20, 2) ds = DataLoader( - _factory(data, 5), + chunked_factory(data, 5), batch_size=5, sample_shape=(2,), total_size=10_000, @@ -234,7 +227,7 @@ def test_sanity_silent_when_drop_last_truncates(): """An exactly-correct total_size does not warn when batch_size does not divide N: the trailing partial batch is dropped by design.""" data = np.arange(25, dtype="float64").reshape(25, 1) - ds = DataLoader(_factory(data, 5), batch_size=10, sample_shape=(1,), total_size=25) + ds = DataLoader(chunked_factory(data, 5), batch_size=10, sample_shape=(1,), total_size=25) with warnings.catch_warnings(): warnings.simplefilter("error", UserWarning) list(ds._stream_batches()) @@ -244,7 +237,9 @@ def test_sanity_silent_for_auto_resolved_non_divisible_n(): """total_size='auto' must not warn against the N it just resolved.""" data = np.arange(25, dtype="float64").reshape(25, 1) with pytest.warns(UserWarning, match="counting pass"): - ds = DataLoader(_factory(data, 5), batch_size=10, sample_shape=(1,), total_size="auto") + ds = 
DataLoader( + chunked_factory(data, 5), batch_size=10, sample_shape=(1,), total_size="auto" + ) with warnings.catch_warnings(): warnings.simplefilter("error", UserWarning) list(ds._stream_batches()) @@ -254,7 +249,7 @@ def test_sanity_check_counts_the_completed_pass_not_cumulative_rows(): """A partially consumed stray stream must not inflate the epoch-boundary check: with a correct total_size, the next full pass stays silent.""" data = np.arange(100, dtype="float64").reshape(100, 1) - ds = DataLoader(_factory(data, 10), batch_size=10, sample_shape=(1,), total_size=100) + ds = DataLoader(chunked_factory(data, 10), batch_size=10, sample_shape=(1,), total_size=100) stray = ds._stream_batches() for _ in range(3): next(stray) @@ -267,7 +262,7 @@ def test_sanity_check_not_fooled_by_cumulative_rows_matching_total_size(): """The converse: a wrong total_size that happens to equal the cumulative row counter must still warn after a true full pass.""" data = np.arange(100, dtype="float64").reshape(100, 1) - ds = DataLoader(_factory(data, 10), batch_size=10, sample_shape=(1,), total_size=130) + ds = DataLoader(chunked_factory(data, 10), batch_size=10, sample_shape=(1,), total_size=130) stray = ds._stream_batches() for _ in range(3): next(stray)