Skip to content

Streaming variational inference: out-of-core DataLoader for minibatch ADVI#8325

Draft
YichengYang-Ethan wants to merge 27 commits into
pymc-devs:mainfrom
YichengYang-Ethan:streaming-dataset-draft
Draft

Streaming variational inference: out-of-core DataLoader for minibatch ADVI#8325
YichengYang-Ethan wants to merge 27 commits into
pymc-devs:mainfrom
YichengYang-Ethan:streaming-dataset-draft

Conversation

@YichengYang-Ethan

@YichengYang-Ethan YichengYang-Ethan commented Jun 9, 2026

Copy link
Copy Markdown

Draft for mentor review (GSoC 2026, Streaming Variational Inference). Not for merge yet.

Split out per review: this PR is now the out-of-core data layer only; the Trainer follows in a stacked PR once this lands.

Adds pymc.variational.streaming, an out-of-core minibatch VI path for data that does not fit in RAM. The API matches torch.utils.data:

  • IterableDataset / parquet_source: a re-iterable, out-of-core source of rows. parquet_source reads one row group at a time, with the column order frozen at construction so a shard with a permuted schema cannot silently swap features.
  • DataLoader: batches and optionally shuffles the source. Blocks of any size are re-batched in order, and a raw array works directly: DataLoader(np.random.normal(size=(10_000, 2)), batch_size=64) from the VI overview sketch yields (64, 2) row batches (sample_shape defaults to arr.shape[1:] for an array source). Trailing rows that do not fill a final batch are dropped, like drop_last=True in torch, because the model observes a fixed-shape placeholder; with shuffle=True the dropped remainder is re-drawn each epoch, while a fixed replay order keeps dropping the same tail rows (documented, and the second open question below). The loader is sized, so total_size=len(loader) gives the N / batch_size ELBO rescaling (the same mechanism as pm.Minibatch).

With bounded source chunks the full dataset never sits in RAM at once: the model observes only a (batch_size, *sample_shape) placeholder that is overwritten with the next minibatch each step, and peak memory is bounded by the batch, the source chunks in flight, and the optional shuffle buffer (concatenating a fill briefly holds a second copy) — independent of N.

Status: 55 tests in tests/variational/test_streaming.py and test_streaming_autosize.py. The total_size log-likelihood rescaling is checked numerically (scaled by exactly N / batch_size, the same mechanism as pm.Minibatch), alongside the re-batching, shuffle-buffer, parquet_source, and total_size="auto" behavior.

Open questions for discussion:

  • Whether total_size can leave the model body (currently len(loader)).
  • Whether the loader should carry the end-of-pass remainder into the next epoch instead of dropping it, so every row participates even without shuffling.
  • len(loader) returns the row count N, matching the VI overview note (__len__ is total_size), not the batch count that torch returns.

Example notebook: pymc-devs/pymc-examples#888

@zaxtax zaxtax left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving some comments as you prepare the draft!

Comment thread tests/variational/test_streaming.py Outdated


def test_plain_loader_rebatches_arbitrary_blocks():
# blocks of 3 with batch_size=4: re-batched in order; the trailing 2 rows that

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reserve comments for non-obvious code flow

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, every test describes itself in a docstring now

Comment thread tests/variational/test_streaming.py Outdated
batches = list(src())
assert all(b.shape == (10, 1) for b in batches)
seen = np.sort(np.concatenate([b.ravel() for b in batches]))
# 140 rows, batch 10 -> 14 full batches, nothing dropped (140 % 10 == 0)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this comment

Comment thread tests/variational/test_streaming.py Outdated
np.testing.assert_array_equal(ds.as_tensor().get_value(), np.full((4, 1), 3.0))


def test_shuffle_buffer_conserves_rows_non_dividing():

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use docstrings not comments for describing the test

Comment thread tests/variational/test_streaming.py Outdated


def test_total_size_rescales_logp_like_minibatch():
# observed=buf[:, k] + total_size=N must scale the observed log-likelihood by

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unclear comment

Comment thread tests/variational/test_streaming.py Outdated
DataLoader(_chunks(data, 4), batch_size=True, sample_shape=(1,), total_size=8)


def test_shuffle_buffer_reshuffles_across_epochs():

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment would be better done with a more descriptive test name

Comment thread pymc/variational/streaming.py Outdated
np.zeros((batch_size, *self._sample_shape), dtype=dtype), name=name
)

