@@ -82,6 +82,8 @@ def __init__(
8282 rank : int = 0 ,
8383 draft_vocab_size : int = None ,
8484 target_vocab_size : int = None ,
85+ storage_precision : str = "bfloat16" ,
86+ enable_int8_quantization : bool = False ,
8587 ):
8688 """
8789 Initialize the hidden state generator.
@@ -92,12 +94,20 @@ def __init__(
9294 rank: Process rank for distributed training
9395 draft_vocab_size: Size of draft model vocabulary (required for vocab mapping)
9496 target_vocab_size: Size of target model vocabulary (required for vocab mapping)
97+ storage_precision: Storage precision for hidden states (bfloat16, float16, float32)
98+ enable_int8_quantization: Whether to enable per-token absmax int8 quantization
99+ for hidden_states and target_hiddens. Reduces storage by ~50%.
95100 """
96101 self .target_model = target_model
97102 self .output_dir = Path (output_dir )
98103 self .rank = rank
99104 self .draft_vocab_size = draft_vocab_size
100105 self .target_vocab_size = target_vocab_size
106+
107+ # Storage optimization config
108+ self .storage_precision = storage_precision
109+ self .enable_int8_quantization = enable_int8_quantization
110+
101111 _max_pixels = os .environ .get ("MAX_PIXELS" )
102112 _min_pixels = os .environ .get ("MIN_PIXELS" , "1024" )
103113 self .max_pixels = int (_max_pixels ) if _max_pixels is not None else None
@@ -122,6 +132,89 @@ def __init__(
122132 # Resolve image_pad token id for vLLM loss_mask rebuilding
123133 self ._image_pad_token_id = None
124134
135+ @staticmethod
136+ def _quantize_per_token_absmax_int8 (tensor : torch .Tensor ):
137+ """Per-token absmax int8 quantization.
138+
139+ For a tensor of shape [B, N, D], compute per-token (per row in the
140+ last two dims) scale = max(|x|) / 127, then quantize to int8.
141+
142+ Args:
143+ tensor: Float tensor of shape [B, N, D]
144+
145+ Returns:
146+ Tuple of (quantized_int8 [B, N, D], scales [B, N, 1])
147+ """
148+ # tensor: [B, N, D]
149+ absmax = tensor .abs ().amax (dim = - 1 , keepdim = True ).clamp (min = 1e-10 ) # [B, N, 1]
150+ scale = absmax / 127.0 # [B, N, 1]
151+ quantized = (tensor / scale ).round ().clamp (- 127 , 127 ).to (torch .int8 ) # [B, N, D]
152+ return quantized , scale
153+
154+ @staticmethod
155+ def _dequantize_per_token_absmax_int8 (
156+ quantized : torch .Tensor , scale : torch .Tensor , target_dtype : torch .dtype = torch .bfloat16
157+ ):
158+ """Dequantize per-token absmax int8 back to float.
159+
160+ Args:
161+ quantized: int8 tensor of shape [B, N, D]
162+ scale: float tensor of shape [B, N, 1]
163+ target_dtype: Target float dtype
164+
165+ Returns:
166+ Dequantized float tensor of shape [B, N, D]
167+ """
168+ return quantized .to (target_dtype ) * scale .to (target_dtype )
169+
170+ def _convert_precision (self , data_point : Dict [str , torch .Tensor ]) -> Dict [str , torch .Tensor ]:
171+ """Convert data precision for storage."""
172+ dtype_map = {
173+ "bfloat16" : torch .bfloat16 ,
174+ "float16" : torch .float16 ,
175+ "float32" : torch .float32 ,
176+ }
177+ target_dtype = dtype_map .get (self .storage_precision , torch .bfloat16 )
178+
179+ converted_data_point = {}
180+ for key , tensor in data_point .items ():
181+ if isinstance (tensor , torch .Tensor ) and tensor .is_floating_point ():
182+ if tensor .dtype != target_dtype :
183+ converted_data_point [key ] = tensor .to (target_dtype )
184+ else :
185+ converted_data_point [key ] = tensor
186+ else :
187+ converted_data_point [key ] = tensor
188+
189+ return converted_data_point
190+
191+ def _apply_int8_quantization (
192+ self , data_point : Dict [str , torch .Tensor ]
193+ ) -> Dict [str , torch .Tensor ]:
194+ """Apply per-token absmax int8 quantization to hidden_states and target_hiddens.
195+
196+ Replaces the original float tensors with int8 data + float32 scales:
197+ hidden_states -> hidden_states_int8 (int8) + hidden_states_scales (float32)
198+ target_hiddens -> target_hiddens_int8 (int8) + target_hiddens_scales (float32)
199+
200+ This reduces storage by ~50% for these large tensors.
201+ """
202+ if not self .enable_int8_quantization :
203+ return data_point
204+
205+ new_data_point = {}
206+ for key , tensor in data_point .items ():
207+ if key in ("hidden_states" , "target_hiddens" ) and isinstance (tensor , torch .Tensor ):
208+ # Quantize: tensor shape [B, N, D]
209+ q_int8 , scales = self ._quantize_per_token_absmax_int8 (tensor .float ())
210+ new_data_point [f"{ key } _int8" ] = q_int8 # [B, N, D] int8
211+ new_data_point [f"{ key } _scales" ] = scales # [B, N, 1] float32
212+ # Do NOT keep the original float tensor
213+ else :
214+ new_data_point [key ] = tensor
215+
216+ return new_data_point
217+
125218 def _resolve_image_pad_token_id (self ):
126219 """Lazily resolve the image_pad token id from the target model's tokenizer."""
127220 if self ._image_pad_token_id is not None :
@@ -287,7 +380,9 @@ def _init_memmap(self, data_point: Dict[str, torch.Tensor], total_samples: int):
287380 extra_dims = per_sample_shape [1 :] # () or (D,) or (3*D,)
288381
289382 # Determine numpy dtype
290- if tensor .dtype == torch .bfloat16 :
383+ if tensor .dtype == torch .int8 :
384+ np_dtype = np .int8
385+ elif tensor .dtype == torch .bfloat16 :
291386 np_dtype = np .float16
292387 elif tensor .dtype == torch .float16 :
293388 np_dtype = np .float16
@@ -412,12 +507,23 @@ def _expand_sample_capacity(self):
412507 )
413508
414509 def _write_sample_to_memmap (self , data_point : Dict [str , torch .Tensor ]):
415- """Write a single sample to packed memmap files.
510+ """Write a single sample to packed memmap files with storage optimizations .
416511
417512 Data is appended at position self._total_tokens_written in the packed array.
418513 The offsets array is updated to record the boundary.
419514 """
420- # Get the seq_len of the current sample
515+ # Apply storage optimizations:
516+ # 1. Per-token absmax int8 quantization for hidden_states/target_hiddens
517+ data_point = self ._apply_int8_quantization (data_point )
518+
519+ # 2. Convert remaining float tensors to target precision
520+ data_point = self ._convert_precision (data_point )
521+
522+ # Initialize memmap on first write
523+ if not self ._memmap_initialized :
524+ self ._init_memmap (data_point , self ._total_samples_estimate )
525+
526+ # Get seq_len after optimization
421527 sample_seq_len = 0
422528 for _ , tensor in data_point .items ():
423529 if isinstance (tensor , torch .Tensor ) and tensor .ndim >= 2 :
@@ -444,6 +550,8 @@ def _write_sample_to_memmap(self, data_point: Dict[str, torch.Tensor]):
444550 arr = tensor .squeeze (0 ) # [N, ...]
445551 if tensor .dtype == torch .bfloat16 :
446552 arr = arr .float ().half () # bfloat16 -> float32 -> float16
553+ elif tensor .dtype == torch .int8 :
554+ pass # int8 can be directly converted to numpy
447555 arr_np = arr .contiguous ().numpy ()
448556
449557 seq_len = arr_np .shape [0 ]
@@ -524,11 +632,15 @@ def _finalize_memmap(self):
524632 old_offsets_path = self ._memmap_dir / "offsets.npy"
525633 os .replace (str (final_offsets_path ), str (old_offsets_path ))
526634
527- # Save metadata JSON
635+ # Save metadata JSON with storage optimization info
528636 metadata = {
529637 "format" : "packed" , # Distinguish from the old rectangular format
530638 "total_samples" : self ._sample_count ,
531639 "total_tokens" : total_tokens ,
640+ "storage_optimization" : {
641+ "precision" : self .storage_precision ,
642+ "int8_quantization" : self .enable_int8_quantization ,
643+ },
532644 "fields" : {},
533645 }
534646 for field_name in self ._field_dtypes :
@@ -659,10 +771,6 @@ def _process_single_sample(self, idx: int, row: Dict[str, Any]) -> bool:
659771 batch_token_dict = dict (zip (unique_ids .tolist (), counts .tolist ()))
660772 self .token_dict .update (batch_token_dict )
661773
662- # Initialize memmap on first write
663- if not self ._memmap_initialized :
664- self ._init_memmap (data_point , self ._total_samples_estimate )
665-
666774 # Write directly to memmap
667775 self ._write_sample_to_memmap (data_point )
668776 return True
@@ -882,6 +990,23 @@ def parse_arguments() -> argparse.Namespace:
882990 help = "Path to draft model config file, used to read draft_vocab_size and vocab_size "
883991 "for computing vocab mapping" ,
884992 )
993+
994+ # Storage optimization arguments
995+ parser .add_argument (
996+ "--storage_precision" ,
997+ type = str ,
998+ default = "bfloat16" ,
999+ choices = ["bfloat16" , "float16" , "float32" ],
1000+ help = "Storage precision for hidden states. bfloat16/float16 reduce storage by 50%%" ,
1001+ )
1002+ parser .add_argument (
1003+ "--enable_int8_quantization" ,
1004+ action = "store_true" ,
1005+ default = False ,
1006+ help = "Enable per-token absmax int8 quantization for hidden_states and target_hiddens. "
1007+ "Reduces storage by ~50%% for these tensors with minimal quality loss." ,
1008+ )
1009+
8851010 return parser .parse_args ()
8861011
8871012
@@ -1074,12 +1199,22 @@ def main():
10741199 output_dir = f"{ args .outdir } /rank_{ rank } "
10751200 logger .info (f"writing hidden states to { output_dir } " , extra = {"rank" : rank })
10761201
1202+ # Log storage optimization settings
1203+ logger .info (
1204+ f"[Storage Optimization Settings] "
1205+ f"Precision: { args .storage_precision } , "
1206+ f"Int8 Quantization: { args .enable_int8_quantization } " ,
1207+ extra = {"rank" : rank },
1208+ )
1209+
10771210 generator = HiddenStateGenerator (
10781211 target_model ,
10791212 output_dir ,
10801213 rank = rank ,
10811214 draft_vocab_size = draft_vocab_size ,
10821215 target_vocab_size = target_vocab_size ,
1216+ storage_precision = args .storage_precision ,
1217+ enable_int8_quantization = args .enable_int8_quantization ,
10831218 )
10841219 successful , failed = generator .generate (dataset_slice )
10851220
0 commit comments