@@ -82,6 +82,8 @@ def __init__(
8282 rank : int = 0 ,
8383 draft_vocab_size : int = None ,
8484 target_vocab_size : int = None ,
85+ storage_precision : str = "bfloat16" ,
86+ enable_int8_quantization : bool = False ,
8587 ):
8688 """
8789 Initialize the hidden state generator.
@@ -92,12 +94,20 @@ def __init__(
9294 rank: Process rank for distributed training
9395 draft_vocab_size: Size of draft model vocabulary (required for vocab mapping)
9496 target_vocab_size: Size of target model vocabulary (required for vocab mapping)
97+ storage_precision: Storage precision for hidden states (bfloat16, float16, float32)
98+ enable_int8_quantization: Whether to enable per-token absmax int8 quantization
99+ for hidden_states and target_hiddens. Reduces storage by ~50%.
95100 """
96101 self .target_model = target_model
97102 self .output_dir = Path (output_dir )
98103 self .rank = rank
99104 self .draft_vocab_size = draft_vocab_size
100105 self .target_vocab_size = target_vocab_size
106+
107+ # Storage optimization config
108+ self .storage_precision = storage_precision
109+ self .enable_int8_quantization = enable_int8_quantization
110+
101111 _max_pixels = os .environ .get ("MAX_PIXELS" )
102112 _min_pixels = os .environ .get ("MIN_PIXELS" , "1024" )
103113 self .max_pixels = int (_max_pixels ) if _max_pixels is not None else None
@@ -122,6 +132,89 @@ def __init__(
122132 # Resolve image_pad token id for vLLM loss_mask rebuilding
123133 self ._image_pad_token_id = None
124134
135+ @staticmethod
136+ def _quantize_per_token_absmax_int8 (tensor : torch .Tensor ):
137+ """Per-token absmax int8 quantization.
138+
139+ For a tensor of shape [B, N, D], compute per-token (per row in the
140+ last two dims) scale = max(|x|) / 127, then quantize to int8.
141+
142+ Args:
143+ tensor: Float tensor of shape [B, N, D]
144+
145+ Returns:
146+ Tuple of (quantized_int8 [B, N, D], scales [B, N, 1])
147+ """
148+ # tensor: [B, N, D]
149+ absmax = tensor .abs ().amax (dim = - 1 , keepdim = True ).clamp (min = 1e-10 ) # [B, N, 1]
150+ scale = absmax / 127.0 # [B, N, 1]
151+ quantized = (tensor / scale ).round ().clamp (- 127 , 127 ).to (torch .int8 ) # [B, N, D]
152+ return quantized , scale
153+
154+ @staticmethod
155+ def _dequantize_per_token_absmax_int8 (
156+ quantized : torch .Tensor , scale : torch .Tensor , target_dtype : torch .dtype = torch .bfloat16
157+ ):
158+ """Dequantize per-token absmax int8 back to float.
159+
160+ Args:
161+ quantized: int8 tensor of shape [B, N, D]
162+ scale: float tensor of shape [B, N, 1]
163+ target_dtype: Target float dtype
164+
165+ Returns:
166+ Dequantized float tensor of shape [B, N, D]
167+ """
168+ return quantized .to (target_dtype ) * scale .to (target_dtype )
169+
170+ def _convert_precision (self , data_point : Dict [str , torch .Tensor ]) -> Dict [str , torch .Tensor ]:
171+ """Convert data precision for storage."""
172+ dtype_map = {
173+ "bfloat16" : torch .bfloat16 ,
174+ "float16" : torch .float16 ,
175+ "float32" : torch .float32 ,
176+ }
177+ target_dtype = dtype_map .get (self .storage_precision , torch .bfloat16 )
178+
179+ converted_data_point = {}
180+ for key , tensor in data_point .items ():
181+ if isinstance (tensor , torch .Tensor ) and tensor .is_floating_point ():
182+ if tensor .dtype != target_dtype :
183+ converted_data_point [key ] = tensor .to (target_dtype )
184+ else :
185+ converted_data_point [key ] = tensor
186+ else :
187+ converted_data_point [key ] = tensor
188+
189+ return converted_data_point
190+
191+ def _apply_int8_quantization (
192+ self , data_point : Dict [str , torch .Tensor ]
193+ ) -> Dict [str , torch .Tensor ]:
194+ """Apply per-token absmax int8 quantization to hidden_states and target_hiddens.
195+
196+ Replaces the original float tensors with int8 data + float32 scales:
197+ hidden_states -> hidden_states_int8 (int8) + hidden_states_scales (float32)
198+ target_hiddens -> target_hiddens_int8 (int8) + target_hiddens_scales (float32)
199+
200+ This reduces storage by ~50% for these large tensors.
201+ """
202+ if not self .enable_int8_quantization :
203+ return data_point
204+
205+ new_data_point = {}
206+ for key , tensor in data_point .items ():
207+ if key in ("hidden_states" , "target_hiddens" ) and isinstance (tensor , torch .Tensor ):
208+ # Quantize: tensor shape [B, N, D]
209+ q_int8 , scales = self ._quantize_per_token_absmax_int8 (tensor .float ())
210+ new_data_point [f"{ key } _int8" ] = q_int8 # [B, N, D] int8
211+ new_data_point [f"{ key } _scales" ] = scales # [B, N, 1] float32
212+ # Do NOT keep the original float tensor
213+ else :
214+ new_data_point [key ] = tensor
215+
216+ return new_data_point
217+
125218 def _resolve_image_pad_token_id (self ):
126219 """Lazily resolve the image_pad token id from the target model's tokenizer."""
127220 if self ._image_pad_token_id is not None :
@@ -412,12 +505,19 @@ def _expand_sample_capacity(self):
412505 )
413506
414507 def _write_sample_to_memmap (self , data_point : Dict [str , torch .Tensor ]):
415- """Write a single sample to packed memmap files.
508+ """Write a single sample to packed memmap files with storage optimizations .
416509
417510 Data is appended at position self._total_tokens_written in the packed array.
418511 The offsets array is updated to record the boundary.
419512 """
420- # Get the seq_len of the current sample
513+ # Apply storage optimizations:
514+ # 1. Per-token absmax int8 quantization for hidden_states/target_hiddens
515+ data_point = self ._apply_int8_quantization (data_point )
516+
517+ # 2. Convert remaining float tensors to target precision
518+ data_point = self ._convert_precision (data_point )
519+
520+ # Get seq_len after optimization
421521 sample_seq_len = 0
422522 for _ , tensor in data_point .items ():
423523 if isinstance (tensor , torch .Tensor ) and tensor .ndim >= 2 :
@@ -524,11 +624,15 @@ def _finalize_memmap(self):
524624 old_offsets_path = self ._memmap_dir / "offsets.npy"
525625 os .replace (str (final_offsets_path ), str (old_offsets_path ))
526626
527- # Save metadata JSON
627+ # Save metadata JSON with storage optimization info
528628 metadata = {
529629 "format" : "packed" , # Distinguish from the old rectangular format
530630 "total_samples" : self ._sample_count ,
531631 "total_tokens" : total_tokens ,
632+ "storage_optimization" : {
633+ "precision" : self .storage_precision ,
634+ "int8_quantization" : self .enable_int8_quantization ,
635+ },
532636 "fields" : {},
533637 }
534638 for field_name in self ._field_dtypes :
@@ -882,6 +986,23 @@ def parse_arguments() -> argparse.Namespace:
882986 help = "Path to draft model config file, used to read draft_vocab_size and vocab_size "
883987 "for computing vocab mapping" ,
884988 )
989+
990+ # Storage optimization arguments
991+ parser .add_argument (
992+ "--storage_precision" ,
993+ type = str ,
994+ default = "bfloat16" ,
995+ choices = ["bfloat16" , "float16" , "float32" ],
996+ help = "Storage precision for hidden states. bfloat16/float16 reduce storage by 50%%" ,
997+ )
998+ parser .add_argument (
999+ "--enable_int8_quantization" ,
1000+ action = "store_true" ,
1001+ default = False ,
1002+ help = "Enable per-token absmax int8 quantization for hidden_states and target_hiddens. "
1003+ "Reduces storage by ~50%% for these tensors with minimal quality loss." ,
1004+ )
1005+
8851006 return parser .parse_args ()
8861007
8871008
@@ -1074,12 +1195,22 @@ def main():
10741195 output_dir = f"{ args .outdir } /rank_{ rank } "
10751196 logger .info (f"writing hidden states to { output_dir } " , extra = {"rank" : rank })
10761197
1198+ # Log storage optimization settings
1199+ logger .info (
1200+ f"[Storage Optimization Settings] "
1201+ f"Precision: { args .storage_precision } , "
1202+ f"Int8 Quantization: { args .enable_int8_quantization } " ,
1203+ extra = {"rank" : rank },
1204+ )
1205+
10771206 generator = HiddenStateGenerator (
10781207 target_model ,
10791208 output_dir ,
10801209 rank = rank ,
10811210 draft_vocab_size = draft_vocab_size ,
10821211 target_vocab_size = target_vocab_size ,
1212+ storage_precision = args .storage_precision ,
1213+ enable_int8_quantization = args .enable_int8_quantization ,
10831214 )
10841215 successful , failed = generator .generate (dataset_slice )
10851216
0 commit comments