From 700fa684da4f65f367a9968a6bbbd153ffccc2e8 Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Thu, 19 Mar 2026 13:41:54 -0400 Subject: [PATCH 1/3] Add support for async hidden states connector Signed-off-by: Fynn Schmitt-Ulms --- scripts/data_generation_offline.py | 5 +++ .../data_generation/vllm_client.py | 39 +++++++++++++++++++ src/speculators/train/data.py | 6 +++ 3 files changed, 50 insertions(+) diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index ca1be6bc0..f42ef6197 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -30,6 +30,7 @@ DEFAULT_MAX_RETRIES, DEFAULT_REQUEST_TIMEOUT, generate_hidden_states_async, + wait_for_lock_async, ) from speculators.train.data import build_client_item from speculators.train.logger import setup_root_logger @@ -288,6 +289,10 @@ async def worker( timeout=request_timeout, max_retries=max_retries, ) + lock_path = hidden_states_path + ".lock" + if Path(lock_path).exists(): + await wait_for_lock_async(lock_path) + async with write_semaphore: # Limit number of active disk writes await asyncio.to_thread( shutil.move, hidden_states_path, target_hidden_states_path diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index 2b1a1f97b..30c54ecc6 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -1,6 +1,8 @@ import asyncio +import fcntl import functools import logging +import os import time from typing import TYPE_CHECKING, Any, TypedDict @@ -122,6 +124,43 @@ class ClientItem(TypedDict): instead of passing `token_ids` to Completions API.""" +async def _poll_lock_async(fd, poll_interval): + while True: + try: + fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + return + except BlockingIOError: + await asyncio.sleep(poll_interval) + + +async def wait_for_lock_async(lock_path, timeout=10.0, poll_interval=0.1): + fd = os.open(lock_path, os.O_RDONLY) + try: + await asyncio.wait_for(_poll_lock_async(fd, poll_interval), timeout=timeout) + finally: + os.close(fd) + os.remove(lock_path) + + +def wait_for_lock(lock_path, timeout=10.0, poll_interval=0.1): + fd = os.open(lock_path, os.O_RDONLY) + try: + deadline = time.monotonic() + timeout + while True: + try: + fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + return + except BlockingIOError: + if time.monotonic() >= deadline: + raise TimeoutError( + f"Timed out waiting for lock: {lock_path}" + ) from None + time.sleep(poll_interval) + finally: + os.close(fd) + os.remove(lock_path) + + @with_retries async def generate_hidden_states_async( client: openai.AsyncClient, diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index fc4a4af9a..9a98d452c 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -21,6 +21,7 @@ DEFAULT_REQUEST_TIMEOUT, ClientItem, generate_hidden_states, + wait_for_lock, ) from speculators.train.noise_transforms import TransformTensors @@ -259,6 +260,11 @@ def _compute_approx_lengths(self) -> list[int]: def _maybe_load_hs_file(self, index: int) -> dict[str, torch.Tensor] | None: file_idx = self._map_to_file_idx(index) candidate_path = self.hidden_states_path / f"hs_{file_idx}.safetensors" + + lock_path = str(candidate_path) + ".lock" + if Path(lock_path).exists(): + wait_for_lock(lock_path) + if candidate_path.exists(): return load_file(candidate_path) From e7dceba4106ad70a1a3c2827bb67b300aeadbad0 Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Tue, 14 Apr 2026 20:03:31 +0000 Subject: [PATCH 2/3] Add coderabbit suggestions Signed-off-by: Fynn Schmitt-Ulms --- scripts/data_generation_offline.py | 10 ++++---- .../data_generation/vllm_client.py | 14 +++++++---- src/speculators/train/data.py | 25 +++++++++++-------- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index f42ef6197..c91d527be 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -171,21 +171,21 @@ def parse_args(): def get_existing_hidden_state_indices(output_path: Path) -> list[int]: """Find existing `hs_i.safetensors` files (where i is the file index)""" - existing_file_indices = [] + existing_file_indices_set: set[int] = set() if not output_path.exists(): - return existing_file_indices + return [] for file_path in output_path.iterdir(): if file_path.name.startswith("hs_") and file_path.name.endswith(".safetensors"): index_str = file_path.stem[3:] # Remove "hs_" prefix try: file_index = int(index_str) - existing_file_indices.append(file_index) + existing_file_indices_set.add(file_index) except ValueError: continue - return sorted(existing_file_indices) + return sorted(existing_file_indices_set) def get_indices_to_process( @@ -248,7 +248,7 @@ def check_safetensors_file(path: Path, tokens: list[int]): ) -async def worker( +async def worker( # noqa: C901 client, model: str, queue: "asyncio.Queue[dict[str, Any]]", diff --git a/src/speculators/data_generation/vllm_client.py b/src/speculators/data_generation/vllm_client.py index 30c54ecc6..6549444a9 100644 --- a/src/speculators/data_generation/vllm_client.py +++ b/src/speculators/data_generation/vllm_client.py @@ -137,9 +137,11 @@ async def wait_for_lock_async(lock_path, timeout=10.0, poll_interval=0.1): fd = os.open(lock_path, os.O_RDONLY) try: await asyncio.wait_for(_poll_lock_async(fd, poll_interval), timeout=timeout) - finally: + except BaseException: os.close(fd) - os.remove(lock_path) + raise + os.close(fd) + os.remove(lock_path) def wait_for_lock(lock_path, timeout=10.0, poll_interval=0.1): @@ -149,16 +151,18 @@ def wait_for_lock(lock_path, timeout=10.0, poll_interval=0.1): while True: try: fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB) - return + break except BlockingIOError: if time.monotonic() >= deadline: raise TimeoutError( f"Timed out waiting for lock: {lock_path}" ) from None time.sleep(poll_interval) - finally: + except BaseException: os.close(fd) - os.remove(lock_path) + raise + os.close(fd) + os.remove(lock_path) @with_retries diff --git a/src/speculators/train/data.py b/src/speculators/train/data.py index 9a98d452c..048e2c5f0 100644 --- a/src/speculators/train/data.py +++ b/src/speculators/train/data.py @@ -285,19 +285,22 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None: timeout=self.request_timeout, max_retries=self.max_retries, ) - except Exception as e: # noqa: BLE001 - warnings.warn(str(e), stacklevel=1) - return None - loaded_hs = load_file(hs_filepath) + loaded_hs = load_file(hs_filepath) - match self.on_generate: - case "cache": - file_idx = self._map_to_file_idx(index) - target_path = self.hidden_states_path / f"hs_{file_idx}.safetensors" - shutil.move(hs_filepath, target_path) - case "delete": - Path(hs_filepath).unlink() + match self.on_generate: + case "cache": + file_idx = self._map_to_file_idx(index) + target_path = self.hidden_states_path / f"hs_{file_idx}.safetensors" + shutil.move(hs_filepath, target_path) + case "delete": + Path(hs_filepath).unlink() + except Exception as e: # noqa: BLE001 + warnings.warn( + f"Failed to load/cache hidden states for sample {index}: {e}", + stacklevel=1, + ) + return None return loaded_hs From a21a509012d57ae72a0aa3f5cc4e6abb7202ff35 Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Tue, 26 May 2026 18:52:41 +0000 Subject: [PATCH 3/3] Add noqa Signed-off-by: Fynn Schmitt-Ulms --- scripts/data_generation_offline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/data_generation_offline.py b/scripts/data_generation_offline.py index c91d527be..d55b9601c 100644 --- a/scripts/data_generation_offline.py +++ b/scripts/data_generation_offline.py @@ -290,7 +290,7 @@ async def worker( # noqa: C901 max_retries=max_retries, ) lock_path = hidden_states_path + ".lock" - if Path(lock_path).exists(): + if Path(lock_path).exists(): # noqa: ASYNC240 await wait_for_lock_async(lock_path) async with write_semaphore: # Limit number of active disk writes