Skip to content

Commit fb3b2f7

Browse files
committed
Added embedding_dtype and vocabulary_quantization to config
1 parent 5a8578d commit fb3b2f7

2 files changed

Lines changed: 4 additions & 0 deletions

File tree

model2vec/hf_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ def save_pretrained(
5353

5454
save_file(model_weights, folder_path / "model.safetensors")
5555
tokenizer.save(str(folder_path / "tokenizer.json"), pretty=False)
56+
57+
# Add embedding dtype to config
58+
config["embedding_dtype"] = np.dtype(embeddings.dtype).name
5659
json.dump(config, open(folder_path / "config.json", "w"), indent=4)
5760

5861
# Create modules.json

model2vec/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,7 @@ def quantize_model(
506506
embeddings, token_mapping, weights = quantize_vocabulary(
507507
n_clusters=vocabulary_quantization, weights=model.weights, embeddings=model.embedding
508508
)
509+
model.config["vocabulary_quantization"] = vocabulary_quantization
509510
else:
510511
embeddings = model.embedding
511512
token_mapping = model.token_mapping

0 commit comments

Comments
 (0)