Skip to content

Commit efdd29b

Browse files
committed
feat(speculative): optimize hidden_states file size with storage_precision & int8_quantization
1 parent e8f44ae commit efdd29b

4 files changed

Lines changed: 216 additions & 29 deletions

File tree

angelslim/compressor/speculative/train/data/dataset_builder/offline_dataset_builder.py

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,10 @@ def __init__(self, memmap_dirs: list):
240240
self._dir_offsets = [] # [np.memmap, ...]
241241
self._dir_metadata = [] # [metadata_dict, ...]
242242

243+
# Storage optimization attributes
244+
self._storage_precision = "float32" # Default precision
245+
self._int8_quantization = False # Whether int8 quantization is enabled
246+
243247
total_samples = 0
244248
for dir_idx, memmap_dir in enumerate(memmap_dirs):
245249
metadata_path = os.path.join(memmap_dir, "metadata.json")
@@ -249,6 +253,12 @@ def __init__(self, memmap_dirs: list):
249253
self._dir_metadata.append(metadata)
250254
n_samples = metadata["total_samples"]
251255

256+
# Read storage optimization config
257+
if "storage_optimization" in metadata:
258+
storage_opt = metadata["storage_optimization"]
259+
self._storage_precision = storage_opt.get("precision", "float32")
260+
self._int8_quantization = storage_opt.get("int8_quantization", False)
261+
252262
# Open memmap files (read-only mode)
253263
field_memmaps = {}
254264
for field_name, field_info in metadata["fields"].items():
@@ -289,6 +299,8 @@ def __init__(self, memmap_dirs: list):
289299
rank0_print(
290300
f"[MemmapDataset] Dir {dir_idx}: {memmap_dir} "
291301
f"({n_samples} samples, {metadata.get('total_tokens', 'N/A')} total tokens)"
302+
f" | Storage Opt: precision={self._storage_precision}, "
303+
f"int8_quantization={self._int8_quantization}"
292304
)
293305

294306
self.total_samples = total_samples
@@ -298,36 +310,30 @@ def __init__(self, memmap_dirs: list):
298310
f"from {len(memmap_dirs)} memmap directories"
299311
)
300312

