11import os
22import json
33from typing import Union
4-
54from scipy .spatial import distance
65import numpy as np
6+ import pandas as pd
77import bottleneck
88
99from .fileformat import WordVecSpaceFile
@@ -60,19 +60,19 @@ def _check_indices_or_words(self, items):
6060
6161 return w
6262
63- def _check_vec (self , v , normalised = False ):
63+ def _check_vec (self , v , normalized = False ):
6464 if isinstance (v , np .ndarray ) and len (v .shape ) == 2 and v .dtype == np .float32 :
65- if normalised :
65+ if normalized :
6666 m = np .linalg .norm (v )
6767 return v / m
6868
6969 return v
7070
7171 else :
7272 if isinstance (v , (list , tuple )):
73- return self .get_vectors (v , normalized = normalised )
73+ return self .get_vectors (v , normalized = normalized )
7474
75- return self .get_vector (v , normalized = normalised )
75+ return self .get_vector (v , normalized = normalized )
7676
7777 def get_manifest (self ) -> dict :
7878 manifest_info = open (os .path .join (self .input_dir , "manifest.json" ), "r" )
@@ -148,6 +148,7 @@ def get_distance(
148148 word_or_index1 : Union [int , str ],
149149 word_or_index2 : Union [int , str ],
150150 metric : str = "cosine" ,
151+ normalized : bool = True ,
151152 ) -> float :
152153
153154 w1 = word_or_index1
@@ -156,9 +157,9 @@ def get_distance(
156157 if not metric :
157158 metric = self .metric
158159
159- if metric == "cosine" or "angular" :
160- vec1 = self ._check_vec (w1 , True )
161- vec2 = self ._check_vec (w2 , True )
160+ if metric in ( "cosine" , "angular" ) :
161+ vec1 = self ._check_vec (w1 , normalized )
162+ vec2 = self ._check_vec (w2 , normalized )
162163
163164 return 1 - np .dot (vec1 , vec2 .T )
164165
@@ -186,19 +187,20 @@ def get_distances(
186187 row_words_or_indices : Union [list , np .ndarray ],
187188 col_words_or_indices : Union [list , None , np .ndarray ] = None ,
188189 metric = None ,
190+ normalized : bool = True ,
189191 ) -> np .ndarray :
190192
191193 r = row_words_or_indices
192194 c = col_words_or_indices
193195
194196 metric , r , c = self ._check_r_and_c (r , c , metric )
195197
196- if metric == "cosine" or "angular" :
197- row_vectors = self ._check_vec (r , True )
198+ if metric in ( "cosine" , "angular" ) :
199+ row_vectors = self ._check_vec (r , normalized )
198200
199201 col_vectors = self .vecs
200202 if c is not None and len (c ):
201- col_vectors = self ._check_vec (c , True )
203+ col_vectors = self ._check_vec (c , normalized )
202204
203205 if len (r ) == 1 :
204206 nvecs , dim = col_vectors .shape
@@ -214,6 +216,10 @@ def get_distances(
214216 )
215217 res = self ._perform_sgemm (row_vectors , col_vectors , mat_out )
216218
219+ if not normalized :
220+ res = np .multiply (res , self .mags )
221+ return res
222+
217223 return 1 - res
218224
219225 elif metric == "euclidean" :
@@ -226,19 +232,25 @@ def get_distances(
226232
227233 return distance .cdist (row_vectors , col_vectors , "euclidean" )
228234
229- def _nearest_sorting (self , d , k ):
235+ def _nearest_sorting (self , d , k , normalized = True ):
230236
231237 ner = self ._make_array (shape = (len (d ), k ), dtype = np .uint32 )
232238 dist = self ._make_array (shape = (len (d ), k ), dtype = np .float32 )
233239
234240 for index , p in enumerate (d ):
235- # FIXME: better variable name for b_sort
236- b_sort = bottleneck .argpartition (p , k )[:k ]
237- pr_dist = np .take (p , b_sort )
241+ if normalized :
242+ # FIXME: better variable name for b_sort
243+ b_sort = bottleneck .argpartition (p , k )[:k ]
244+ pr_dist = np .take (p , b_sort )
238245
239- # FIXME: better variable name for a_sorted
240- a_sorted = np .argsort (pr_dist )
241- indices = np .take (b_sort , a_sorted )
246+ # FIXME: better variable name for a_sorted
247+ a_sorted = np .argsort (pr_dist )
248+ indices = np .take (b_sort , a_sorted )
249+
250+ else :
251+ d = pd .Series (p )
252+ d = d .nlargest (k )
253+ indices = d .keys ()
242254
243255 ner [index ] = indices
244256 dist [index ] = np .take (p , indices )
@@ -253,25 +265,28 @@ def get_nearest(
253265 combination : bool = False ,
254266 weights : list = None ,
255267 metric : str = "cosine" ,
268+ normalized : bool = True ,
256269 ) -> np .ndarray :
257270
258- d = self .get_distances (v_w_i , metric = metric )
271+ d = self .get_distances (v_w_i , metric = metric , normalized = normalized )
259272
260273 if not weights :
261274 weights = np .ones (len (v_w_i ))
262275
263276 if combination and len (weights ) == len (v_w_i ):
264277 weights = np .array (weights )
265278 w_d = np .dot (weights , d )
266- nearest_indices , dist = self ._nearest_sorting (w_d .reshape (1 , len (w_d )), k )
279+ nearest_indices , dist = self ._nearest_sorting (
280+ w_d .reshape (1 , len (w_d )), k , normalized
281+ )
267282
268283 if distances :
269284 return nearest_indices , dist
270285
271286 else :
272287 return nearest_indices
273288
274- nearest_indices , dist = self ._nearest_sorting (d , k )
289+ nearest_indices , dist = self ._nearest_sorting (d , k , normalized )
275290
276291 if (
277292 isinstance (v_w_i , (list , tuple ))
0 commit comments