286 changes: 219 additions & 67 deletions scripts/prepare_hidden_states.py

Large diffs are not rendered by default.

374 changes: 300 additions & 74 deletions scripts/train_dflash.py

Large diffs are not rendered by default.

424 changes: 391 additions & 33 deletions specforge/core/dflash.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions specforge/data/__init__.py
@@ -1,12 +1,14 @@
from .preprocessing import (
build_eagle3_dataset,
build_offline_dflash_dataset,
build_offline_eagle3_dataset,
generate_vocab_mapping_file,
)
from .utils import prepare_dp_dataloaders

__all__ = [
"build_eagle3_dataset",
"build_offline_dflash_dataset",
"build_offline_eagle3_dataset",
"generate_vocab_mapping_file",
"prepare_dp_dataloaders",
188 changes: 188 additions & 0 deletions specforge/data/preprocessing.py
@@ -648,6 +648,194 @@ def build_offline_eagle3_dataset(
)


# ==============================
# Offline DFlash Dataset
# ==============================
class OfflineDFlashDataset(torch.utils.data.Dataset):
"""Offline dataset for DFlash training from pre-computed hidden states."""

def __init__(
self,
datapath,
transform=None,
max_len=2048,
block_size=16,
use_usp_preprocess=False,
):
self.datapaths = datapath
self.transform = transform
self._epoch = 0
self.max_len = max_len
self.block_size = block_size
self.use_usp_preprocess = use_usp_preprocess
if use_usp_preprocess:
sp_group = get_draft_sp_group()
self.sp_rank = torch.distributed.get_rank(sp_group)
self.sp_size = torch.distributed.get_world_size(sp_group)
ring_group = get_sp_ring_group()
self.ring_rank = torch.distributed.get_rank(ring_group)
self.sp_ring_size = torch.distributed.get_world_size(ring_group)

@staticmethod
def process_data(data, max_len, block_size=16, transform=None):
"""Only apply max_len truncation. Block-size truncation and loss_mask
processing are handled by OnlineDFlashModel.forward() to align with online mode.
"""
hidden_state = data["hidden_state"].squeeze(0)[:max_len]
input_ids = data["input_ids"][:max_len]
loss_mask = data["loss_mask"][:max_len]

new_data = {
"hidden_state": hidden_state[None, :],
"input_ids": input_ids[None, :],
"loss_mask": loss_mask[None, :],
"attention_mask": torch.ones(1, len(input_ids), dtype=torch.long),
}
if transform:
new_data = transform(new_data)
return new_data

@staticmethod
def process_data_usp(
data,
max_len,
transform=None,
sp_rank=0,
sp_size=1,
):
new_data = {}

input_ids = data["input_ids"]
if input_ids.ndim == 1:
input_ids = input_ids.unsqueeze(0)

global_len = min(max_len, input_ids.shape[1])
chunk_size = (global_len + sp_size - 1) // sp_size
start = sp_rank * chunk_size
local_len = chunk_size
end = min(start + local_len, global_len)

def _slice_and_pad_1d(tensor):
if tensor.ndim == 1:
tensor = tensor.unsqueeze(0)
tensor = tensor[:, :global_len]
sliced = tensor[:, start:min(end, tensor.shape[1])]
valid_len = sliced.shape[1]
if valid_len < local_len:
pad_len = local_len - valid_len
if tensor.ndim == 2:
sliced = F.pad(sliced, (0, pad_len))
else:
sliced = F.pad(sliced, (0, 0, 0, pad_len))
return sliced.contiguous(), valid_len

def _slice_and_pad_hidden(tensor):
# Offline DFlash stores hidden_state as [1, seq_len, hidden_size].
# USP shards only the sequence dimension and must keep hidden_size intact.
if tensor.ndim == 3:
tensor = tensor.squeeze(0)
tensor = tensor[:global_len]
sliced = tensor[start:min(end, tensor.shape[0])]
valid_len = sliced.shape[0]
if valid_len < local_len:
pad_len = local_len - valid_len
sliced = F.pad(sliced, (0, 0, 0, pad_len))
return sliced.unsqueeze(0).contiguous(), valid_len

new_data["hidden_state"], _ = _slice_and_pad_hidden(data["hidden_state"])
new_data["input_ids"], valid_len = _slice_and_pad_1d(input_ids)

full_loss_mask = data["loss_mask"]
if full_loss_mask.ndim == 1:
full_loss_mask = full_loss_mask.unsqueeze(0)
full_loss_mask = full_loss_mask[:, :global_len]
new_data["loss_mask"], _ = _slice_and_pad_1d(full_loss_mask)

attention_mask = torch.zeros((1, local_len), dtype=torch.long)
attention_mask[:, :valid_len] = 1
new_data["attention_mask"] = attention_mask

position_ids = torch.zeros((1, local_len), dtype=torch.long)
if valid_len > 0:
position_ids[:, :valid_len] = torch.arange(
start, start + valid_len, dtype=torch.long
)
new_data["position_ids"] = position_ids

if transform:
new_data = transform(new_data)
return new_data
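    # Worked example of the sharding arithmetic above (illustrative values,
    # not taken from this PR): with global_len=10 and sp_size=4,
    #   chunk_size = (10 + 4 - 1) // 4 = 3
    #   rank 0 owns [0:3], rank 1 [3:6], rank 2 [6:9], rank 3 [9:10] + 2 pad,
    # so every rank returns a fixed-width shard of chunk_size positions, and
    # attention_mask / position_ids mark which of those positions are valid.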

def __len__(self):
return len(self.datapaths)

def _open_file(self, index):
data_path = self.datapaths[index]
if data_path.endswith(".gz"):
with gzip.open(data_path, "rb") as f:
return torch.load(io.BytesIO(f.read()), weights_only=False)
return torch.load(data_path, weights_only=False, mmap=True)

def __getitem__(self, index):
for offset in range(len(self.datapaths)):
current_index = (index + offset) % len(self.datapaths)
current_path = self.datapaths[current_index]
try:
data = self._open_file(current_index)
except Exception as e:
print(
f"ERROR Failed to load DFlash sample index={current_index} "
f"path={current_path} error={e}"
)
continue

if self.use_usp_preprocess:
processed = self.process_data_usp(
data,
self.max_len,
transform=self.transform,
sp_rank=self.sp_rank,
sp_size=self.sp_size,
)
else:
processed = self.process_data(
data, self.max_len, self.block_size, self.transform
)

processed["sample_index"] = current_index
processed["sample_path"] = current_path
return processed

raise RuntimeError("No valid DFlash samples available after filtering corrupted entries.")

def set_epoch(self, epoch):
self._epoch = epoch


def build_offline_dflash_dataset(
hidden_states_path: str,
max_len: int = 2048,
block_size: int = 16,
use_usp_preprocess: bool = False,
) -> torch.utils.data.Dataset:
"""Build offline DFlash dataset from pre-computed hidden states.

Args:
hidden_states_path: Path to directory containing hidden state files.
max_len: Maximum sequence length.
        block_size: Block size for DFlash (used for truncation).
        use_usp_preprocess: Whether to pre-shard each sample across the
            sequence-parallel (USP) group at load time.

Returns:
OfflineDFlashDataset instance.
"""
return OfflineDFlashDataset(
list_local_files(hidden_states_path),
max_len=max_len,
block_size=block_size,
use_usp_preprocess=use_usp_preprocess,
)
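# Hedged usage sketch (not part of this diff; everything beyond the
# signatures shown above is an assumption):
#
#     dataset = build_offline_dflash_dataset(
#         "outputs/hidden_states",     # hypothetical directory of .pt/.gz files
#         max_len=2048,
#         block_size=16,
#         use_usp_preprocess=False,
#     )
#     loader = prepare_dp_dataloaders(
#         dataset,                     # assumed leading positional argument
#         requires_target=False,       # DFlash samples carry no 'target'
#     )
#     batch = next(iter(loader))       # input_ids, hidden_state, loss_mask, ...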


# ==============================
# Vocab Mapping
# ==============================
71 changes: 54 additions & 17 deletions specforge/data/utils.py
@@ -34,9 +34,16 @@ class DataCollatorWithPadding:
Datacollator that will dynamically pad the inputs for batching.
"""

def __init__(self):
def __init__(self, requires_target: bool = True):
"""Initialize DataCollator.

Args:
requires_target: If True, requires 'target' field when 'hidden_state' is
present. Set to False for DFlash and True for Eagle3.
"""
self.sp_degree = torch.distributed.get_world_size(get_draft_sp_group())
self.ulysses_degree = torch.distributed.get_world_size(get_sp_ulysses_group())
self.requires_target = requires_target

def paddingtensor(self, intensors: torch.Tensor, N: int) -> torch.Tensor:
"""
@@ -96,8 +103,14 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
max_length = (
(max_length + self.sp_degree - 1) // self.sp_degree
) * self.sp_degree
# position max len; Ulysses does not need chunked position ids
position_max_len = max_length * self.ulysses_degree
# Eagle3 USP keeps a globally-expanded position_ids tensor because its
# attention path gathers rotary positions across Ulysses ranks. DFlash
# applies RoPE before SeqAllToAll4D, so its USP path only needs local
# absolute positions. `requires_target=False` is the DFlash collator mode.
if self.requires_target:
position_max_len = max_length * self.ulysses_degree
else:
position_max_len = max_length

batch_input_ids = torch.cat(
[self.paddingtensor2D(item["input_ids"], max_length) for item in features]
Expand Down Expand Up @@ -127,26 +140,44 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
"hidden_state": None,
"target": None,
}
if all("sample_index" in item for item in features):
batch["sample_index"] = [int(item["sample_index"]) for item in features]
if all("sample_path" in item for item in features):
batch["sample_path"] = [item["sample_path"] for item in features]
if batch_position_ids is not None:
batch["position_ids"] = batch_position_ids
if all("hidden_state" in item for item in features):
assert all(
"target" in item for item in features
), "target is required when hidden_state is provided"
if self.sp_degree > 1: # USP mode
if self.sp_degree > 1 and self.requires_target: # Eagle3 USP mode
batch["hidden_state"] = torch.cat(
[item["hidden_state"] for item in features]
)
elif self.sp_degree > 1: # DFlash USP mode
batch["hidden_state"] = torch.cat(
[
self.paddingtensor(item["hidden_state"], max_length)
for item in features
]
)
else:
batch["hidden_state"] = torch.cat(
[
self.paddingtensor(item["hidden_state"], max_length)
for item in features
]
)
batch["target"] = torch.cat(
[self.paddingtensor(item["target"], max_length) for item in features]
)
if self.requires_target:
if not all("target" in item for item in features):
raise ValueError(
"requires_target=True but 'target' field missing in some features. "
"Use requires_target=False for DFlash training."
)
if all("target" in item for item in features):
batch["target"] = torch.cat(
[
self.paddingtensor(item["target"], max_length)
for item in features
]
)
return batch


@@ -234,18 +265,21 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
"target": None,
}
if all("hidden_state" in item for item in features):
assert all(
"target" in item for item in features
), "target is required when hidden_state is provided"
batch["hidden_state"] = torch.cat(
[
self.paddingtensor(item["hidden_state"], max_length)
for item in features
]
)
batch["target"] = torch.cat(
[self.paddingtensor(item["target"], max_length) for item in features]
)
# target is optional for DFlash (only hidden_state is needed)
# but required for Eagle3 (both hidden_state and target are needed)
if all("target" in item for item in features):
batch["target"] = torch.cat(
[
self.paddingtensor(item["target"], max_length)
for item in features
]
)
return batch


@@ -258,6 +292,7 @@ def prepare_dp_dataloaders(
shuffle: Optional[bool] = False,
is_vlm: Optional[bool] = False,
prefetch_factor: Optional[int] = 2,
requires_target: bool = True,
**dataloader_kwargs,
) -> DataLoader:
"""
@@ -271,6 +306,8 @@
pin_memory: Whether to pin memory for data loading.
shuffle: Whether to shuffle the dataset.
is_vlm: Whether the dataset is a vision-language model dataset.
requires_target: Whether 'target' field is required when 'hidden_state' is present.
Set to False for DFlash, True for Eagle3.
**dataloader_kwargs: Additional keyword arguments for the DataLoader.

Returns:
@@ -296,7 +333,7 @@
num_workers=num_workers,
pin_memory=pin_memory,
prefetch_factor=prefetch_factor,
collate_fn=datacollator_cls(),
collate_fn=datacollator_cls(requires_target=requires_target),
drop_last=True,
**dataloader_kwargs,
)
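A minimal, self-contained sketch of the collator's DFlash-mode padding. Assumptions: no torch.distributed setup, hypothetical hidden_size=8, and only the hidden_state padding is shown, whereas the real DataCollatorWithPadding also pads input_ids, loss_mask, and attention_mask:

import torch
import torch.nn.functional as F

# Two variable-length DFlash features; no 'target' tensor is present.
features = [
    {"input_ids": torch.ones(1, 5, dtype=torch.long), "hidden_state": torch.randn(1, 5, 8)},
    {"input_ids": torch.ones(1, 3, dtype=torch.long), "hidden_state": torch.randn(1, 3, 8)},
]
max_length = max(f["input_ids"].shape[1] for f in features)  # -> 5

# Pad each hidden_state along the sequence dimension, then batch: [2, 5, 8].
batch_hidden = torch.cat([
    F.pad(f["hidden_state"], (0, 0, 0, max_length - f["hidden_state"].shape[1]))
    for f in features
])
assert batch_hidden.shape == (2, 5, 8)

With requires_target=False the collator skips the Eagle3-only 'target' concatenation, which is exactly the DFlash path in the diff above.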