Skip to content

Commit 711800a

Browse files
committed
feat(speculative): optimize hidden_states file size with storage_precision & int8_quantization
1 parent e8f44ae commit 711800a

4 files changed

Lines changed: 225 additions & 34 deletions

File tree

angelslim/compressor/speculative/train/data/dataset_builder/offline_dataset_builder.py

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,10 @@ def __init__(self, memmap_dirs: list):
240240
self._dir_offsets = [] # [np.memmap, ...]
241241
self._dir_metadata = [] # [metadata_dict, ...]
242242

243+
# Storage optimization attributes
244+
self._storage_precision = "float32" # Default precision
245+
self._int8_quantization = False # Whether int8 quantization is enabled
246+
243247
total_samples = 0
244248
for dir_idx, memmap_dir in enumerate(memmap_dirs):
245249
metadata_path = os.path.join(memmap_dir, "metadata.json")
@@ -249,6 +253,12 @@ def __init__(self, memmap_dirs: list):
249253
self._dir_metadata.append(metadata)
250254
n_samples = metadata["total_samples"]
251255

256+
# Read storage optimization config
257+
if "storage_optimization" in metadata:
258+
storage_opt = metadata["storage_optimization"]
259+
self._storage_precision = storage_opt.get("precision", "float32")
260+
self._int8_quantization = storage_opt.get("int8_quantization", False)
261+
252262
# Open memmap files (read-only mode)
253263
field_memmaps = {}
254264
for field_name, field_info in metadata["fields"].items():
@@ -289,6 +299,8 @@ def __init__(self, memmap_dirs: list):
289299
rank0_print(
290300
f"[MemmapDataset] Dir {dir_idx}: {memmap_dir} "
291301
f"({n_samples} samples, {metadata.get('total_tokens', 'N/A')} total tokens)"
302+
f" | Storage Opt: precision={self._storage_precision}, "
303+
f"int8_quantization={self._int8_quantization}"
292304
)
293305

294306
self.total_samples = total_samples
@@ -298,36 +310,30 @@ def __init__(self, memmap_dirs: list):
298310
f"from {len(memmap_dirs)} memmap directories"
299311
)
300312

