Commit 0803335

fix(ci): fix tests and model compatibility for CI/CD
- Limit the transformers dependency to <5.0.0 to prevent breaking changes in model outputs.
- Fix the LlamaRotaryEmbedding instantiation in LlamaAdapter by passing the correct dim.
- Update test imports from src.lema to lema.

Co-authored-by: Pomilon <220483426+Pomilon@users.noreply.github.com>
1 parent a92a385 commit 0803335

6 files changed, 11 additions & 11 deletions

File tree

- pyproject.toml
- requirements.txt
- src/lema/models/llama.py
- tests/test_core_components.py
- tests/test_gradient_equivalence.py
- tests/test_llama_forward.py

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ classifiers = [
 ]
 dependencies = [
     "torch>=2.0.0",
-    "transformers>=4.30.0",
+    "transformers>=4.30.0,<5.0.0",
     "safetensors>=0.3.0",
     "accelerate>=0.20.0",
     "peft>=0.4.0",

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 torch>=2.0.0
-transformers>=4.30.0
+transformers>=4.30.0,<5.0.0
 safetensors>=0.3.0
 accelerate>=0.20.0
 peft>=0.4.0

src/lema/models/llama.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ def __init__(self, config: Dict[str, Any]):
         self.hf_config = LlamaConfig(**config)
         if getattr(self.hf_config, "_attn_implementation", None) is None:
             self.hf_config._attn_implementation = config.get("attn_implementation", "eager")
-        self.rotary_emb = LlamaRotaryEmbedding(self.hf_config)
+        self.rotary_emb = LlamaRotaryEmbedding(self.hf_config.hidden_size // self.hf_config.num_attention_heads, max_position_embeddings=self.hf_config.max_position_embeddings)
         self.layer_pool: List[nn.Module] = []
         self.param_mappings: Dict[int, List[tuple]] = {}
         self._max_pool_size = 8
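The new call passes the per-head rotary dimension (hidden_size // num_attention_heads, for configs without an explicit head_dim) positionally, matching the signature most transformers 4.x releases expect; late 4.x releases instead construct the embedding from the whole config object. A hypothetical compatibility shim, not part of this commit (`make_rotary_emb` is an illustrative name), could tolerate both signatures by inspecting the constructor rather than guessing from the version string:

import inspect

from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

def make_rotary_emb(hf_config):
    # Older 4.x releases: LlamaRotaryEmbedding(dim, max_position_embeddings=...).
    # Newer releases: LlamaRotaryEmbedding(config, ...).
    params = inspect.signature(LlamaRotaryEmbedding.__init__).parameters
    if "dim" in params:
        head_dim = hf_config.hidden_size // hf_config.num_attention_heads
        return LlamaRotaryEmbedding(
            head_dim, max_position_embeddings=hf_config.max_position_embeddings
        )
    return LlamaRotaryEmbedding(config=hf_config)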

tests/test_core_components.py

Lines changed: 3 additions & 3 deletions
@@ -1,9 +1,9 @@
 import pytest
 import torch
 import torch.nn as nn
-from src.lema.core.gbi import GlobalBinaryIndex
-from src.lema.core.lora import LoRAManager, LoRAWrapper
-from src.lema.core.memory import TripleBufferManager
+from lema.core.gbi import GlobalBinaryIndex
+from lema.core.lora import LoRAManager, LoRAWrapper
+from lema.core.memory import TripleBufferManager
 
 # Mocking
 class MockAdapter:

tests/test_gradient_equivalence.py

Lines changed: 4 additions & 4 deletions
@@ -2,10 +2,10 @@
 import torch.nn as nn
 from transformers import GPT2Config, GPT2LMHeadModel
 from safetensors.torch import save_file
-from src.lema.core.gbi import GlobalBinaryIndex
-from src.lema.models.gpt2 import GPT2Adapter
-from src.lema.engine.trainer import LemaTrainer
-from src.lema.config import LemaConfig, MemoryStrategy
+from lema.core.gbi import GlobalBinaryIndex
+from lema.models.gpt2 import GPT2Adapter
+from lema.engine.trainer import LemaTrainer
+from lema.config import LemaConfig, MemoryStrategy
 import os
 import pytest

tests/test_llama_forward.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from src.lema.models.llama import LlamaAdapter
+from lema.models.llama import LlamaAdapter
 from transformers import LlamaConfig
 
 def test_llama_adapter_forward():
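Dropping the src. prefix assumes the tests run against an installed lema package (for example an editable install) rather than importing through the repository's src/ layout. A quick, hypothetical guard for the CI job, assuming nothing beyond the standard library:

import importlib.util

# Fail fast with a clear message if the package was not installed
# (e.g. via `pip install -e .`) before the test suite runs.
assert importlib.util.find_spec("lema") is not None, (
    "lema is not importable; install the project before running the tests"
)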
