Commit b8ea565

fix: add is_compileable attribute to HFInferenceParams for transformers 5.4+
transformers >= 5.4 checks cache.is_compileable in generate(). The custom HFInferenceParams class (TE-based cache) did not implement this attribute, causing AttributeError during test_generate_with_cache tests.

Set is_compileable = False since this cache type is not compatible with torch.compile generate().

Tested locally:
- models/mixtral: 52 passed, 3 skipped, 26 xfailed (3 local-only OOM on 32GB GPU, pass on CI L4)
- recipes/mixtral_native_te: 7 passed
- recipes/opengenome2_mixtral_native_te: 20 passed

Signed-off-by: svc-bionemo <267129667+svc-bionemo@users.noreply.github.com>
1 parent 5609f30 commit b8ea565

3 files changed

Lines changed: 12 additions & 0 deletions

bionemo-recipes/models/mixtral/modeling_mixtral_te.py

Lines changed: 4 additions & 0 deletions
@@ -872,6 +872,10 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None):
 class HFInferenceParams(InferenceParams):
     """Extension of the InferenceParams class to support HF generate() and beam search."""
 
+    # Required by transformers >= 5.4 _valid_auto_compile_criteria(); this
+    # custom TE-based cache is not compatible with torch.compile generate().
+    is_compileable = False
+
     def get_seq_length(self, layer_idx: int = 0) -> int:
         """Return the current cached sequence length.
 

bionemo-recipes/recipes/mixtral_native_te/modeling_mixtral_te.py

Lines changed: 4 additions & 0 deletions
@@ -878,6 +878,10 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None):
 class HFInferenceParams(InferenceParams):
     """Extension of the InferenceParams class to support HF generate() and beam search."""
 
+    # Required by transformers >= 5.4 _valid_auto_compile_criteria(); this
+    # custom TE-based cache is not compatible with torch.compile generate().
+    is_compileable = False
+
     def get_seq_length(self, layer_idx: int = 0) -> int:
         """Return the current cached sequence length.
 

bionemo-recipes/recipes/opengenome2_mixtral_native_te/modeling_mixtral_te.py

Lines changed: 4 additions & 0 deletions
@@ -878,6 +878,10 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None):
 class HFInferenceParams(InferenceParams):
     """Extension of the InferenceParams class to support HF generate() and beam search."""
 
+    # Required by transformers >= 5.4 _valid_auto_compile_criteria(); this
+    # custom TE-based cache is not compatible with torch.compile generate().
+    is_compileable = False
+
     def get_seq_length(self, layer_idx: int = 0) -> int:
         """Return the current cached sequence length.
 
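
The failure mode described in the commit message can be reproduced without transformers or Transformer Engine. The sketch below is an illustration only: `InferenceParamsStub`, `CacheWithoutFlag`, `CacheWithFlag`, `wants_compile`, and `probe` are hypothetical names, and `wants_compile` merely assumes that the caller reads `cache.is_compileable` as a plain attribute, as the commit message states. It is not the actual transformers check.

```python
class InferenceParamsStub:
    """Hypothetical stand-in for the TE InferenceParams base class."""


class CacheWithoutFlag(InferenceParamsStub):
    """Mimics HFInferenceParams before this commit: no is_compileable."""


class CacheWithFlag(InferenceParamsStub):
    """Mimics HFInferenceParams after this commit."""

    # Class-level attribute: every instance reports False without any
    # change to __init__.
    is_compileable = False


def wants_compile(cache) -> bool:
    """Hypothetical caller that reads the flag via plain attribute access."""
    return bool(cache.is_compileable)


def probe(cache) -> str:
    """Report which path a cache object would take in the hypothetical caller."""
    try:
        return "compile" if wants_compile(cache) else "eager"
    except AttributeError:
        return "AttributeError"


results = {
    "before": probe(CacheWithoutFlag()),
    "after": probe(CacheWithFlag()),
}
print(results)
```

Declaring the flag on the class rather than in `__init__` keeps the change to a single line per file and also covers instances built by code paths that bypass the subclass constructor.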
