Skip to content

Commit 8d0c948

Browse files
committed
bug fixes in get_nearest; modified logic for combined get_nearest; pep8 fixes
1 parent 22f33dd commit 8d0c948

4 files changed

Lines changed: 86 additions & 30 deletions

File tree

README.md

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
# WordVecSpace
22
A high performance pure python module that helps in loading and performing operations on word vector spaces created using Google's Word2vec tool.
33

4-
This module has ability to the load data into memory using `WordVecSpaceMem` and it can also support performing operations on the data which is on the disk using `WordVecSpaceAnnoy` and `WordVecSpaceDisk`.
4+
This module can load the data into memory using `WordVecSpaceMem`, and it also supports performing operations on data that stays on disk using `WordVecSpaceAnnoy` and `WordVecSpaceDisk`.
55

66
## Installation
7-
> Prerequisites: >=Python3.5.2
7+
> Prerequisites: Python >= 3.5.2
88
99
```bash
10+
1011
$ sudo apt install libopenblas-base # Optional
1112
$ sudo pip3 install wordvecspace
1213
```
@@ -21,8 +22,8 @@ word vector space data. Here are two ways to get that.
2122
#### Download pre-computed sample data
2223

2324
```bash
24-
$ wget https://s3.amazonaws.com/deepcompute-public-data/wordvecspace/small_test_data.tgz
25-
$ tar zxvf small_test_data.tgz
25+
$ wget https://s3.amazonaws.com/deepcompute-public-data/wordvecspace/test_data-0_5_4.tgz
26+
$ tar zxvf test_data-0_5_4.tgz
2627
```
2728

2829
> NOTE: We got this data by downloading the `text8` corpus
@@ -41,7 +42,7 @@ $ git clone https://github.com/tmikolov/word2vec.git
4142

4243
# 1. Navigate to the folder word2vec
4344
# 2. open demo-word.sh for editing
44-
# 3. Edit the command "time ./word2vec -train text8 -output vectors.bin -cbow 1 -size 200 -window 8 -negative 25 -hs 0 -sample 1e-4 -threads 20 -binary 1 -iter 15" ----to----> "time ./word2vec -train text8 -output vectors.bin -cbow 1 -size 5 -window 8 -negative 25 -hs 0 -sample 1e-4 -threads 20 -binary 1 -save-vocab vocab.txt -iter 15" to get vocab.txt file also as output.
45+
# 3. Edit the command "time ./word2vec -train text8 -output vectors.bin -cbow 1 -size 200 -window 8 -negative 25 -hs 0 -sample 1e-4 -threads 20 -binary 1 -iter 15" ----to----> "time ./word2vec -train text8 -output vectors.bin -cbow 1 -size 5 -window 8 -negative 25 -hs 0 -sample 1e-4 -threads 20 -binary 1 -save-vocab vocab.txt -iter 15" to get vocab.txt file also as output.
4546
# 4. Run demo-word.sh
4647

4748
$ chmod +x demo-word.sh
@@ -96,7 +97,7 @@ $ wordvecspace convert /home/user/bindata /home/user/output_dir
9697

9798
`WordVecSpaceMem` and `WordVecSpaceDisk` use a brute-force algorithm which compares the given word with all the words in the vector space.
9899