# ----- read-only state ---------------------------------------------------

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove all these banners

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment thread pymc/variational/streaming.py Outdated
"""
return self._shared

def advance(self) -> None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is advance used in any public API here or in pytorch?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, torch's DataLoader has no advance and there's no public one here; I removed that path the only advance left is a private _advance closure inside Trainer.fit that drives set_data

Comment thread pymc/variational/streaming.py Outdated

Notes
-----
This is the *starting point* Rob suggested: the per-step ``set_data`` logic

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope this note won't be here once it's draft ready ;)

Comment thread pymc/variational/streaming.py Outdated

Returns whatever :func:`pymc.fit` returns for the chosen method.
"""
from pymc.model import modelcontext

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would do import as instead of an internal import here.

Comment thread tests/variational/test_streaming.py Outdated


def test_dataloader_shuffle_true_yields_full_batches():
# shuffle=True wraps the source in a bounded shuffle_buffer internally; batches

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be a docstring not a comment. Please apply this to the rest of the tests

pm.Minibatch random-indexes a fully-resident array (peak memory O(N)).
StreamingDataset feeds minibatches from an arbitrary source into a small
pytensor.shared buffer (peak memory O(batch_size)), reusing the existing
total_size / create_minibatch_rv rescaling unchanged. Adds a shuffle_buffer
helper and an equivalence test (streaming ADVI == in-RAM pm.Minibatch ADVI).
Close three silent-corruption holes found in a 5-lens review:
- reject total_size <= 0 in __init__: 0 is falsy and skips the N/batch_size
  rescaling entirely (posterior collapses to prior); negative flips the data
  log-likelihood's sign via get_scaling.
- shuffle_buffer now accumulates max(buffer_size, batch_size) rows before
  emitting, so buffer_size < batch_size no longer silently discards the whole
  stream; also validate buffer_size/batch_size as positive ints.
- positive-int checks use numbers.Integral (accept numpy ints, reject bool).

+5 regression tests; existing 10 unchanged and passing.
A seeded shuffle_buffer rebuilt its RNG from the same seed on every factory
call, so under cycle=True every epoch replayed one fixed permutation -- which
weakens the very mixing the buffer exists to provide and compounds the
block-shuffle bias on ordered data. Derive a fresh per-epoch sub-stream from a
SeedSequence so the order differs across epochs while staying reproducible for
a given seed. +2 tests.
Cuts the "user must pass total_size" burden (open question pymc-devs#1 for the design review):

- total_size="auto" resolves N from a source's .n_rows (cheap -- e.g. Parquet
  footer metadata via the new parquet_source) else one counting pass over a
  finite, re-readable source. One-shot / infinite sources still pass total_size
  explicitly (and are rejected with a clear error under "auto").
- a free sanity check using the existing rows_streamed counter: at the first
  epoch boundary, warn if total_size grossly disagrees with the rows actually
  streamed in one pass (catches a wrong-but-positive total_size).
- parquet_source(directory): a finite, re-readable source carrying .n_rows read
  from Parquet metadata (no data scan).

+7 tests; the existing 17 are unchanged and still pass.
…tion

An adversarial re-review surfaced edge cases the first hardening pass missed:

- total_size / batch_size: numpy integers were accepted but stored unchanged,
  so a stored np.int64 reached create_minibatch_rv and raised "Invalid type
  for total_size". Normalize to Python int at construction.
- _make_factory: a zero-arg factory returning a non-iterator iterable (e.g. a
  list) crashed in __next__ ("'list' object is not an iterator"); wrap in iter().
- total_size="auto": a factory that returns the same one-shot iterator each call
  now raises, instead of leaving the first advance() empty.
- fit_callback: seeds the buffer by default. PyMC runs callbacks after each
  step, so an unseeded first step trained on the zero-initialized placeholder.
- _validate: a 0-D batch now raises a clear ValueError instead of IndexError.

Adds 7 regression tests (31 total).
…ze="auto"

shuffle_buffer now propagates a known .n_rows (e.g. parquet_source's, read from
Parquet metadata) to its wrapped factory, so the common composition

    StreamingDataset(shuffle_buffer(parquet_source(dir)), total_size="auto")

resolves N for free instead of doing a full counting pass over the data. The
only discrepancy is the single dropped trailing partial batch (< batch_size
rows), which is within the auto-size sanity tolerance.

Adds 2 regression tests (33 total).
…iner

Design-review feedback from Rob (mentor): the streaming API should mirror
torch.utils.data so the mental model transfers, and the user-facing callback
should go away in favour of a Lightning-style trainer.

