Skip to content

Commit 0bfaceb

Browse files
committed
Fix CI failures: restore is_compileable and unify quantized_model_init
- Restore the is_compileable property on HFInferenceParams (accidentally dropped from PR 1500), which is required by newer transformers generate().
- Unify the get_autocast_context init path so it works both standalone (model tests, no outer context) and with an outer quantized_model_init (recipe training). FP8/FP4 layers use a per-layer quantized_model_init with preserve_high_precision_init_val=True; BF16 layers use quantized_model_init(enabled=False) to override any outer context.

Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
1 parent a611fb4 commit 0bfaceb

3 files changed

Lines changed: 24 additions & 15 deletions

File tree

bionemo-recipes/models/llama3/modeling_llama_te.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -409,11 +409,9 @@ def get_autocast_context(
409409

410410
if init and self.config.use_quantized_model_init:
411411
if precision in ("fp8", "fp4"):
412-
# Let the outer quantized_model_init context handle FP8/FP4 layers. Using nullcontext()
413-
# preserves the outer context's settings (recipe, preserve_high_precision_init_val).
414-
# A nested quantized_model_init would override preserve_high_precision_init_val to False.
415-
return nullcontext()
416-
# BF16 layers: explicitly disable quantized init to override any outer quantized_model_init context.
412+
return transformer_engine.pytorch.quantized_model_init(
413+
recipe=recipe, preserve_high_precision_init_val=True
414+
)
417415
return transformer_engine.pytorch.quantized_model_init(enabled=False)
418416

419417
if precision == "fp8":
@@ -633,6 +631,11 @@ def get_seq_length(self, layer_idx: int = 0) -> int:
633631
return 0
634632
return max(self.sequences.values())
635633

634+
@property
635+
def is_compileable(self) -> bool:
636+
"""Required by HuggingFace transformers generate() auto-compile check."""
637+
return False
638+
636639
def reorder_cache(self, beam_idx: torch.LongTensor):
637640
"""Reorder the cache based on the beam indices."""
638641
if isinstance(self.cache_manager, PagedKVCacheManager):

bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -415,11 +415,9 @@ def get_autocast_context(
415415

416416
if init and self.config.use_quantized_model_init:
417417
if precision in ("fp8", "fp4"):
418-
# Let the outer quantized_model_init context handle FP8/FP4 layers. Using nullcontext()
419-
# preserves the outer context's settings (recipe, preserve_high_precision_init_val).
420-
# A nested quantized_model_init would override preserve_high_precision_init_val to False.
421-
return nullcontext()
422-
# BF16 layers: explicitly disable quantized init to override any outer quantized_model_init context.
418+
return transformer_engine.pytorch.quantized_model_init(
419+
recipe=recipe, preserve_high_precision_init_val=True
420+
)
423421
return transformer_engine.pytorch.quantized_model_init(enabled=False)
424422

425423
if precision == "fp8":
@@ -639,6 +637,11 @@ def get_seq_length(self, layer_idx: int = 0) -> int:
639637
return 0
640638
return max(self.sequences.values())
641639

640+
@property
641+
def is_compileable(self) -> bool:
642+
"""Required by HuggingFace transformers generate() auto-compile check."""
643+
return False
644+
642645
def reorder_cache(self, beam_idx: torch.LongTensor):
643646
"""Reorder the cache based on the beam indices."""
644647
if isinstance(self.cache_manager, PagedKVCacheManager):

bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -415,11 +415,9 @@ def get_autocast_context(
415415

416416
if init and self.config.use_quantized_model_init:
417417
if precision in ("fp8", "fp4"):
418-
# Let the outer quantized_model_init context handle FP8/FP4 layers. Using nullcontext()
419-
# preserves the outer context's settings (recipe, preserve_high_precision_init_val).
420-
# A nested quantized_model_init would override preserve_high_precision_init_val to False.
421-
return nullcontext()
422-
# BF16 layers: explicitly disable quantized init to override any outer quantized_model_init context.
418+
return transformer_engine.pytorch.quantized_model_init(
419+
recipe=recipe, preserve_high_precision_init_val=True
420+
)
423421
return transformer_engine.pytorch.quantized_model_init(enabled=False)
424422

425423
if precision == "fp8":
@@ -639,6 +637,11 @@ def get_seq_length(self, layer_idx: int = 0) -> int:
639637
return 0
640638
return max(self.sequences.values())
641639

640+
@property
641+
def is_compileable(self) -> bool:
642+
"""Required by HuggingFace transformers generate() auto-compile check."""
643+
return False
644+
642645
def reorder_cache(self, beam_idx: torch.LongTensor):
643646
"""Reorder the cache based on the beam indices."""
644647
if isinstance(self.cache_manager, PagedKVCacheManager):

0 commit comments

Comments
 (0)