
Commit eb21ede

Merge commit (message: "merge")
2 parents: cb7feb2 + 3f5786a

5 files changed: 66 additions & 8 deletions

model2vec/distill/distillation.py

Lines changed: 17 additions & 6 deletions
@@ -314,24 +314,35 @@ def _clean_vocabulary(tokenizer: Tokenizer, vocabulary: list[str], added_tokens:
     n_duplicates = 0
     n_multiword = 0
     for token in vocabulary:
-        if tokenizer.normalizer is not None:
-            token = tokenizer.normalizer.normalize_str(token)
+        normalizer = tokenizer.normalizer
+        if normalizer is not None:
+            token = normalizer.normalize_str(token)
 
         if not token:
             n_empty += 1
             continue
-        if token in seen_tokens or token in added_tokens_set:
-            n_duplicates += 1
-            continue
 
         pre_tokenizer = tokenizer.pre_tokenizer
+        # We need to check whether the pretokenized token is a single word or not.
         if pre_tokenizer is not None:
             pretokenized_tokens = pre_tokenizer.pre_tokenize_str(token)
             if len(pretokenized_tokens) != 1:
                 n_multiword += 1
                 continue
+            new_token = pretokenized_tokens[-1][0]
+        else:
+            new_token = token
+
+        # We need to check whether the pretokenized token is in the vocabulary.
+        # But we need to return the original token, because that will be tokenized
+        # again by the tokenizer during featurization.
+        if new_token in seen_tokens or new_token in added_tokens_set:
+            n_duplicates += 1
+            continue
 
-        seen_tokens.add(token)
+        # Add the possibly pretokenized token to _seen_
+        seen_tokens.add(new_token)
+        # Add the original string to the vocabulary.
         cleaned_vocabulary.append(token)
 
     if n_duplicates:
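
The change above deduplicates on the normalized, pre-tokenized form of each token (new_token), while still appending the original string to the vocabulary, because the original string is what gets tokenized again during featurization. A minimal sketch of that flow, using a stand-in tokenizer (the model id below is illustrative, not part of this commit):

from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("bert-base-uncased")  # illustrative stand-in

for token in ["Hello", "hello", "new york"]:
    # Normalize first, e.g. lowercasing: "Hello" -> "hello".
    normalized = tokenizer.normalizer.normalize_str(token)
    # Pre-tokenize to detect multiword entries.
    pieces = tokenizer.pre_tokenizer.pre_tokenize_str(normalized)
    if len(pieces) != 1:
        print(f"{token!r}: multiword, skipped")  # "new york" is dropped
        continue
    new_token = pieces[-1][0]
    # "Hello" and "hello" both map to "hello", so the second one
    # now counts as a duplicate instead of slipping through.
    print(f"{token!r} deduplicates as {new_token!r}")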

model2vec/model.py

Lines changed: 15 additions & 0 deletions
@@ -152,6 +152,7 @@ def from_pretrained(
         token: str | None = None,
         normalize: bool | None = None,
         quantize_to: str | DType | None = None,
+        dimensionality: int | None = None,
     ) -> StaticModel:
         """
         Load a StaticModel from a local path or huggingface hub path.
@@ -163,7 +164,11 @@
         :param normalize: Whether to normalize the embeddings.
         :param quantize_to: The dtype to quantize the model to. If None, no quantization is done.
             If a string is passed, it is converted to a DType.
+        :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
+            This is useful if you want to load a model with a lower dimensionality.
+            Note that this only applies if you have trained your model using mrl or PCA.
         :return: A StaticModel
+        :raises: ValueError if the dimensionality is greater than the model dimensionality.
         """
         from model2vec.hf_utils import load_pretrained
 
@@ -172,6 +177,16 @@
         if quantize_to is not None:
             quantize_to = DType(quantize_to)
             embeddings = quantize_embeddings(embeddings, quantize_to)
+        if dimensionality is not None:
+            if dimensionality > embeddings.shape[1]:
+                raise ValueError(
+                    f"Dimensionality {dimensionality} is greater than the model dimensionality {embeddings.shape[1]}"
+                )
+            embeddings = embeddings[:, :dimensionality]
+            if config.get("apply_pca", None) is None:
+                logger.warning(
+                    "You are reducing the dimensionality of the model, but we can't find a pca key in the model config. This might not work as expected."
+                )
 
         return cls(
             embeddings,
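
A hedged usage sketch of the new dimensionality argument. The model id is illustrative, and truncation is only meaningful for models trained with MRL or PCA, as the docstring notes:

from model2vec import StaticModel

# Load the full model, then a truncated copy of the same weights.
full = StaticModel.from_pretrained("minishlab/potion-base-8M")  # illustrative id
small = StaticModel.from_pretrained("minishlab/potion-base-8M", dimensionality=64)

print(full.embedding.shape[1])              # e.g. 256
print(small.encode(["hello world"]).shape)  # (1, 64)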

model2vec/py.typed

Whitespace-only changes.

pyproject.toml

Lines changed: 5 additions & 1 deletion
@@ -42,7 +42,11 @@ packages = ["model2vec"]
 include-package-data = true
 
 [tool.setuptools.package-data]
-model2vec = ["assets/modelcards/model_card_template.md", "assets/modelcards/classifier_template.md"]
+model2vec = [
+    "assets/modelcards/model_card_template.md",
+    "assets/modelcards/classifier_template.md",
+    "py.typed"
+]
 
 [project.optional-dependencies]
 dev = [
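
For context, py.typed is the empty PEP 561 marker file; listing it under package-data ensures it ships in built distributions, so type checkers treat model2vec's inline annotations as authoritative. A small sketch of what that enables downstream (hypothetical consumer module):

# downstream_check.py -- hypothetical consumer module.
# With py.typed installed, mypy checks this code against model2vec's own
# annotations instead of treating the package as untyped (Any).
from model2vec import StaticModel

def load_model(path: str) -> StaticModel:
    return StaticModel.from_pretrained(path)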

tests/test_model.py

Lines changed: 29 additions & 1 deletion
@@ -207,12 +207,40 @@ def test_load_pretrained_quantized(
 
     # Load the model back from the same path
     loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float32")
-
     # Assert that the loaded model has the same properties as the original one
     assert loaded_model.embedding.dtype == np.float32
     assert loaded_model.embedding.shape == mock_vectors.shape
 
 
+def test_load_pretrained_dim(
+    tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
+) -> None:
+    """Test loading a pretrained model with dimensionality."""
+    # Save the model to a temporary path
+    model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
+    save_path = tmp_path / "saved_model"
+    model.save_pretrained(save_path)
+
+    loaded_model = StaticModel.from_pretrained(save_path, dimensionality=2)
+
+    # Assert that the loaded model has the same properties as the original one
+    np.testing.assert_array_equal(loaded_model.embedding, mock_vectors[:, :2])
+    assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab()
+    assert loaded_model.config == mock_config
+
+    # Load the model back from the same path
+    loaded_model = StaticModel.from_pretrained(save_path, dimensionality=None)
+
+    # Assert that the loaded model has the same properties as the original one
+    np.testing.assert_array_equal(loaded_model.embedding, mock_vectors)
+    assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab()
+    assert loaded_model.config == mock_config
+
+    # Load the model back from the same path
+    with pytest.raises(ValueError):
+        StaticModel.from_pretrained(save_path, dimensionality=3000)
+
+
 def test_initialize_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
     """Tests whether the normalization initialization is correct."""
     model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=None)
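
A fixture-free sketch of the same round trip, for readers without the test harness; the tokenizer id and the path are stand-ins, not fixtures from this repository:

import numpy as np
from tokenizers import Tokenizer
from model2vec import StaticModel

tok = Tokenizer.from_pretrained("bert-base-uncased")  # stand-in tokenizer
vectors = np.random.rand(tok.get_vocab_size(), 8).astype(np.float32)

model = StaticModel(vectors, tok, {})
model.save_pretrained("/tmp/dim_demo")  # stand-in path

small = StaticModel.from_pretrained("/tmp/dim_demo", dimensionality=2)
assert small.embedding.shape == (tok.get_vocab_size(), 2)

# Asking for more dimensions than the model has raises ValueError.
try:
    StaticModel.from_pretrained("/tmp/dim_demo", dimensionality=3000)
except ValueError as err:
    print(err)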
