File tree Expand file tree Collapse file tree
opengenome2_llama_native_te Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -409,11 +409,9 @@ def get_autocast_context(
409409
410410 if init and self .config .use_quantized_model_init :
411411 if precision in ("fp8" , "fp4" ):
412- # Let the outer quantized_model_init context handle FP8/FP4 layers. Using nullcontext()
413- # preserves the outer context's settings (recipe, preserve_high_precision_init_val).
414- # A nested quantized_model_init would override preserve_high_precision_init_val to False.
415- return nullcontext ()
416- # BF16 layers: explicitly disable quantized init to override any outer quantized_model_init context.
412+ return transformer_engine .pytorch .quantized_model_init (
413+ recipe = recipe , preserve_high_precision_init_val = True
414+ )
417415 return transformer_engine .pytorch .quantized_model_init (enabled = False )
418416
419417 if precision == "fp8" :
@@ -633,6 +631,11 @@ def get_seq_length(self, layer_idx: int = 0) -> int:
633631 return 0
634632 return max (self .sequences .values ())
635633
634+ @property
635+ def is_compileable (self ) -> bool :
636+ """Required by HuggingFace transformers generate() auto-compile check."""
637+ return False
638+
636639 def reorder_cache (self , beam_idx : torch .LongTensor ):
637640 """Reorder the cache based on the beam indices."""
638641 if isinstance (self .cache_manager , PagedKVCacheManager ):
Original file line number Diff line number Diff line change @@ -415,11 +415,9 @@ def get_autocast_context(
415415
416416 if init and self .config .use_quantized_model_init :
417417 if precision in ("fp8" , "fp4" ):
418- # Let the outer quantized_model_init context handle FP8/FP4 layers. Using nullcontext()
419- # preserves the outer context's settings (recipe, preserve_high_precision_init_val).
420- # A nested quantized_model_init would override preserve_high_precision_init_val to False.
421- return nullcontext ()
422- # BF16 layers: explicitly disable quantized init to override any outer quantized_model_init context.
418+ return transformer_engine .pytorch .quantized_model_init (
419+ recipe = recipe , preserve_high_precision_init_val = True
420+ )
423421 return transformer_engine .pytorch .quantized_model_init (enabled = False )
424422
425423 if precision == "fp8" :
@@ -639,6 +637,11 @@ def get_seq_length(self, layer_idx: int = 0) -> int:
639637 return 0
640638 return max (self .sequences .values ())
641639
640+ @property
641+ def is_compileable (self ) -> bool :
642+ """Required by HuggingFace transformers generate() auto-compile check."""
643+ return False
644+
642645 def reorder_cache (self , beam_idx : torch .LongTensor ):
643646 """Reorder the cache based on the beam indices."""
644647 if isinstance (self .cache_manager , PagedKVCacheManager ):
Original file line number Diff line number Diff line change @@ -415,11 +415,9 @@ def get_autocast_context(
415415
416416 if init and self .config .use_quantized_model_init :
417417 if precision in ("fp8" , "fp4" ):
418- # Let the outer quantized_model_init context handle FP8/FP4 layers. Using nullcontext()
419- # preserves the outer context's settings (recipe, preserve_high_precision_init_val).
420- # A nested quantized_model_init would override preserve_high_precision_init_val to False.
421- return nullcontext ()
422- # BF16 layers: explicitly disable quantized init to override any outer quantized_model_init context.
418+ return transformer_engine .pytorch .quantized_model_init (
419+ recipe = recipe , preserve_high_precision_init_val = True
420+ )
423421 return transformer_engine .pytorch .quantized_model_init (enabled = False )
424422
425423 if precision == "fp8" :
@@ -639,6 +637,11 @@ def get_seq_length(self, layer_idx: int = 0) -> int:
639637 return 0
640638 return max (self .sequences .values ())
641639
640+ @property
641+ def is_compileable (self ) -> bool :
642+ """Required by HuggingFace transformers generate() auto-compile check."""
643+ return False
644+
642645 def reorder_cache (self , beam_idx : torch .LongTensor ):
643646 """Reorder the cache based on the beam indices."""
644647 if isinstance (self .cache_manager , PagedKVCacheManager ):
You can’t perform that action at this time.
0 commit comments