Skip to content

Commit d62920e

Browse files
committed
Fix bug in PCA
1 parent 5d32041 commit d62920e

1 file changed

Lines changed: 9 additions & 5 deletions

File tree

model2vec/distill/__main__.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -104,15 +104,19 @@ def distill(
104104
tokenizer = create_tokenizer_from_vocab(tokens, unk_token="[UNK]", pad_token="[PAD]")
105105

106106
if pca_dims is not None:
107-
if pca_dims < embeddings.shape[1]:
107+
if pca_dims >= embeddings.shape[1]:
108+
raise ValueError(
109+
f"PCA dimension ({pca_dims}) is larger than the number of dimensions in the embeddings ({embeddings.shape[1]})"
110+
)
111+
if pca_dims >= len(tokens):
112+
logger.warning(
113+
f"PCA dimension ({pca_dims}) is larger than the number of tokens in the vocabulary ({len(tokens)}). Not applying PCA."
114+
)
115+
elif pca_dims < embeddings.shape[1]:
108116
logger.info(f"Applying PCA with n_components {pca_dims}")
109117

110118
p = PCA(n_components=pca_dims, whiten=False)
111119
embeddings = p.fit_transform(embeddings)
112-
else:
113-
raise ValueError(
114-
f"PCA dimension ({pca_dims}) is larger than the number of dimensions in the embeddings ({embeddings.shape[1]})"
115-
)
116120

117121
if apply_zipf:
118122
logger.info("Applying Zipf weighting")

0 commit comments

Comments
 (0)