Skip to content

Commit ac71e0e

Browse files
authored
Merge pull request #62 from KempnerInstitute/train-state-ownership-check
Gate train_state.pt load behind an ownership check
2 parents 499adcf + 8b94a05 commit ac71e0e

2 files changed

Lines changed: 244 additions & 1 deletion

File tree

kempnerforge/checkpoint/manager.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111

1212
import json
1313
import logging
14+
import os
1415
import shutil
16+
import stat
1517
from pathlib import Path
1618
from typing import Any
1719

@@ -30,6 +32,44 @@
3032
_METADATA_FILE = "metadata.json"
3133

3234

35+
def _load_train_state(path: Path) -> dict[str, Any]:
36+
"""Load ``train_state.pt`` under an explicit trust boundary.
37+
38+
``train_state.pt`` carries scheduler state, dataloader state, and a
39+
caller-supplied ``extra`` dict, so it is loaded with ``weights_only=False``
40+
(i.e. full pickle). Any object in the file whose class defines
41+
``__reduce__`` runs arbitrary Python during ``torch.load``. On shared
42+
filesystems this is a real attack surface: anyone who can write into
43+
another user's checkpoint directory gets code execution in that user's
44+
training process on next resume.
45+
46+
Refuses to load files not owned by the current UID and warns when the
47+
file is group- or world-writable. This does not defend against a
48+
same-UID compromise — if the attacker can write as you, they already
49+
win — but it closes the common "group-writable shared checkpoint dir"
50+
foot-gun and makes the trust boundary visible.
51+
52+
Checkpoints imported from outside the lab (HuggingFace Hub, colleague
53+
transfers, etc.) will fail this check and must be either chown'd to the
54+
current user after inspection or converted to a weights-only-safe form.
55+
"""
56+
st = path.stat()
57+
uid = os.getuid()
58+
if st.st_uid != uid:
59+
raise PermissionError(
60+
f"Refusing to load {path}: owned by uid={st.st_uid}, current uid={uid}. "
61+
f"train_state.pt is a pickle and loading it executes arbitrary Python. "
62+
f"If you trust this checkpoint, chown it to the current user after inspection."
63+
)
64+
if st.st_mode & (stat.S_IWGRP | stat.S_IWOTH):
65+
logger.warning(
66+
f"{path} is group/world-writable (mode={oct(st.st_mode & 0o777)}); "
67+
f"train_state.pt is a pickle and any writer can inject arbitrary code "
68+
f"at load time. Consider chmod g-w,o-w on the checkpoint directory."
69+
)
70+
return torch.load(path, map_location="cpu", weights_only=False)
71+
72+
3373
class CheckpointManager:
3474
"""Manages save/load/cleanup of distributed checkpoints.
3575
@@ -203,8 +243,12 @@ def load(
203243
train_state_exists = train_state_path.exists()
204244

205245
if train_state_exists:
246+
# Rank-0-authoritative: only rank 0 reads the file. The
247+
# ownership check inside ``_load_train_state`` runs there and
248+
# the resulting state is broadcast to all ranks below. Other
249+
# ranks pass ``None`` into the broadcast.
206250
train_state = (
207-
torch.load(train_state_path, map_location="cpu", weights_only=False)
251+
_load_train_state(train_state_path)
208252
if self._rank == 0 or not dist.is_initialized()
209253
else None
210254
)
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
"""Tests for the train_state.pt trust boundary.
2+
3+
``train_state.pt`` is loaded with ``weights_only=False`` (full pickle),
4+
because it carries scheduler state and an arbitrary caller-supplied
5+
``extra`` dict. That means any ``__reduce__`` in the file runs at load
6+
time. Shared-FS clusters (HolyLFS, Kempner lab storage) and imported
7+
"pretrained" checkpoints both make the write side of that file
8+
attacker-reachable, so the loader MUST at minimum refuse to execute
9+
pickles planted by a different UID.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import os
15+
import stat
16+
from pathlib import Path
17+
from unittest.mock import MagicMock
18+
19+
import pytest
20+
import torch
21+
22+
import kempnerforge.checkpoint.manager as manager_mod
23+
from kempnerforge.checkpoint.manager import CheckpointManager, _load_train_state
24+
from kempnerforge.config.schema import CheckpointConfig
25+
26+
27+
class _Payload:
28+
"""Pickle-time side effect. If ``__reduce__`` runs, the marker file appears."""
29+
30+
def __init__(self, marker: Path) -> None:
31+
self._marker = marker
32+
33+
def __reduce__(self):
34+
# Tells pickle: on load, call os.system(cmd). os.system is a stand-in
35+
# for any arbitrary command the attacker chooses.
36+
return (os.system, (f"touch {self._marker}",))
37+
38+
39+
def _write_malicious_train_state(path: Path, marker: Path) -> None:
    """Write a torch-format file that fires a side effect on load.

    ``torch.save`` wraps pickle, so a ``__reduce__`` on any object inside
    still runs when the file is opened with ``torch.load(weights_only=False)``.
    """
    state = dict(step=42, tokens_seen=1000, rng={}, payload=_Payload(marker))
    torch.save(state, path)
52+
53+
54+
def _write_benign_train_state(path: Path, step: int = 7, tokens_seen: int = 128) -> None:
55+
torch.save({"step": step, "tokens_seen": tokens_seen, "rng": {}}, path)
56+
57+
58+
def _fake_ckpt_dir(tmp_path: Path, marker: Path) -> Path:
    """Build a ``step_42`` checkpoint dir with a malicious train_state.pt
    and point ``latest`` at it so ``_resolve_load_path`` picks it up."""
    step_dir = tmp_path / "step_42"
    step_dir.mkdir()
    _write_malicious_train_state(step_dir / "train_state.pt", marker)
    latest = tmp_path / "latest"
    latest.symlink_to("step_42")
    return step_dir
