Skip to content

Commit dbc6005

Browse files
authored
Merge pull request #1563 from therealmichaelberna/transformers-v5-compat
feat: add compatibility with transformers V5
2 parents 6fd1762 + 94993fd commit dbc6005

File tree

10 files changed

+278
-3
lines changed

10 files changed

+278
-3
lines changed

FlagEmbedding/finetune/reranker/decoder_only/layerwise/modeling_minicpm_reranker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
logging,
5151
replace_return_docstrings,
5252
)
53-
from transformers.utils.import_utils import is_torch_fx_available
53+
from FlagEmbedding.utils.transformers_compat import is_torch_fx_available
5454
from .configuration_minicpm_reranker import LayerWiseMiniCPMConfig
5555
import re
5656

FlagEmbedding/inference/reranker/decoder_only/models/modeling_minicpm_reranker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
logging,
5151
replace_return_docstrings,
5252
)
53-
from transformers.utils.import_utils import is_torch_fx_available
53+
from FlagEmbedding.utils.transformers_compat import is_torch_fx_available
5454
from .configuration_minicpm_reranker import LayerWiseMiniCPMConfig
5555
import re
5656

FlagEmbedding/utils/__init__.py

Whitespace-only changes.
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
from packaging import version
import transformers


# Parsed version of the installed transformers distribution; falls back to
# "0.0.0" when the package exposes no __version__ attribute.
TF_VER = version.parse(getattr(transformers, "__version__", "0.0.0"))
# True when running under transformers 5.x or later — the versions that
# removed is_torch_fx_available and relocated several import utilities.
IS_TF_V5_OR_HIGHER = TF_VER >= version.parse("5.0.0")
# ------------- torch.fx availability -------------
# v5 removed is_torch_fx_available. We emulate it via feature detection.
def is_torch_fx_available():
    """Return True when ``torch.fx`` can be imported, False otherwise.

    Drop-in replacement for ``transformers.utils.import_utils
    .is_torch_fx_available``, which was removed in transformers v5.
    """
    try:
        import torch.fx  # noqa: F401
    except Exception:
        # Any failure (torch not installed, broken native libs, ...) counts
        # as "not available" — this is a best-effort feature probe.
        return False
    return True
