Streaming variational inference: out-of-core DataLoader for minibatch ADVI #8325
YichengYang-Ethan wants to merge 27 commits into
15 files changed
zaxtax left a comment
Leaving some comments as you prepare the draft!

def test_plain_loader_rebatches_arbitrary_blocks():
    # blocks of 3 with batch_size=4: re-batched in order; the trailing 2 rows that
Reserve comments for non-obvious code flow
Done, every test describes itself in a docstring now

batches = list(src())
assert all(b.shape == (10, 1) for b in batches)
seen = np.sort(np.concatenate([b.ravel() for b in batches]))
# 140 rows, batch 10 -> 14 full batches, nothing dropped (140 % 10 == 0)
np.testing.assert_array_equal(ds.as_tensor().get_value(), np.full((4, 1), 3.0))

def test_shuffle_buffer_conserves_rows_non_dividing():
Use docstrings not comments for describing the test

def test_total_size_rescales_logp_like_minibatch():
    # observed=buf[:, k] + total_size=N must scale the observed log-likelihood by
DataLoader(_chunks(data, 4), batch_size=True, sample_shape=(1,), total_size=8)

def test_shuffle_buffer_reshuffles_across_epochs():
This comment would be better done with a more descriptive test name

np.zeros((batch_size, *self._sample_shape), dtype=dtype), name=name
)

# ----- read-only state ---------------------------------------------------
"""
return self._shared

def advance(self) -> None:
Is advance used in any public API here or in pytorch?
No, torch's DataLoader has no advance and there's no public one here. I removed that path; the only advance left is a private _advance closure inside Trainer.fit that drives set_data.

Notes
-----
This is the *starting point* Rob suggested: the per-step ``set_data`` logic
I hope this note won't be here once it's draft ready ;)

Returns whatever :func:`pymc.fit` returns for the chosen method.
"""
from pymc.model import modelcontext
I would do import as instead of an internal import here.

def test_dataloader_shuffle_true_yields_full_batches():
    # shuffle=True wraps the source in a bounded shuffle_buffer internally; batches
Should be a docstring not a comment. Please apply this to the rest of the tests
pm.Minibatch random-indexes a fully-resident array (peak memory O(N)). StreamingDataset feeds minibatches from an arbitrary source into a small pytensor.shared buffer (peak memory O(batch_size)), reusing the existing total_size / create_minibatch_rv rescaling unchanged. Adds a shuffle_buffer helper and an equivalence test (streaming ADVI == in-RAM pm.Minibatch ADVI).
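To make the rescaling concrete, here is a minimal plain-Python sketch of what total_size is meant to achieve: a minibatch log-likelihood sum is multiplied by N / batch_size so that it estimates the full-data sum. The function name is hypothetical and this is not PyMC's create_minibatch_rv, only the arithmetic it relies on.

```python
import math


def scaled_minibatch_logp(batch_logps, total_size):
    """Scale a minibatch log-likelihood sum so it estimates the full-data sum.

    Hypothetical helper for illustration; `batch_logps` holds one
    log-likelihood term per row of the minibatch.
    """
    batch_size = len(batch_logps)
    return (total_size / batch_size) * math.fsum(batch_logps)


# Four per-row terms standing in for one minibatch drawn from N = 8 rows.
batch = [-1.0, -2.0, -0.5, -0.5]
estimate = scaled_minibatch_logp(batch, total_size=8)
```

With N = 8 and a batch of 4 rows the factor is 2, so a batch sum of -4.0 becomes -8.0.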
Close three silent-corruption holes found in a 5-lens review:
- reject total_size <= 0 in __init__: 0 is falsy and skips the N/batch_size rescaling entirely (posterior collapses to prior); negative flips the data log-likelihood's sign via get_scaling.
- shuffle_buffer now accumulates max(buffer_size, batch_size) rows before emitting, so buffer_size < batch_size no longer silently discards the whole stream; also validate buffer_size/batch_size as positive ints.
- positive-int checks use numbers.Integral (accept numpy ints, reject bool).
+5 regression tests; existing 10 unchanged and passing.
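A positive-int check of the kind described here might look like the sketch below. The helper name is an assumption, not the PR's actual function; the point is that bool is a subclass of int and has to be excluded explicitly, while anything registered as numbers.Integral (which includes numpy integer scalars) is accepted and returned as a Python int.

```python
import numbers


def check_positive_int(value, name):
    """Return `value` as a Python int, or raise if it is not a positive integer.

    Hypothetical helper; `name` is only used in the error message.
    """
    # bool passes isinstance(..., numbers.Integral), so reject it first.
    if isinstance(value, bool) or not isinstance(value, numbers.Integral):
        raise TypeError(f"{name} must be an integer, got {type(value).__name__}")
    if value <= 0:
        raise ValueError(f"{name} must be positive, got {value}")
    return int(value)
```

Rejecting 0 explicitly matters for total_size in particular, because a falsy value would otherwise skip the rescaling without any error.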
A seeded shuffle_buffer rebuilt its RNG from the same seed on every factory call, so under cycle=True every epoch replayed one fixed permutation -- which weakens the very mixing the buffer exists to provide and compounds the block-shuffle bias on ordered data. Derive a fresh per-epoch sub-stream from a SeedSequence so the order differs across epochs while staying reproducible for a given seed. +2 tests.
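The idea of a fresh but reproducible order per epoch can be sketched with the standard library alone. Note the swap: the commit uses a numpy SeedSequence, while this stand-in derives one child seed per epoch from a parent random.Random; the function name is hypothetical.

```python
import random


def epoch_orders(n_rows, seed, n_epochs):
    """Return one row permutation per epoch, reproducible for a given seed."""
    parent = random.Random(seed)
    orders = []
    for _ in range(n_epochs):
        # Draw a new child seed each epoch instead of rebuilding from `seed`,
        # which would replay the same permutation every time.
        child = random.Random(parent.getrandbits(64))
        order = list(range(n_rows))
        child.shuffle(order)
        orders.append(order)
    return orders


first = epoch_orders(n_rows=50, seed=0, n_epochs=2)
again = epoch_orders(n_rows=50, seed=0, n_epochs=2)
```

Calling it twice with the same seed gives the same sequence of orders, while the two epochs within one call use different permutations of the same rows.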
Cuts the "user must pass total_size" burden (open question pymc-devs#1 for the design review):
- total_size="auto" resolves N from a source's .n_rows (cheap -- e.g. Parquet footer metadata via the new parquet_source) else one counting pass over a finite, re-readable source. One-shot / infinite sources still pass total_size explicitly (and are rejected with a clear error under "auto").
- a free sanity check using the existing rows_streamed counter: at the first epoch boundary, warn if total_size grossly disagrees with the rows actually streamed in one pass (catches a wrong-but-positive total_size).
- parquet_source(directory): a finite, re-readable source carrying .n_rows read from Parquet metadata (no data scan).
+7 tests; the existing 17 are unchanged and still pass.
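The "auto" resolution order can be illustrated with a small sketch: prefer a cheap .n_rows attribute, otherwise count rows in one pass. The function name and the toy source are assumptions made for illustration, and the sketch presumes a finite, re-readable factory of row blocks.

```python
def resolve_total_size(source_factory):
    """Resolve N from `source_factory.n_rows` if present, else by counting.

    Hypothetical helper; `source_factory` is a zero-argument callable that
    returns a fresh iterator over blocks of rows on every call.
    """
    n_rows = getattr(source_factory, "n_rows", None)
    if n_rows is not None:
        return int(n_rows)
    # Fallback: one counting pass over the whole source.
    return sum(len(block) for block in source_factory())


def blocks():
    return iter([[1, 2, 3], [4, 5, 6], [7, 8]])


counted = resolve_total_size(blocks)        # falls back to the counting pass
blocks.n_rows = 100                         # pretend cheap metadata is available
from_metadata = resolve_total_size(blocks)  # no pass over the data
```

The metadata path trusts whatever .n_rows says, which is one reason a sanity check against the rows actually streamed is useful.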
…tion
An adversarial re-review surfaced edge cases the first hardening pass missed:
- total_size / batch_size: numpy integers were accepted but stored unchanged,
so a stored np.int64 reached create_minibatch_rv and raised "Invalid type
for total_size". Normalize to Python int at construction.
- _make_factory: a zero-arg factory returning a non-iterator iterable (e.g. a
list) crashed in __next__ ("'list' object is not an iterator"); wrap in iter().
- total_size="auto": a factory that returns the same one-shot iterator each call
now raises, instead of leaving the first advance() empty.
- fit_callback: seeds the buffer by default. PyMC runs callbacks after each
step, so an unseeded first step trained on the zero-initialized placeholder.
- _validate: a 0-D batch now raises a clear ValueError instead of IndexError.
Adds 7 regression tests (31 total).
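Two of the edge cases above, a factory returning a plain list and a factory returning the same one-shot iterator on every call, can be sketched as follows. Both function names are hypothetical stand-ins, not the PR's _make_factory.

```python
def make_factory(source):
    """Wrap `source` so that every call returns an iterator."""
    if callable(source):
        # A factory may return a list or another iterable; normalize it.
        return lambda: iter(source())
    return lambda: iter(source)


def returns_same_iterator(factory):
    """True when two calls hand back the very same iterator object."""
    return factory() is factory()


list_factory = make_factory(lambda: [[1, 2], [3, 4]])
one_shot = iter([[1, 2], [3, 4]])
stale_factory = make_factory(lambda: one_shot)
```

Because iter() on an iterator returns the iterator itself, the identity check flags the stale factory while the list-backed one passes.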
…ze="auto"
shuffle_buffer now propagates a known .n_rows (e.g. parquet_source's, read from
Parquet metadata) to its wrapped factory, so the common composition
StreamingDataset(shuffle_buffer(parquet_source(dir)), total_size="auto")
resolves N for free instead of doing a full counting pass over the data. The
only discrepancy is the single dropped trailing partial batch (< batch_size
rows), which is within the auto-size sanity tolerance.
Adds 2 regression tests (33 total).
…iner
Design-review feedback from Rob (mentor): the streaming API should mirror torch.utils.data so the mental model transfers, and the user-facing callback should go away in favour of a Lightning-style trainer.
- IterableDataset: re-iterable out-of-core source base (parquet_source now returns one); carries an optional .n_rows for total_size="auto".
- DataLoader: the former StreamingDataset, renamed; gains PyTorch-style shuffle=/buffer_size=/seed= (wraps shuffle_buffer internally). Still owns the fixed pytensor.shared buffer the model observes; advance()/as_tensor() kept.
- Trainer: Trainer(method="advi").fit(model, loader, n) drives VI with NO user-facing callbacks -- it seeds the buffer and advances it each step internally. The per-step advance is wired into pm.fit privately.
All hardening preserved (int normalization, total_size guards + "auto", shuffle row-conservation + per-epoch reshuffle, copy-before-borrow, validation). shuffle_buffer/parquet_source stay public. 36 tests pass (1 skipped: pyarrow).
total_size still appears in the model (total_size=loader.total_size); removing it is an open design question for Rob -- see notes. It is compiled into the logp graph at register_rv time (MinibatchRandomVariable Op), so fit-time injection needs either Trainer graph surgery or a dims-based rule in core.
Follows jessegrabowski/pymc VI_Overview.ipynb (the VI rework Rob/Jesse are
building) instead of my ad-hoc shapes:
- DataLoader.__len__ == total_size N (sized like a PyTorch DataLoader), and
__iter__ yields the validated minibatch stream. This is the answer to Rob's
open question: total_size leaves the model and becomes len(loader).
- Trainer takes (method=, dataloader=, model=, data_name=) and fit(n); it streams
each minibatch into the model's pm.Data placeholder via model.set_data, so the
model is fully decoupled from the loader and the user writes no callbacks.
- Model idiom is now pm.Data("batch", placeholder) + total_size=len(loader),
matching the blueprint; verified end-to-end (recovers in-RAM pm.Minibatch ADVI).
- Kept the as_tensor()/advance() shared-buffer path as a documented advanced
escape hatch; dropped the now-unused _seed_buffer/_advance_callback.
38 tests pass (1 skipped: pyarrow). Open for Rob: spelling DataLoader (PyTorch,
per his "match PyTorch") vs Dataloader (Jesse's draft); method-as-string until
the ADVI(Inference).step rework lands.
- Trainer's stream now updates batches_seen/rows_streamed and runs the one-shot total_size sanity check at each epoch boundary (previously dead on the Trainer path; __iter__ stays side-effect-free).
- total_size="auto" with shuffle=True counts the unshuffled source, fixing an undercount of up to batch_size-1 rows.
- Trainer default data_name "data" -> "batch" to match the examples/tests.
- Clarify len(loader)==N (rows, not batches) in docstrings; raise a clear error when a cycled source restarts empty.
- Register the streaming API in docs/source/api/vi.rst.
- Add regression tests for the auto-size shuffle count and Trainer counters.
The non-shuffle path previously required the source to yield exact batch_size blocks and raised on anything else, while the docstrings promised re-batching. Now both paths re-batch: blocks of any size are sliced in order with remainders carried across blocks, and a raw array (or any single-sample stream) is accepted directly, so the VI-rework sketch usage Dataloader(<array>, batch_size=...) works as written. Trailing rows that do not fill a final batch are dropped, like drop_last=True in torch, since the model observes a fixed-shape placeholder. Also: total_size="auto" counts a single-sample stream as rows rather than flattened elements; Trainer.fit(callbacks=...) appends user callbacks after the internal advance instead of raising a duplicate keyword error.
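The re-batching described here can be sketched with plain lists standing in for arrays: rows are sliced in order, remainders are carried across blocks, and whatever does not fill a final batch is dropped. The function name is hypothetical and the real loader works on arrays rather than lists.

```python
def rebatch(blocks, batch_size):
    """Yield lists of exactly `batch_size` rows, carrying remainders across blocks."""
    pending = []
    for block in blocks:
        pending.extend(block)
        while len(pending) >= batch_size:
            yield pending[:batch_size]
            pending = pending[batch_size:]
    # Rows still in `pending` do not fill a batch and are dropped,
    # which mirrors drop_last=True behavior.


rows = list(range(10))
blocks_of_3 = [rows[i:i + 3] for i in range(0, len(rows), 3)]
batches = list(rebatch(blocks_of_3, batch_size=4))
```

Ten rows arriving in blocks of 3 come out as two batches of 4, and the trailing 2 rows are dropped.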
- Drop the shared-buffer path (as_tensor/advance and the cycle/name parameters): neither exists in torch.utils.data and the Trainer never used it. Manual stepping stays available through plain iteration plus set_data.
- Move modelcontext/fit imports to module level.
- Replace test comments with docstrings, drop redundant comments and section banners, and rename the reshuffle test descriptively.
shuffle_buffer concatenates yields along the leading axis, so a raw array source under shuffle=True had its rows flattened (2-D) or crashed on shape[0] (scalars). Promote single samples to one-row blocks before the shuffle wrap, with the same helper the re-batcher uses. Also tighten a few docstring claims: the parquet dtype follows the file columns, and the shuffle buffer bound is stated as rows held.
DataLoader infers sample_shape from a raw array source, so DataLoader(arr, batch_size=...) batches rows instead of silently flattening them to scalars. The total_size check no longer warns on an exact N when drop-last truncates the final batch, and its advice covers a wrong source n_rows. Trainer.fit routes all kwargs through one merge so constructor defaults like random_seed work as documented, accepts an Inference instance, and rejects an unknown data_name before consuming a batch. parquet_source validates columns against the schema up front. The shuffle_buffer docstring states the true buffer bound.
aa560c8 to 13a6a05
- Trainer.fit(n) consumes exactly n batches: the advance after the final step is skipped, so a finite source is not over-consumed
- the total_size sanity check counts the pass that completed instead of the cumulative row counter, which inflated across partial streams
- parquet_source freezes the column order at construction and reads one row group at a time, so a permuted shard schema cannot silently swap features and peak read memory is a row group, not a file
- warn at construction when a fixed-order loader would drop the same non-divisible tail every pass
- total_size='auto' probes that the factory can actually be re-read, catching factories that close over an already-consumed iterator
- document the shuffle-buffer transient concatenation copy and the full-buffer case
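The "skip the advance after the final step" rule in the first bullet can be illustrated in isolation. This is a toy loop under assumed names (fit_steps, step), not the Trainer's implementation, and it ignores any lookahead the real stream may keep.

```python
def fit_steps(batches, n, step):
    """Feed exactly n batches to `step`, reading nothing past the n-th."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    stream = iter(batches)
    current = next(stream)  # seed batch for the first step
    for i in range(n):
        step(current)
        if i < n - 1:  # skip the advance after the final step
            current = next(stream)


consumed = []


def counting_source():
    """Yield five one-row batches and record each one as it is produced."""
    for k in range(5):
        consumed.append(k)
        yield [k]


seen = []
fit_steps(counting_source(), n=3, step=seen.append)
```

After three steps the source has produced exactly three batches, so a finite source is not read further than the model needs.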
…diagnostics
- the internal advance skips only fit's own final step, so Inference.refine on a method instance keeps streaming instead of silently retraining on the last batch
- keep the rebatcher one batch ahead in the accounting stream, so the total_size sanity check still fires when fit(n) stops exactly at the pass boundary
- drop the fixed-order divisibility warning: it false-alarmed on the module's own pre-shuffled-on-disk example and on manual shuffle_buffer wrapping; the drop-last caveat lives in the docs instead
- validate n in Trainer.fit (fit(0) consumed the seed batch; fit(-1) failed deep inside PyTensor)
- normalize shuffle_buffer's factory output with iter(), which a re-iterable-returning factory would otherwise restart every fill
- parquet_source rejects non-numeric columns at construction and names the shard when a later file is missing a frozen column
- name the sample_shape remedy in the block-shape error; spell behavior consistently
The class summary still claimed the full dataset never enters memory in the absolute; match the module docstring's bounded-source-chunks framing and fix a double space.
- _ParquetDataset checks each shard's column types, so a later shard whose column turned non-numeric is named instead of failing as an opaque float cast downstream (parquet_source only saw the first shard)
- the fit docstring no longer says 'exactly n consumed'; it feeds exactly n batches to the model, but the one-batch lookahead can read a re-readable source one batch further
- the refine test now uses distinct batch markers and pins the honest resume-not-reseed behavior (its first step reuses fit's last batch) instead of only checking counters on all-ones data
)

def _chunks(data, size):
fresh iterator each call (a bare generator would be one-shot). It yields the data a chunk of rows at a time, the way an out-of-core source like I can fold into a conftest if you'd prefer.
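The difference between a re-readable factory and a bare generator can be shown in a few lines. This is a sketch of what a helper like _chunks might do, with plain lists in place of arrays; it is not the test suite's actual code.

```python
def chunks(data, size):
    """Return a zero-argument factory yielding `data` in blocks of `size` rows."""
    def factory():
        # A new generator is created on every call, so the source is re-readable.
        return (data[i:i + size] for i in range(0, len(data), size))
    return factory


source = chunks(list(range(6)), 4)
first_pass = list(source())
second_pass = list(source())

# A bare generator, by contrast, is exhausted after one pass.
one_shot = (x for x in range(3))
drained = list(one_shot)
second_read = list(one_shot)
```

The factory gives the same blocks on every pass, while the second read of the bare generator comes back empty, which is why multi-epoch streaming needs the factory form.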
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
## main #8325 +/- ##
===========================================
+ Coverage 79.80% 91.66% +11.86%
===========================================
Files 125 126 +1
Lines 20526 20810 +284
===========================================
+ Hits 16380 19076 +2696
+ Misses 4146 1734 -2412
Keep this PR to the dataset/loader layer (IterableDataset, DataLoader, shuffle_buffer, parquet_source); the Trainer and its tests move to a stacked follow-up PR. Fold the re-readable chunked-source factory the loader tests share into tests/variational/streaming_helpers.py, which doubles as a place to explain why a re-readable factory (not a one-shot generator) is needed.
@ricardoV94 Do we want this here or in extras?
If it can go in extras, let's go there first
Draft for mentor review (GSoC 2026, Streaming Variational Inference). Not for merge yet.

Split out per review: this PR is now the out-of-core data layer only; the Trainer follows in a stacked PR once this lands.

Adds pymc.variational.streaming, an out-of-core minibatch VI path for data that does not fit in RAM. The API matches torch.utils.data:
- IterableDataset/parquet_source: a re-iterable, out-of-core source of rows. parquet_source reads one row group at a time, with the column order frozen at construction so a shard with a permuted schema cannot silently swap features.
- DataLoader: batches and optionally shuffles the source. Blocks of any size are re-batched in order, and a raw array works directly: DataLoader(np.random.normal(size=(10_000, 2)), batch_size=64) from the VI overview sketch yields (64, 2) row batches (sample_shape defaults to arr.shape[1:] for an array source). Trailing rows that do not fill a final batch are dropped, like drop_last=True in torch, because the model observes a fixed-shape placeholder; with shuffle=True the dropped remainder is re-drawn each epoch, while a fixed replay order keeps dropping the same tail rows (documented, and the second open question below). The loader is sized, so total_size=len(loader) gives the N / batch_size ELBO rescaling (the same mechanism as pm.Minibatch).

With bounded source chunks the full dataset never sits in RAM at once: the model observes only a (batch_size, *sample_shape) placeholder that is overwritten with the next minibatch each step, and peak memory is bounded by the batch, the source chunks in flight, and the optional shuffle buffer (concatenating a fill briefly holds a second copy), independent of N.

Status: 55 tests in tests/variational/test_streaming.py and test_streaming_autosize.py. The total_size log-likelihood rescaling is checked numerically (scaled by exactly N / batch_size, the same mechanism as pm.Minibatch), alongside the re-batching, shuffle-buffer, parquet_source, and total_size="auto" behavior.

Open questions for discussion:
- total_size can leave the model body (currently len(loader)).
- len(loader) returns the row count N, matching the VI overview note (__len__ is total_size), not the batch count that torch returns.

Example notebook: pymc-devs/pymc-examples#888
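To make the len(loader) convention concrete, here is a toy sized loader whose length is the row count N rather than the number of batches. The class is a hypothetical stand-in for illustration and shares nothing with the real DataLoader beyond that convention and the drop-last behavior.

```python
class SizedLoader:
    """Toy loader: len() is the row count N, iteration yields full batches only."""

    def __init__(self, rows, batch_size):
        self.rows = list(rows)
        self.batch_size = batch_size

    def __len__(self):
        # Rows, not batches, so it can be passed straight to total_size.
        return len(self.rows)

    def __iter__(self):
        # Drop the trailing rows that do not fill a final batch.
        full = len(self.rows) - len(self.rows) % self.batch_size
        for start in range(0, full, self.batch_size):
            yield self.rows[start:start + self.batch_size]


loader = SizedLoader(range(10), batch_size=4)
n_batches = sum(1 for _ in loader)
scale = len(loader) / loader.batch_size
```

Here len(loader) is 10 while only 2 batches are yielded, and the rescaling factor N / batch_size comes out as 2.5.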