99-
`WordVecSpaceAnnoy` takes wordvecspace output_dir as input and creates annoy indexes in another file (index file). Using this file `annoy` gives approximate results quickly. For better understanding of `Annoy` please go through this [link](https://github.com/spotify/annoy)
100+
`WordVecSpaceAnnoy` takes the wordvecspace output_dir as input and creates annoy indexes in another file (the index file). Using this file, `annoy` gives approximate results quickly. For a better understanding of `Annoy`, please go through this [link](https://github.com/spotify/annoy).
100101

101102
As we have seen how to import `WordVecSpaceDisk` above, let us look at `WordVecSpaceAnnoy` and `WordVecSpaceMem`
102103

@@ -193,6 +194,7 @@ wordvecspace.exception.UnknownWord: "inidia"
193194
>>> print(wv.get_indices(['the', 'deepcompute', 'india']))
194195
[1, None, 509]
195196

197+
196198
>>> print(wv.get_indices(['the', 'deepcompute', 'india'], raise_exc=True))
197199
Traceback (most recent call last):
198200
File "/usr/lib/python3.6/code.py", line 91, in runcode
@@ -342,8 +344,33 @@ wordvecspace.exception.UnknownWord: "inidia"
342344
[[3844, 16727, 15811, 42731, 41516], [509, 3389, 486, 523, 7125]]
343345

344346
# Get nearest neighbors for a (optionally weighted) combination of the given words
345-
>>> print(wv.get_nearest(['india', 'bosnia'], 10, combination=True))
346-
[523, 509, 486]
347+
>>> wv.get_words(wv.get_nearest(['india', 'pakistan'], 10)[0])
348+
['india', 'indian', 'delhi', 'subcontinent', 'hyderabad', 'pradesh', 'pakistan', 'gujarat', 'bombay', 'chhattisgarh']
349+
>>> wv.get_words(wv.get_nearest(['india', 'pakistan'], 10)[1])
350+
['pakistan', 'pakistani', 'india', 'bangladesh', 'peshawar', 'afghanistan', 'baluchistan', 'balochistan', 'kashmir', 'islamabad']
351+
>>> wv.get_words(wv.get_nearest(['india', 'pakistan'], 10, combination=True)[0])
352+
['pakistan', 'india', 'indian', 'bangladesh', 'pakistani', 'subcontinent', 'shimla', 'delhi', 'punjab', 'ladakh']
353+
>>> wv.get_words(wv.get_nearest(['india', 'pakistan'], 10, combination=True, weights=[1, 0])[0])
354+
['india', 'indian', 'delhi', 'subcontinent', 'hyderabad', 'pradesh', 'pakistan', 'gujarat', 'bombay', 'chhattisgarh']
355+
>>> wv.get_words(wv.get_nearest(['india', 'pakistan'], 10, combination=True, weights=[0, 1])[0])
356+
['pakistan', 'pakistani', 'india', 'bangladesh', 'peshawar', 'afghanistan', 'baluchistan', 'balochistan', 'kashmir', 'islamabad']
357+
>>> wv.get_words(wv.get_nearest(['india', 'pakistan'], 10, combination=True, weights=[0.7, 0.3])[0])
358+
['india', 'pakistan', 'indian', 'subcontinent', 'delhi', 'bangladesh', 'hyderabad', 'shimla', 'punjab', 'bengal']
359+
>>> wv.get_words(wv.get_nearest(['india', 'pakistan'], 10, combination=True, weights=[0.3, 0.7])[0])
360+
['pakistan', 'india', 'pakistani', 'bangladesh', 'subcontinent', 'indian', 'shimla', 'punjab', 'kashmir', 'ladakh']
361+
362+
# Get nearest with vector(s)
363+
>>> wv.get_words(wv.get_nearest(wv.get_vector('india').reshape(1, wv.dim), k=5))
364+
['india', 'indian', 'subcontinent', 'bombay', 'bengal']
365+
>>> wv.get_words(wv.get_nearest(wv.get_vectors(['india', 'pakistan']), k=5)[0])
366+
['india', 'indian', 'subcontinent', 'bombay', 'bengal']
367+
>>> wv.get_words(wv.get_nearest(wv.get_vectors(['india', 'pakistan']), k=5)[1])
368+
['pakistan', 'pakistani', 'kargil', 'afghanistan', 'bangladesh']
369+
>>> wv.get_words(wv.get_nearest(wv.get_vectors(['india', 'pakistan']), k=5, combination=True)[0])
370+
['india', 'pakistan', 'indian', 'pakistani', 'subcontinent']
371+
>>> wv.get_words(wv.get_nearest(wv.get_vectors(['india', 'pakistan']), k=5, combination=True, weights=[0.4, 0.6])[0])
372+
['pakistan', 'india', 'pakistani', 'kargil', 'indian']
373+
347374
```
348375

349376
## Service

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from setuptools import setup, find_packages
22

3-
version = '0.5.3'
3+
version = '0.5.4'
44
setup(
55
name="wordvecspace",
66
python_requires='>3.5.1',

test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from wordvecspace import mem
88
from wordvecspace import annoy
99
from wordvecspace import disk
10+
from wordvecspace import wvspace
1011

1112
def suite_test():
1213
suite = unittest.TestSuite()
@@ -15,6 +16,7 @@ def suite_test():
1516
suite.addTests(doctest.DocTestSuite(mem))
1617
suite.addTests(doctest.DocTestSuite(annoy))
1718
suite.addTests(doctest.DocTestSuite(disk))
19+
suite.addTests(doctest.DocTestSuite(wvspace))
1820

1921
return suite
2022

@@ -23,3 +25,4 @@ def suite_test():
2325
doctest.testmod(mem)
2426
doctest.testmod(annoy)
2527
doctest.testmod(disk)
28+
doctest.testmod(wvspace)

wordvecspace/wvspace.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
# $export WORDVECSPACE_DATADIR=/path/to/data/
1717
DATAFILE_ENV_VAR = os.environ.get('WORDVECSPACE_DATADIR', '')
1818

19+
1920
class WordVecSpace(WordVecSpaceBase):
2021
METRIC = 'cosine'
22+
DEFAULT_K = 512
2123

2224
def __init__(self, input_dir: str, metric: str=METRIC) -> None:
2325
self._f = WordVecSpaceFile(input_dir, mode='r')
@@ -40,6 +42,7 @@ def _make_array(self, shape, dtype):
4042
def _check_index_or_word(self, item):
4143
if isinstance(item, str):
4244
return self.get_index(item)
45+
4346
return item
4447

4548
def _check_indices_or_words(self, items):
@@ -54,14 +57,17 @@ def _check_indices_or_words(self, items):
5457
if isinstance(w, (list, tuple)):
5558
if isinstance(w[0], str):
5659
return self.get_indices(w)
60+
5761
return w
5862

5963
def _check_vec(self, v, normalised=False):
60-
if isinstance(v, np.ndarray) and len(v.shape) == 2 and v.dtype==np.float32:
64+
if isinstance(v, np.ndarray) and len(v.shape) == 2 and v.dtype == np.float32:
6165
if normalised:
6266
m = np.linalg.norm(v)
6367
return v / m
68+
6469
return v
70+
6571
else:
6672
if isinstance(v, (list, tuple)):
6773
return self.get_vectors(v, normalized=normalised)
@@ -133,8 +139,8 @@ def get_vectors(self, words_or_indices: list, normalized: bool=False) -> np.ndar
133139

134140
return np.multiply(vecs.T, mags).T
135141

136-
def get_distance(self, word_or_index1: Union[int, str],\
137-
word_or_index2: Union[int, str], metric: str='cosine') -> float:
142+
def get_distance(self, word_or_index1: Union[int, str],
143+
word_or_index2: Union[int, str], metric: str='cosine') -> float:
138144

139145
w1 = word_or_index1
140146
w2 = word_or_index2
@@ -167,8 +173,10 @@ def _check_r_and_c(self, r, c, m):
167173

168174
return m, r, c
169175

170-
def get_distances(self, row_words_or_indices: Union[list, np.ndarray],\
171-
col_words_or_indices: Union[list, None, np.ndarray]=None, metric=None) -> np.ndarray:
176+
def get_distances(self,
177+
row_words_or_indices: Union[list, np.ndarray],
178+
col_words_or_indices: Union[list, None, np.ndarray]=None,
179+
metric=None) -> np.ndarray:
172180

173181
r = row_words_or_indices
174182
c = col_words_or_indices
@@ -186,7 +194,6 @@ def get_distances(self, row_words_or_indices: Union[list, np.ndarray],\
186194
nvecs, dim = col_vectors.shape
187195

188196
vec_out = self._make_array((len(col_vectors), len(row_vectors)), dtype=np.float32)
189-
190197
res = self._perform_sgemv(row_vectors, col_vectors, vec_out, nvecs, dim)
191198

192199
else:
@@ -205,33 +212,52 @@ def get_distances(self, row_words_or_indices: Union[list, np.ndarray],\
205212

206213
return distance.cdist(row_vectors, col_vectors, 'euclidean')
207214

208-
DEFAULT_K = 512
209-
210-
def get_nearest(self, v_w_i: list, k: int=DEFAULT_K,\
211-
distances: bool=False, combination: bool=False,\
212-
metric: str='cosine') -> np.ndarray:
213-
214-
d = self.get_distances(v_w_i, metric=metric)
215+
def _nearest_sorting(self, d, k):
215216

216217
ner = self._make_array(shape=(len(d), k), dtype=np.uint32)
217218
dist = self._make_array(shape=(len(d), k), dtype=np.float32)
218219

219220
for index, p in enumerate(d):
221+
# FIXME: better variable name for b_sort
220222
b_sort = bottleneck.argpartition(p, k)[:k]
221-
pr_dist = np.take(d, b_sort)
223+
pr_dist = np.take(p, b_sort)
222224

225+
# FIXME: better variable name for a_sorted
223226
a_sorted = np.argsort(pr_dist)
224227
indices = np.take(b_sort, a_sorted)
225228

226229
ner[index] = indices
227230
dist[index] = np.take(p, indices)
228231

229-
if combination:
230-
ner = set(ner[0]).intersection(*ner)
231-
return (ner, dist) if distances else ner
232+
return ner, dist
233+
234+
def get_nearest(self, v_w_i: list,
235+
k: int=DEFAULT_K,
236+
distances: bool=False,
237+
combination: bool=False,
238+
weights: list=None,
239+
metric: str='cosine') -> np.ndarray:
240+
241+
d = self.get_distances(v_w_i, metric=metric)
242+
243+
if not weights:
244+
weights = np.ones(len(v_w_i))
245+
246+
if combination and len(weights) == len(v_w_i):
247+
weights = np.array(weights)
248+
w_d = np.dot(weights, d)
249+
nearest_indices, dist = self._nearest_sorting(w_d.reshape(1, len(w_d)), k)
250+
251+
if distances:
252+
return nearest_indices, dist
253+
254+
else:
255+
return nearest_indices
256+
257+
nearest_indices, dist = self._nearest_sorting(d, k)
258+
259+
if isinstance(v_w_i, (list, tuple)) or isinstance(v_w_i, np.ndarray) and len(v_w_i) > 1:
260+
return (nearest_indices, dist) if distances else nearest_indices
232261

233-
if isinstance(v_w_i, (list, tuple)) or \
234-
isinstance(v_w_i, np.ndarray) and len(v_w_i) > 1:
235-
return (ner, dist) if distances else ner
236262
else:
237-
return (ner[0], dist[0]) if distances else ner[0]
263+
return (nearest_indices[0], dist[0]) if distances else nearest_indices[0]

0 commit comments

Comments
 (0)