301-
def __len__(self) -> int:
302-
return self.total_samples
303-
304-
def get_sample_length(self, idx: int) -> int:
305-
"""Get the sequence length of a sample without loading data.
313+
def _convert_precision(self, tensor: torch.Tensor) -> torch.Tensor:
314+
"""Convert tensor precision based on storage config."""
315+
if self._storage_precision == "bfloat16":
316+
return tensor.to(torch.bfloat16)
317+
elif self._storage_precision == "float16":
318+
return tensor.to(torch.float16)
319+
else: # float32 or unknown
320+
return tensor.to(torch.float32)
306321

307-
This is an O(1) operation that only reads from the offsets array,
308-
useful for length-based bucketing samplers to reduce padding waste.
322+
@staticmethod
323+
def _dequantize_per_token_absmax_int8(
324+
quantized: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype = torch.bfloat16
325+
):
326+
"""Dequantize per-token absmax int8 back to float.
309327
310328
Args:
311-
idx: Global sample index.
312-
313-
Returns:
314-
Sequence length (number of tokens) of the sample.
315-
"""
316-
dir_idx, local_idx = self.global_index[idx]
317-
offsets = self._dir_offsets[dir_idx]
318-
return int(offsets[local_idx + 1]) - int(offsets[local_idx])
319-
320-
def get_all_sample_lengths(self) -> np.ndarray:
321-
"""Get sequence lengths for all samples efficiently.
329+
quantized: int8 tensor of shape [B, N, D]
330+
scale: float tensor of shape [B, N, 1]
331+
target_dtype: Target float dtype
322332
323333
Returns:
324-
numpy array of shape (total_samples,) with each sample's length.
334+
Dequantized float tensor of shape [B, N, D]
325335
"""
326-
lengths = np.empty(self.total_samples, dtype=np.int64)
327-
for i, (dir_idx, local_idx) in enumerate(self.global_index):
328-
offsets = self._dir_offsets[dir_idx]
329-
lengths[i] = int(offsets[local_idx + 1]) - int(offsets[local_idx])
330-
return lengths
336+
return quantized.to(target_dtype) * scale.to(target_dtype)
331337

332338
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
333339
dir_idx, local_idx = self.global_index[idx]
@@ -348,10 +354,27 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
348354
# additional copy.
349355
tensor = torch.from_numpy(np.ascontiguousarray(arr))
350356

357+
# Apply precision conversion to every floating-point field. int8 payloads are not floating point and pass through unchanged; the float `*_scales` fields ARE converted here.
358+
if tensor.is_floating_point():
359+
tensor = self._convert_precision(tensor)
360+
351361
# Add batch dimension [1, seq_len, ...]
352362
tensor = tensor.unsqueeze(0)
353363
data[field_name] = tensor
354364

365+
# Dequantize int8 quantized hidden_states and target_hiddens
366+
if self._int8_quantization:
367+
target_dtype = (
368+
torch.bfloat16 if self._storage_precision == "bfloat16" else torch.float16
369+
)
370+
for base_name in ("hidden_states", "target_hiddens"):
371+
int8_key = f"{base_name}_int8"
372+
scales_key = f"{base_name}_scales"
373+
if int8_key in data and scales_key in data:
374+
data[base_name] = self._dequantize_per_token_absmax_int8(
375+
data.pop(int8_key), data.pop(scales_key), target_dtype=target_dtype
376+
)
377+
355378
# Generate attention_mask
356379
if "input_ids" in data:
357380
data["attention_mask"] = torch.ones_like(data["input_ids"])
@@ -463,7 +486,7 @@ def __init__(
463486
super().__init__(dataset)
464487
self.dataset = dataset
465488
self.batch_size = batch_size
466-
self.bucket_size = bucket_size or max(batch_size * 50, 100)
489+
self.bucket_size = bucket_size or max(batch_size * 20, 100)
467490
self.seed = seed
468491
self.epoch = 0
469492

angelslim/compressor/speculative/train/trainer/eagle3_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(
112112
self._perf_total_step_time: float = 0.0
113113
self._perf_last_step_end: float = 0.0
114114
self._perf_dataloader_wait_time: float = 0.0
115-
self._perf_log_interval: int = 50 # Log performance stats every N global steps
115+
self._perf_log_interval: int = 20 # Log performance stats every N global steps
116116
self._perf_last_logged_global_step: int = 0 # Last global step when PERF was logged
117117

118118
# PyTorch Profiler

tools/generate_hidden_for_draft_model.py

Lines changed: 134 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def __init__(
8282
rank: int = 0,
8383
draft_vocab_size: int = None,
8484
target_vocab_size: int = None,
85+
storage_precision: str = "bfloat16",
86+
enable_int8_quantization: bool = False,
8587
):
8688
"""
8789
Initialize the hidden state generator.
@@ -92,12 +94,20 @@ def __init__(
9294
rank: Process rank for distributed training
9395
draft_vocab_size: Size of draft model vocabulary (required for vocab mapping)
9496
target_vocab_size: Size of target model vocabulary (required for vocab mapping)
97+
storage_precision: Storage precision for hidden states (bfloat16, float16, float32)
98+
enable_int8_quantization: Whether to enable per-token absmax int8 quantization
99+
for hidden_states and target_hiddens. Reduces storage by ~50%.
95100
"""
96101
self.target_model = target_model
97102
self.output_dir = Path(output_dir)
98103
self.rank = rank
99104
self.draft_vocab_size = draft_vocab_size
100105
self.target_vocab_size = target_vocab_size
106+
107+
# Storage optimization config
108+
self.storage_precision = storage_precision
109+
self.enable_int8_quantization = enable_int8_quantization
110+
101111
_max_pixels = os.environ.get("MAX_PIXELS")
102112
_min_pixels = os.environ.get("MIN_PIXELS", "1024")
103113
self.max_pixels = int(_max_pixels) if _max_pixels is not None else None
@@ -122,6 +132,89 @@ def __init__(
122132
# Resolve image_pad token id for vLLM loss_mask rebuilding
123133
self._image_pad_token_id = None
124134

135+
@staticmethod
136+
def _quantize_per_token_absmax_int8(tensor: torch.Tensor):
137+
"""Per-token absmax int8 quantization.
138+
139+
For a tensor of shape [B, N, D], compute per-token (per row in the
140+
last two dims) scale = max(|x|) / 127, then quantize to int8.
141+
142+
Args:
143+
tensor: Float tensor of shape [B, N, D]
144+
145+
Returns:
146+
Tuple of (quantized_int8 [B, N, D], scales [B, N, 1])
147+
"""
148+
# tensor: [B, N, D]
149+
absmax = tensor.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) # [B, N, 1]
150+
scale = absmax / 127.0 # [B, N, 1]
151+
quantized = (tensor / scale).round().clamp(-127, 127).to(torch.int8) # [B, N, D]
152+
return quantized, scale
153+
154+
@staticmethod
155+
def _dequantize_per_token_absmax_int8(
156+
quantized: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype = torch.bfloat16
157+
):
158+
"""Dequantize per-token absmax int8 back to float.
159+
160+
Args:
161+
quantized: int8 tensor of shape [B, N, D]
162+
scale: float tensor of shape [B, N, 1]
163+
target_dtype: Target float dtype
164+
165+
Returns:
166+
Dequantized float tensor of shape [B, N, D]
167+
"""
168+
return quantized.to(target_dtype) * scale.to(target_dtype)
169+
170+
def _convert_precision(self, data_point: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
171+
"""Convert data precision for storage."""
172+
dtype_map = {
173+
"bfloat16": torch.bfloat16,
174+
"float16": torch.float16,
175+
"float32": torch.float32,
176+
}
177+
target_dtype = dtype_map.get(self.storage_precision, torch.bfloat16)
178+
179+
converted_data_point = {}
180+
for key, tensor in data_point.items():
181+
if isinstance(tensor, torch.Tensor) and tensor.is_floating_point():
182+
if tensor.dtype != target_dtype:
183+
converted_data_point[key] = tensor.to(target_dtype)
184+
else:
185+
converted_data_point[key] = tensor
186+
else:
187+
converted_data_point[key] = tensor
188+
189+
return converted_data_point
190+
191+
def _apply_int8_quantization(
192+
self, data_point: Dict[str, torch.Tensor]
193+
) -> Dict[str, torch.Tensor]:
194+
"""Apply per-token absmax int8 quantization to hidden_states and target_hiddens.
195+
196+
Replaces the original float tensors with int8 data + float32 scales:
197+
hidden_states -> hidden_states_int8 (int8) + hidden_states_scales (float32)
198+
target_hiddens -> target_hiddens_int8 (int8) + target_hiddens_scales (float32)
199+
200+
This reduces storage by ~50% for these large tensors.
201+
"""
202+
if not self.enable_int8_quantization:
203+
return data_point
204+
205+
new_data_point = {}
206+
for key, tensor in data_point.items():
207+
if key in ("hidden_states", "target_hiddens") and isinstance(tensor, torch.Tensor):
208+
# Quantize: tensor shape [B, N, D]
209+
q_int8, scales = self._quantize_per_token_absmax_int8(tensor.float())
210+
new_data_point[f"{key}_int8"] = q_int8 # [B, N, D] int8
211+
new_data_point[f"{key}_scales"] = scales # [B, N, 1] float32
212+
# Do NOT keep the original float tensor
213+
else:
214+
new_data_point[key] = tensor
215+
216+
return new_data_point
217+
125218
def _resolve_image_pad_token_id(self):
126219
"""Lazily resolve the image_pad token id from the target model's tokenizer."""
127220
if self._image_pad_token_id is not None:
@@ -412,12 +505,19 @@ def _expand_sample_capacity(self):
412505
)
413506