# ------------- other utilities that moved -------------
# Pattern:
# try the new location first (v5), then fall back to v4 path, else provide a safe default.
def import_from_candidates(candidates, default=None):
    """Return the first attribute importable from *candidates*.

    Parameters:
        candidates: iterable of ``(module_name, attribute_name)`` pairs,
            tried in order — list the transformers v5 location first, then
            the v4 fallback.
        default: value returned when no candidate resolves.

    Only "module missing" (ImportError) and "name moved away"
    (AttributeError) trigger the fallback to the next candidate; any other
    import-time error propagates instead of being silently swallowed, so
    genuinely broken modules surface to the caller.
    """
    import importlib

    for module_name, attribute_name in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            # This transformers version does not ship the module; try next.
            continue
        try:
            return getattr(module, attribute_name)
        except AttributeError:
            # Module exists but the symbol lives elsewhere; try next.
            continue
    return default

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
include_package_data=True,
1616
install_requires=[
1717
'torch>=1.6.0',
18-
'transformers>=4.44.2',
18+
'transformers>=4.44.2,<6.0.0',
1919
'datasets>=2.19.0',
2020
'accelerate>=0.20.1',
2121
'sentence_transformers',

tests/README.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# FlagEmbedding Tests
2+
3+
This directory contains tests for the FlagEmbedding library, including compatibility tests for Transformers 5.0.
4+
5+
## Test Files
6+
7+
- `test_imports_v5.py`: Tests that imports work with Transformers v5, particularly the compatibility layer for `is_torch_fx_available`.
8+
- `test_infer_embedder_basic.py`: Tests basic functionality of BGE embedder models with a small public checkpoint.
9+
- `test_infer_reranker_basic.py`: Tests basic functionality of reranker models.
10+
11+
## Running Tests
12+
13+
1. create a python venv `python -m venv pytest_venv`
14+
2. activate venv `source pytest_venv/bin/activate`
15+
3. install pytest `pip install pytest`
16+
4. install flagembedding package in development mode: `pip install -e .`
17+
18+
Then run the tests using pytest:
19+
20+
```bash
21+
# Run all tests
22+
pytest tests/
23+
24+
# Run a specific test file
25+
pytest tests/test_imports_v5.py
26+
27+
# Run with verbose output
28+
pytest -v tests/
29+
```
30+
31+
## Transformers 5.0 Compatibility
32+
33+
The tests verify that FlagEmbedding works with Transformers 5.0, which removed the `is_torch_fx_available` function.
34+
The compatibility layer in `FlagEmbedding/utils/transformers_compat.py` provides this function for backward compatibility.
35+
36+
**Note:** Transformers 5.0 requires Python 3.10 or higher. If you're using Python 3.9 or lower, you'll need to upgrade your Python version to test with Transformers 5.0.
37+
38+
To test with a specific version of transformers (with Python 3.10+):
39+
40+
```bash
41+
pip install transformers==5.0.0
42+
pytest tests/
```

tests/conftest.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
"""
Common pytest fixtures and configuration for FlagEmbedding tests.
"""

import os
import pytest
import torch
from packaging import version
import transformers

# Check if we're using transformers v5+
# (same computation as FlagEmbedding/utils/transformers_compat.py)
TF_VER = version.parse(getattr(transformers, "__version__", "0.0.0"))
IS_TF_V5_OR_HIGHER = TF_VER >= version.parse("5.0.0")
@pytest.fixture(scope="session")
def device():
    """Device string shared by all tests: CUDA when present, else CPU."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
20+
21+
22+
@pytest.fixture(scope="session")
def transformers_version():
    """Return the transformers version (parsed via packaging.version)."""
    return TF_VER

tests/test_imports_v5.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
"""
Test that imports work with Transformers v5.

This test verifies that the compatibility layer in FlagEmbedding/utils/transformers_compat.py
properly handles the removal of is_torch_fx_available in Transformers v5.
"""

import pytest
import transformers
from packaging import version

# Import the compatibility layer
from FlagEmbedding.utils.transformers_compat import is_torch_fx_available

# Check if we're using transformers v5+
TF_VER = version.parse(getattr(transformers, "__version__", "0.0.0"))
IS_TF_V5_OR_HIGHER = TF_VER >= version.parse("5.0.0")
# Import the files mentioned in issue #1561 that use is_torch_fx_available
def test_import_modeling_minicpm_reranker_inference():
    """Test importing the modeling_minicpm_reranker module from inference.

    Regression check for issue #1561: under Transformers v5 this import
    failed because the module pulled is_torch_fx_available from
    transformers.utils.import_utils, which v5 removed.
    """
    from FlagEmbedding.inference.reranker.decoder_only.models.modeling_minicpm_reranker import (
        LayerWiseMiniCPMForCausalLM,
    )

    # The import succeeding is the real test; the assert guards against a stub.
    assert LayerWiseMiniCPMForCausalLM is not None
30+
def test_import_modeling_minicpm_reranker_finetune():
    """Test importing the modeling_minicpm_reranker module from finetune.

    Same regression check as the inference variant (issue #1561): the
    finetune copy of the module also imported is_torch_fx_available from
    the location removed in Transformers v5.
    """
    from FlagEmbedding.finetune.reranker.decoder_only.layerwise.modeling_minicpm_reranker import (
        LayerWiseMiniCPMForCausalLM,
    )

    # The import succeeding is the real test; the assert guards against a stub.
    assert LayerWiseMiniCPMForCausalLM is not None
39+
@pytest.mark.skipif(not IS_TF_V5_OR_HIGHER, reason="Only relevant for Transformers v5+")
def test_is_torch_fx_available_v5():
    """Test that is_torch_fx_available works with Transformers v5.

    The compat shim must be callable and return a bool even though the
    upstream function no longer exists in v5.
    """
    # This should not raise an exception
    result = is_torch_fx_available()
    # The result depends on whether torch.fx is available, but the function should work
    assert isinstance(result, bool)
47+
48+
def test_transformers_version(transformers_version):
    """Test that we can detect the transformers version.

    Uses the session-scoped ``transformers_version`` fixture from conftest.py.
    """
    assert transformers_version is not None
    # Visible with ``pytest -s``; handy when debugging version-specific skips.
    print(f"Transformers version: {transformers_version}")

tests/test_infer_embedder_basic.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""
2+
Test basic functionality of BGE embedder models with Transformers v5.
3+
4+
This test loads a small/public BGE checkpoint and runs a single encode on toy strings,
5+
verifying that the shape/dtype are correct and that cosine similarity is sane.
6+
"""
7+
import pytest
8+
import torch
9+
import numpy as np
10+
from FlagEmbedding import FlagModel
11+
12+
def cosine_similarity(a, b):
    """Compute cosine similarity between two vectors.

    Returns 0.0 when either vector has zero norm, instead of the
    NaN / divide-by-zero the naive ``dot / (|a| * |b|)`` formula produces.
    """
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return np.dot(a, b) / denom
16+
def test_bge_embedder_basic(device):
    """Test basic functionality of BGE embedder.

    Encodes a related query/passage pair with a small public checkpoint and
    checks shapes, NaN-freeness, and that their cosine similarity is sane.
    """
    # Load a small BGE model
    model_name = "BAAI/bge-base-en-v1.5"
    model = FlagModel(model_name, device=device)

    # Test encoding single strings
    query = "What is the capital of France?"
    passage = "Paris is the capital and most populous city of France."

    # Get embeddings
    query_embedding = model.encode(query)
    passage_embedding = model.encode(passage)

    # Check shapes and types
    assert isinstance(query_embedding, np.ndarray)
    assert isinstance(passage_embedding, np.ndarray)
    assert query_embedding.ndim == 1  # Should be a 1D vector
    assert passage_embedding.ndim == 1  # Should be a 1D vector

    # Check that embeddings have reasonable values
    assert not np.isnan(query_embedding).any()
    assert not np.isnan(passage_embedding).any()

    # Check cosine similarity is reasonable (should be high for related texts)
    similarity = cosine_similarity(query_embedding, passage_embedding)
    # Cosine similarity lies in [-1, 1] (not [0, 1]); allow a tiny float
    # tolerance so values rounding marginally past 1.0 don't fail spuriously.
    assert -1.0 - 1e-6 <= similarity <= 1.0 + 1e-6
    assert similarity > 0.5  # These texts should be somewhat similar
45+
def test_bge_embedder_batch(device):
46+
"""Test batch encoding with BGE embedder."""
47+
# Load a small BGE model
48+
model_name = "BAAI/bge-base-en-v1.5"
49+
model = FlagModel(model_name, device=device)
50+
51+
# Test batch encoding
52+
queries = [
53+
"What is the capital of France?",
54+
"Who wrote Romeo and Juliet?"
55+
]
56+
57+
# Get embeddings
58+
embeddings = model.encode(queries)
59+
60+
# Check shapes and types
61+
assert isinstance(embeddings, np.ndarray)
62+
assert embeddings.ndim == 2 # Should be a 2D array (batch_size x embedding_dim)
63+
assert embeddings.shape[0] == len(queries)
64+
65+
# Check that embeddings have reasonable values
66+
assert not np.isnan(embeddings).any()

tests/test_infer_reranker_basic.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""
2+
Test basic functionality of reranker models with Transformers v5.
3+
4+
This test instantiates a lightweight reranker and calls compute_score on query/doc pairs
5+
to validate the forward pass.
6+
"""
7+
8+
import pytest
9+
import torch
10+
import numpy as np
11+
from FlagEmbedding import FlagReranker
12+
13+
14+
def test_reranker_basic(device):
    """A single query/passage pair must yield one finite relevance score."""
    # Lightweight public reranker checkpoint.
    reranker = FlagReranker("BAAI/bge-reranker-base", device=device)

    # One relevant pair: the passage literally answers the query.
    pairs = [("What is the capital of France?",
              "Paris is the capital and most populous city of France.")]

    # NOTE(review): some FlagReranker versions return a bare float for a
    # single pair rather than a one-element list — confirm compute_score's
    # contract before relying on the indexing below.
    scores = reranker.compute_score(pairs)
    score = scores[0]

    assert isinstance(score, float)
    # Raw logits are model-dependent but should stay in a sane range.
    assert -100 < score < 100
35+
def test_reranker_batch(device):
    """Scoring several passages must rank the genuinely relevant one first."""
    # Lightweight public reranker checkpoint.
    reranker = FlagReranker("BAAI/bge-reranker-base", device=device)

    question = "What is the capital of France?"
    candidates = [
        "Paris is the capital and most populous city of France.",
        "Berlin is the capital and largest city of Germany.",
        "London is the capital and largest city of England and the United Kingdom.",
    ]

    # Pair the single question with every candidate passage.
    query_passage_pairs = []
    for candidate in candidates:
        query_passage_pairs.append((question, candidate))

    scores = reranker.compute_score(query_passage_pairs)

    # One float score per candidate, returned as a plain list.
    assert isinstance(scores, list)
    assert len(scores) == len(candidates)
    assert all(isinstance(s, float) for s in scores)

    # The Paris passage (index 0) is the one that answers the question.
    assert scores[0] == max(scores), "Paris should have the highest score"

0 commit comments

Comments
 (0)