diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index 04882c09..39f827dd 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -30,6 +30,7 @@ DEFAULT_MAX_RETRIES, DEFAULT_REQUEST_TIMEOUT, generate_hidden_states_async, + wait_for_lock_async, ) from speculators.train.logger import setup_root_logger @@ -169,21 +170,21 @@ def parse_args(): def get_existing_hidden_state_indices(output_path: Path) -> list[int]: """Find existing `hs_i.safetensors` files (where i is the file index)""" - existing_file_indices = [] + existing_file_indices_set: set[int] = set() if not output_path.exists(): - return existing_file_indices + return [] for file_path in output_path.iterdir(): if file_path.name.startswith("hs_") and file_path.name.endswith(".safetensors"): index_str = file_path.stem[3:] # Remove "hs_" prefix try: file_index = int(index_str) - existing_file_indices.append(file_index) + existing_file_indices_set.add(file_index) except ValueError: continue - return sorted(existing_file_indices) + return sorted(existing_file_indices_set) def get_indices_to_process( @@ -246,7 +247,7 @@ def check_safetensors_file(path: Path, tokens: list[int]): ) -async def worker( +async def worker( # noqa: C901 client, model: str, queue: "asyncio.Queue[dict[str, Any]]", @@ -289,6 +290,10 @@ async def worker( timeout=request_timeout, max_retries=max_retries, ) + lock_path = hidden_states_path + ".lock" + if Path(lock_path).exists(): + await wait_for_lock_async(lock_path) + async with write_semaphore: # Limit number of active disk writes await asyncio.to_thread( shutil.move, hidden_states_path, target_hidden_states_path diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index da766a65..3336ea69 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -1,6 +1,8 @@ import asyncio +import fcntl import functools import logging +import os import time import openai @@ -99,6 +101,47 @@ def extract_output(completion, token_ids) -> str: return completion.kv_transfer_params.get("hidden_states_path") +async def _poll_lock_async(fd, poll_interval): + while True: + try: + fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + return + except BlockingIOError: + await asyncio.sleep(poll_interval) + + +async def wait_for_lock_async(lock_path, timeout=10.0, poll_interval=0.1): + fd = os.open(lock_path, os.O_RDONLY) + try: + await asyncio.wait_for(_poll_lock_async(fd, poll_interval), timeout=timeout) + except BaseException: + os.close(fd) + raise + os.close(fd) + os.remove(lock_path) + + +def wait_for_lock(lock_path, timeout=10.0, poll_interval=0.1): + fd = os.open(lock_path, os.O_RDONLY) + try: + deadline = time.monotonic() + timeout + while True: + try: + fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + break + except BlockingIOError: + if time.monotonic() >= deadline: + raise TimeoutError( + f"Timed out waiting for lock: {lock_path}" + ) from None + time.sleep(poll_interval) + except BaseException: + os.close(fd) + raise + os.close(fd) + os.remove(lock_path) + + @with_retries async def generate_hidden_states_async( client: openai.AsyncClient, diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index b1e05c91..1dd65334 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -20,6 +20,7 @@ DEFAULT_MAX_RETRIES, DEFAULT_REQUEST_TIMEOUT, generate_hidden_states, + wait_for_lock, ) from speculators.train.noise_transforms import TransformTensors @@ -248,6 +249,11 @@ def _compute_approx_lengths(self) -> list[int]: def _maybe_load_hs_file(self, index: int) -> dict[str, torch.Tensor] | None: file_idx = self._map_to_file_idx(index) candidate_path = self.hidden_states_path / f"hs_{file_idx}.safetensors" + + lock_path = str(candidate_path) + ".lock" + if Path(lock_path).exists(): + wait_for_lock(lock_path) + if candidate_path.exists(): return load_file(candidate_path) @@ -266,19 +272,22 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None: timeout=self.request_timeout, max_retries=self.max_retries, ) - except Exception as e: # noqa: BLE001 - warnings.warn(str(e), stacklevel=1) - return None - loaded_hs = load_file(hs_filepath) + loaded_hs = load_file(hs_filepath) - match self.on_generate: - case "cache": - file_idx = self._map_to_file_idx(index) - target_path = self.hidden_states_path / f"hs_{file_idx}.safetensors" - shutil.move(hs_filepath, target_path) - case "delete": - Path(hs_filepath).unlink() + match self.on_generate: + case "cache": + file_idx = self._map_to_file_idx(index) + target_path = self.hidden_states_path / f"hs_{file_idx}.safetensors" + shutil.move(hs_filepath, target_path) + case "delete": + Path(hs_filepath).unlink() + except Exception as e: # noqa: BLE001 + warnings.warn( + f"Failed to load/cache hidden states for sample {index}: {e}", + stacklevel=1, + ) + return None return loaded_hs