Skip to content

Commit 8e0be62

Browse files
authored
Merge pull request #19 from MinishLab/fix_pca_bug
Fix pca bug
2 parents 5d32041 + 272c379 commit 8e0be62

2 files changed

Lines changed: 10 additions & 6 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,7 @@ vocabulary = ["word1", "word2", "word3"]
108108
model_name = "BAAI/bge-base-en-v1.5"
109109

110110
# Distill the model with the custom vocabulary
111-
m2v_model = distill(model_name=model_name, vocabulary=vocabulary, pca_dims=256)
111+
m2v_model = distill(model_name=model_name, vocabulary=vocabulary, pca_dims=None)
112112

113113
# Save the model
114114
m2v_model.save_pretrained("m2v_model")

model2vec/distill/__main__.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -104,15 +104,19 @@ def distill(
104104
tokenizer = create_tokenizer_from_vocab(tokens, unk_token="[UNK]", pad_token="[PAD]")
105105

106106
if pca_dims is not None:
107-
if pca_dims < embeddings.shape[1]:
107+
if pca_dims >= embeddings.shape[1]:
108+
raise ValueError(
109+
f"PCA dimension ({pca_dims}) is larger than the number of dimensions in the embeddings ({embeddings.shape[1]})"
110+
)
111+
if pca_dims >= len(tokens):
112+
logger.warning(
113+
f"PCA dimension ({pca_dims}) is larger than the number of tokens in the vocabulary ({len(tokens)}). Not applying PCA."
114+
)
115+
elif pca_dims < embeddings.shape[1]:
108116
logger.info(f"Applying PCA with n_components {pca_dims}")
109117

110118
p = PCA(n_components=pca_dims, whiten=False)
111119
embeddings = p.fit_transform(embeddings)
112-
else:
113-
raise ValueError(
114-
f"PCA dimension ({pca_dims}) is larger than the number of dimensions in the embeddings ({embeddings.shape[1]})"
115-
)
116120

117121
if apply_zipf:
118122
logger.info("Applying Zipf weighting")

0 commit comments

Comments
 (0)