
Commit dacf606

Author: Han Wang
test(pt_expt): shrink change-bias water dataset to 5 frames
``TestChangeBias`` was the dominant memory hog in the ``Test Python`` shard ``(10, 3.13)`` of the CI matrix. By itself it peaked at ~5 GB RSS, leaving so little headroom under the 7 GB GitHub-hosted runner ceiling that the shard sporadically lost communication with the GitHub Actions server (the intermittent ``runner lost communication`` flake observed across many recent PRs).

Profile finding: peak RSS scales **linearly at ~50 MB per frame** during ``dp change-bias``'s in-process ``main(cmds)`` call. The model forward, run via ``compute_output_stats``, enumerates ``nbatches = min(data.get_nbatches()) = 80`` frames of the water example, and each frame leaks ~50 MB into torch's caching allocator. The leak is not in autograd, since the wrapper already runs under ``torch.no_grad()``; it is somewhere in ``forward_common_atomic`` and is a separate bug.

Constraint: we **must** keep ``nbatches == total dataset frames`` to preserve determinism for ``test_change_bias_pt2_pte_consistency``, which compares the .pte and .pt2 invocations with ``atol=1e-10``. ``_load_batch_set`` shuffles the loaded set, so ``nbatches < total_frames`` would sample a random subset, and the two calls (running in the same Python process with an advancing ``dp_random`` state) would see different frames and produce different biases. Full enumeration sees every frame, so the aggregate bias is invariant under shuffle.

Solution: build a 5-frame subset of ``examples/water/data/data_0`` in ``TestChangeBias.setUpClass`` and point both the trainer config and the change-bias ``-s`` argument at it. ``nbatches`` then resolves to 5 (the new dataset size, i.e. full enumeration), peak RSS drops to ~1.7 GB for the whole class, and all 9 tests in the class (including the strict ``atol=1e-10`` consistency check) still pass. Class wall time also improves (from ~3:55, since each change-bias invocation does less data-loop work).
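The determinism constraint above can be illustrated with a toy sketch. It is not the project's code: ``aggregate``, ``values``, and the two fixed orders are hypothetical stand-ins for the bias computation, the per-frame data, and the shuffles that two consecutive calls would draw from a shared, advancing random state.

```python
import numpy as np


def aggregate(values: np.ndarray, order: np.ndarray, nbatches: int) -> float:
    """Average the first ``nbatches`` frames of ``values`` taken in ``order``."""
    return float(np.mean(values[order[:nbatches]]))


# Toy per-frame quantity for a 5-frame dataset (illustrative values only).
values = np.array([1.0, 2.0, 4.0, 8.0, 16.0])

# Two different shuffle orders, standing in for two calls that share an
# advancing random state (an assumption made for this sketch).
order_a = np.array([0, 1, 2, 3, 4])
order_b = np.array([4, 3, 2, 1, 0])

# Full enumeration: both calls see every frame, so the aggregates agree.
full_a = aggregate(values, order_a, nbatches=5)
full_b = aggregate(values, order_b, nbatches=5)

# Partial enumeration: each call sees a different subset, so they differ.
part_a = aggregate(values, order_a, nbatches=2)
part_b = aggregate(values, order_b, nbatches=2)

print(abs(full_a - full_b) <= 1e-10)
print(part_a, part_b)
```

With ``nbatches`` equal to the dataset size the order of frames stops mattering, which is why the fix shrinks the dataset rather than lowering ``nbatches``.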
1 parent 4e64f8b commit dacf606

1 file changed

Lines changed: 56 additions & 5 deletions

source/tests/pt_expt/test_change_bias.py

@@ -118,19 +118,70 @@ def _make_config(data_dir: str) -> dict:
     }
 
 
+def _make_subset_dataset(src_system: str, dst_system: str, n_frames: int) -> None:
+    """Copy ``type{,_map}.raw`` and the first ``n_frames`` of every ``.npy``
+    in ``set.000`` from ``src_system`` to ``dst_system``.
+
+    Used by ``TestChangeBias`` to shrink the water/data_0 example (80
+    frames) down to a tiny subset so that ``dp change-bias`` enumerates
+    over only ``n_frames`` frames. Why this matters: the in-process
+    ``main(cmds)`` path runs the model forward over ``nbatches`` frames
+    via ``compute_output_stats``, and each frame leaks ~50 MB into
+    torch's caching allocator. At ``n_frames=80`` (the default,
+    ``min(data.get_nbatches()) = 80``) peak RSS hits ~5 GB which OOMs
+    the 7 GB GitHub-hosted CI runner. Shrinking to ``n_frames=5`` keeps
+    peak at ~800 MB while preserving **determinism**: the test
+    ``test_change_bias_pt2_pte_consistency`` asserts ``atol=1e-10``
+    between two .pte and .pt2 calls in the same process, which requires
+    every frame to be seen on each call regardless of the
+    shuffle-based ``_load_batch_set`` order. ``nbatches == total
+    frames`` makes the forward enumerate every frame and so the
+    aggregate bias is invariant under shuffle.
+    """
+    import numpy as np
+
+    src_set = os.path.join(src_system, "set.000")
+    dst_set = os.path.join(dst_system, "set.000")
+    os.makedirs(dst_set, exist_ok=True)
+    for raw in ("type.raw", "type_map.raw"):
+        src = os.path.join(src_system, raw)
+        if os.path.isfile(src):
+            shutil.copyfile(src, os.path.join(dst_system, raw))
+    for fname in os.listdir(src_set):
+        if not fname.endswith(".npy"):
+            continue
+        arr = np.load(os.path.join(src_set, fname))
+        np.save(os.path.join(dst_set, fname), arr[:n_frames])
+
+
 class TestChangeBias(unittest.TestCase):
     """Test dp change-bias for the pt_expt backend."""
 
     @classmethod
     def setUpClass(cls) -> None:
-        data_dir = os.path.join(EXAMPLE_DIR, "data")
-        if not os.path.isdir(data_dir):
-            raise unittest.SkipTest(f"Example data not found: {data_dir}")
+        full_data_dir = os.path.join(EXAMPLE_DIR, "data")
+        if not os.path.isdir(full_data_dir):
+            raise unittest.SkipTest(f"Example data not found: {full_data_dir}")
+        cls.tmpdir = tempfile.mkdtemp()
+        cls.old_cwd = os.getcwd()
+
+        # Shrink the water example dataset (80 frames) to a 5-frame
+        # subset. ``dp change-bias`` defaults to enumerating every
+        # frame (``nbatches = min(data.get_nbatches())``), and each
+        # frame's forward pass leaks ~50 MB into torch's allocator; at
+        # 80 frames peak RSS pushes the 7 GB CI runner into OOM. See
+        # the docstring of ``_make_subset_dataset`` for why we keep
+        # full enumeration (determinism) but shrink the dataset.
+        data_dir = os.path.join(cls.tmpdir, "data")
+        os.makedirs(data_dir, exist_ok=True)
+        _make_subset_dataset(
+            src_system=os.path.join(full_data_dir, "data_0"),
+            dst_system=os.path.join(data_dir, "data_0"),
+            n_frames=5,
+        )
         cls.data_dir = data_dir
         cls.data_file = [os.path.join(data_dir, "data_0")]
 
-        cls.tmpdir = tempfile.mkdtemp()
-        cls.old_cwd = os.getcwd()
         os.chdir(cls.tmpdir)
 
         # Build & train 1-step model
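The truncation logic of the new helper can be tried outside the test suite with a standalone sketch like the one below. It restates the helper's logic under the name ``make_subset_dataset`` and runs it on a synthetic system; the file names ``coord.npy`` and ``energy.npy``, the array shapes, and the ``type.raw`` contents are assumptions chosen for illustration, not taken from the water example. Note that ``arr[:n_frames]`` relies on the frame index being the first axis of every ``.npy`` array.

```python
import os
import shutil
import tempfile

import numpy as np


def make_subset_dataset(src_system: str, dst_system: str, n_frames: int) -> None:
    """Copy the raw type files and the first ``n_frames`` of each .npy array."""
    src_set = os.path.join(src_system, "set.000")
    dst_set = os.path.join(dst_system, "set.000")
    os.makedirs(dst_set, exist_ok=True)
    for raw in ("type.raw", "type_map.raw"):
        src = os.path.join(src_system, raw)
        if os.path.isfile(src):
            shutil.copyfile(src, os.path.join(dst_system, raw))
    for fname in os.listdir(src_set):
        if fname.endswith(".npy"):
            arr = np.load(os.path.join(src_set, fname))
            # Assumes axis 0 is the frame axis.
            np.save(os.path.join(dst_set, fname), arr[:n_frames])


with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "src")
    dst = os.path.join(tmp, "dst")
    os.makedirs(os.path.join(src, "set.000"))

    # Synthetic 80-frame system with 3 atoms (hypothetical contents).
    with open(os.path.join(src, "type.raw"), "w") as f:
        f.write("0\n1\n1\n")
    np.save(os.path.join(src, "set.000", "coord.npy"), np.zeros((80, 9)))
    np.save(os.path.join(src, "set.000", "energy.npy"), np.arange(80.0))

    make_subset_dataset(src, dst, n_frames=5)

    coord = np.load(os.path.join(dst, "set.000", "coord.npy"))
    energy = np.load(os.path.join(dst, "set.000", "energy.npy"))
    has_type_raw = os.path.isfile(os.path.join(dst, "type.raw"))

print(coord.shape, energy.shape, has_type_raw)
```

Because the helper keeps the first ``n_frames`` rather than a random sample, the subset is identical on every run, so the shrunken dataset adds no new source of nondeterminism.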
