Skip to content

Commit bc0a489

Browse files
author
Ram Idavalapati
authored
Merge pull request #38 from deep-compute/combined_get_nearest
bug fixes in get_nearest; modified logic for combined get_nearest
2 parents 22f33dd + 5c09ce6 commit bc0a489

5 files changed

Lines changed: 88 additions & 33 deletions

File tree

.travis.yml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,20 @@ language: python
22
python:
33
- '3.5'
44
before_install:
5-
- wget 'https://s3.amazonaws.com/deepcompute-public-data/wordvecspace/small_test_data.tgz'
6-
&& tar xvzf small_test_data.tgz
7-
- export WORDVECSPACE_DATADIR='small_test_data'
5+
- wget 'https://s3.amazonaws.com/deepcompute-public-data/wordvecspace/test_data-0_5_4.tgz' && tar xvzf test_data-0_5_4.tgz
6+
- export WORDVECSPACE_DATADIR='test_data-0_5_4.tgz'
87
install:
8+
- sudo apt update
99
- pip install .[service]
10-
- sudo apt install libopenblas-base
1110
script:
1211
- echo "No tests for this release"
1312
deploy:
1413
- provider: releases
1514
skip_cleanup: true
1615
api-key:
1716
secure: LmVvlW+FdYNIDlinjJ4sieONrcx1jaw18J7/mpHBD9ppIWZ+TB6H/iNqkqkh4WvULZttJrTHRYE6rQHXww7KK2UMrjVNE/TVUPaLFDeRRFvLDinAbqJkn+QJia0TuRa/26Bg9cDcvNYTghy7s37xpK2bJTEMF/eCM9b9RHYXilESYy8Z4l8IkFn5vnaDDfT5iV8xjuuOE4lsf4KC3L0xXIkYnKC/LbDVDj3B9h52TpsteL6cZtn/ExAThor5SrVymW7oMR1qrPQv8btNAdxymqJvEbjaP5RUuX7ehihev0Yge47A2X9gvxDRv+a6wM0HOvT4aGsMwCWo++fb0taWH7HUXFxSvkzKhsl74kDMmnE0WarcI/8L/3Q/zRhW1a2vAtj3O0FDHtzS/OK/k3TDk6Fh/LOvk2mTuGD3L34YxJrXxDxnt4tK2ubde8cGeA7pI5jRLNTNQXUip6Dxhr/5ZnMmG2nHI6ujjmDnucE+CHBtUmS1wjBn6ootE4pdoyti0aaA9OrVoGrf39pK7FAG38KJghqn8I3YCLoeapWjI4/DI0WIfq2Vl+v6yQar3Dn9lBLpWFLrjUmZnAx2F1e0P2y0VUg9hl0bINzIIrm2mHw4Zsl2GlMVSR033cwvcbdyeNxKMAfSV3EZBDpNuI6nlkkUZG1O72N/WV+kFRtSdQA=
18-
name: wordvecspace-0.5.3
19-
tag_name: 0.5.3
17+
name: wordvecspace-0.5.4
18+
tag_name: 0.5.4
2019
on:
2120
repo: deep-compute/wordvecspace
2221
# pypitest

README.md

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
# WordVecSpace
22
A high-performance, pure-Python module that helps in loading and performing operations on word vector spaces created using Google's Word2vec tool.
33

4-
This module has ability to the load data into memory using `WordVecSpaceMem` and it can also support performing operations on the data which is on the disk using `WordVecSpaceAnnoy` and `WordVecSpaceDisk`.
4+
This module can load the data into memory using `WordVecSpaceMem`, and it also supports performing operations on data that resides on disk using `WordVecSpaceAnnoy` and `WordVecSpaceDisk`.
55

66
## Installation
77
> Prerequisites: >=Python3.5.2
88
99
```bash
10+
1011
$ sudo apt install libopenblas-base # Optional
1112
$ sudo pip3 install wordvecspace
1213
```
@@ -21,8 +22,8 @@ word vector space data. Here are two ways to get that.
2122
#### Download pre-computed sample data
2223

