Skip to content

Commit dc431a4

Browse files
committed
Support Enum values in PersistentDict
1 parent 7dc54b1 commit dc431a4

1 file changed

Lines changed: 10 additions & 1 deletion

File tree

pytools/persistent_dict.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
THE SOFTWARE.
2727
"""
2828

29+
from enum import Enum
2930
import logging
3031
import hashlib
3132
import collections.abc as abc
@@ -233,7 +234,8 @@ def rec(self, key_hash, key):
233234
digest = inner_key_hash.digest()
234235

235236
if digest is None:
236-
tname = type(key).__name__
237+
tp = type(key)
238+
tname = tp.__name__
237239
method = None
238240
try:
239241
method = getattr(self, "update_for_"+tname)
@@ -245,6 +247,9 @@ def rec(self, key_hash, key):
245247
if isinstance(key, np.dtype):
246248
method = self.update_for_specific_dtype
247249

250+
elif issubclass(tp, Enum):
251+
method = self.update_for_enum
252+
248253
if method is not None:
249254
inner_key_hash = self.new_hash()
250255
method(inner_key_hash, key)
@@ -287,6 +292,10 @@ def update_for_int(key_hash, key):
287292
except OverflowError:
288293
sz *= 2
289294

295+
@classmethod
296+
def update_for_enum(cls, key_hash, key):
297+
cls.update_for_str(key_hash, str(key))
298+
290299
@staticmethod
291300
def update_for_bool(key_hash, key):
292301
key_hash.update(str(key).encode("utf8"))

0 commit comments

Comments
 (0)