Skip to content

Commit 01317c0

Browse files
committed
wip
1 parent 93b41ea commit 01317c0

3 files changed

Lines changed: 69 additions & 35 deletions

File tree

model2vec/distill/distillation.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def distill_from_model(
3030
token_remove_pattern: str | None = r"\[unused\d+\]",
3131
quantize_to: DType | str = DType.Float16,
3232
use_subword: bool | None = None,
33+
vocabulary_quantization: int | None = None,
3334
) -> StaticModel:
3435
"""
3536
Distill a staticmodel from a sentence transformer.
@@ -113,14 +114,21 @@ def distill_from_model(
113114
tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token]
114115
)
115116

116-
_, weights = post_process_embeddings(np.asarray(embeddings), None, sif_coefficient=sif_coefficient)
117-
km = KMeans(4096, random_state=42)
118-
km.fit(embeddings)
119-
clustered_embeddings = km.predict(embeddings)
120-
mapping = {idx: x for idx, x in enumerate(clustered_embeddings)}
117+
if vocabulary_quantization is not None:
118+
_, weights = post_process_embeddings(np.asarray(embeddings), None, sif_coefficient=sif_coefficient)
119+
km = KMeans(vocabulary_quantization, random_state=42)
120+
km.fit(embeddings)
121+
clustered_embeddings = km.predict(embeddings)
122+
mapping = {idx: x for idx, x in enumerate(clustered_embeddings)}
121123

122-
embeddings = km.cluster_centers_
123-
embeddings, _ = post_process_embeddings(embeddings, pca_dims, sif_coefficient=sif_coefficient)
124+
embeddings = km.cluster_centers_
125+
embeddings, _ = post_process_embeddings(embeddings, pca_dims, sif_coefficient=sif_coefficient)
126+
else:
127+
# Post-process the embeddings.
128+
embeddings, weights = post_process_embeddings(
129+
np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient
130+
)
131+
mapping = {idx: token.form for idx, token in enumerate(all_tokens)}
124132
# Quantize the embeddings.
125133
embeddings = quantize_embeddings(embeddings, quantize_to)
126134

@@ -219,6 +227,7 @@ def distill(
219227
trust_remote_code: bool = False,
220228
quantize_to: DType | str = DType.Float16,
221229
use_subword: bool | None = None,
230+
vocabulary_quantization: int | None = None,
222231
) -> StaticModel:
223232
"""
224233
Distill a staticmodel from a sentence transformer.
@@ -263,4 +272,5 @@ def distill(
263272
sif_coefficient=sif_coefficient,
264273
quantize_to=quantize_to,
265274
use_subword=use_subword,
275+
vocabulary_quantization=vocabulary_quantization,
266276
)

model2vec/hf_utils.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def save_pretrained(
2424
config: dict[str, Any],
2525
create_model_card: bool = True,
2626
subfolder: str | None = None,
27+
weights: np.ndarray | None = None,
2728
**kwargs: Any,
2829
) -> None:
2930
"""
@@ -39,7 +40,12 @@ def save_pretrained(
3940
"""
4041
folder_path = folder_path / subfolder if subfolder else folder_path
4142
folder_path.mkdir(exist_ok=True, parents=True)
42-
save_file({"embeddings": embeddings}, folder_path / "model.safetensors")
43+
44+
model_weights = {"embeddings": embeddings}
45+
if weights is not None:
46+
model_weights["weights"] = weights
47+
48+
save_file(model_weights, folder_path / "model.safetensors")
4349
tokenizer.save(str(folder_path / "tokenizer.json"), pretty=False)
4450
json.dump(config, open(folder_path / "config.json", "w"), indent=4)
4551

@@ -99,7 +105,7 @@ def load_pretrained(
99105
subfolder: str | None = None,
100106
token: str | None = None,
101107
from_sentence_transformers: bool = False,
102-
) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
108+
) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any], np.ndarray | None]:
103109
"""
104110
Loads a pretrained model from a folder.
105111
@@ -177,8 +183,14 @@ def load_pretrained(
177183
opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy"))
178184
if from_sentence_transformers:
179185
embeddings = opened_tensor_file.get_tensor("embedding.weight")
186+
weights = None
180187
else:
181188
embeddings = opened_tensor_file.get_tensor("embeddings")
189+
try:
190+
weights = opened_tensor_file.get_tensor("weights")
191+
except Exception:
192+
# Broad `except Exception` (not a bare except) because safetensors does not export its own errors; if the "weights" tensor cannot be read, fall back to None.
193+
weights = None
182194

183195
tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
184196
config = json.load(open(config_path))
@@ -188,7 +200,7 @@ def load_pretrained(
188200
f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
189201
)
190202

191-
return embeddings, tokenizer, config, metadata
203+
return embeddings, tokenizer, config, metadata, weights
192204

193205

194206
def _get_metadata_from_readme(readme_path: Path) -> dict[str, Any]:

model2vec/model.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class StaticModel:
2424
def __init__(
2525
self,
2626
vectors: np.ndarray,
27-
weights: np.ndarray,
27+
weights: np.ndarray | None,
2828
token_mapping: dict[int, int],
2929
tokenizer: Tokenizer,
3030
config: dict[str, Any] | None = None,
@@ -107,6 +107,8 @@ def save_pretrained(self, path: PathLike, model_name: str | None = None, subfold
107107
"""
108108
from model2vec.hf_utils import save_pretrained
109109

