Skip to content

Commit b5dde50

Browse files
authored
Merge pull request #56 from KempnerInstitute/mmap-init-failure-cleanup
Release partial mmaps when MemoryMappedDataset init fails
2 parents 7f3e519 + 8341a52 commit b5dde50

2 files changed

Lines changed: 150 additions & 14 deletions

File tree

kempnerforge/data/dataset.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import annotations
1414

1515
import bisect
16+
import contextlib
1617
import logging
1718
from pathlib import Path
1819

@@ -108,20 +109,27 @@ def __init__(
108109
self._cumulative_samples: list[int] = [0]
109110
total_tokens = 0
110111

111-
for f in self._files:
112-
if self._is_bin:
113-
# Raw binary: flat array of tokens. Infer dtype from file size
114-
# or use uint32 (most common for modern tokenizers with vocab > 65535)
115-
file_size = f.stat().st_size
116-
dtype = np.uint32 if file_size % 4 == 0 else np.uint16
117-
n_tokens = file_size // np.dtype(dtype).itemsize
118-
mmap = np.memmap(str(f), dtype=dtype, mode="r", shape=(n_tokens,))
119-
else:
120-
mmap = np.load(str(f), mmap_mode="r")
121-
n_samples = len(mmap) // seq_len
122-
self._mmaps.append(mmap)
123-
total_tokens += len(mmap)
124-
self._cumulative_samples.append(self._cumulative_samples[-1] + n_samples)
112+
# If any open fails partway, close the ones we already opened so they
113+
# don't leak via the exception traceback (pytest, logger.exception,
114+
# post-mortem debuggers all pin the partial `self` and its mmaps).
115+
try:
116+
for f in self._files:
117+
if self._is_bin:
118+
# Raw binary: flat array of tokens. Infer dtype from file size
119+
# or use uint32 (most common for modern tokenizers with vocab > 65535)
120+
file_size = f.stat().st_size
121+
dtype = np.uint32 if file_size % 4 == 0 else np.uint16
122+
n_tokens = file_size // np.dtype(dtype).itemsize
123+
mmap = np.memmap(str(f), dtype=dtype, mode="r", shape=(n_tokens,))
124+
else:
125+
mmap = np.load(str(f), mmap_mode="r")
126+
n_samples = len(mmap) // seq_len
127+
self._mmaps.append(mmap)
128+
total_tokens += len(mmap)
129+
self._cumulative_samples.append(self._cumulative_samples[-1] + n_samples)
130+
except Exception:
131+
self._close_mmaps()
132+
raise
125133

126134
self._total_samples = self._cumulative_samples[-1]
127135
logger.info(
@@ -177,6 +185,27 @@ def load_state_dict(self, state: dict) -> None:
177185
"""Restore from checkpoint. Only ``epoch`` is restored; sample count is derived."""
178186
self._epoch = state.get("epoch", 0)
179187

188+
def _close_mmaps(self) -> None:
    """Close every tracked mmap and forget the references. Safe to repeat."""
    for wrapper in self._mmaps:
        raw = getattr(wrapper, "_mmap", None)
        if raw is None or raw.closed:
            continue
        # Suppress BufferError (exported views still pin the mapping, so a
        # forced close is unsafe — dropping the ref lets GC finish it) and
        # ValueError (another code path already closed the handle).
        with contextlib.suppress(BufferError, ValueError):
            raw.close()
    self._mmaps.clear()
199+
200+
def close(self) -> None:
    """Explicitly release all memory maps; safe to call more than once.

    This is the supported cleanup path — callers should not lean on
    ``__del__`` to release the mappings.
    """
    self._close_mmaps()
203+
204+
def __del__(self) -> None:
    """Finalizer fallback only; explicit :meth:`close` is the supported path."""
    try:
        self._close_mmaps()
    except Exception:  # noqa: BLE001 — never let a finalizer raise
        pass
208+
180209

181210
class HuggingFaceDataset(Dataset):
182211
"""HuggingFace dataset with on-the-fly tokenization and sequence packing.
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
"""Unit tests for MemoryMappedDataset mmap lifecycle.
2+
3+
Covers the partial-open failure path (prior mmaps must be released so they
4+
don't leak through exception tracebacks) and the success-path invariant
5+
(mmaps stay open for the life of the dataset).
6+
"""
7+
8+
from __future__ import annotations
9+
10+
from unittest.mock import patch
11+
12+
import numpy as np
13+
import pytest
14+
15+
from kempnerforge.data.dataset import MemoryMappedDataset
16+
17+
18+
def _write_npy_files(tmp_path, n_files: int, tokens_per_file: int = 1024) -> None:
    """Write *n_files* uint32 ``.npy`` shards of *tokens_per_file* tokens each.

    Every shard carries the same ``arange`` payload — only the filename
    differs — so the array is built once outside the loop instead of being
    reallocated per file.
    """
    arr = np.arange(tokens_per_file, dtype=np.uint32)
    for i in range(n_files):
        np.save(tmp_path / f"shard_{i:03d}.npy", arr)
22+
23+
24+
def test_partial_open_failure_closes_prior_mmaps(tmp_path):
    """If np.load raises partway through, earlier mmaps must be released.

    Without the cleanup, those mmaps survive inside exception tracebacks
    (pytest frames, logger.exception, post-mortem debuggers) and pile up
    virtual-memory mappings on Lustre/NFS clusters under retry loops.
    """
    _write_npy_files(tmp_path, n_files=5)

    real_load = np.load
    call_count = {"n": 0}
    opened_mmaps: list = []

    def flaky(*args, **kwargs):
        call_count["n"] += 1
        if call_count["n"] == 3:
            raise RuntimeError("simulated Lustre hiccup")
        loaded = real_load(*args, **kwargs)
        opened_mmaps.append(loaded)
        return loaded

    with patch("kempnerforge.data.dataset.np.load", side_effect=flaky):
        with pytest.raises(RuntimeError, match="Lustre hiccup"):
            MemoryMappedDataset(str(tmp_path), seq_len=128)

    assert len(opened_mmaps) == 2, f"expected 2 opens before failure, got {len(opened_mmaps)}"
    for loaded in opened_mmaps:
        raw = getattr(loaded, "_mmap", None)
        assert raw is not None
        assert raw.closed, "mmap was not closed after __init__ raised"
56+
57+
58+
def test_close_is_idempotent_and_releases_mmaps(tmp_path):
    """close() must release every mmap handle and tolerate repeated calls."""
    _write_npy_files(tmp_path, n_files=3)
    ds = MemoryMappedDataset(str(tmp_path), seq_len=128)

    handles = [mm._mmap for mm in ds._mmaps]
    for handle in handles:
        assert not handle.closed

    ds.close()
    for handle in handles:
        assert handle.closed

    # Calling close() a second time must be a harmless no-op.
    ds.close()
71+
72+
def test_successful_init_keeps_mmaps_open(tmp_path):
    """Regression guard: the cleanup fix must not touch the success path."""
    _write_npy_files(tmp_path, n_files=2)
    ds = MemoryMappedDataset(str(tmp_path), seq_len=128)
    # Every shard's mapping must still be live after a clean __init__.
    assert not any(mm._mmap.closed for mm in ds._mmaps)
    sample = ds[0]
    assert "input_ids" in sample
    ds.close()
80+
81+
82+
def test_partial_open_failure_on_bin_files(tmp_path):
    """The same leak guarantee must hold for the raw ``.bin`` branch."""
    payload = np.arange(512, dtype=np.uint32).tobytes()
    for i in range(4):
        (tmp_path / f"shard_{i:03d}.bin").write_bytes(payload)

    real_memmap = np.memmap
    captured: list = []
    call_count = {"n": 0}

    def flaky_memmap(*args, **kwargs):
        call_count["n"] += 1
        if call_count["n"] == 2:
            raise RuntimeError("simulated bin open failure")
        mapped = real_memmap(*args, **kwargs)
        captured.append(mapped)
        return mapped

    with patch("kempnerforge.data.dataset.np.memmap", side_effect=flaky_memmap):
        with pytest.raises(RuntimeError, match="bin open failure"):
            MemoryMappedDataset(str(tmp_path), seq_len=128, file_pattern="*.bin")

    assert len(captured) == 1
    raw = getattr(captured[0], "_mmap", None)
    assert raw is not None
    assert raw.closed

0 commit comments

Comments
 (0)