Skip to content

Commit 655b762

Browse files
committed
Fix HFInferenceParams missing is_compileable for transformers 5.x compatibility
transformers 5.5.4 (introduced via PyTorch 26.03 base container) added an is_compileable property check on cache objects in generate(). Add is_compileable returning False to HFInferenceParams in all model files (llama3, qwen2, qwen3, mixtral) and their recipe copies. Signed-off-by: svc-bionemo <svc-bionemo@nvidia.com> Signed-off-by: svc-bionemo <267129667+svc-bionemo@users.noreply.github.com>
1 parent b508876 commit 655b762

6 files changed

Lines changed: 30 additions & 0 deletions

File tree

bionemo-recipes/models/llama3/modeling_llama_te.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,3 +591,8 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
591591
updated_key_cache = key_cache.index_select(0, beam_idx)
592592
updated_value_cache = value_cache.index_select(0, beam_idx)
593593
self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache)
594+
595+
@property
596+
def is_compileable(self) -> bool:
597+
"""Return False as this cache is not compatible with torch.compile."""
598+
return False

bionemo-recipes/models/mixtral/modeling_mixtral_te.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,11 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
876876
updated_value_cache = value_cache.index_select(0, beam_idx)
877877
self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache)
878878

879+
@property
880+
def is_compileable(self) -> bool:
881+
"""Return False as this cache is not compatible with torch.compile."""
882+
return False
883+
879884

880885
@torch.compile(fullgraph=True)
881886
def _build_expert_sort_indices(recv_counts: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:

bionemo-recipes/models/qwen/modeling_qwen2_te.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,3 +576,8 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
576576
updated_key_cache = key_cache.index_select(0, beam_idx)
577577
updated_value_cache = value_cache.index_select(0, beam_idx)
578578
self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache)
579+
580+
@property
581+
def is_compileable(self) -> bool:
582+
"""Return False as this cache is not compatible with torch.compile."""
583+
return False

bionemo-recipes/models/qwen/modeling_qwen3_te.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,3 +586,8 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
586586
updated_key_cache = key_cache.index_select(0, beam_idx)
587587
updated_value_cache = value_cache.index_select(0, beam_idx)
588588
self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache)
589+
590+
@property
591+
def is_compileable(self) -> bool:
592+
"""Return False as this cache is not compatible with torch.compile."""
593+
return False

bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,3 +597,8 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
597597
updated_key_cache = key_cache.index_select(0, beam_idx)
598598
updated_value_cache = value_cache.index_select(0, beam_idx)
599599
self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache)
600+
601+
@property
602+
def is_compileable(self) -> bool:
603+
"""Return False as this cache is not compatible with torch.compile."""
604+
return False

bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,3 +597,8 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
597597
updated_key_cache = key_cache.index_select(0, beam_idx)
598598
updated_value_cache = value_cache.index_select(0, beam_idx)
599599
self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache)
600+
601+
@property
602+
def is_compileable(self) -> bool:
603+
"""Return False as this cache is not compatible with torch.compile."""
604+
return False

0 commit comments

Comments
 (0)