Skip to content

Commit 758b5ee

Browse files
committed
Handle token export mkdir races
1 parent aff43a4 commit 758b5ee

2 files changed

Lines changed: 66 additions & 1 deletion

File tree

src/prime_rl/trainer/rl/token_export.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import atexit
22
import json
33
import math
4+
import time
45
from collections.abc import Mapping, Sequence
56
from pathlib import Path
67
from typing import Any
@@ -12,6 +13,8 @@
1213
from prime_rl.trainer.rl.loss import compute_importance_ratio_and_mismatch_kl
1314

1415
SCHEMA_VERSION = 1
16+
MKDIR_RETRY_DELAY_SECONDS = 0.1
17+
MKDIR_MAX_ATTEMPTS = 5
1518

1619

1720
class DisabledTokenExporter:
@@ -88,7 +91,7 @@ def _start_step(self, step: int) -> None:
8891
self._current_step = step
8992
self._sequences_this_step = 0
9093
step_dir = self.output_dir / f"step_{step}"
91-
step_dir.mkdir(parents=True, exist_ok=True)
94+
_mkdir_existing_dir_ok(step_dir)
9295
self._file = (step_dir / f"rank_{self.rank}.jsonl").open("w", encoding="utf-8")
9396

9497
def _write(self, record: dict[str, Any]) -> None:
@@ -113,6 +116,19 @@ def setup_token_exporter(
113116
return exporter
114117

115118

119+
def _mkdir_existing_dir_ok(path: Path) -> None:
    """Create ``path`` as a directory, retrying when ``mkdir`` reports a conflict.

    ``path.mkdir(parents=True, exist_ok=True)`` is attempted up to
    ``MKDIR_MAX_ATTEMPTS`` times. A ``FileExistsError`` counts as success when
    ``path.is_dir()`` is true at the moment it is checked; otherwise the
    function sleeps ``MKDIR_RETRY_DELAY_SECONDS`` and tries again.

    Raises:
        FileExistsError: If the last attempt fails and ``path`` is still not
            a directory (for example, a regular file occupies the path).
    """
    # NOTE(review): presumably this guards against several writers creating
    # the same directory at once -- confirm against the callers.
    for attempts_left in reversed(range(MKDIR_MAX_ATTEMPTS)):
        try:
            path.mkdir(parents=True, exist_ok=True)
        except FileExistsError:
            if path.is_dir():
                # Something else created the directory; nothing left to do.
                return
            if attempts_left == 0:
                # Out of retries: surface the original error unchanged.
                raise
        else:
            return
        time.sleep(MKDIR_RETRY_DELAY_SECONDS)
130+
131+
116132
def _export_columns(
117133
micro_batch: Mapping[str, Any], model_output: Mapping[str, Tensor], loss_config: Any
118134
) -> dict[str, list[Any]]:
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from pathlib import Path
2+
3+
import pytest
4+
5+
from prime_rl.trainer.rl import token_export
6+
from prime_rl.trainer.rl.token_export import _mkdir_existing_dir_ok
7+
8+
9+
def test_mkdir_existing_dir_ok_retries_transient_file_exists(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A single spurious ``FileExistsError`` is retried and the directory ends up created.

    ``Path.mkdir`` is patched to fail once for the target path, and
    ``time.sleep`` is patched to create the directory, standing in for
    another writer finishing while the helper waits.
    """
    target = tmp_path / "token_exports" / "step_31"
    real_mkdir = Path.mkdir
    calls = 0

    def flaky_mkdir(
        self: Path,
        mode: int = 0o777,
        parents: bool = False,
        exist_ok: bool = False,
    ) -> None:
        nonlocal calls
        # Only the very first mkdir on the target path fails; every other
        # call (parents, later attempts) goes to the real implementation.
        if self != target or calls != 0:
            real_mkdir(self, mode=mode, parents=parents, exist_ok=exist_ok)
            return
        calls += 1
        raise FileExistsError(str(self))

    def create_dir_during_retry(_: float) -> None:
        real_mkdir(target, parents=True, exist_ok=True)

    monkeypatch.setattr(Path, "mkdir", flaky_mkdir)
    monkeypatch.setattr(token_export.time, "sleep", create_dir_during_retry)

    _mkdir_existing_dir_ok(target)

    assert target.is_dir()
    assert calls == 1
38+
39+
40+
def test_mkdir_existing_dir_ok_raises_when_path_is_file(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A regular file occupying the target path surfaces ``FileExistsError``."""
    export_root = tmp_path / "token_exports"
    export_root.mkdir(parents=True)
    target = export_root / "step_31"
    target.write_text("not a directory", encoding="utf-8")

    def skip_sleep(_: float) -> None:
        # Keep the retry loop from actually waiting between attempts.
        return None

    monkeypatch.setattr(token_export.time, "sleep", skip_sleep)

    with pytest.raises(FileExistsError):
        _mkdir_existing_dir_ok(target)

0 commit comments

Comments
 (0)