110+
self.config["token_mapping"] = list(self.token_mapping.items())
111+
110112
save_pretrained(
111113
folder_path=Path(path),
112114
embeddings=self.embedding,
@@ -116,6 +118,7 @@ def save_pretrained(self, path: PathLike, model_name: str | None = None, subfold
116118
language=self.language,
117119
model_name=model_name,
118120
subfolder=subfolder,
121+
weights=self.weights,
119122
)
120123

121124
def tokenize(self, sentences: Sequence[str], max_length: int | None = None) -> list[list[int]]:
@@ -131,8 +134,6 @@ def tokenize(self, sentences: Sequence[str], max_length: int | None = None) -> l
131134
m = max_length * self.median_token_length
132135
sentences = [sentence[:m] for sentence in sentences]
133136

134-
max_len = max([len(sentence) for sentence in sentences])
135-
# self.tokenizer.model.max_input_chars_per_word = max_len + 1
136137
if self._can_encode_fast:
137138
encodings: list[Encoding] = self.tokenizer.encode_batch_fast(sentences, add_special_tokens=False)
138139
else:
@@ -159,6 +160,7 @@ def from_pretrained(
159160
subfolder: str | None = None,
160161
quantize_to: str | DType | None = None,
161162
dimensionality: int | None = None,
163+
vocabulary_quantization: int | None = None,
162164
) -> StaticModel:
163165
"""
164166
Load a StaticModel from a local path or huggingface hub path.
@@ -178,36 +180,45 @@ def from_pretrained(
178180
"""
179181
from model2vec.hf_utils import load_pretrained
180182

181-
embeddings, tokenizer, config, metadata = load_pretrained(
183+
embeddings, tokenizer, config, metadata, weights = load_pretrained(
182184
folder_or_repo_path=path,
183185
token=token,
184186
from_sentence_transformers=False,
185187
subfolder=subfolder,
186188
)
187189

188-
weights = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-32
189-
embeddings = embeddings / weights
190-
191-
"""from sklearn.cluster import KMeans
192-
from sklearn.decomposition import PCA
193-
km = KMeans(n_clusters=4096, random_state=0)
194-
km.fit(embeddings)
195-
# Do PCA again?
196-
assignments = km.predict(embeddings)
197-
embeddings = km.cluster_centers_
198-
199-
p = PCA(n_components=dimensionality)
200-
embeddings = p.fit_transform(embeddings)
201-
202-
token_mapping = {i: x for i, x in enumerate(assignments)}"""
203-
token_mapping = {i: i for i in range(len(embeddings))}
204-
205190
embeddings = quantize_and_reduce_dim(
206191
embeddings=embeddings,
207192
quantize_to=quantize_to,
208193
dimensionality=dimensionality,
209194
)
210195

196+
if vocabulary_quantization is not None:
197+
if len(embeddings) != len(tokenizer.get_vocab()):
198+
raise ValueError(
199+
"Already quantized. "
200+
)
201+
202+
if weights is None:
203+
weights = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-32
204+
embeddings = embeddings / weights
205+
206+
# Quantize the vocabulary
207+
from sklearn.cluster import KMeans
208+
kmeans = KMeans(n_clusters=vocabulary_quantization, random_state=42)
209+
kmeans.fit(embeddings)
210+
token_mapping = {idx: x for idx, x in enumerate(kmeans.predict(embeddings))}
211+
embeddings = kmeans.cluster_centers_
212+
213+
else:
214+
token_mapping = config.pop("token_mapping", None)
215+
if isinstance(token_mapping, list):
216+
# If the token mapping is a list, convert it to a dict
217+
token_mapping = {int(k): int(v) for k, v in token_mapping}
218+
elif token_mapping is None:
219+
# If no token mapping is provided, use the default mapping
220+
token_mapping = {i: i for i in range(len(embeddings))}
221+
211222
return cls(
212223
embeddings,
213224
weights,
@@ -245,7 +256,7 @@ def from_sentence_transformers(
245256
"""
246257
from model2vec.hf_utils import load_pretrained
247258

248-
embeddings, tokenizer, config, metadata = load_pretrained(
259+
embeddings, tokenizer, config, metadata, weights = load_pretrained(
249260
folder_or_repo_path=path,
250261
token=token,
251262
from_sentence_transformers=True,
@@ -258,9 +269,10 @@ def from_sentence_transformers(
258269
dimensionality=dimensionality,
259270
)
260271

261-
weights = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-32
262-
embeddings = embeddings / weights
263-
token_mapping = {i: i for i in range(len(embeddings))}
272+
token_mapping = config.pop("token_mapping", None)
273+
if token_mapping is None:
274+
# If no token mapping is provided, use the default mapping
275+
token_mapping = {i: i for i in range(len(embeddings))}
264276

265277
return cls(
266278
embeddings,

0 commit comments

Comments
 (0)