Commit 2be6c40

fix(ci): handle LlamaRotaryEmbedding signature changes in newer transformers versions
- Add a try-except block to support both `transformers<=4.40.0` and `transformers>=4.45.0` signatures for `LlamaRotaryEmbedding`
- Update `transformers` dependency requirement to `<5.0.0` to maintain broader stability

Co-authored-by: Pomilon <220483426+Pomilon@users.noreply.github.com>
1 parent 0803335 commit 2be6c40

1 file changed

Lines changed: 9 additions & 1 deletion

File tree

src/lema/models/llama.py
@@ -11,7 +11,15 @@ def __init__(self, config: Dict[str, Any]):
         self.hf_config = LlamaConfig(**config)
         if getattr(self.hf_config, "_attn_implementation", None) is None:
             self.hf_config._attn_implementation = config.get("attn_implementation", "eager")
-        self.rotary_emb = LlamaRotaryEmbedding(self.hf_config.hidden_size // self.hf_config.num_attention_heads, max_position_embeddings=self.hf_config.max_position_embeddings)
+
+        try:
+            self.rotary_emb = LlamaRotaryEmbedding(self.hf_config)
+        except TypeError:
+            self.rotary_emb = LlamaRotaryEmbedding(
+                self.hf_config.hidden_size // self.hf_config.num_attention_heads,
+                max_position_embeddings=self.hf_config.max_position_embeddings
+            )
+
         self.layer_pool: List[nn.Module] = []
         self.param_mappings: Dict[int, List[tuple]] = {}
         self._max_pool_size = 8
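
Below is a minimal, standalone sketch (not part of the commit) of the same compatibility pattern, exercised outside the `lema` codebase. The `LlamaConfig` values are illustrative placeholders, and the behavior of the two constructor signatures is assumed to match what the commit message describes for older and newer `transformers` releases.

from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

# Small, illustrative config; real values would come from the model's config dict.
config = LlamaConfig(
    hidden_size=64,
    num_attention_heads=4,
    num_hidden_layers=2,
    max_position_embeddings=128,
)

try:
    # Newer transformers (per the commit message, >=4.45.0): the constructor
    # takes the model config directly.
    rotary_emb = LlamaRotaryEmbedding(config)
except TypeError:
    # Older transformers (<=4.40.0): the constructor takes the per-head
    # dimension and the maximum sequence length as separate arguments.
    rotary_emb = LlamaRotaryEmbedding(
        config.hidden_size // config.num_attention_heads,
        max_position_embeddings=config.max_position_embeddings,
    )

Catching `TypeError` keeps the code agnostic to the installed version: the new-style call is attempted first, and the legacy signature is used only if that call is rejected.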
