diff --git a/bertopic/backend/_sklearn.py b/bertopic/backend/_sklearn.py index d8150fe6..8e60f823 100644 --- a/bertopic/backend/_sklearn.py +++ b/bertopic/backend/_sklearn.py @@ -65,4 +65,4 @@ def embed(self, documents, verbose=False): except NotFittedError: embeddings = self.pipe.fit_transform(documents) - return embeddings + return embeddings.toarray() if hasattr(embeddings, "toarray") else embeddings