2324
```bash
24-
$ wget https://s3.amazonaws.com/deepcompute-public-data/wordvecspace/small_test_data.tgz
25-
$ tar zxvf small_test_data.tgz
25+
$ wget https://s3.amazonaws.com/deepcompute-public-data/wordvecspace/test_data-0_5_4.tgz
26+
$ tar zxvf test_data-0_5_4.tgz
2627
```
2728

2829
> NOTE: We got this data by downloading the `text8` corpus
@@ -193,6 +194,7 @@ wordvecspace.exception.UnknownWord: "inidia"
193194
>>> print(wv.get_indices(['the', 'deepcompute', 'india']))
194195
[1, None, 509]
195196

197+
196198
>>> print(wv.get_indices(['the', 'deepcompute', 'india'], raise_exc=True))
197199
Traceback (most recent call last):
198200
File "/usr/lib/python3.6/code.py", line 91, in runcode
@@ -342,8 +344,33 @@ wordvecspace.exception.UnknownWord: "inidia"
342344
[[3844, 16727, 15811, 42731, 41516], [509, 3389, 486, 523, 7125]]
343345

344346
# Get common nearest neighbors among given words
345-
>>> print(wv.get_nearest(['india', 'bosnia'], 10, combination=True))
346-
[523, 509, 486]
347+
>>> wv.get_words(wv.get_nearest(['india', 'pakistan'], 10)[0])
348+
['india', 'indian', 'delhi', 'subcontinent', 'hyderabad', 'pradesh', 'pakistan', 'gujarat', 'bombay', 'chhattisgarh']
349+
>>> wv.get_words(wv.get_nearest(['india', 'pakistan'], 10)[1])
350+
['pakistan', 'pakistani', 'india', 'bangladesh', 'peshawar', 'afghanistan', 'baluchistan', 'balochistan', 'kashmir', 'islamabad']
351+
>>> wv.get_words(wv.get_nearest(['india', 'pakistan'], 10, combination=True)[0])
352+
['pakistan', 'india', 'indian', 'bangladesh', 'pakistani', 'subcontinent', 'shimla', 'delhi', 'punjab', 'ladakh']
353+
>>> wv.get_words(wv.get_nearest(['india', 'pakistan'], 10, combination=True, weights=[1, 0])[0])
354+
['india', 'indian', 'delhi', 'subcontinent', 'hyderabad', 'pradesh', 'pakistan', 'gujarat', 'bombay', 'chhattisgarh']
355+
>>> wv.get_words(wv.get_nearest(['india', 'pakistan'], 10, combination=True, weights=[0, 1])[0])
356+
['pakistan', 'pakistani', 'india', 'bangladesh', 'peshawar', 'afghanistan', 'baluchistan', 'balochistan', 'kashmir', 'islamabad']
357+
>>> wv.get_words(wv.get_nearest(['india', 'pakistan'], 10, combination=True, weights=[0.7, 0.3])[0])
358+
['india', 'pakistan', 'indian', 'subcontinent', 'delhi', 'bangladesh', 'hyderabad', 'shimla', 'punjab', 'bengal']
359+
>>> wv.get_words(wv.get_nearest(['india', 'pakistan'], 10, combination=True, weights=[0.3, 0.7])[0])
360+
['pakistan', 'india', 'pakistani', 'bangladesh', 'subcontinent', 'indian', 'shimla', 'punjab', 'kashmir', 'ladakh']
361+
362+
# Get nearest with vector(s)
363+
>>> wv.get_words(wv.get_nearest(wv.get_vector('india').reshape(1, wv.dim), k=5))
364+
['india', 'indian', 'subcontinent', 'bombay', 'bengal']
365+
>>> wv.get_words(wv.get_nearest(wv.get_vectors(['india', 'pakistan']), k=5)[0])
366+
['india', 'indian', 'subcontinent', 'bombay', 'bengal']
367+
>>> wv.get_words(wv.get_nearest(wv.get_vectors(['india', 'pakistan']), k=5)[1])
368+
['pakistan', 'pakistani', 'kargil', 'afghanistan', 'bangladesh']
369+
>>> wv.get_words(wv.get_nearest(wv.get_vectors(['india', 'pakistan']), k=5, combination=True)[0])
370+
['india', 'pakistan', 'indian', 'pakistani', 'subcontinent']
371+
>>> wv.get_words(wv.get_nearest(wv.get_vectors(['india', 'pakistan']), k=5, combination=True, weights=[0.4, 0.6])[0])
372+
['pakistan', 'india', 'pakistani', 'kargil', 'indian']
373+
347374
```
348375

349376
## Service

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from setuptools import setup, find_packages
22

3-
version = '0.5.3'
3+
version = '0.5.4'
44
setup(
55
name="wordvecspace",
66
python_requires='>3.5.1',

test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from wordvecspace import mem
88
from wordvecspace import annoy
99
from wordvecspace import disk
10+
from wordvecspace import wvspace
1011

1112
def suite_test():
1213
suite = unittest.TestSuite()
@@ -15,6 +16,7 @@ def suite_test():
1516
suite.addTests(doctest.DocTestSuite(mem))
1617
suite.addTests(doctest.DocTestSuite(annoy))
1718
suite.addTests(doctest.DocTestSuite(disk))
19+
suite.addTests(doctest.DocTestSuite(wvspace))
1820

1921
return suite
2022

@@ -23,3 +25,4 @@ def suite_test():
2325
doctest.testmod(mem)
2426
doctest.testmod(annoy)
2527
doctest.testmod(disk)
28+
doctest.testmod(wvspace)

wordvecspace/wvspace.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
# $export WORDVECSPACE_DATADIR=/path/to/data/
1717
DATAFILE_ENV_VAR = os.environ.get('WORDVECSPACE_DATADIR', '')
1818

19+
1920
class WordVecSpace(WordVecSpaceBase):
2021
METRIC = 'cosine'
22+
DEFAULT_K = 512
2123

2224
def __init__(self, input_dir: str, metric: str=METRIC) -> None:
2325
self._f = WordVecSpaceFile(input_dir, mode='r')
@@ -40,6 +42,7 @@ def _make_array(self, shape, dtype):
4042
def _check_index_or_word(self, item):
4143
if isinstance(item, str):
4244
return self.get_index(item)
45+
4346
return item
4447

4548
def _check_indices_or_words(self, items):
@@ -54,14 +57,17 @@ def _check_indices_or_words(self, items):
5457
if isinstance(w, (list, tuple)):
5558
if isinstance(w[0], str):
5659
return self.get_indices(w)
60+
5761
return w
5862

5963
def _check_vec(self, v, normalised=False):
60-
if isinstance(v, np.ndarray) and len(v.shape) == 2 and v.dtype==np.float32:
64+
if isinstance(v, np.ndarray) and len(v.shape) == 2 and v.dtype == np.float32:
6165
if normalised:
6266
m = np.linalg.norm(v)
6367
return v / m
68+
6469
return v
70+
6571
else:
6672
if isinstance(v, (list, tuple)):
6773
return self.get_vectors(v, normalized=normalised)
@@ -133,8 +139,8 @@ def get_vectors(self, words_or_indices: list, normalized: bool=False) -> np.ndar
133139

134140
return np.multiply(vecs.T, mags).T
135141

136-
def get_distance(self, word_or_index1: Union[int, str],\
137-
word_or_index2: Union[int, str], metric: str='cosine') -> float:
142+
def get_distance(self, word_or_index1: Union[int, str],
143+
word_or_index2: Union[int, str], metric: str='cosine') -> float:
138144

139145
w1 = word_or_index1
140146
w2 = word_or_index2
@@ -167,8 +173,10 @@ def _check_r_and_c(self, r, c, m):
167173

168174
return m, r, c
169175

170-
def get_distances(self, row_words_or_indices: Union[list, np.ndarray],\
171-
col_words_or_indices: Union[list, None, np.ndarray]=None, metric=None) -> np.ndarray:
176+
def get_distances(self,
177+
row_words_or_indices: Union[list, np.ndarray],
178+
col_words_or_indices: Union[list, None, np.ndarray]=None,
179+
metric=None) -> np.ndarray:
172180

173181
r = row_words_or_indices
174182
c = col_words_or_indices
@@ -186,7 +194,6 @@ def get_distances(self, row_words_or_indices: Union[list, np.ndarray],\
186194
nvecs, dim = col_vectors.shape
187195

188196
vec_out = self._make_array((len(col_vectors), len(row_vectors)), dtype=np.float32)
189-
190197
res = self._perform_sgemv(row_vectors, col_vectors, vec_out, nvecs, dim)
191198

192199
else:
@@ -205,33 +212,52 @@ def get_distances(self, row_words_or_indices: Union[list, np.ndarray],\
205212

206213
return distance.cdist(row_vectors, col_vectors, 'euclidean')
207214

208-
DEFAULT_K = 512
209-
210-
def get_nearest(self, v_w_i: list, k: int=DEFAULT_K,\
211-
distances: bool=False, combination: bool=False,\
212-
metric: str='cosine') -> np.ndarray:
213-
214-
d = self.get_distances(v_w_i, metric=metric)
215+
def _nearest_sorting(self, d, k):
215216

216217
ner = self._make_array(shape=(len(d), k), dtype=np.uint32)
217218
dist = self._make_array(shape=(len(d), k), dtype=np.float32)
218219

219220
for index, p in enumerate(d):
221+
# FIXME: better variable name for b_sort
220222
b_sort = bottleneck.argpartition(p, k)[:k]
221-
pr_dist = np.take(d, b_sort)
223+
pr_dist = np.take(p, b_sort)
222224

225+
# FIXME: better variable name for a_sorted
223226
a_sorted = np.argsort(pr_dist)
224227
indices = np.take(b_sort, a_sorted)
225228

226229
ner[index] = indices
227230
dist[index] = np.take(p, indices)
228231

229-
if combination:
230-
ner = set(ner[0]).intersection(*ner)
231-
return (ner, dist) if distances else ner
232+
return ner, dist
233+
234+
def get_nearest(self, v_w_i: list,
235+
k: int=DEFAULT_K,
236+
distances: bool=False,
237+
combination: bool=False,
238+
weights: list=None,
239+
metric: str='cosine') -> np.ndarray:
240+
241+
d = self.get_distances(v_w_i, metric=metric)
242+
243+
if not weights:
244+
weights = np.ones(len(v_w_i))
245+
246+
if combination and len(weights) == len(v_w_i):
247+
weights = np.array(weights)
248+
w_d = np.dot(weights, d)
249+
nearest_indices, dist = self._nearest_sorting(w_d.reshape(1, len(w_d)), k)
250+
251+
if distances:
252+
return nearest_indices, dist
253+
254+
else:
255+
return nearest_indices
256+
257+
nearest_indices, dist = self._nearest_sorting(d, k)
258+
259+
if isinstance(v_w_i, (list, tuple)) or isinstance(v_w_i, np.ndarray) and len(v_w_i) > 1:
260+
return (nearest_indices, dist) if distances else nearest_indices
232261

233-
if isinstance(v_w_i, (list, tuple)) or \
234-
isinstance(v_w_i, np.ndarray) and len(v_w_i) > 1:
235-
return (ner, dist) if distances else ner
236262
else:
237-
return (ner[0], dist[0]) if distances else ner[0]
263+
return (nearest_indices[0], dist[0]) if distances else nearest_indices[0]

0 commit comments

Comments
 (0)