Commit 232c9e6

tests: Add distill tests and CI (#42)
* Fixed failing test
* Added tests for distillation, added tokenizer for testing
* Removed imports
* Added extra test
* Updated test
* Updated test
* Small update
* Updated init
* Added CI for codeconv
* Added CI for codeconv
* Added CI for codeconv
* Added CI for codeconv
* Added CI for codeconv
* Added CI for codeconv
* Added CI for codeconv
* Added CI for codeconv
* Added CI for codeconv
* Added CI for codeconv
* Resolved comments
1 parent 2f09539 commit 232c9e6

9 files changed

Lines changed: 31005 additions & 10 deletions

.github/workflows/ci.yaml

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+name: Run tests and upload coverage
+
+on:
+  push
+
+jobs:
+  test:
+    name: Run tests with pytest
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: ["ubuntu-latest", "windows-latest", "macos-latest"]
+        python-version: ["3.10"]
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          allow-prereleases: true
+
+      # Step for Windows: Create and activate a virtual environment
+      - name: Create and activate a virtual environment (Windows)
+        if: ${{ runner.os == 'Windows' }}
+        run: |
+          irm https://astral.sh/uv/install.ps1 | iex
+          uv venv .venv
+          "VIRTUAL_ENV=.venv" | Out-File -FilePath $env:GITHUB_ENV -Append
+          "$PWD/.venv/Scripts" | Out-File -FilePath $env:GITHUB_PATH -Append
+
+      # Step for Unix: Create and activate a virtual environment
+      - name: Create and activate a virtual environment (Unix)
+        if: ${{ runner.os != 'Windows' }}
+        run: |
+          curl -LsSf https://astral.sh/uv/install.sh | sh
+          uv venv .venv
+          echo "VIRTUAL_ENV=.venv" >> $GITHUB_ENV
+          echo "$PWD/.venv/bin" >> $GITHUB_PATH
+
+      # Install dependencies using uv pip
+      - name: Install dependencies
+        run: make install
+        # run: uv pip install -e ".[pytest]"
+
+      # Run tests with coverage
+      - name: Run tests under coverage
+        run: |
+          coverage run -m pytest
+          coverage report
+
+      # Upload results to Codecov
+      - name: Upload results to Codecov
+        uses: codecov/codecov-action@v4
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}

.pre-commit-config.yaml

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@ repos:
         description: Prevent giant files from being committed.
       - id: check-case-conflict
         description: Check for files with names that would conflict on case-insensitive filesystems like MacOS/Windows.
+      - id: check-yaml
+        description: Check yaml files for syntax errors.
   - repo: https://github.com/jsh9/pydoclint
     rev: 0.5.3
     hooks:

README.md

Lines changed: 2 additions & 9 deletions
@@ -1,9 +1,3 @@
-
-
-
-
-
-
 <div align="center">
 <h1>Model2Vec: Distill a Small Fast Model from any Sentence Transformer</h1>
 </div>
@@ -22,12 +16,11 @@
 <a href="https://pypi.org/project/model2vec/"><img src="https://img.shields.io/pypi/pyversions/model2vec" alt="Supported Python versions"></a>
 <a href="https://pepy.tech/project/model2vec">
 <img src="https://static.pepy.tech/badge/model2vec" alt="Downloads">
-</a>
+</a>
 <a href="https://github.com/MinishLab/model2vec/blob/main/LICENSE"><img src="https://img.shields.io/badge/license-MIT-green" alt="License - MIT"></a>
 </h2>
 </div>

-
 <p align="center">
 <img src="assets/images/model2vec_model_diagram.png" alt="Model2Vec">
 </p>
@@ -66,7 +59,7 @@ embeddings = model.encode(["It's dangerous to go alone!", "It's a secret to ever

 And that's it. You can use the model to classify texts, to cluster, or to build a RAG system.

-Instead of using on of our models, you can distill your own Model2Vec model from a Sentence Transformer model. The following code snippet shows how to distill a model:
+Instead of using one of our models, you can distill your own Model2Vec model from a Sentence Transformer model. The following code snippet shows how to distill a model:
 ```python
 from model2vec.distill import distill


tests/conftest.py

Lines changed: 41 additions & 0 deletions
@@ -1,8 +1,15 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
 import numpy as np
 import pytest
+import torch
 from tokenizers import Tokenizer
 from tokenizers.models import WordLevel
 from tokenizers.pre_tokenizers import Whitespace
+from transformers import AutoModel, BertTokenizerFast


 @pytest.fixture
@@ -18,6 +25,40 @@ def mock_tokenizer() -> Tokenizer:
     return tokenizer


+@pytest.fixture
+def mock_berttokenizer() -> BertTokenizerFast:
+    """Load the real BertTokenizerFast from the provided tokenizer.json file."""
+    tokenizer_path = Path("tests/data/test_tokenizer/tokenizer.json")
+    return BertTokenizerFast(tokenizer_file=str(tokenizer_path))
+
+
+@pytest.fixture
+def mock_transformer() -> AutoModel:
+    """Create a mock transformer model."""
+
+    class MockPreTrainedModel:
+        def __init__(self) -> None:
+            self.device = "cpu"
+
+        def to(self, device: str) -> MockPreTrainedModel:
+            self.device = device
+            return self
+
+        def __call__(self, *args: Any, **kwargs: Any) -> Any:
+            # Simulate a last_hidden_state output for a transformer model
+            batch_size, seq_length = kwargs["input_ids"].shape
+            # Return a tensor of shape (batch_size, seq_length, 768)
+            return type(
+                "BaseModelOutputWithPoolingAndCrossAttentions",
+                (object,),
+                {
+                    "last_hidden_state": torch.rand(batch_size, seq_length, 768)  # Simulate 768 hidden units
+                },
+            )
+
+    return MockPreTrainedModel()
+
+
 @pytest.fixture
 def mock_vectors() -> np.ndarray:
     """Create mock vectors."""
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+{
+  "cls_token": "[CLS]",
+  "mask_token": "[MASK]",
+  "pad_token": "[PAD]",
+  "sep_token": "[SEP]",
+  "unk_token": "[UNK]"
+}
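
Note on the `mock_transformer` fixture in `tests/conftest.py`: the following is a minimal, dependency-free sketch of the same mocking pattern, not the project's code. It assumes nested Python lists in place of `torch.rand`, and the names `MockModel` and `MockOutput` are hypothetical. It shows that `type(name, bases, namespace)` returns a new class rather than an instance, and that reading `last_hidden_state` from that class still works, which appears to be all the fixture relies on.

```python
from typing import Any


class MockModel:
    """Hypothetical stand-in that mirrors the structure of the fixture's mock."""

    def __init__(self) -> None:
        self.device = "cpu"

    def to(self, device: str) -> "MockModel":
        # Record the requested device and return self, like a real model's .to()
        self.device = device
        return self

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Assumption: nested lists stand in for a tensor so the sketch needs no torch.
        input_ids = kwargs["input_ids"]
        batch_size, seq_length = len(input_ids), len(input_ids[0])
        hidden = [[[0.0] * 768 for _ in range(seq_length)] for _ in range(batch_size)]
        # type(name, bases, namespace) builds a class; the class itself is returned.
        return type("MockOutput", (object,), {"last_hidden_state": hidden})


model = MockModel().to("cpu")
output = model(input_ids=[[101, 7592, 102], [101, 2088, 102]])
state = output.last_hidden_state
shape = (len(state), len(state[0]), len(state[0][0]))
```

Because the mock reads `kwargs["input_ids"]`, it has to be called with `input_ids` as a keyword argument; a positional call would raise `KeyError`.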