66+
67+
68+
def _make_manager(tmp_path: Path) -> CheckpointManager:
    """Construct a CheckpointManager over ``tmp_path`` with mock model/optimizer."""
    cfg = CheckpointConfig(dir=str(tmp_path))
    mock_model = MagicMock()
    mock_model.state_dict.return_value = {}
    mock_optimizer = MagicMock()
    mock_optimizer.state_dict.return_value = {}
    return CheckpointManager(config=cfg, model=mock_model, optimizer=mock_optimizer)
75+
76+
77+
class TestLoadTrainStateOwnershipGate:
    @staticmethod
    def _load_with_warning_capture(path: Path) -> list:
        """Run ``_load_train_state`` while capturing WARNING records emitted by
        the manager's logger; return the captured records.

        Asserting via a direct logger handler rather than pytest's caplog keeps
        the tests robust to other tests mutating logger propagation. Shared by
        the group-writable and private-mode tests, which differ only in the
        file mode they set up and the assertion they make on the records.
        """
        import logging

        records: list[logging.LogRecord] = []

        class _Capture(logging.Handler):
            def emit(self, record: logging.LogRecord) -> None:
                records.append(record)

        handler = _Capture(level=logging.WARNING)
        logger = logging.getLogger("kempnerforge.checkpoint.manager")
        prior_level = logger.level
        logger.setLevel(logging.WARNING)
        logger.addHandler(handler)
        try:
            _load_train_state(path)
        finally:
            logger.removeHandler(handler)
            logger.setLevel(prior_level)
        return records

    def test_rejects_foreign_owned_file(self, tmp_path: Path) -> None:
        """_load_train_state refuses when st_uid != current uid."""
        from unittest.mock import patch

        marker = tmp_path / "rce_marker"
        path = tmp_path / "train_state.pt"
        _write_malicious_train_state(path, marker)

        # patch.object restores getuid even if the assertion machinery raises.
        # NOTE: ``manager_mod.os`` is the shared ``os`` module object, so this
        # patches ``os.getuid`` process-wide for the duration of the context —
        # same effect as the manual assignment it replaces, but exception-safe.
        with patch.object(manager_mod.os, "getuid", return_value=os.getuid() + 12345):
            with pytest.raises(PermissionError, match="Refusing to load"):
                _load_train_state(path)

        assert not marker.exists(), (
            "payload fired despite ownership gate — torch.load was reached before the check"
        )

    def test_accepts_own_file(self, tmp_path: Path) -> None:
        """_load_train_state loads when st_uid matches current uid."""
        path = tmp_path / "train_state.pt"
        _write_benign_train_state(path)

        loaded = _load_train_state(path)
        assert loaded["step"] == 7
        assert loaded["tokens_seen"] == 128

    def test_warns_on_group_writable(self, tmp_path: Path) -> None:
        """_load_train_state warns (but still loads) on group-writable files.

        Same-UID group-writable is the footgun this case addresses: a
        colleague in your lab group can plant the file. We warn instead of
        refusing because the same-UID assumption typically holds on HPC
        shared FS and the user deserves a heads-up, not a hard block.
        """
        path = tmp_path / "train_state.pt"
        _write_benign_train_state(path, step=1, tokens_seen=1)
        path.chmod(path.stat().st_mode | stat.S_IWGRP)

        records = self._load_with_warning_capture(path)

        assert any("group/world-writable" in r.getMessage() for r in records), (
            "expected a warning about group/world-writable train_state.pt"
        )

    def test_no_warning_on_private_mode(self, tmp_path: Path) -> None:
        """Files with a tight mode (600/640 without group write) don't warn."""
        path = tmp_path / "train_state.pt"
        _write_benign_train_state(path, step=1, tokens_seen=1)
        path.chmod(0o600)

        records = self._load_with_warning_capture(path)

        assert not any("group/world-writable" in r.getMessage() for r in records)
170+
171+
172+
class TestManagerLoadRejectsForeignCheckpoint:
    def test_load_raises_before_executing_pickle(self, tmp_path: Path) -> None:
        """CheckpointManager.load() blocks a foreign-owned pickle before it fires.

        Pre-fix, the pickle ran through ``torch.load(..., weights_only=False)``
        with no provenance check. Post-fix, ``_load_train_state`` raises
        ``PermissionError`` before ``torch.load`` is reached, so the
        ``__reduce__`` side effect never executes.

        ``exclude_keys`` skips the DCP model/optimizer load so we don't need
        a real checkpoint to reach the train_state.pt branch.
        """
        from unittest.mock import patch

        marker = tmp_path / "rce_marker_2"
        _fake_ckpt_dir(tmp_path, marker)
        mgr = _make_manager(tmp_path)

        # patch.object restores getuid even if load() raises something other
        # than the expected PermissionError. ``manager_mod.os`` is the shared
        # ``os`` module object, so this is a process-wide patch for the
        # duration of the context — same effect as the manual assignment it
        # replaces, but exception-safe.
        with patch.object(manager_mod.os, "getuid", return_value=os.getuid() + 12345):
            with pytest.raises(PermissionError, match="Refusing to load"):
                mgr.load(exclude_keys=["model", "optimizer"])

        assert not marker.exists(), (
            "ownership gate did not block the load — payload fired despite the check"
        )

0 commit comments

Comments
 (0)