Skip to content

Commit 7cc1e78

Browse files
committed
fix: add is_compileable attribute to HFInferenceParams for transformers 5.4+
transformers >= 5.4 checks cache.is_compileable in generate(). The custom HFInferenceParams class (TE-based cache) did not implement this attribute, causing AttributeError during test_generate_with_cache tests. Set is_compileable = False since this cache type is not compatible with torch.compile generate(). Tested locally: - models/mixtral: 52 passed, 3 skipped, 26 xfailed (3 local-only OOM on 32GB GPU, pass on CI L4) - recipes/mixtral_native_te: 7 passed - recipes/opengenome2_mixtral_native_te: 20 passed Signed-off-by: svc-bionemo <267129667+svc-bionemo@users.noreply.github.com>
1 parent 376a00f commit 7cc1e78

4 files changed

Lines changed: 6 additions & 5 deletions

File tree

bionemo-recipes/models/mixtral/modeling_mixtral_te.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,8 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None):
880880
class HFInferenceParams(InferenceParams):
881881
"""Extension of the InferenceParams class to support HF generate() and beam search."""
882882

883+
# Required by transformers >= 5.4 _valid_auto_compile_criteria(); this
884+
# custom TE-based cache is not compatible with torch.compile generate().
883885
is_compileable = False
884886

885887
def get_seq_length(self, layer_idx: int = 0) -> int:

bionemo-recipes/recipes/mixtral_native_te/modeling_mixtral_te.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -886,6 +886,8 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None):
886886
class HFInferenceParams(InferenceParams):
887887
"""Extension of the InferenceParams class to support HF generate() and beam search."""
888888

889+
# Required by transformers >= 5.4 _valid_auto_compile_criteria(); this
890+
# custom TE-based cache is not compatible with torch.compile generate().
889891
is_compileable = False
890892

891893
def get_seq_length(self, layer_idx: int = 0) -> int:

bionemo-recipes/recipes/opengenome2_mixtral_native_te/modeling_mixtral_te.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,8 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None):
880880
class HFInferenceParams(InferenceParams):
881881
"""Extension of the InferenceParams class to support HF generate() and beam search."""
882882

883+
# Required by transformers >= 5.4 _valid_auto_compile_criteria(); this
884+
# custom TE-based cache is not compatible with torch.compile generate().
883885
is_compileable = False
884886

885887
def get_seq_length(self, layer_idx: int = 0) -> int:

ci/scripts/check_copied_files.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -205,11 +205,6 @@ def _compare_file_contents(source_file: Path, dest_file: Path, source_display: s
205205
"bionemo-recipes/models/codonfm/modeling_codonfm_te.py": [
206206
"bionemo-recipes/recipes/codonfm_native_te/modeling_codonfm_te.py",
207207
],
208-
# Mixtral TE model -> recipe sync
209-
"bionemo-recipes/models/mixtral/modeling_mixtral_te.py": [
210-
"bionemo-recipes/recipes/mixtral_native_te/modeling_mixtral_te.py",
211-
"bionemo-recipes/recipes/opengenome2_mixtral_native_te/modeling_mixtral_te.py",
212-
],
213208
# Common test library - synced between models
214209
"bionemo-recipes/models/esm2/tests/common": [
215210
"bionemo-recipes/models/llama3/tests/common",

0 commit comments

Comments
 (0)