301-
def __len__(self) -> int:
302-
return self.total_samples
303-
304-
def get_sample_length(self, idx: int) -> int:
305-
"""Get the sequence length of a sample without loading data.
313+
def _convert_precision(self, tensor: torch.Tensor) -> torch.Tensor:
314+
"""Convert tensor precision based on storage config."""
315+
if self._storage_precision == "bfloat16":
316+
return tensor.to(torch.bfloat16)
317+
elif self._storage_precision == "float16":
318+
return tensor.to(torch.float16)
319+
else: # float32 or unknown
320+
return tensor.to(torch.float32)
306321

307-
This is an O(1) operation that only reads from the offsets array,
308-
useful for length-based bucketing samplers to reduce padding waste.
322+
@staticmethod
323+
def _dequantize_per_token_absmax_int8(
324+
quantized: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype = torch.bfloat16
325+
):
326+
"""Dequantize per-token absmax int8 back to float.
309327
310328
Args:
311-
idx: Global sample index.
312-
313-
Returns:
314-
Sequence length (number of tokens) of the sample.
315-
"""
316-
dir_idx, local_idx = self.global_index[idx]
317-
offsets = self._dir_offsets[dir_idx]
318-
return int(offsets[local_idx + 1]) - int(offsets[local_idx])
319-
320-
def get_all_sample_lengths(self) -> np.ndarray:
321-
"""Get sequence lengths for all samples efficiently.
329+
quantized: int8 tensor of shape [B, N, D]
330+
scale: float tensor of shape [B, N, 1]
331+
target_dtype: Target float dtype
322332
323333
Returns:
324-
numpy array of shape (total_samples,) with each sample's length.
334+
Dequantized float tensor of shape [B, N, D]
325335
"""
326-
lengths = np.empty(self.total_samples, dtype=np.int64)
327-
for i, (dir_idx, local_idx) in enumerate(self.global_index):
328-
offsets = self._dir_offsets[dir_idx]
329-
lengths[i] = int(offsets[local_idx + 1]) - int(offsets[local_idx])
330-
return lengths
336+
return quantized.to(target_dtype) * scale.to(target_dtype)
331337

332338
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
333339
dir_idx, local_idx = self.global_index[idx]
@@ -348,10 +354,27 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
348354
# additional copy.
349355
tensor = torch.from_numpy(np.ascontiguousarray(arr))
350356

357+
# Apply precision conversion for float tensors (skip int8 quantized data)
358+
if tensor.is_floating_point():
359+
tensor = self._convert_precision(tensor)
360+
351361
# Add batch dimension [1, seq_len, ...]
352362
tensor = tensor.unsqueeze(0)
353363
data[field_name] = tensor
354364

365+
# Dequantize int8 quantized hidden_states and target_hiddens
366+
if self._int8_quantization:
367+
target_dtype = (
368+
torch.bfloat16 if self._storage_precision == "bfloat16" else torch.float16
369+
)
370+
for base_name in ("hidden_states", "target_hiddens"):
371+
int8_key = f"{base_name}_int8"
372+
scales_key = f"{base_name}_scales"
373+
if int8_key in data and scales_key in data:
374+
data[base_name] = self._dequantize_per_token_absmax_int8(
375+
data.pop(int8_key), data.pop(scales_key), target_dtype=target_dtype
376+
)
377+
355378
# Generate attention_mask
356379
if "input_ids" in data:
357380
data["attention_mask"] = torch.ones_like(data["input_ids"])
@@ -463,7 +486,7 @@ def __init__(
463486
super().__init__(dataset)
464487
self.dataset = dataset
465488
self.batch_size = batch_size
466-
self.bucket_size = bucket_size or max(batch_size * 50, 100)
489+
self.bucket_size = bucket_size or max(batch_size * 20, 100)
467490
self.seed = seed
468491
self.epoch = 0
469492

angelslim/compressor/speculative/train/trainer/eagle3_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(
112112
self._perf_total_step_time: float = 0.0
113113
self._perf_last_step_end: float = 0.0
114114
self._perf_dataloader_wait_time: float = 0.0
115-
self._perf_log_interval: int = 50 # Log performance stats every N global steps
115+
self._perf_log_interval: int = 20 # Log performance stats every N global steps
116116
self._perf_last_logged_global_step: int = 0 # Last global step when PERF was logged
117117

118118
# PyTorch Profiler

tools/generate_hidden_for_draft_model.py

Lines changed: 143 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def __init__(
8282
rank: int = 0,
8383
draft_vocab_size: int = None,
8484
target_vocab_size: int = None,
85+
storage_precision: str = "bfloat16",
86+
enable_int8_quantization: bool = False,
8587
):
8688
"""
8789
Initialize the hidden state generator.
@@ -92,12 +94,20 @@ def __init__(
9294
rank: Process rank for distributed training
9395
draft_vocab_size: Size of draft model vocabulary (required for vocab mapping)
9496
target_vocab_size: Size of target model vocabulary (required for vocab mapping)
97+
storage_precision: Storage precision for hidden states (bfloat16, float16, float32)
98+
enable_int8_quantization: Whether to enable per-token absmax int8 quantization
99+
for hidden_states and target_hiddens. Reduces storage by ~50%.
95100
"""
96101
self.target_model = target_model
97102
self.output_dir = Path(output_dir)
98103
self.rank = rank
99104
self.draft_vocab_size = draft_vocab_size
100105
self.target_vocab_size = target_vocab_size
106+
107+
# Storage optimization config
108+
self.storage_precision = storage_precision
109+
self.enable_int8_quantization = enable_int8_quantization
110+
101111
_max_pixels = os.environ.get("MAX_PIXELS")
102112
_min_pixels = os.environ.get("MIN_PIXELS", "1024")
103113
self.max_pixels = int(_max_pixels) if _max_pixels is not None else None
@@ -122,6 +132,89 @@ def __init__(
122132
# Resolve image_pad token id for vLLM loss_mask rebuilding
123133
self._image_pad_token_id = None
124134

135+
@staticmethod
136+
def _quantize_per_token_absmax_int8(tensor: torch.Tensor):
137+
"""Per-token absmax int8 quantization.
138+
139+
For a tensor of shape [B, N, D], compute per-token (per row in the
140+
last two dims) scale = max(|x|) / 127, then quantize to int8.
141+
142+
Args:
143+
tensor: Float tensor of shape [B, N, D]
144+
145+
Returns:
146+
Tuple of (quantized_int8 [B, N, D], scales [B, N, 1])
147+
"""
148+
# tensor: [B, N, D]
149+
absmax = tensor.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) # [B, N, 1]
150+
scale = absmax / 127.0 # [B, N, 1]
151+
quantized = (tensor / scale).round().clamp(-127, 127).to(torch.int8) # [B, N, D]
152+
return quantized, scale
153+
154+
@staticmethod
155+
def _dequantize_per_token_absmax_int8(
156+
quantized: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype = torch.bfloat16
157+
):
158+
"""Dequantize per-token absmax int8 back to float.
159+
160+
Args:
161+
quantized: int8 tensor of shape [B, N, D]
162+
scale: float tensor of shape [B, N, 1]
163+
target_dtype: Target float dtype
164+
165+
Returns:
166+
Dequantized float tensor of shape [B, N, D]
167+
"""
168+
return quantized.to(target_dtype) * scale.to(target_dtype)
169+
170+
def _convert_precision(self, data_point: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
171+
"""Convert data precision for storage."""
172+
dtype_map = {
173+
"bfloat16": torch.bfloat16,
174+
"float16": torch.float16,
175+
"float32": torch.float32,
176+
}
177+
target_dtype = dtype_map.get(self.storage_precision, torch.bfloat16)
178+
179+
converted_data_point = {}
180+
for key, tensor in data_point.items():
181+
if isinstance(tensor, torch.Tensor) and tensor.is_floating_point():
182+
if tensor.dtype != target_dtype:
183+
converted_data_point[key] = tensor.to(target_dtype)
184+
else:
185+
converted_data_point[key] = tensor
186+
else:
187+
converted_data_point[key] = tensor
188+
189+
return converted_data_point
190+
191+
def _apply_int8_quantization(
192+
self, data_point: Dict[str, torch.Tensor]
193+
) -> Dict[str, torch.Tensor]:
194+
"""Apply per-token absmax int8 quantization to hidden_states and target_hiddens.
195+
196+
Replaces the original float tensors with int8 data + float32 scales:
197+
hidden_states -> hidden_states_int8 (int8) + hidden_states_scales (float32)
198+
target_hiddens -> target_hiddens_int8 (int8) + target_hiddens_scales (float32)
199+
200+
This reduces storage by ~50% for these large tensors.
201+
"""
202+
if not self.enable_int8_quantization:
203+
return data_point
204+
205+
new_data_point = {}
206+
for key, tensor in data_point.items():
207+
if key in ("hidden_states", "target_hiddens") and isinstance(tensor, torch.Tensor):
208+
# Quantize: tensor shape [B, N, D]
209+
q_int8, scales = self._quantize_per_token_absmax_int8(tensor.float())
210+
new_data_point[f"{key}_int8"] = q_int8 # [B, N, D] int8
211+
new_data_point[f"{key}_scales"] = scales # [B, N, 1] float32
212+
# Do NOT keep the original float tensor
213+
else:
214+
new_data_point[key] = tensor
215+
216+
return new_data_point
217+
125218
def _resolve_image_pad_token_id(self):
126219
"""Lazily resolve the image_pad token id from the target model's tokenizer."""
127220
if self._image_pad_token_id is not None:
@@ -287,7 +380,9 @@ def _init_memmap(self, data_point: Dict[str, torch.Tensor], total_samples: int):
287380
extra_dims = per_sample_shape[1:] # () or (D,) or (3*D,)
288381

289382
# Determine numpy dtype
290-
if tensor.dtype == torch.bfloat16:
383+
if tensor.dtype == torch.int8:
384+
np_dtype = np.int8
385+
elif tensor.dtype == torch.bfloat16:
291386
np_dtype = np.float16
292387
elif tensor.dtype == torch.float16:
293388
np_dtype = np.float16
@@ -412,12 +507,23 @@ def _expand_sample_capacity(self):
412507
)
413508

