Commit 31405ac

HuiyingLi and claude authored
docs(nemotron-omni): use device_map fast path for SFT inference (#2126)
The SFT inference snippet under Step 4 instantiated a 30B model on CPU via `AutoModel.from_config()` solely to read its concrete `trust_remote_code` class, then re-loaded weights through that class. On the v3 dump this CPU instantiation alone takes ~5 minutes.

Verified locally on `auto2604rc4`, against both the base v3 dump and a consolidated SFT checkpoint, that `AutoModel.from_pretrained(CKPT, trust_remote_code=True, dtype=torch.bfloat16, device_map={"": torch.cuda.current_device()})` resolves to `NemotronH_Nano_Omni_Reasoning_V3` correctly and produces structured `<s_total>...</s_total>` output; the `from_config` round-trip and the `all_tied_weights_keys` patch are no longer needed. The LoRA section already uses the same fast path.

Total inference setup drops from ~5 min to ~85 s on the consolidated dump.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent bb209de · commit 31405ac
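A minimal sketch of the verification described in the commit message, assuming `CKPT` points at a consolidated checkpoint directory (the path below is a placeholder; the class-name and device checks simply mirror the commit's notes):

```python
import torch
from transformers import AutoModel

CKPT = "<checkpoint_dir>/LOWEST_VAL/model/consolidated"  # placeholder path

# Fast path: AutoModel resolves the trust_remote_code class directly, and
# device_map streams weights onto the current GPU, skipping CPU instantiation.
model = AutoModel.from_pretrained(
    CKPT, trust_remote_code=True, torch_dtype=torch.bfloat16,
    device_map={"": torch.cuda.current_device()},
)

# Per the commit's verification notes: the auto class should resolve straight
# to the remote-code class, and weights should land on GPU in bf16.
assert type(model).__name__ == "NemotronH_Nano_Omni_Reasoning_V3"
param = next(model.parameters())
assert param.device.type == "cuda" and param.dtype == torch.bfloat16
```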

1 file changed: docs/guides/vlm/nemotron-omni.md (8 additions, 9 deletions)
```diff
@@ -313,7 +313,7 @@ to spot-check structured output.
 ```python
 import torch
 import json
-from transformers import AutoConfig, AutoModel, AutoProcessor
+from transformers import AutoModel, AutoProcessor
 from datasets import load_dataset
 from nemo_automodel.components.datasets.vlm.utils import json2token
 
@@ -323,19 +323,18 @@ CKPT = "<checkpoint_dir>/LOWEST_VAL/model/consolidated"
 processor = AutoProcessor.from_pretrained(CKPT, trust_remote_code=True)
 tokenizer = processor.tokenizer
 
-# Resolve the trust_remote_code model class via from_config, then load weights.
-# Using AutoModel.from_pretrained directly can mis-route on v3 dumps.
-config = AutoConfig.from_pretrained(CKPT, trust_remote_code=True)
-model_class = type(AutoModel.from_config(config, trust_remote_code=True))
-if not hasattr(model_class, "all_tied_weights_keys"):
-    model_class.all_tied_weights_keys = {}
-model = model_class.from_pretrained(CKPT, trust_remote_code=True, torch_dtype=torch.bfloat16)
+# `device_map` streams weights directly to GPU; skipping the AutoModel.from_config
+# CPU-instantiation step saves ~5 min on the 30B v3 dump.
+model = AutoModel.from_pretrained(
+    CKPT, trust_remote_code=True, torch_dtype=torch.bfloat16,
+    device_map={"": torch.cuda.current_device()},
+)
 
 # Reset RADIO's `summary_idxs` (non-persistent buffer; can be a meta tensor after load)
 if hasattr(model, "vision_model") and hasattr(model.vision_model, "radio_model"):
     model.vision_model.radio_model.summary_idxs = None
 
-model = model.cuda().eval()
+model.eval()
 
 # Load dataset
 dataset = load_dataset("naver-clova-ix/cord-v2")
```
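A rough sketch for reproducing the setup-time claim; the ~85 s figure is from the commit message, and `CKPT` is again a placeholder:

```python
import time

import torch
from transformers import AutoModel

CKPT = "<checkpoint_dir>/LOWEST_VAL/model/consolidated"  # placeholder path

t0 = time.perf_counter()
model = AutoModel.from_pretrained(
    CKPT, trust_remote_code=True, torch_dtype=torch.bfloat16,
    device_map={"": torch.cuda.current_device()},
).eval()
# ~85 s on the consolidated dump per the commit, vs ~5 min via the old
# from_config round-trip that built the 30B model on CPU first.
print(f"inference setup: {time.perf_counter() - t0:.0f}s")
```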
