Skip to content

Commit 46cba79

Browse files
committed
Updated docstrings
1 parent 5dfdcd1 commit 46cba79

2 files changed

Lines changed: 15 additions & 7 deletions

File tree

model2vec/distill/distillation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def distill_from_model(
6060
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
6161
:param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
6262
:param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
63-
:param pooling: The pooling strategy to use for creating embeddings. Can be one of "mean", "last", or "cls".
63+
:param pooling: The pooling strategy to use for creating embeddings. Can be one of "mean" (default), "last", "first", or "pooler".
6464
:return: A StaticModel
6565
:raises: ValueError if the vocabulary is empty after preprocessing.
6666
@@ -259,7 +259,7 @@ def distill(
259259
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
260260
:param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
261261
:param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
262-
:param pooling: The pooling strategy to use for creating embeddings. Can be one of "mean", "last", or "cls".
262+
:param pooling: The pooling strategy to use for creating embeddings. Can be one of "mean" (default), "last", "first", or "pooler".
263263
:return: A StaticModel
264264
265265
"""

model2vec/distill/inference.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,22 +218,30 @@ def post_process_embeddings(
218218
if pca_dims > embeddings.shape[1]:
219219
logger.warning(
220220
f"PCA dimension ({pca_dims}) is larger than the number of dimensions in the embeddings ({embeddings.shape[1]}). "
221-
"Applying PCA, but not reducing dimensionality. If this is not desired, set `pca_dims` to None."
221+
"Applying PCA, but not reducing dimensionality. If this is not desired, please set `pca_dims` to None. "
222+
"Applying PCA will probably improve performance, so consider just leaving it."
222223
)
223224
pca_dims = embeddings.shape[1]
224225
if pca_dims >= embeddings.shape[0]:
225226
logger.warning(
226227
f"PCA dimension ({pca_dims}) is larger than the number of tokens in the vocabulary ({embeddings.shape[0]}). Not applying PCA."
227228
)
228229
elif pca_dims <= embeddings.shape[1]:
230+
if isinstance(pca_dims, float):
231+
logger.info(f"Applying PCA with {pca_dims} explained variance.")
232+
else:
233+
logger.info(f"Applying PCA with n_components {pca_dims}")
234+
229235
orig_dims = embeddings.shape[1]
230236
p = PCA(n_components=pca_dims, svd_solver="full")
231237
embeddings = p.fit_transform(embeddings)
238+
232239
if embeddings.shape[1] < orig_dims:
233-
logger.info(
234-
f"Reduced dimensionality {orig_dims} -> {embeddings.shape[1]} "
235-
f"(explained var ratio: {np.sum(p.explained_variance_ratio_):.3f})."
236-
)
240+
explained_variance_ratio = np.sum(p.explained_variance_ratio_)
241+
explained_variance = np.sum(p.explained_variance_)
242+
logger.info(f"Reduced dimensionality from {orig_dims} to {embeddings.shape[1]}.")
243+
logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.")
244+
logger.info(f"Explained variance: {explained_variance:.3f}.")
237245

238246
if sif_coefficient is not None:
239247
logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")

0 commit comments

Comments
 (0)