
Commit 8b94a05

Merge remote-tracking branch 'origin/main' into train-state-ownership-check

# Conflicts:
#	kempnerforge/checkpoint/manager.py

2 parents: 338e698 + 499adcf

18 files changed: 1324 additions & 37 deletions

kempnerforge/checkpoint/manager.py (62 additions & 4 deletions)
```diff
@@ -100,6 +100,10 @@ def __init__(
         self._async_ckpt = AsyncCheckpointer(mode=config.async_mode)
         self._process_group = process_group
         self._pp_rank = pp_rank
+        # Dataloader state stashed during load() when the caller cannot yet
+        # provide a dataloader object. Applied later via
+        # apply_dataloader_state() once the loader is constructed.
+        self._pending_dataloader_state: dict[str, Any] | None = None
 
     def _checkpoint_dir(self, step: int) -> Path:
         return self.base_dir / f"step_{step}"
@@ -170,6 +174,13 @@ def save(
         # Cleanup old checkpoints
         self._cleanup()
 
+        # save() is a collective: non-rank-0 ranks must not return until
+        # rank-0 has committed train_state.pt, metadata.json, and the
+        # latest symlink. Without this barrier, post-save hooks or readers
+        # on other ranks race rank-0's writes (especially on NFS/Lustre).
+        if dist.is_initialized():
+            dist.barrier()
+
     def wait(self) -> None:
         """Block until any pending async checkpoint save completes."""
         self._async_ckpt.wait()
@@ -218,18 +229,46 @@ def load(
         if "optimizer" in dcp_state:
             self.optimizer.load_state_dict(dcp_state["optimizer"])
 
-        # Load non-distributed state
+        # Load non-distributed state. On NFS/Lustre, independent stat()
+        # calls can disagree briefly across ranks; if some ranks enter
+        # this branch and others don't, the broadcast_object_list below
+        # hangs. Use a rank-0-authoritative existence check broadcast to
+        # all ranks so every rank takes the same branch.
         train_state_path = ckpt_dir / _TRAIN_STATE_FILE
-        if train_state_path.exists():
-            train_state = _load_train_state(train_state_path)
+        if dist.is_initialized():
+            exists_flag = [train_state_path.exists() if self._rank == 0 else False]
+            dist.broadcast_object_list(exists_flag, src=0)
+            train_state_exists = bool(exists_flag[0])
+        else:
+            train_state_exists = train_state_path.exists()
+
+        if train_state_exists:
+            # Rank-0-authoritative: only rank 0 reads the file. The
+            # ownership check inside ``_load_train_state`` runs there and
+            # the resulting state is broadcast to all ranks below. Other
+            # ranks pass ``None`` into the broadcast.
+            train_state = (
+                _load_train_state(train_state_path)
+                if self._rank == 0 or not dist.is_initialized()
+                else None
+            )
 
-            # Broadcast from rank 0 to all ranks
+            # Broadcast from rank 0 to all ranks. PyTorch 2.11's
+            # broadcast_object_list does not accept async_op, so a per-op
+            # timeout cannot be wired here — this call inherits the 1800s
+            # process-group default. A wedged rank will still surface, just
+            # later than the other fast-fail paths in this patch.
            if dist.is_initialized():
                object_list = [train_state if self._rank == 0 else None]
                dist.broadcast_object_list(object_list, src=0)
                train_state = object_list[0]
 
            assert train_state is not None, "train_state broadcast failed"
+            # Stash dataloader state if the caller can't yet provide the loader
+            # object. Training loops construct the dataloader after load() so
+            # apply_dataloader_state() can restore it once it exists.
+            if dataloader is None and "dataloader" in train_state:
+                self._pending_dataloader_state = train_state["dataloader"]
            step, tokens_seen, extra = restore_train_state(
                train_state,
                scheduler=scheduler,
```
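The existence-check half of this hunk is a reusable pattern on its own. A minimal sketch of it as a standalone helper, extracted from the code above (the helper name `_rank0_path_exists` is ours, not part of the patch):

```python
from pathlib import Path

import torch.distributed as dist


def _rank0_path_exists(path: Path, rank: int) -> bool:
    """Every rank returns rank 0's answer, so all ranks take the same branch."""
    if not dist.is_initialized():
        return path.exists()
    # This is a collective: every rank in the group must call it, or the
    # group hangs; that is the same failure mode the hunk above defends against.
    flag = [path.exists() if rank == 0 else False]
    dist.broadcast_object_list(flag, src=0)
    return bool(flag[0])
```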
```diff
@@ -240,6 +279,25 @@ def load(
 
         return 0, 0, {}
 
+    def apply_dataloader_state(self, dataloader: Any) -> None:
+        """Apply any dataloader state stashed during load().
+
+        Training loops call load() before constructing the dataloader (since
+        the dataloader depends on phase/annealing state that load() restores).
+        This method applies the stashed state once the loader exists.
+
+        No-op if no state is pending, or if the loader does not support
+        ``load_state_dict`` (e.g., plain torch DataLoader for HF streaming).
+        """
+        if self._pending_dataloader_state is None:
+            return
+        if dataloader is None or not hasattr(dataloader, "load_state_dict"):
+            self._pending_dataloader_state = None
+            return
+        dataloader.load_state_dict(self._pending_dataloader_state)
+        self._pending_dataloader_state = None
+        logger.info("Applied stashed dataloader state")
+
     def _resolve_load_path(self, path: str | None = None) -> Path | None:
         """Resolve the checkpoint path to load from."""
         if path is not None:
```
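Taken together with the stash in load(), the intended call order on resume looks roughly like the sketch below. Only `load()` and `apply_dataloader_state()` are methods from this patch; the `CheckpointManager` constructor signature and the `build_dataloader` helper are assumptions for illustration.

```python
# Hypothetical resume flow; constructor and build_dataloader are assumed names.
manager = CheckpointManager(config, model=model, optimizer=optimizer)

# 1. load() runs first and restores step/tokens/phase state. The dataloader
#    does not exist yet, so its state lands in _pending_dataloader_state.
step, tokens_seen, extra = manager.load(dataloader=None, scheduler=scheduler)

# 2. Build the dataloader from the restored phase/annealing state.
dataloader = build_dataloader(config, phase=extra.get("phase"))

# 3. Hand the stashed state to the loader. No-op on cold start, and a
#    deliberate no-op when the loader has no load_state_dict (HF streaming).
manager.apply_dataloader_state(dataloader)
```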

kempnerforge/data/dataloader.py (12 additions & 8 deletions)
```diff
@@ -67,8 +67,12 @@ def __init__(
 
     def __iter__(self):
         self.sampler.set_epoch(self._epoch)
+        # Re-apply skip on every iter() so double-resume within the same epoch
+        # stays aligned. The sampler consumes _skip once per iter(), and
+        # _batches_yielded persists across save/load so the skip is re-computable.
+        if self._batches_yielded > 0:
+            self.sampler.set_skip(self._batches_yielded * self.batch_size)
         self._iterator = iter(self._dataloader)
-        self._batches_yielded = 0
         return self
 
     def __next__(self) -> dict[str, torch.Tensor]:
@@ -97,16 +101,16 @@ def state_dict(self) -> dict:
         }
 
     def load_state_dict(self, state: dict) -> None:
-        """Restore from checkpoint. Restores sampler state and skips to saved batch position."""
+        """Restore from checkpoint. ``__iter__`` re-applies the sampler skip from
+        ``_batches_yielded``, so double-resume within the same epoch stays aligned.
+        """
         self._epoch = state.get("epoch", 0)
-        batches_yielded = state.get("batches_yielded", 0)
+        self._batches_yielded = state.get("batches_yielded", 0)
 
         # Set sampler state for resumption
         if "sampler" in state:
             self.sampler.load_state_dict(state["sampler"])
 
-        # Skip ahead to the correct position in the current epoch
-        if batches_yielded > 0:
-            self.sampler.set_skip(batches_yielded * self.batch_size)
-
-        logger.info(f"Resumed DataLoader: epoch={self._epoch}, skip_batches={batches_yielded}")
+        logger.info(
+            f"Resumed DataLoader: epoch={self._epoch}, skip_batches={self._batches_yielded}"
+        )
```
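The double-resume case the comments describe is subtle enough to deserve a worked example. This toy stand-in (not the real sampler or loader classes) shows why re-applying the skip on every iter() keeps a second resume aligned:

```python
class ToySampler:
    """Stand-in for a skip-ahead sampler: the skip is consumed once per iter()."""

    def __init__(self, n: int):
        self.n, self._skip = n, 0

    def set_skip(self, k: int) -> None:
        self._skip = k

    def __iter__(self):
        skip, self._skip = self._skip, 0  # consumed exactly once
        return iter(range(skip, self.n))


class ToyLoader:
    """Stand-in wrapper using the patched __iter__ bookkeeping."""

    def __init__(self, sampler: ToySampler, batch_size: int = 1):
        self.sampler, self.batch_size = sampler, batch_size
        self._batches_yielded = 0

    def __iter__(self):
        if self._batches_yielded > 0:  # re-apply skip on every iter()
            self.sampler.set_skip(self._batches_yielded * self.batch_size)
        self._it = iter(self.sampler)
        return self

    def __next__(self):
        item = next(self._it)
        self._batches_yielded += 1
        return item


loader = ToyLoader(ToySampler(10))
it = iter(loader)
print(next(it), next(it))      # 0 1  (two batches consumed, then "checkpoint")

resumed = ToyLoader(ToySampler(10))
resumed._batches_yielded = 2   # what load_state_dict now restores
print(next(iter(resumed)))     # 2: a second save/load/iter repeats cleanly
```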

kempnerforge/data/dataset.py (43 additions & 14 deletions)
```diff
@@ -13,6 +13,7 @@
 from __future__ import annotations
 
 import bisect
+import contextlib
 import logging
 from pathlib import Path
 
@@ -108,20 +109,27 @@ def __init__(
         self._cumulative_samples: list[int] = [0]
         total_tokens = 0
 
-        for f in self._files:
-            if self._is_bin:
-                # Raw binary: flat array of tokens. Infer dtype from file size
-                # or use uint32 (most common for modern tokenizers with vocab > 65535)
-                file_size = f.stat().st_size
-                dtype = np.uint32 if file_size % 4 == 0 else np.uint16
-                n_tokens = file_size // np.dtype(dtype).itemsize
-                mmap = np.memmap(str(f), dtype=dtype, mode="r", shape=(n_tokens,))
-            else:
-                mmap = np.load(str(f), mmap_mode="r")
-            n_samples = len(mmap) // seq_len
-            self._mmaps.append(mmap)
-            total_tokens += len(mmap)
-            self._cumulative_samples.append(self._cumulative_samples[-1] + n_samples)
+        # If any open fails partway, close the ones we already opened so they
+        # don't leak via the exception traceback (pytest, logger.exception,
+        # post-mortem debuggers all pin the partial `self` and its mmaps).
+        try:
+            for f in self._files:
+                if self._is_bin:
+                    # Raw binary: flat array of tokens. Infer dtype from file size
+                    # or use uint32 (most common for modern tokenizers with vocab > 65535)
+                    file_size = f.stat().st_size
+                    dtype = np.uint32 if file_size % 4 == 0 else np.uint16
+                    n_tokens = file_size // np.dtype(dtype).itemsize
+                    mmap = np.memmap(str(f), dtype=dtype, mode="r", shape=(n_tokens,))
+                else:
+                    mmap = np.load(str(f), mmap_mode="r")
+                n_samples = len(mmap) // seq_len
+                self._mmaps.append(mmap)
+                total_tokens += len(mmap)
+                self._cumulative_samples.append(self._cumulative_samples[-1] + n_samples)
+        except Exception:
+            self._close_mmaps()
+            raise
 
         self._total_samples = self._cumulative_samples[-1]
         logger.info(
```
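The leak path named in the comment is easy to reproduce: the exception's traceback pins the `__init__` frame, and through it the half-built `self` with its open mmaps. A CPython toy illustration (the class is a stand-in for the dataset, not repo code):

```python
import sys


class Holder:
    def __init__(self):
        self.resource = "opened"  # stands in for an already-opened mmap
        raise RuntimeError("second open failed")


try:
    Holder()
except RuntimeError:
    tb = sys.exc_info()[2]
    init_frame = tb.tb_next.tb_frame  # the __init__ frame, kept alive by tb
    # The partially built instance is still reachable through the traceback;
    # pytest and logger.exception hold such tracebacks well past the except.
    print(init_frame.f_locals["self"].resource)  # -> "opened"
```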
```diff
@@ -177,6 +185,27 @@ def load_state_dict(self, state: dict) -> None:
         """Restore from checkpoint. Only ``epoch`` is restored; sample count is derived."""
         self._epoch = state.get("epoch", 0)
 
+    def _close_mmaps(self) -> None:
+        """Release the underlying mmap objects. Idempotent."""
+        for mm in self._mmaps:
+            inner = getattr(mm, "_mmap", None)
+            if inner is not None and not inner.closed:
+                # BufferError: live views into the mapping still exist — can't
+                # force-close safely; drop the ref and let GC finish it.
+                # ValueError: already closed by another code path.
+                with contextlib.suppress(BufferError, ValueError):
+                    inner.close()
+        self._mmaps.clear()
+
+    def close(self) -> None:
+        """Release the underlying mmaps. Preferred path; do not rely on ``__del__``."""
+        self._close_mmaps()
+
+    def __del__(self) -> None:
+        """GC safety net only. Prefer explicit :meth:`close`."""
+        with contextlib.suppress(Exception):
+            self._close_mmaps()
+
 
 class HuggingFaceDataset(Dataset):
     """HuggingFace dataset with on-the-fly tokenization and sequence packing.
```

kempnerforge/data/sampler.py (25 additions & 0 deletions)
```diff
@@ -15,6 +15,27 @@
 from torch.utils.data import Dataset, Sampler
 
 
+def _validate_weights(weights: list[float], context: str) -> None:
+    """Fail fast on empty, negative, or all-zero weight lists.
+
+    The two normalization branches in ``MixtureSampler`` disagree on all-zero
+    input: the ``temperature == 1.0`` branch divides by ``sum(weights)`` and
+    raises ``ZeroDivisionError``; the ``temperature != 1.0`` branch clamps via
+    ``max(w, 1e-12)`` and silently produces uniform sampling. Reject both
+    cases up-front with a clear error so misconfigured phase transitions
+    surface immediately instead of crashing mid-run or drifting silently.
+    """
+    if not weights:
+        raise ValueError(f"{context}: weights list is empty")
+    if any(w < 0 for w in weights):
+        raise ValueError(f"{context}: weights must be non-negative, got {weights}")
+    if sum(weights) <= 0:
+        raise ValueError(
+            f"{context}: weights must sum to > 0 (at least one dataset must have "
+            f"weight > 0), got {weights}"
+        )
+
+
 class DistributedSampler(Sampler[int]):
     """Deterministic distributed sampler with skip-ahead support.
 
@@ -159,6 +180,8 @@ def __init__(
         self._dataset_sizes = [cumulative_sizes[i + 1] - cumulative_sizes[i] for i in range(n)]
         self._offsets = list(cumulative_sizes[:n])
 
+        _validate_weights(weights, "MixtureSampler(weights=...)")
+
         # Apply temperature scaling and normalize
         if temperature != 1.0:
             import math as _math
@@ -285,6 +308,8 @@ def update_weights(self, weights: list[float], temperature: float = 1.0) -> None
         if len(weights) != n:
             raise ValueError(f"Expected {n} weights, got {len(weights)}")
 
+        _validate_weights(weights, "MixtureSampler.update_weights")
+
         # Apply temperature scaling and normalize (same logic as __init__)
         if temperature != 1.0:
             import math as _math
```
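For context, the two branches the docstring contrasts reduce to something like the sketch below. This is reconstructed from the docstring's description; the actual MixtureSampler normalization code is not part of this diff:

```python
def _normalize(weights: list[float], temperature: float) -> list[float]:
    if temperature != 1.0:
        # The clamp turns an all-zero input into n equal tiny values, so
        # normalization silently yields uniform sampling. One of the two
        # failure modes _validate_weights now rejects up front.
        scaled = [max(w, 1e-12) ** (1.0 / temperature) for w in weights]
    else:
        # All-zero input makes sum(scaled) == 0 and raises ZeroDivisionError
        # below, mid-run. The other failure mode.
        scaled = list(weights)
    total = sum(scaled)
    return [w / total for w in scaled]


print(_normalize([3.0, 1.0], temperature=2.0))  # ~[0.63, 0.37], softened
print(_normalize([3.0, 1.0], temperature=1.0))  # [0.75, 0.25]
```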

kempnerforge/distributed/setup.py (55 additions & 3 deletions)
```diff
@@ -8,8 +8,16 @@
 
 import logging
 import os
+
+# ``random`` is aliased to ``_random`` because ``init_distributed`` has a
+# function-local ``import random`` for SLURM-port derivation (a fresh
+# ``random.Random(int(job_id))`` factory isolated from the global RNG that
+# ``_set_seed`` mutates). The underscore keeps the two ``random`` bindings
+# unambiguous when grepping the file.
+import random as _random
 from datetime import timedelta
 
+import numpy as np
 import torch
 import torch.distributed as dist
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
```
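The function-local import this comment describes presumably backs a port-derivation helper along these lines. A sketch under stated assumptions: only the `random.Random(int(job_id))` pattern comes from the comment; the helper name, the fallback default, and the port range are guesses:

```python
import os
import random


def _derive_master_port(default: int = 29500) -> int:
    """Sketch: derive a MASTER_PORT that all ranks of one SLURM job agree on."""
    job_id = os.environ.get("SLURM_JOB_ID")
    if job_id is None:
        return default
    # A fresh Random(int(job_id)) gives every rank the same pseudo-random
    # port while leaving the global RNG, which _set_seed seeds, untouched.
    return random.Random(int(job_id)).randint(20000, 65000)
```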
```diff
@@ -82,17 +90,59 @@ def _set_nccl_env() -> None:
     os.environ.setdefault("NCCL_IB_DISABLE", "0")
     os.environ.setdefault("NCCL_NET_GDR_LEVEL", "2")
 
+    # Ensure NCCL actually enforces the process-group timeout. The default in
+    # PyTorch 2.2+ is "1", but a user shell/SLURM prolog may override it to
+    # "0", at which point the PG timeout becomes advisory and stuck collectives
+    # can hang indefinitely. Set a safe default and warn loudly if the user
+    # has explicitly disabled it.
+    existing = os.environ.get("TORCH_NCCL_ASYNC_ERROR_HANDLING")
+    if existing == "0":
+        logger.warning(
+            "TORCH_NCCL_ASYNC_ERROR_HANDLING=0 detected — NCCL timeouts "
+            "are advisory; stuck collectives can hang indefinitely."
+        )
+    else:
+        os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1")
+
+
+def _barrier_with_timeout(seconds: float, reason: str) -> None:
+    """dist.barrier with an explicit per-op timeout and a diagnostic log.
+
+    The process-group default timeout (``config.nccl_timeout_sec``) is sized
+    for training collectives (minutes of reduce on large tensors). Init-path
+    barriers should fail fast so mesh or env misconfiguration does not block
+    a job for 30 minutes before surfacing a useful error.
+    """
+    work = dist.barrier(async_op=True)
+    try:
+        done = work.wait(timeout=timedelta(seconds=seconds))  # type: ignore[reportOptionalMemberAccess]
+    except RuntimeError as e:
+        raise RuntimeError(
+            f"Barrier timed out after {seconds}s during {reason}. "
+            f"Common causes: MASTER_ADDR/MASTER_PORT disagreement across ranks, "
+            f"a rank missing from the job, or the IB interface unreachable. "
+            f"Underlying: {e}"
+        ) from e
+    if done is False:
+        raise RuntimeError(f"Barrier timed out after {seconds}s during {reason}.")
 
 
 def _set_seed(seed: int, rank: int, pp_rank: int = 0) -> None:
     """Set deterministic seeds for reproducibility.
 
     - Same seed across data-parallel replicas (for consistent dropout)
     - Different seed across pipeline stages (for stochastic depth variation)
+    - Covers torch (CPU + all visible CUDA devices), Python's random, and
+      NumPy's legacy global RNG — matches the four generators captured by
+      ``checkpoint.state.get_rng_state`` so cold start and warm resume
+      seed the same set of generators.
     """
     effective_seed = seed + pp_rank
     torch.manual_seed(effective_seed)
     if torch.cuda.is_available():
-        torch.cuda.manual_seed(effective_seed)
+        torch.cuda.manual_seed_all(effective_seed)
+    _random.seed(effective_seed)
+    np.random.seed(effective_seed)
 
 
 def init_distributed(config: DistributedConfig, seed: int = 42) -> DeviceMesh | None:
```
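The "four generators" the docstring refers to map onto the capture side roughly as follows. A sketch of what `checkpoint.state.get_rng_state` plausibly collects; the function name and the four-generator claim come from the docstring, while the body and key names are assumptions:

```python
import random

import numpy as np
import torch


def get_rng_state() -> dict:
    """Capture the same four generators _set_seed initializes."""
    state = {
        "torch_cpu": torch.get_rng_state(),
        "python": random.getstate(),
        "numpy": np.random.get_state(),  # legacy global RNG
    }
    if torch.cuda.is_available():
        # One state per visible device, matching manual_seed_all's coverage.
        state["torch_cuda"] = torch.cuda.get_rng_state_all()
    return state
```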
```diff
@@ -209,8 +259,10 @@ def init_distributed(config: DistributedConfig, seed: int = 42) -> DeviceMesh | None:
         mesh_dim_names=tuple(mesh_dims),
     )
 
-    # Ensure all ranks have finished mesh creation before proceeding
-    dist.barrier()
+    # Ensure all ranks have finished mesh creation before proceeding.
+    # A 60s bound fails fast on mesh misconfiguration rather than inheriting
+    # the 1800s PG timeout.
+    _barrier_with_timeout(60.0, reason="DeviceMesh construction")
 
     # Set seed (vary by PP rank for different dropout/stochastic depth per stage)
     pp_rank = 0
```
