Skip to content

Commit bf807ba

Browse files
authored
Merge pull request #1 from Pomilon/fix-ci-tests-16470124072715661906
Fix tests and model compatibility for CI/CD
2 parents a92a385 + 2be6c40 commit bf807ba

6 files changed

Lines changed: 19 additions & 11 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ classifiers = [
 ]
 dependencies = [
     "torch>=2.0.0",
-    "transformers>=4.30.0",
+    "transformers>=4.30.0,<5.0.0",
     "safetensors>=0.3.0",
     "accelerate>=0.20.0",
     "peft>=0.4.0",

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
 torch>=2.0.0
-transformers>=4.30.0
+transformers>=4.30.0,<5.0.0
 safetensors>=0.3.0
 accelerate>=0.20.0
 peft>=0.4.0

src/lema/models/llama.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,15 @@ def __init__(self, config: Dict[str, Any]):
         self.hf_config = LlamaConfig(**config)
         if getattr(self.hf_config, "_attn_implementation", None) is None:
             self.hf_config._attn_implementation = config.get("attn_implementation", "eager")
-        self.rotary_emb = LlamaRotaryEmbedding(self.hf_config)
+
+        try:
+            self.rotary_emb = LlamaRotaryEmbedding(self.hf_config)
+        except TypeError:
+            self.rotary_emb = LlamaRotaryEmbedding(
+                self.hf_config.hidden_size // self.hf_config.num_attention_heads,
+                max_position_embeddings=self.hf_config.max_position_embeddings
+            )
+
         self.layer_pool: List[nn.Module] = []
         self.param_mappings: Dict[int, List[tuple]] = {}
         self._max_pool_size = 8

tests/test_core_components.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
 import pytest
 import torch
 import torch.nn as nn
-from src.lema.core.gbi import GlobalBinaryIndex
-from src.lema.core.lora import LoRAManager, LoRAWrapper
-from src.lema.core.memory import TripleBufferManager
+from lema.core.gbi import GlobalBinaryIndex
+from lema.core.lora import LoRAManager, LoRAWrapper
+from lema.core.memory import TripleBufferManager
 
 # Mocking
 class MockAdapter:

tests/test_gradient_equivalence.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
 import torch.nn as nn
 from transformers import GPT2Config, GPT2LMHeadModel
 from safetensors.torch import save_file
-from src.lema.core.gbi import GlobalBinaryIndex
-from src.lema.models.gpt2 import GPT2Adapter
-from src.lema.engine.trainer import LemaTrainer
-from src.lema.config import LemaConfig, MemoryStrategy
+from lema.core.gbi import GlobalBinaryIndex
+from lema.models.gpt2 import GPT2Adapter
+from lema.engine.trainer import LemaTrainer
+from lema.config import LemaConfig, MemoryStrategy
 import os
 import pytest
 
1111

tests/test_llama_forward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from src.lema.models.llama import LlamaAdapter
+from lema.models.llama import LlamaAdapter
 from transformers import LlamaConfig
 
 def test_llama_adapter_forward():

Comments (0)