Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ dist/
eval/
*_ckpt*/
output/
outputs/
outs/
wandb/
tools/results/
__pycache__/
__pycache__/outputs/
6 changes: 4 additions & 2 deletions angelslim/compressor/speculative/train/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,13 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
"target_hiddens": None,
}

# Check if both hidden_states and target_hiddens exist in all features
if all("hidden_states" in item and "target_hiddens" in item for item in features):
# Handle hidden_states and target_hiddens independently
if all("hidden_states" in item for item in features):
batch["hidden_states"] = torch.cat(
[paddingtensor(item["hidden_states"], max_length) for item in features]
)

if all("target_hiddens" in item for item in features):
batch["target_hiddens"] = torch.cat(
[paddingtensor(item["target_hiddens"], max_length) for item in features]
)
Expand Down
37 changes: 26 additions & 11 deletions angelslim/compressor/speculative/train/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,41 +174,56 @@ def _create_online_datasets(
if self.display:
num_proc = None

# Determine min_loss_tokens for DFlash filtering
min_loss_tokens = None
if self.data_args.modal_type == "DFlash":
block_size = getattr(self.data_args, "block_size", 16)
min_loss_tokens = 2 * block_size

# Create training dataset
train_dataset = None
if self.data_args.train_data_path is not None:
train_path = getattr(self.data_args, "train_data_path", None)
if train_path is not None:
train_dataset = self.online_dataset_builder.build_dataset(
self.data_args.train_data_path,
train_path,
num_proc=num_proc,
shuffle=True,
sample_num=self.data_args.sample_num,
sample_num=getattr(self.data_args, "sample_num", None),
min_loss_tokens=min_loss_tokens,
)

# Create evaluation dataset
eval_dataset = None
if self.data_args.eval_data_path is not None:
eval_path = getattr(self.data_args, "eval_data_path", None)
if eval_path is not None:
eval_dataset = self.online_dataset_builder.build_dataset(
self.data_args.eval_data_path,
eval_path,
num_proc=num_proc,
shuffle=False,
sample_num=self.data_args.sample_num,
sample_num=getattr(self.data_args, "sample_num", None),
min_loss_tokens=min_loss_tokens,
)

data_collator = self.online_dataset_builder.get_data_collator()

return train_dataset, eval_dataset, data_collator

def _create_offline_datasets(self) -> Tuple[Dataset, Optional[Dataset]]:
def _create_offline_datasets(self) -> Tuple[Dataset, Optional[Dataset], Any]:
"""
Create offline datasets from pre-computed .ckpt files.

Returns:
Tuple of (train_dataset, eval_dataset)
Tuple of (train_dataset, eval_dataset, data_collator)
"""
if self.offline_dataset_builder is None:
return None, None, None

# Create train dataset
train_dataset = self.offline_dataset_builder.build_dataset(
self.data_args.train_hidden_path
)
train_dataset = None
if self.data_args.train_hidden_path is not None:
train_dataset = self.offline_dataset_builder.build_dataset(
self.data_args.train_hidden_path
)

# Create eval dataset if path is provided
eval_dataset = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@
class DatasetBuilder(metaclass=ABCMeta):
@abstractmethod
def build_dataset(
self, datapath: str, num_proc: int = 8, shuffle: bool = True, **kwargs
self,
datapath: str,
num_proc: int = 8,
shuffle: bool = True,
min_loss_tokens: Optional[int] = None,
**kwargs,
) -> Dataset:
pass

Expand Down Expand Up @@ -127,6 +132,7 @@ def build_dataset(
num_proc: int = 8,
shuffle: bool = True,
sample_num: Optional[int] = None,
min_loss_tokens: Optional[int] = None,
) -> Dataset:
try:
# Load dataset
Expand Down Expand Up @@ -161,6 +167,18 @@ def build_dataset(
num_proc=num_proc,
desc="Filtering empty input_ids",
)

if min_loss_tokens is not None:
processed_ds = processed_ds.filter(
lambda batch: [
sum(sum(x) if isinstance(x, list) else x for x in m) >= min_loss_tokens
for m in batch["loss_mask"]
],
batched=True,
num_proc=num_proc,
desc=f"Filtering sequences with loss tokens < {min_loss_tokens}",
)

processed_ds.set_format(type="torch")

return processed_ds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def build_dataset(
num_proc: int = 8,
shuffle: bool = True,
sample_num: Optional[int] = None,
min_loss_tokens: Optional[int] = None,
) -> Dataset:
try:
# Load dataset
Expand Down Expand Up @@ -146,11 +147,19 @@ def build_dataset(
num_proc=num_proc,
desc="Filtering empty input_ids",
)
if min_loss_tokens is not None:
processed_ds = processed_ds.filter(
lambda batch: [
sum(sum(x) if isinstance(x, list) else x for x in m) >= min_loss_tokens
for m in batch["loss_mask"]
],
batched=True,
num_proc=num_proc,
desc=f"Filtering sequences with loss tokens < {min_loss_tokens}",
)

torch_columns = [c for c in processed_ds.column_names if c != "image_paths"]
processed_ds.set_format(type="torch", columns=torch_columns, output_all_columns=True)
rank0_print(
f"processed_ds size:{len(processed_ds)}, columns: {processed_ds.column_names}"
)

return processed_ds

Expand Down Expand Up @@ -324,6 +333,7 @@ def build_dataset(
num_proc: int = 8,
shuffle: bool = True,
sample_num: Optional[int] = None,
min_loss_tokens: Optional[int] = None,
) -> Dataset:
try:
# Load dataset
Expand Down Expand Up @@ -374,6 +384,16 @@ def build_dataset(
num_proc=num_proc,
desc="Filtering empty input_ids",
)
if min_loss_tokens is not None:
processed_ds = processed_ds.filter(
lambda batch: [
sum(sum(x) if isinstance(x, list) else x for x in m) >= min_loss_tokens
for m in batch["loss_mask"]
],
batched=True,
num_proc=num_proc,
desc=f"Filtering sequences with loss tokens < {min_loss_tokens}",
)
torch_columns = [c for c in processed_ds.column_names if c != "image_paths"]
processed_ds.set_format(type="torch", columns=torch_columns, output_all_columns=True)

Expand Down Expand Up @@ -572,6 +592,7 @@ def build_dataset(
num_proc: int = 8,
shuffle: bool = True,
sample_num: Optional[int] = None,
min_loss_tokens: Optional[int] = None,
) -> Dataset:
try:
# Load dataset
Expand Down Expand Up @@ -623,6 +644,18 @@ def build_dataset(
num_proc=num_proc,
desc="Filtering empty input_ids",
)

if min_loss_tokens is not None:
processed_ds = processed_ds.filter(
lambda batch: [
sum(sum(x) if isinstance(x, list) else x for x in m) >= min_loss_tokens
for m in batch["loss_mask"]
],
batched=True,
num_proc=num_proc,
desc=f"Filtering sequences with loss tokens < {min_loss_tokens}",
)

processed_ds.set_format(type="torch")

return processed_ds
Expand Down Expand Up @@ -886,6 +919,7 @@ def build_dataset(
num_proc: int = 8,
shuffle: bool = True,
sample_num: Optional[int] = None,
min_loss_tokens: Optional[int] = None,
) -> Dataset:
try:
if not isinstance(datapath, list):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@

from .draft_model_factory import DraftModelConfig, create_draft_model
from .llama_eagle3 import CosyVoice3Eagle3LlamaForCausalLM, Eagle3LlamaForCausalLM
from .qwen_dflash import QwenDFlashDraftModel

__all__ = [
"create_draft_model",
"DraftModelConfig",
"Eagle3LlamaForCausalLM",
"CosyVoice3Eagle3LlamaForCausalLM",
"QwenDFlashDraftModel",
]
Loading
Loading