5151import os
5252import sys
5353
54+
5455os .environ .setdefault ("TORCH_COMPILE_DISABLE" , "1" )
5556
5657import torch
6162from megatron .bridge import AutoBridge
6263from megatron .bridge .utils .common_utils import disable_mtp_for_inference
6364
65+
6466SIMILARITY_THRESHOLD = 0.98
6567
6668
@@ -71,6 +73,7 @@ def _is_rank_0() -> bool:
7173
7274
def print_rank_0(msg: str):
    """Print a message only from rank 0.

    The message is written with ``print(..., flush=True)`` so it appears
    immediately; when ``_is_rank_0()`` is false the call does nothing.

    Args:
        msg: Text to print.
    """
    if _is_rank_0():
        print(msg, flush=True)
7679
@@ -79,6 +82,7 @@ def print_rank_0(msg: str):
7982# Image+Text Preprocessing
8083# ========================================================================== #
8184
85+
8286def preprocess_image_text (hf_model_path : str , prompt : str , image_path : str ):
8387 """Use the HF processor to preprocess an image+text prompt.
8488
@@ -88,25 +92,22 @@ def preprocess_image_text(hf_model_path: str, prompt: str, image_path: str):
8892 Returns a dict with all tensors needed for both HF and Megatron forward.
8993 """
9094 from transformers import AutoProcessor
91- from PIL import Image
9295
93- processor = AutoProcessor .from_pretrained (
94- hf_model_path , trust_remote_code = True
95- )
96+ processor = AutoProcessor .from_pretrained (hf_model_path , trust_remote_code = True )
9697
9798 # Build chat messages with image
98- messages = [{
99- "role" : "user" ,
100- "content" : [
101- {"type" : "text" , "text" : prompt },
102- {"type" : "image_url" , "image_url" : {"url" : image_path }},
103- ]
104- }]
99+ messages = [
100+ {
101+ "role" : "user" ,
102+ "content" : [
103+ {"type" : "text" , "text" : prompt },
104+ {"type" : "image_url" , "image_url" : {"url" : image_path }},
105+ ],
106+ }
107+ ]
105108
106109 # Apply chat template
107- text = processor .tokenizer .apply_chat_template (
108- messages , tokenize = False , add_generation_prompt = True
109- )
110+ text = processor .tokenizer .apply_chat_template (messages , tokenize = False , add_generation_prompt = True )
110111 print_rank_0 (f" Chat template text (first 200 chars): { text [:200 ]} ..." )
111112
112113 # Process vision info (loads and resizes images)
@@ -140,6 +141,7 @@ def preprocess_image_text(hf_model_path: str, prompt: str, image_path: str):
140141# Phase 1: HF Model Forward
141142# ========================================================================== #
142143
144+
143145def run_hf_forward (
144146 hf_model_path : str ,
145147 input_ids : torch .Tensor ,
@@ -180,6 +182,7 @@ def run_hf_forward(
180182 # route tensors through meta-device shape inference, which breaks for
181183 # data-dependent ops (torch.nonzero) used in ERNIE VL MoE routing.
182184 from accelerate .hooks import remove_hook_from_module
185+
183186 for _name , _module in hf_model .named_modules ():
184187 remove_hook_from_module (_module )
185188 print_rank_0 (" Removed accelerate dispatch hooks from all modules." )
@@ -192,7 +195,7 @@ def run_hf_forward(
192195 # indexing on meta-dispatched tensors.
193196 _fixed_moe = 0
194197 for _name , _module in hf_model .named_modules ():
195- if hasattr (_module , ' use_correction_bias' ) and _module .use_correction_bias :
198+ if hasattr (_module , " use_correction_bias" ) and _module .use_correction_bias :
196199 _module .use_correction_bias = False
197200 _fixed_moe += 1
198201 if _fixed_moe :
@@ -201,19 +204,17 @@ def run_hf_forward(
201204 # Safety: fix inv_freq if stuck on meta device (only happens with
202205 # device_map="auto"). With device_map={"": device} this is a no-op.
203206 for name , module in hf_model .named_modules ():
204- if hasattr (module , ' inv_freq' ) and isinstance (module .inv_freq , torch .Tensor ):
205- if module .inv_freq .device .type == ' meta' :
207+ if hasattr (module , " inv_freq" ) and isinstance (module .inv_freq , torch .Tensor ):
208+ if module .inv_freq .device .type == " meta" :
206209 dim = module .inv_freq .shape [0 ] * 2 # inv_freq has shape [dim//2]
207210 theta = 10000.0
208- module .inv_freq = 1.0 / (theta ** (
209- torch .arange (0 , dim , 2 , dtype = torch .float32 ) / dim
210- ))
211+ module .inv_freq = 1.0 / (theta ** (torch .arange (0 , dim , 2 , dtype = torch .float32 ) / dim ))
211212 print_rank_0 (f" Fixed meta inv_freq in { name } -> CPU, shape={ module .inv_freq .shape } " )
212213
213214 if processor_output is not None :
214215 # Image+text VL forward
215216 # Register image preprocessor for GPU-side pixel normalization
216- if processor is not None and hasattr (hf_model , ' add_image_preprocess' ):
217+ if processor is not None and hasattr (hf_model , " add_image_preprocess" ):
217218 hf_model .add_image_preprocess (processor )
218219 print_rank_0 (" Image preprocessor registered on HF model" )
219220
@@ -250,7 +251,7 @@ def run_hf_forward(
250251 # 3D M-RoPE position_ids: [batch, seq_len, 3] -- for text-only, all 3 dims identical
251252 position_ids = (
252253 torch .arange (seq_len , dtype = torch .long , device = device )
253- .unsqueeze (0 ) # [1, seq_len]
254+ .unsqueeze (0 ) # [1, seq_len]
254255 .unsqueeze (- 1 ) # [1, seq_len, 1]
255256 .expand (1 , seq_len , 3 ) # [1, seq_len, 3]
256257 .clone ()
@@ -286,7 +287,10 @@ def run_hf_forward(
286287# Phase 2: Megatron Model Forward
287288# ========================================================================== #
288289
290+
289291class SingleBatchIterator :
292+ """Iterator that yields a single batch for Megatron forward scheduling."""
293+
290294 def __init__ (self , batch ):
291295 self .batch = batch
292296 self ._yielded = False
@@ -413,10 +417,7 @@ def run_megatron_forward(
413417 )
414418
415419 # Process output on last pipeline stage
416- is_last_stage = (
417- not dist .is_initialized ()
418- or parallel_state .is_pipeline_last_stage ()
419- )
420+ is_last_stage = not dist .is_initialized () or parallel_state .is_pipeline_last_stage ()
420421
421422 megatron_logits_cpu = None
422423
@@ -433,12 +434,9 @@ def run_megatron_forward(
433434
434435 megatron_logits = output [0 , - 1 , :].float () # [padded_vocab_size]
435436
436- is_primary = (
437- not dist .is_initialized ()
438- or (
439- parallel_state .get_tensor_model_parallel_rank () == 0
440- and parallel_state .get_expert_model_parallel_rank () == 0
441- )
437+ is_primary = not dist .is_initialized () or (
438+ parallel_state .get_tensor_model_parallel_rank () == 0
439+ and parallel_state .get_expert_model_parallel_rank () == 0
442440 )
443441
444442 if is_primary :
@@ -447,8 +445,12 @@ def run_megatron_forward(
447445 top5_tokens = [tokenizer .decode ([idx ]) for idx in top5_ids ]
448446
449447 print_rank_0 (f" Megatron output shape: { output .shape } " )
450- print_rank_0 (f" Megatron logits stats: mean={ megatron_logits .mean ():.4f} , std={ megatron_logits .std ():.4f} " )
451- print_rank_0 (f" Megatron next token: { megatron_next_token .item ()} ('{ tokenizer .decode ([megatron_next_token .item ()])} ')" )
448+ print_rank_0 (
449+ f" Megatron logits stats: mean={ megatron_logits .mean ():.4f} , std={ megatron_logits .std ():.4f} "
450+ )
451+ print_rank_0 (
452+ f" Megatron next token: { megatron_next_token .item ()} ('{ tokenizer .decode ([megatron_next_token .item ()])} ')"
453+ )
452454 print_rank_0 (f" Megatron Top 5: { list (zip (top5_tokens , top5_vals .tolist ()))} " )
453455
454456 megatron_logits_cpu = megatron_logits .cpu ()
@@ -460,7 +462,10 @@ def run_megatron_forward(
460462# Phase 3: Comparison
461463# ========================================================================== #
462464
463- def compare_logits (hf_logits : torch .Tensor , megatron_logits : torch .Tensor , tokenizer , threshold : float = SIMILARITY_THRESHOLD ):
465+
466+ def compare_logits (
467+ hf_logits : torch .Tensor , megatron_logits : torch .Tensor , tokenizer , threshold : float = SIMILARITY_THRESHOLD
468+ ):
464469 """Compare HF and Megatron logits. Returns True if pass."""
465470 print_rank_0 ("\n === Phase 3: Comparing Logits ===" )
466471
@@ -509,7 +514,9 @@ def compare_logits(hf_logits: torch.Tensor, megatron_logits: torch.Tensor, token
509514# Main
510515# ========================================================================== #
511516
517+
512518def main ():
519+ """Run ERNIE 4.5 VL MoE logit comparison between HF and Megatron."""
513520 parser = argparse .ArgumentParser (description = "ERNIE 4.5 VL MoE logit comparison (HF vs Megatron)" )
514521 parser .add_argument ("--hf-model-path" , required = True , help = "Path to HF model directory" )
515522 parser .add_argument ("--prompt" , default = "Hello, how are you?" , help = "Text prompt for comparison" )
@@ -535,6 +542,7 @@ def main():
535542
536543 # Load tokenizer
537544 from transformers import AutoTokenizer
545+
538546 tokenizer = AutoTokenizer .from_pretrained (
539547 args .hf_model_path ,
540548 trust_remote_code = True ,
@@ -554,9 +562,7 @@ def main():
554562 # Image+Text mode: use processor to prepare all inputs
555563 # ============================================================
556564 print_rank_0 ("\n === Preprocessing: Image+Text ===" )
557- processor_output , processor = preprocess_image_text (
558- args .hf_model_path , args .prompt , args .image_path
559- )
565+ processor_output , processor = preprocess_image_text (args .hf_model_path , args .prompt , args .image_path )
560566 input_ids = processor_output ["input_ids" ]
561567
562568 # Extract vision tensors for Megatron side
@@ -572,10 +578,11 @@ def main():
572578 if "token_type_ids" in processor_output :
573579 hf_token_type_ids = processor_output ["token_type_ids" ]
574580 # Take the first seq_len values (drop the extra trailing token)
575- mm_token_type_ids = hf_token_type_ids [:, :input_ids .size (1 )].to (torch .int32 )
581+ mm_token_type_ids = hf_token_type_ids [:, : input_ids .size (1 )].to (torch .int32 )
576582 num_img_tokens = (mm_token_type_ids == 1 ).sum ().item ()
577- print_rank_0 (f" mm_token_type_ids: { mm_token_type_ids .shape } , "
578- f"image tokens: { num_img_tokens } /{ input_ids .size (1 )} " )
583+ print_rank_0 (
584+ f" mm_token_type_ids: { mm_token_type_ids .shape } , image tokens: { num_img_tokens } /{ input_ids .size (1 )} "
585+ )
579586 else :
580587 # ============================================================
581588 # Text-only mode: simple tokenization
@@ -598,7 +605,8 @@ def main():
598605 # Also pad mm_token_type_ids if present
599606 if mm_token_type_ids is not None :
600607 mm_padding = torch .zeros (
601- mm_token_type_ids .shape [0 ], pad_len ,
608+ mm_token_type_ids .shape [0 ],
609+ pad_len ,
602610 dtype = mm_token_type_ids .dtype ,
603611 )
604612 mm_token_type_ids = torch .cat ([mm_token_type_ids , mm_padding ], dim = 1 )
@@ -607,7 +615,9 @@ def main():
607615
608616 # Phase 1: HF forward (rank 0 only)
609617 hf_logits = run_hf_forward (
610- args .hf_model_path , input_ids , tokenizer ,
618+ args .hf_model_path ,
619+ input_ids ,
620+ tokenizer ,
611621 processor_output = processor_output ,
612622 processor = processor ,
613623 )
0 commit comments