Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions scripts/data_generation_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
DEFAULT_MAX_RETRIES,
DEFAULT_REQUEST_TIMEOUT,
generate_hidden_states_async,
wait_for_lock_async,
)
from speculators.train.logger import setup_root_logger

Expand Down Expand Up @@ -169,21 +170,21 @@ def parse_args():
def get_existing_hidden_state_indices(output_path: Path) -> list[int]:
"""Find existing `hs_i.safetensors` files (where i is the file index)"""

existing_file_indices = []
existing_file_indices_set: set[int] = set()

if not output_path.exists():
return existing_file_indices
return []

for file_path in output_path.iterdir():
if file_path.name.startswith("hs_") and file_path.name.endswith(".safetensors"):
index_str = file_path.stem[3:] # Remove "hs_" prefix
try:
file_index = int(index_str)
existing_file_indices.append(file_index)
existing_file_indices_set.add(file_index)
except ValueError:
continue

return sorted(existing_file_indices)
return sorted(existing_file_indices_set)


def get_indices_to_process(
Expand Down Expand Up @@ -246,7 +247,7 @@ def check_safetensors_file(path: Path, tokens: list[int]):
)


async def worker(
async def worker( # noqa: C901
client,
model: str,
queue: "asyncio.Queue[dict[str, Any]]",
Expand Down Expand Up @@ -289,6 +290,10 @@ async def worker(
timeout=request_timeout,
max_retries=max_retries,
)
lock_path = hidden_states_path + ".lock"
if Path(lock_path).exists():
await wait_for_lock_async(lock_path)

async with write_semaphore: # Limit number of active disk writes
await asyncio.to_thread(
shutil.move, hidden_states_path, target_hidden_states_path
Expand Down
43 changes: 43 additions & 0 deletions src/speculators/data_generation/vllm_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import fcntl
import functools
import logging
import os
import time

import openai
Expand Down Expand Up @@ -99,6 +101,47 @@ def extract_output(completion, token_ids) -> str:
return completion.kv_transfer_params.get("hidden_states_path")


async def _poll_lock_async(fd, poll_interval):
while True:
try:
fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
return
except BlockingIOError:
await asyncio.sleep(poll_interval)


async def wait_for_lock_async(lock_path, timeout=10.0, poll_interval=0.1):
fd = os.open(lock_path, os.O_RDONLY)
try:
await asyncio.wait_for(_poll_lock_async(fd, poll_interval), timeout=timeout)
except BaseException:
os.close(fd)
raise
os.close(fd)
os.remove(lock_path)


def wait_for_lock(lock_path, timeout=10.0, poll_interval=0.1):
fd = os.open(lock_path, os.O_RDONLY)
try:
deadline = time.monotonic() + timeout
while True:
try:
fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
break
except BlockingIOError:
if time.monotonic() >= deadline:
raise TimeoutError(
f"Timed out waiting for lock: {lock_path}"
) from None
time.sleep(poll_interval)
except BaseException:
os.close(fd)
raise
os.close(fd)
os.remove(lock_path)


@with_retries
async def generate_hidden_states_async(
client: openai.AsyncClient,
Expand Down
31 changes: 20 additions & 11 deletions src/speculators/train/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
DEFAULT_MAX_RETRIES,
DEFAULT_REQUEST_TIMEOUT,
generate_hidden_states,
wait_for_lock,
)
from speculators.train.noise_transforms import TransformTensors

Expand Down Expand Up @@ -248,6 +249,11 @@ def _compute_approx_lengths(self) -> list[int]:
def _maybe_load_hs_file(self, index: int) -> dict[str, torch.Tensor] | None:
file_idx = self._map_to_file_idx(index)
candidate_path = self.hidden_states_path / f"hs_{file_idx}.safetensors"

lock_path = str(candidate_path) + ".lock"
if Path(lock_path).exists():
wait_for_lock(lock_path)

if candidate_path.exists():
return load_file(candidate_path)

Expand All @@ -266,19 +272,22 @@ def _maybe_generate_hs(self, index: int) -> dict[str, torch.Tensor] | None:
timeout=self.request_timeout,
max_retries=self.max_retries,
)
except Exception as e: # noqa: BLE001
warnings.warn(str(e), stacklevel=1)
return None

loaded_hs = load_file(hs_filepath)
loaded_hs = load_file(hs_filepath)

match self.on_generate:
case "cache":
file_idx = self._map_to_file_idx(index)
target_path = self.hidden_states_path / f"hs_{file_idx}.safetensors"
shutil.move(hs_filepath, target_path)
case "delete":
Path(hs_filepath).unlink()
match self.on_generate:
case "cache":
file_idx = self._map_to_file_idx(index)
target_path = self.hidden_states_path / f"hs_{file_idx}.safetensors"
shutil.move(hs_filepath, target_path)
case "delete":
Path(hs_filepath).unlink()
except Exception as e: # noqa: BLE001
warnings.warn(
f"Failed to load/cache hidden states for sample {index}: {e}",
stacklevel=1,
)
return None

return loaded_hs

Expand Down
Loading