- IterableDataset: re-iterable out-of-core source base (parquet_source now
  returns one); carries an optional .n_rows for total_size="auto".
- DataLoader: the former StreamingDataset, renamed; gains PyTorch-style
  shuffle=/buffer_size=/seed= (wraps shuffle_buffer internally). Still owns the
  fixed pytensor.shared buffer the model observes; advance()/as_tensor() kept.
- Trainer: Trainer(method="advi").fit(model, loader, n) drives VI with NO
  user-facing callbacks -- it seeds the buffer and advances it each step
  internally. The per-step advance is wired into pm.fit privately.

All hardening preserved (int normalization, total_size guards + "auto",
shuffle row-conservation + per-epoch reshuffle, copy-before-borrow, validation).
shuffle_buffer/parquet_source stay public. 36 tests pass (1 skipped: pyarrow).

total_size still appears in the model (total_size=loader.total_size); removing
it is an open design question for Rob -- see notes. It is compiled into the logp
graph at register_rv time (MinibatchRandomVariable Op), so fit-time injection
needs either Trainer graph surgery or a dims-based rule in core.
Follows jessegrabowski/pymc VI_Overview.ipynb (the VI rework Rob/Jesse are
building) instead of my ad-hoc shapes:

- DataLoader.__len__ == total_size N (sized like a PyTorch DataLoader), and
  __iter__ yields the validated minibatch stream. This is the answer to Rob's
  open question: total_size leaves the model and becomes len(loader).
- Trainer takes (method=, dataloader=, model=, data_name=) and fit(n); it streams
  each minibatch into the model's pm.Data placeholder via model.set_data, so the
  model is fully decoupled from the loader and the user writes no callbacks.
- Model idiom is now pm.Data("batch", placeholder) + total_size=len(loader),
  matching the blueprint; verified end-to-end (recovers in-RAM pm.Minibatch ADVI).
- Kept the as_tensor()/advance() shared-buffer path as a documented advanced
  escape hatch; dropped the now-unused _seed_buffer/_advance_callback.

