Skip to content

Commit b4a0934

Browse files
committed
Fix up keys method
1 parent d2d0170 commit b4a0934

2 files changed

Lines changed: 54 additions & 7 deletions

File tree

cdblib/compat.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,13 @@ def finish(self):
8585
class cdb:
8686
def __init__(self, f, encoding='utf-8'):
8787
self._file_path = f
88+
8889
self.encoding = encoding
90+
strict = not bool(encoding)
8991

9092
self._file_obj = open(self._file_path, mode='rb')
9193
self._mmap_obj = mmap(self._file_obj.fileno(), 0, access=ACCESS_READ)
92-
self._reader = Reader(self._mmap_obj)
94+
self._reader = Reader(self._mmap_obj, strict=strict)
9395

9496
self._keys = self._get_key_iterator()
9597
self._items = cycle(chain(self._decoded_items(), [None]))
@@ -180,7 +182,7 @@ def getall(self, k):
180182
for value in self._reader.gets(k):
181183
try:
182184
value = value.decode(self.encoding)
183-
except (AttributeError, UnicodeDecodeError):
185+
except (AttributeError, UnicodeDecodeError, TypeError):
184186
value = value
185187
ret_append(value)
186188

@@ -189,7 +191,9 @@ def getall(self, k):
189191
def keys(self):
190192
"""Return a list of the distinct keys stored in the database.
191193
"""
192-
return self._reader.keys()
194+
all_keys = (k for k, v in self._decoded_items())
195+
unique_keys = self._unique_keys(all_keys)
196+
return list(unique_keys)
193197

194198
@property
195199
def name(self):

tests/compat_test.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,24 +33,31 @@ def setUp(self):
3333
def tearDown(self):
3434
rmtree(self.temp_dir, ignore_errors=False)
3535

36-
def _get_reader(self):
36+
def _get_reader(self, **kwargs):
3737
self.db.finish()
38-
return cdb.init(self.cdb_path.encode('utf-8'))
38+
return cdb.init(self.cdb_path.encode('utf-8'), **kwargs)
3939

4040
def test_add(self):
4141
self.db.add('a', '4')
4242
self.assertEqual(self.db.numentries, 6)
4343
self.db.add('a', '4')
4444
self.assertEqual(self.db.numentries, 7)
4545

46+
with self.assertRaises(TypeError):
47+
self.db.add(1, '1')
48+
4649
def test_numentries(self):
4750
self.assertEqual(self.db.numentries, 5)
4851
self.db.finish()
4952
self.assertEqual(self.db.numentries, 5)
5053

54+
def test_cdbmake_fd(self):
55+
self.assertTrue(isinstance(self.db.fd, int))
56+
5157
def test_finish(self):
5258
self.db.finish()
5359
self.assertFalse(exists(self.tmp_path))
60+
self.db.finish()
5461

5562
def test_get(self):
5663
reader = self._get_reader()
@@ -100,10 +107,15 @@ def test_nextkey(self):
100107
self.assertIsNone(reader.nextkey())
101108
self.assertIsNone(reader.nextkey())
102109

103-
def test_name_size(self):
110+
def test_keys(self):
111+
reader = self._get_reader()
112+
self.assertEqual(reader.keys(), ['a', 'b', 'c'])
113+
114+
def test_name_size_fd(self):
104115
reader = self._get_reader()
105116
self.assertEqual(reader.name.decode('utf-8'), self.cdb_path)
106117
self.assertEqual(reader.size, 2178)
118+
self.assertTrue(isinstance(reader.fd, int))
107119

108120

109121
@unittest.skipIf(not test_cdb, 'Tests for Python 2 module')
@@ -113,7 +125,38 @@ class PythonCDBTests(CompatTests, unittest.TestCase):
113125

114126
@unittest.skipIf(test_cdb, 'Tests for Python 3 module')
115127
class PythonPureCDBTests(CompatTests, unittest.TestCase):
116-
pass
128+
def test_cdbmake_cleanup(self):
129+
# Cleanup after close - no exception
130+
self.db.finish()
131+
self.db._cleanup()
132+
133+
# Exception during cleanup - we soldier on
134+
self.db._temp_obj = None
135+
self.db._cleanup()
136+
137+
def test_add_after_finish(self):
138+
self.db.finish()
139+
with self.assertRaises(cdb.error):
140+
self.db.add('d', '1')
141+
142+
def test_cdb_cleanup(self):
143+
# Cleanup after close - no exception
144+
reader = self._get_reader()
145+
reader._mmap_obj.close()
146+
reader._file_obj.close()
147+
reader._cleanup()
148+
149+
reader._mmap_obj = None
150+
reader._file_obj = None
151+
152+
# Exception during cleanup - we soldier on
153+
reader._cleanup()
154+
155+
def test_no_encoding(self):
156+
reader = self._get_reader(encoding=None)
157+
self.assertEqual(reader.get(b'a'), b'1')
158+
self.assertEqual(reader.getall(b'a'), [b'1', b'2', b'\x80'])
159+
self.assertEqual(reader.keys(), [b'a', b'b', b'c'])
117160

118161

119162
if __name__ == '__main__':

0 commit comments

Comments
 (0)