414509
def _write_sample_to_memmap(self, data_point: Dict[str, torch.Tensor]):
415-
"""Write a single sample to packed memmap files.
510+
"""Write a single sample to packed memmap files with storage optimizations.
416511
417512
Data is appended at position self._total_tokens_written in the packed array.
418513
The offsets array is updated to record the boundary.
419514
"""
420-
# Get the seq_len of the current sample
515+
# Apply storage optimizations:
516+
# 1. Per-token absmax int8 quantization for hidden_states/target_hiddens
517+
data_point = self._apply_int8_quantization(data_point)
518+
519+
# 2. Convert remaining float tensors to target precision
520+
data_point = self._convert_precision(data_point)
521+
522+
# Initialize memmap on first write
523+
if not self._memmap_initialized:
524+
self._init_memmap(data_point, self._total_samples_estimate)
525+
526+
# Get seq_len after optimization
421527
sample_seq_len = 0
422528
for _, tensor in data_point.items():
423529
if isinstance(tensor, torch.Tensor) and tensor.ndim >= 2:
@@ -444,6 +550,8 @@ def _write_sample_to_memmap(self, data_point: Dict[str, torch.Tensor]):
444550
arr = tensor.squeeze(0) # [N, ...]
445551
if tensor.dtype == torch.bfloat16:
446552
arr = arr.float().half() # bfloat16 -> float32 -> float16
553+
elif tensor.dtype == torch.int8:
554+
pass # int8 can be directly converted to numpy
447555
arr_np = arr.contiguous().numpy()
448556

449557
seq_len = arr_np.shape[0]
@@ -524,11 +632,15 @@ def _finalize_memmap(self):
524632
old_offsets_path = self._memmap_dir / "offsets.npy"
525633
os.replace(str(final_offsets_path), str(old_offsets_path))
526634

527-
# Save metadata JSON
635+
# Save metadata JSON with storage optimization info
528636
metadata = {
529637
"format": "packed", # Distinguish from the old rectangular format
530638
"total_samples": self._sample_count,
531639
"total_tokens": total_tokens,
640+
"storage_optimization": {
641+
"precision": self.storage_precision,
642+
"int8_quantization": self.enable_int8_quantization,
643+
},
532644
"fields": {},
533645
}
534646
for field_name in self._field_dtypes:
@@ -659,10 +771,6 @@ def _process_single_sample(self, idx: int, row: Dict[str, Any]) -> bool:
659771
batch_token_dict = dict(zip(unique_ids.tolist(), counts.tolist()))
660772
self.token_dict.update(batch_token_dict)
661773

662-
# Initialize memmap on first write
663-
if not self._memmap_initialized:
664-
self._init_memmap(data_point, self._total_samples_estimate)
665-
666774
# Write directly to memmap
667775
self._write_sample_to_memmap(data_point)
668776
return True
@@ -882,6 +990,23 @@ def parse_arguments() -> argparse.Namespace:
882990
help="Path to draft model config file, used to read draft_vocab_size and vocab_size "
883991
"for computing vocab mapping",
884992
)
993+
994+
# Storage optimization arguments
995+
parser.add_argument(
996+
"--storage_precision",
997+
type=str,
998+
default="bfloat16",
999+
choices=["bfloat16", "float16", "float32"],
1000+
help="Storage precision for hidden states. bfloat16/float16 reduce storage by 50%%",
1001+
)
1002+
parser.add_argument(
1003+
"--enable_int8_quantization",
1004+
action="store_true",
1005+
default=False,
1006+
help="Enable per-token absmax int8 quantization for hidden_states and target_hiddens. "
1007+
"Reduces storage by ~50%% for these tensors with minimal quality loss.",
1008+
)
1009+
8851010
return parser.parse_args()
8861011

8871012

@@ -1074,12 +1199,22 @@ def main():
10741199
output_dir = f"{args.outdir}/rank_{rank}"
10751200
logger.info(f"writing hidden states to {output_dir}", extra={"rank": rank})
10761201

1202+
# Log storage optimization settings
1203+
logger.info(
1204+
f"[Storage Optimization Settings] "
1205+
f"Precision: {args.storage_precision}, "
1206+
f"Int8 Quantization: {args.enable_int8_quantization}",
1207+
extra={"rank": rank},
1208+
)
1209+
10771210
generator = HiddenStateGenerator(
10781211
target_model,
10791212
output_dir,
10801213
rank=rank,
10811214
draft_vocab_size=draft_vocab_size,
10821215
target_vocab_size=target_vocab_size,
1216+
storage_precision=args.storage_precision,
1217+
enable_int8_quantization=args.enable_int8_quantization,
10831218
)
10841219
successful, failed = generator.generate(dataset_slice)
10851220

0 commit comments

Comments
 (0)