414507
def _write_sample_to_memmap(self, data_point: Dict[str, torch.Tensor]):
415-
"""Write a single sample to packed memmap files.
508+
"""Write a single sample to packed memmap files with storage optimizations.
416509
417510
Data is appended at position self._total_tokens_written in the packed array.
418511
The offsets array is updated to record the boundary.
419512
"""
420-
# Get the seq_len of the current sample
513+
# Apply storage optimizations:
514+
# 1. Per-token absmax int8 quantization for hidden_states/target_hiddens
515+
data_point = self._apply_int8_quantization(data_point)
516+
517+
# 2. Convert remaining float tensors to target precision
518+
data_point = self._convert_precision(data_point)
519+
520+
# Get seq_len after optimization
421521
sample_seq_len = 0
422522
for _, tensor in data_point.items():
423523
if isinstance(tensor, torch.Tensor) and tensor.ndim >= 2:
@@ -524,11 +624,15 @@ def _finalize_memmap(self):
524624
old_offsets_path = self._memmap_dir / "offsets.npy"
525625
os.replace(str(final_offsets_path), str(old_offsets_path))
526626

527-
# Save metadata JSON
627+
# Save metadata JSON with storage optimization info
528628
metadata = {
529629
"format": "packed", # Distinguish from the old rectangular format
530630
"total_samples": self._sample_count,
531631
"total_tokens": total_tokens,
632+
"storage_optimization": {
633+
"precision": self.storage_precision,
634+
"int8_quantization": self.enable_int8_quantization,
635+
},
532636
"fields": {},
533637
}
534638
for field_name in self._field_dtypes:
@@ -882,6 +986,23 @@ def parse_arguments() -> argparse.Namespace:
882986
help="Path to draft model config file, used to read draft_vocab_size and vocab_size "
883987
"for computing vocab mapping",
884988
)
989+
990+
# Storage optimization arguments
991+
parser.add_argument(
992+
"--storage_precision",
993+
type=str,
994+
default="bfloat16",
995+
choices=["bfloat16", "float16", "float32"],
996+
help="Storage precision for hidden states. bfloat16/float16 reduce storage by 50%%",
997+
)
998+
parser.add_argument(
999+
"--enable_int8_quantization",
1000+
action="store_true",
1001+
default=False,
1002+
help="Enable per-token absmax int8 quantization for hidden_states and target_hiddens. "
1003+
"Reduces storage by ~50%% for these tensors with minimal quality loss.",
1004+
)
1005+
8851006
return parser.parse_args()
8861007

8871008

@@ -1074,12 +1195,22 @@ def main():
10741195
output_dir = f"{args.outdir}/rank_{rank}"
10751196
logger.info(f"writing hidden states to {output_dir}", extra={"rank": rank})
10761197

1198+
# Log storage optimization settings
1199+
logger.info(
1200+
f"[Storage Optimization Settings] "
1201+
f"Precision: {args.storage_precision}, "
1202+
f"Int8 Quantization: {args.enable_int8_quantization}",
1203+
extra={"rank": rank},
1204+
)
1205+
10771206
generator = HiddenStateGenerator(
10781207
target_model,
10791208
output_dir,
10801209
rank=rank,
10811210
draft_vocab_size=draft_vocab_size,
10821211
target_vocab_size=target_vocab_size,
1212+
storage_precision=args.storage_precision,
1213+
enable_int8_quantization=args.enable_int8_quantization,
10831214
)
10841215
successful, failed = generator.generate(dataset_slice)
10851216

0 commit comments

Comments
 (0)