Skip to content

Commit 7f1c6b1

Browse files
committed
[bridge] fix: Add missing copyright, docstrings, and fix lint/format issues
Add copyright header to ernie45_vl_vit_debug.py, add missing docstrings for public functions/classes (D103/D101), fix import sorting (I001), remove unused import (F401), and apply ruff format to all ERNIE files.

Signed-off-by: kebo01 <kebo01@baidu.com>
1 parent a950c97 commit 7f1c6b1

18 files changed

Lines changed: 310 additions & 362 deletions

examples/models/vlm/ernie_vl/ernie45_vl_fwd_bwd.py

Lines changed: 30 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@
3434
import argparse
3535
import os
3636

37+
3738
# Disable torch.compile to avoid triton compatibility issues in some environments.
3839
# Must be set before importing torch.
3940
os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
@@ -52,6 +53,7 @@ def _is_rank_0() -> bool:
5253

5354

5455
def print_rank_0(msg: str):
56+
"""Print a message only from rank 0."""
5557
if _is_rank_0():
5658
print(msg, flush=True)
5759

@@ -141,6 +143,7 @@ def run_forward_backward(
141143
if with_vision:
142144
# Read vision config and special token IDs from the model provider / HF config
143145
import json
146+
144147
with open(os.path.join(hf_model_path, "config.json")) as f:
145148
hf_cfg = json.load(f)
146149

@@ -153,13 +156,19 @@ def run_forward_backward(
153156
# without index-out-of-range errors in the embedding layer.
154157
text_cfg = hf_cfg.get("text_config", hf_cfg)
155158
cfg_vocab_size = text_cfg.get("vocab_size", vocab_size)
156-
if image_token_id >= cfg_vocab_size or image_start_token_id >= cfg_vocab_size or image_end_token_id >= cfg_vocab_size:
159+
if (
160+
image_token_id >= cfg_vocab_size
161+
or image_start_token_id >= cfg_vocab_size
162+
or image_end_token_id >= cfg_vocab_size
163+
):
157164
# Use the last 3 tokens in the vocab as placeholders
158165
image_token_id = cfg_vocab_size - 3
159166
image_start_token_id = cfg_vocab_size - 2
160167
image_end_token_id = cfg_vocab_size - 1
161-
print_rank_0(f" Remapped special token IDs to fit vocab_size={cfg_vocab_size}: "
162-
f"image_token={image_token_id}, start={image_start_token_id}, end={image_end_token_id}")
168+
print_rank_0(
169+
f" Remapped special token IDs to fit vocab_size={cfg_vocab_size}: "
170+
f"image_token={image_token_id}, start={image_start_token_id}, end={image_end_token_id}"
171+
)
163172

164173
# Also update model config so get_placeholder_mask / get_rope_index use the remapped IDs
165174
for m in megatron_models:
@@ -184,7 +193,7 @@ def run_forward_backward(
184193
pixel_values = torch.randn(num_patches, patch_dim, dtype=torch.bfloat16, device="cuda")
185194

186195
# Number of image placeholder tokens after resampler spatial merge
187-
num_image_tokens = num_patches // (spatial_merge_size ** 2) # 4 // 4 = 1
196+
num_image_tokens = num_patches // (spatial_merge_size**2) # 4 // 4 = 1
188197

189198
# Build input_ids: [text..., image_start, <image_placeholders>, image_end, text...]
190199
# Ensure seq_len is large enough
@@ -212,7 +221,7 @@ def run_forward_backward(
212221
# mm_token_type_ids: 0=text, 1=image placeholder
213222
mm_token_type_ids = torch.zeros(1, actual_seq_len, dtype=torch.int32, device="cuda")
214223
img_start_pos = num_text_before + 1 # position of first image placeholder
215-
mm_token_type_ids[0, img_start_pos:img_start_pos + num_image_tokens] = 1
224+
mm_token_type_ids[0, img_start_pos : img_start_pos + num_image_tokens] = 1
216225

217226
# Labels and loss mask
218227
labels = torch.randint(0, min(vocab_size, 1024), (1, actual_seq_len), device="cuda")
@@ -229,6 +238,7 @@ def run_forward_backward(
229238
# Use real text with tokenizer for meaningful loss measurement.
230239
# Labels are the next-token targets (input shifted right by 1).
231240
from transformers import AutoTokenizer
241+
232242
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, trust_remote_code=True)
233243
token_ids = tokenizer.encode(prompt, add_special_tokens=True)
234244
# Need at least 2 tokens for next-token prediction
@@ -363,10 +373,7 @@ def loss_func(output_tensor, **kwargs):
363373

364374
from megatron.core import parallel_state
365375

366-
is_last_stage = (
367-
not dist.is_initialized()
368-
or parallel_state.is_pipeline_last_stage()
369-
)
376+
is_last_stage = not dist.is_initialized() or parallel_state.is_pipeline_last_stage()
370377

371378
if is_last_stage:
372379
if isinstance(output, list) and len(output) > 0:
@@ -426,8 +433,7 @@ def loss_func(output_tensor, **kwargs):
426433
# At least some params should have gradients
427434
# (Not all will have gradients due to PP - only the local stage's params)
428435
assert params_with_grad > 0, (
429-
f"No parameters have gradients! "
430-
f"total_params={total_params}, params_with_grad={params_with_grad}"
436+
f"No parameters have gradients! total_params={total_params}, params_with_grad={params_with_grad}"
431437
)
432438

433439
# When vision is enabled, verify vision tower and resampler got gradients
@@ -441,8 +447,7 @@ def loss_func(output_tensor, **kwargs):
441447
vision_params_with_grad += 1
442448
print_rank_0(f" Vision params with non-zero gradient: {vision_params_with_grad}")
443449
assert vision_params_with_grad > 0, (
444-
"Vision tower/resampler parameters have no gradients! "
445-
"The vision forward path may not be exercised."
450+
"Vision tower/resampler parameters have no gradients! The vision forward path may not be exercised."
446451
)
447452

448453
print_rank_0(" Gradient verification passed.")
@@ -479,19 +484,24 @@ def _run(
479484

480485
def main():
481486
"""Parse CLI arguments and launch the forward/backward test."""
482-
parser = argparse.ArgumentParser(
483-
description="ERNIE 4.5 VL MoE forward/backward test"
484-
)
487+
parser = argparse.ArgumentParser(description="ERNIE 4.5 VL MoE forward/backward test")
485488
parser.add_argument("--hf-model-path", required=True, help="Path to HF toy model directory")
486489
parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism size")
487490
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallelism size")
488491
parser.add_argument("--ep", type=int, default=1, help="Expert parallelism size")
489492
parser.add_argument("--seq-len", type=int, default=16, help="Sequence length")
490493
parser.add_argument("--forward-only", action="store_true", help="Skip backward pass")
491-
parser.add_argument("--with-vision", action="store_true",
492-
help="Include a dummy image input to exercise the vision tower and resampler")
493-
parser.add_argument("--prompt", type=str, default=None,
494-
help="Use real text prompt instead of random tokens for meaningful loss measurement")
494+
parser.add_argument(
495+
"--with-vision",
496+
action="store_true",
497+
help="Include a dummy image input to exercise the vision tower and resampler",
498+
)
499+
parser.add_argument(
500+
"--prompt",
501+
type=str,
502+
default=None,
503+
help="Use real text prompt instead of random tokens for meaningful loss measurement",
504+
)
495505
args = parser.parse_args()
496506

497507
_run(

examples/models/vlm/ernie_vl/ernie45_vl_logit_compare.py

Lines changed: 53 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,7 @@
5151
import os
5252
import sys
5353

54+
5455
os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
5556

5657
import torch
@@ -61,6 +62,7 @@
6162
from megatron.bridge import AutoBridge
6263
from megatron.bridge.utils.common_utils import disable_mtp_for_inference
6364

65+
6466
SIMILARITY_THRESHOLD = 0.98
6567

6668

@@ -71,6 +73,7 @@ def _is_rank_0() -> bool:
7173

7274

7375
def print_rank_0(msg: str):
76+
"""Print a message only from rank 0."""
7477
if _is_rank_0():
7578
print(msg, flush=True)
7679

@@ -79,6 +82,7 @@ def print_rank_0(msg: str):
7982
# Image+Text Preprocessing
8083
# ========================================================================== #
8184

85+
8286
def preprocess_image_text(hf_model_path: str, prompt: str, image_path: str):
8387
"""Use the HF processor to preprocess an image+text prompt.
8488
@@ -88,25 +92,22 @@ def preprocess_image_text(hf_model_path: str, prompt: str, image_path: str):
8892
Returns a dict with all tensors needed for both HF and Megatron forward.
8993
"""
9094
from transformers import AutoProcessor
91-
from PIL import Image
9295

93-
processor = AutoProcessor.from_pretrained(
94-
hf_model_path, trust_remote_code=True
95-
)
96+
processor = AutoProcessor.from_pretrained(hf_model_path, trust_remote_code=True)
9697

9798
# Build chat messages with image
98-
messages = [{
99-
"role": "user",
100-
"content": [
101-
{"type": "text", "text": prompt},
102-
{"type": "image_url", "image_url": {"url": image_path}},
103-
]
104-
}]
99+
messages = [
100+
{
101+
"role": "user",
102+
"content": [
103+
{"type": "text", "text": prompt},
104+
{"type": "image_url", "image_url": {"url": image_path}},
105+
],
106+
}
107+
]
105108

106109
# Apply chat template
107-
text = processor.tokenizer.apply_chat_template(
108-
messages, tokenize=False, add_generation_prompt=True
109-
)
110+
text = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
110111
print_rank_0(f" Chat template text (first 200 chars): {text[:200]}...")
111112

112113
# Process vision info (loads and resizes images)
@@ -140,6 +141,7 @@ def preprocess_image_text(hf_model_path: str, prompt: str, image_path: str):
140141
# Phase 1: HF Model Forward
141142
# ========================================================================== #
142143

144+
143145
def run_hf_forward(
144146
hf_model_path: str,
145147
input_ids: torch.Tensor,
@@ -180,6 +182,7 @@ def run_hf_forward(
180182
# route tensors through meta-device shape inference, which breaks for
181183
# data-dependent ops (torch.nonzero) used in ERNIE VL MoE routing.
182184
from accelerate.hooks import remove_hook_from_module
185+
183186
for _name, _module in hf_model.named_modules():
184187
remove_hook_from_module(_module)
185188
print_rank_0(" Removed accelerate dispatch hooks from all modules.")
@@ -192,7 +195,7 @@ def run_hf_forward(
192195
# indexing on meta-dispatched tensors.
193196
_fixed_moe = 0
194197
for _name, _module in hf_model.named_modules():
195-
if hasattr(_module, 'use_correction_bias') and _module.use_correction_bias:
198+
if hasattr(_module, "use_correction_bias") and _module.use_correction_bias:
196199
_module.use_correction_bias = False
197200
_fixed_moe += 1
198201
if _fixed_moe:
@@ -201,19 +204,17 @@ def run_hf_forward(
201204
# Safety: fix inv_freq if stuck on meta device (only happens with
202205
# device_map="auto"). With device_map={"": device} this is a no-op.
203206
for name, module in hf_model.named_modules():
204-
if hasattr(module, 'inv_freq') and isinstance(module.inv_freq, torch.Tensor):
205-
if module.inv_freq.device.type == 'meta':
207+
if hasattr(module, "inv_freq") and isinstance(module.inv_freq, torch.Tensor):
208+
if module.inv_freq.device.type == "meta":
206209
dim = module.inv_freq.shape[0] * 2 # inv_freq has shape [dim//2]
207210
theta = 10000.0
208-
module.inv_freq = 1.0 / (theta ** (
209-
torch.arange(0, dim, 2, dtype=torch.float32) / dim
210-
))
211+
module.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
211212
print_rank_0(f" Fixed meta inv_freq in {name} -> CPU, shape={module.inv_freq.shape}")
212213

213214
if processor_output is not None:
214215
# Image+text VL forward
215216
# Register image preprocessor for GPU-side pixel normalization
216-
if processor is not None and hasattr(hf_model, 'add_image_preprocess'):
217+
if processor is not None and hasattr(hf_model, "add_image_preprocess"):
217218
hf_model.add_image_preprocess(processor)
218219
print_rank_0(" Image preprocessor registered on HF model")
219220

@@ -250,7 +251,7 @@ def run_hf_forward(
250251
# 3D M-RoPE position_ids: [batch, seq_len, 3] -- for text-only, all 3 dims identical
251252
position_ids = (
252253
torch.arange(seq_len, dtype=torch.long, device=device)
253-
.unsqueeze(0) # [1, seq_len]
254+
.unsqueeze(0) # [1, seq_len]
254255
.unsqueeze(-1) # [1, seq_len, 1]
255256
.expand(1, seq_len, 3) # [1, seq_len, 3]
256257
.clone()
@@ -286,7 +287,10 @@ def run_hf_forward(
286287
# Phase 2: Megatron Model Forward
287288
# ========================================================================== #
288289

290+
289291
class SingleBatchIterator:
292+
"""Iterator that yields a single batch for Megatron forward scheduling."""
293+
290294
def __init__(self, batch):
291295
self.batch = batch
292296
self._yielded = False
@@ -413,10 +417,7 @@ def run_megatron_forward(
413417
)
414418

415419
# Process output on last pipeline stage
416-
is_last_stage = (
417-
not dist.is_initialized()
418-
or parallel_state.is_pipeline_last_stage()
419-
)
420+
is_last_stage = not dist.is_initialized() or parallel_state.is_pipeline_last_stage()
420421

421422
megatron_logits_cpu = None
422423

@@ -433,12 +434,9 @@ def run_megatron_forward(
433434

434435
megatron_logits = output[0, -1, :].float() # [padded_vocab_size]
435436

436-
is_primary = (
437-
not dist.is_initialized()
438-
or (
439-
parallel_state.get_tensor_model_parallel_rank() == 0
440-
and parallel_state.get_expert_model_parallel_rank() == 0
441-
)
437+
is_primary = not dist.is_initialized() or (
438+
parallel_state.get_tensor_model_parallel_rank() == 0
439+
and parallel_state.get_expert_model_parallel_rank() == 0
442440
)
443441

444442
if is_primary:
@@ -447,8 +445,12 @@ def run_megatron_forward(
447445
top5_tokens = [tokenizer.decode([idx]) for idx in top5_ids]
448446

449447
print_rank_0(f" Megatron output shape: {output.shape}")
450-
print_rank_0(f" Megatron logits stats: mean={megatron_logits.mean():.4f}, std={megatron_logits.std():.4f}")
451-
print_rank_0(f" Megatron next token: {megatron_next_token.item()} ('{tokenizer.decode([megatron_next_token.item()])}')")
448+
print_rank_0(
449+
f" Megatron logits stats: mean={megatron_logits.mean():.4f}, std={megatron_logits.std():.4f}"
450+
)
451+
print_rank_0(
452+
f" Megatron next token: {megatron_next_token.item()} ('{tokenizer.decode([megatron_next_token.item()])}')"
453+
)
452454
print_rank_0(f" Megatron Top 5: {list(zip(top5_tokens, top5_vals.tolist()))}")
453455

454456
megatron_logits_cpu = megatron_logits.cpu()
@@ -460,7 +462,10 @@ def run_megatron_forward(
460462
# Phase 3: Comparison
461463
# ========================================================================== #
462464

463-
def compare_logits(hf_logits: torch.Tensor, megatron_logits: torch.Tensor, tokenizer, threshold: float = SIMILARITY_THRESHOLD):
465+
466+
def compare_logits(
467+
hf_logits: torch.Tensor, megatron_logits: torch.Tensor, tokenizer, threshold: float = SIMILARITY_THRESHOLD
468+
):
464469
"""Compare HF and Megatron logits. Returns True if pass."""
465470
print_rank_0("\n=== Phase 3: Comparing Logits ===")
466471

@@ -509,7 +514,9 @@ def compare_logits(hf_logits: torch.Tensor, megatron_logits: torch.Tensor, token
509514
# Main
510515
# ========================================================================== #
511516

517+
512518
def main():
519+
"""Run ERNIE 4.5 VL MoE logit comparison between HF and Megatron."""
513520
parser = argparse.ArgumentParser(description="ERNIE 4.5 VL MoE logit comparison (HF vs Megatron)")
514521
parser.add_argument("--hf-model-path", required=True, help="Path to HF model directory")
515522
parser.add_argument("--prompt", default="Hello, how are you?", help="Text prompt for comparison")
@@ -535,6 +542,7 @@ def main():
535542

536543
# Load tokenizer
537544
from transformers import AutoTokenizer
545+
538546
tokenizer = AutoTokenizer.from_pretrained(
539547
args.hf_model_path,
540548
trust_remote_code=True,
@@ -554,9 +562,7 @@ def main():
554562
# Image+Text mode: use processor to prepare all inputs
555563
# ============================================================
556564
print_rank_0("\n=== Preprocessing: Image+Text ===")
557-
processor_output, processor = preprocess_image_text(
558-
args.hf_model_path, args.prompt, args.image_path
559-
)
565+
processor_output, processor = preprocess_image_text(args.hf_model_path, args.prompt, args.image_path)
560566
input_ids = processor_output["input_ids"]
561567

562568
# Extract vision tensors for Megatron side
@@ -572,10 +578,11 @@ def main():
572578
if "token_type_ids" in processor_output:
573579
hf_token_type_ids = processor_output["token_type_ids"]
574580
# Take the first seq_len values (drop the extra trailing token)
575-
mm_token_type_ids = hf_token_type_ids[:, :input_ids.size(1)].to(torch.int32)
581+
mm_token_type_ids = hf_token_type_ids[:, : input_ids.size(1)].to(torch.int32)
576582
num_img_tokens = (mm_token_type_ids == 1).sum().item()
577-
print_rank_0(f" mm_token_type_ids: {mm_token_type_ids.shape}, "
578-
f"image tokens: {num_img_tokens}/{input_ids.size(1)}")
583+
print_rank_0(
584+
f" mm_token_type_ids: {mm_token_type_ids.shape}, image tokens: {num_img_tokens}/{input_ids.size(1)}"
585+
)
579586
else:
580587
# ============================================================
581588
# Text-only mode: simple tokenization
@@ -598,7 +605,8 @@ def main():
598605
# Also pad mm_token_type_ids if present
599606
if mm_token_type_ids is not None:
600607
mm_padding = torch.zeros(
601-
mm_token_type_ids.shape[0], pad_len,
608+
mm_token_type_ids.shape[0],
609+
pad_len,
602610
dtype=mm_token_type_ids.dtype,
603611
)
604612
mm_token_type_ids = torch.cat([mm_token_type_ids, mm_padding], dim=1)
@@ -607,7 +615,9 @@ def main():
607615

608616
# Phase 1: HF forward (rank 0 only)
609617
hf_logits = run_hf_forward(
610-
args.hf_model_path, input_ids, tokenizer,
618+
args.hf_model_path,
619+
input_ids,
620+
tokenizer,
611621
processor_output=processor_output,
612622
processor=processor,
613623
)

0 commit comments

Comments
 (0)