38 tests pass (1 skipped: pyarrow). Open for Rob: spelling DataLoader (PyTorch,
per his "match PyTorch") vs Dataloader (Jesse's draft); method-as-string until
the ADVI(Inference).step rework lands.
- Trainer's stream now updates batches_seen/rows_streamed and runs the
  one-shot total_size sanity check at each epoch boundary (previously dead
  on the Trainer path; __iter__ stays side-effect-free).
- total_size="auto" with shuffle=True counts the unshuffled source, fixing
  an undercount of up to batch_size-1 rows.
- Trainer default data_name "data" -> "batch" to match the examples/tests.
- Clarify len(loader)==N (rows, not batches) in docstrings; raise a clear
  error when a cycled source restarts empty.
- Register the streaming API in docs/source/api/vi.rst.
- Add regression tests for the auto-size shuffle count and Trainer counters.
The non-shuffle path previously required the source to yield exact
batch_size blocks and raised on anything else, while the docstrings
promised re-batching. Now both paths re-batch: blocks of any size are
sliced in order with remainders carried across blocks, and a raw array
(or any single-sample stream) is accepted directly, so the VI-rework
sketch usage Dataloader(<array>, batch_size=...) works as written.
Trailing rows that do not fill a final batch are dropped, like
drop_last=True in torch, since the model observes a fixed-shape
placeholder.

Also: total_size="auto" counts a single-sample stream as rows rather
than flattened elements; Trainer.fit(callbacks=...) appends user
callbacks after the internal advance instead of raising a duplicate
keyword error.
- Drop the shared-buffer path (as_tensor/advance and the cycle/name
  parameters): neither exists in torch.utils.data and the Trainer never
  used it. Manual stepping stays available through plain iteration plus
  set_data.
- Move modelcontext/fit imports to module level.
- Replace test comments with docstrings, drop redundant comments and
  section banners, and rename the reshuffle test descriptively.
shuffle_buffer concatenates yields along the leading axis, so a raw
array source under shuffle=True had its rows flattened (2-D) or crashed
on shape[0] (scalars). Promote single samples to one-row blocks before
the shuffle wrap, with the same helper the re-batcher uses.

Also tighten a few docstring claims: the parquet dtype follows the file
columns, and the shuffle buffer bound is stated as rows held.
DataLoader infers sample_shape from a raw array source, so
DataLoader(arr, batch_size=...) batches rows instead of silently
flattening them to scalars. The total_size check no longer warns on an
exact N when drop-last truncates the final batch, and its advice covers
a wrong source n_rows. Trainer.fit routes all kwargs through one merge
so constructor defaults like random_seed work as documented, accepts an
Inference instance, and rejects an unknown data_name before consuming a
batch. parquet_source validates columns against the schema up front.
The shuffle_buffer docstring states the true buffer bound.
@YichengYang-Ethan YichengYang-Ethan force-pushed the streaming-dataset-draft branch from aa560c8 to 13a6a05 Compare June 11, 2026 04:53
- Trainer.fit(n) consumes exactly n batches: the advance after the final
  step is skipped, so a finite source is not over-consumed
- the total_size sanity check counts the pass that completed instead of
  the cumulative row counter, which inflated across partial streams
- parquet_source freezes the column order at construction and reads one
  row group at a time, so a permuted shard schema cannot silently swap
  features and peak read memory is a row group, not a file
- warn at construction when a fixed-order loader would drop the same
  non-divisible tail every pass
- total_size='auto' probes that the factory can actually be re-read,
  catching factories that close over an already-consumed iterator
- document the shuffle-buffer transient concatenation copy and the
  full-buffer case
…diagnostics

- the internal advance skips only fit's own final step, so
  Inference.refine on a method instance keeps streaming instead of
  silently retraining on the last batch
- keep the rebatcher one batch ahead in the accounting stream, so the
  total_size sanity check still fires when fit(n) stops exactly at the
  pass boundary
- drop the fixed-order divisibility warning: it false-alarmed on the
  module's own pre-shuffled-on-disk example and on manual shuffle_buffer
  wrapping; the drop-last caveat lives in the docs instead
- validate n in Trainer.fit (fit(0) consumed the seed batch; fit(-1)
  failed deep inside PyTensor)
- normalize shuffle_buffer's factory output with iter(), which a
  re-iterable-returning factory would otherwise restart every fill
- parquet_source rejects non-numeric columns at construction and names
  the shard when a later file is missing a frozen column
- name the sample_shape remedy in the block-shape error; spell behavior
  consistently
The class summary still claimed the full dataset never enters memory in
the absolute; match the module docstring's bounded-source-chunks framing
and fix a double space.
- _ParquetDataset checks each shard's column types, so a later shard
  whose column turned non-numeric is named instead of failing as an
  opaque float cast downstream (parquet_source only saw the first shard)
- the fit docstring no longer says 'exactly n consumed'; it feeds exactly
  n batches to the model, but the one-batch lookahead can read a
  re-readable source one batch further
- the refine test now uses distinct batch markers and pins the honest
  resume-not-reseed behavior (its first step reuses fit's last batch)
  instead of only checking counters on all-ones data
Comment thread tests/variational/test_streaming.py Outdated
)


def _chunks(data, size):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this here?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fresh iterator each call (a bare generator would be one-shot). It yields the data a chunk of rows at a time, the way an out-of-core source like I can fold into a conftest if you'd prefer.

@codecov

codecov Bot commented Jun 14, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 85.56338% with 41 lines in your changes missing coverage. Please review.
✅ Project coverage is 91.66%. Comparing base (9d24260) to head (cb85e8c).

Files with missing lines Patch % Lines
pymc/variational/streaming.py 85.51% 41 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@             Coverage Diff             @@
##             main    #8325       +/-   ##
===========================================
+ Coverage   79.80%   91.66%   +11.86%     
===========================================
  Files         125      126        +1     
  Lines       20526    20810      +284     
===========================================
+ Hits        16380    19076     +2696     
+ Misses       4146     1734     -2412     
Files with missing lines Coverage Δ
pymc/variational/__init__.py 100.00% <100.00%> (ø)
pymc/variational/streaming.py 85.51% <85.51%> (ø)

... and 43 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Keep this PR to the dataset/loader layer (IterableDataset, DataLoader,
shuffle_buffer, parquet_source); the Trainer and its tests move to a stacked
follow-up PR. Fold the re-readable chunked-source factory the loader tests
share into tests/variational/streaming_helpers.py, which doubles as a place to
explain why a re-readable factory (not a one-shot generator) is needed.
@YichengYang-Ethan YichengYang-Ethan changed the title Streaming variational inference: out-of-core DataLoader + Trainer for minibatch ADVI Streaming variational inference: out-of-core DataLoader for minibatch ADVI Jun 16, 2026
@fonnesbeck

Copy link
Copy Markdown
Member

@ricardoV94 Do we want this here or in extras?

@ricardoV94

Copy link
Copy Markdown
Member

If it can go in extras, let's go there first

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants