Skip to content

Commit 2a5d4b0

Browse files
authored
Fix waiting for lock acquisition during online training (#566)
<!-- markdownlint-disable --> PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED. ## Purpose #424 accidentally added lock waiting logic but only for the first load attempt of the hidden states file. When that first load (intended for pre-existing data) fails, the generate hidden states is called and new hidden states are generated. Unfortunately the new hidden states were loaded without waiting for lock. <!--- Why your changes are needed --> ## Description Make `_maybe_load_hs_file` a helper function so that it can be used by both loads. <!--- High-level concise summary of changes --> ## Related Issue <!--- Link related issue if applicable --> ## Tests Tested locally that this fixes these errors (which subsequently led to training crashing): ``` speculators/src/speculators/train/data.py:300: UserWarning: Failed to load/cache hidden states for sample 29: No such file or directory: /tmp/pytest-of-fynnsu/pytest-55/test_online_smoke_Qwen_Qwen3_V0/hidden_states/chatcmpl-b549e54beee9b9c2-93723e8e.safetensors ``` <!--- Please describe in detail how you tested your changes. --> I have filled in: - [x] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)". - [x] The test plan/results, such as providing test command and pasting the results. - [ ] (Optional) The necessary documentation update. - [x] I (a human) have written or reviewed the code in this pr to the best of my ability. Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
1 parent 1ad7048 commit 2a5d4b0

1 file changed

Lines changed: 15 additions & 15 deletions

File tree

src/speculators/train/data.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,17 @@ def __getitem__(self, index) -> BatchType | None:
178178
return data
179179

180180

181+
def _maybe_load_hs_file(file_path: Path) -> dict[str, torch.Tensor] | None:
182+
lock_path = str(file_path) + ".lock"
183+
if Path(lock_path).exists():
184+
wait_for_lock(lock_path)
185+
186+
if file_path.exists():
187+
return load_file(file_path)
188+
189+
return None
190+
191+
181192
class ArrowDataset(BaseDataset):
182193
def __init__(
183194
self,
@@ -258,19 +269,6 @@ def _compute_approx_lengths(self) -> list[int]:
258269
"""Get lengths of the dataset samples."""
259270
return list(self.data.with_format(None)["seq_len"])
260271

261-
def _maybe_load_hs_file(self, index: int) -> dict[str, torch.Tensor] | None:
262-
file_idx = self._map_to_file_idx(index)
263-
candidate_path = self.hidden_states_path / f"hs_{file_idx}.safetensors"
264-
265-
lock_path = str(candidate_path) + ".lock"
266-
if Path(lock_path).exists():
267-
wait_for_lock(lock_path)
268-
269-
if candidate_path.exists():
270-
return load_file(candidate_path)
271-
272-
return None
273-
274272
def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None:
275273
if not self.client:
276274
self._setup_client()
@@ -287,7 +285,7 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None:
287285
max_retries=self.max_retries,
288286
)
289287

290-
loaded_hs = load_file(hs_filepath)
288+
loaded_hs = _maybe_load_hs_file(Path(hs_filepath))
291289

292290
match self.on_generate:
293291
case "cache":
@@ -306,7 +304,9 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None:
306304
return loaded_hs
307305

308306
def _get_raw_data(self, index):
309-
loaded_hs = self._maybe_load_hs_file(index)
307+
file_idx = self._map_to_file_idx(index)
308+
candidate_path = self.hidden_states_path / f"hs_{file_idx}.safetensors"
309+
loaded_hs = _maybe_load_hs_file(candidate_path)
310310

311311
if loaded_hs is None:
312312
match self.on_missing:

0 commit comments

Comments
 (0)