Skip to content

Commit 8a14dfb

Browse files
add (strict) typing to Record
1 parent 56efa1b commit 8a14dfb

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

pytools/__init__.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,8 @@ class RecordWithoutPickling:
408408
__slots__: ClassVar[List[str]] = []
409409
fields: ClassVar[Set[str]]
410410

411-
def __init__(self, valuedict=None, exclude=None, **kwargs):
411+
def __init__(self, valuedict: Optional[Dict[str, Any]] = None,
412+
exclude: Optional[List[str]] = None, **kwargs: Any) -> None:
412413
assert self.__class__ is not Record
413414

414415
if exclude is None:
@@ -427,7 +428,7 @@ def __init__(self, valuedict=None, exclude=None, **kwargs):
427428
fields.add(key)
428429
setattr(self, key, value)
429430

430-
def get_copy_kwargs(self, **kwargs):
431+
def get_copy_kwargs(self, **kwargs: Any) -> Any:
431432
for f in self.__class__.fields:
432433
if f not in kwargs:
433434
try:
@@ -436,25 +437,25 @@ def get_copy_kwargs(self, **kwargs):
436437
pass
437438
return kwargs
438439

439-
def copy(self, **kwargs):
440+
def copy(self, **kwargs: Any) -> "RecordWithoutPickling":
440441
return self.__class__(**self.get_copy_kwargs(**kwargs))
441442

442-
def __repr__(self):
443+
def __repr__(self) -> str:
443444
return "{}({})".format(
444445
self.__class__.__name__,
445446
", ".join(f"{fld}={getattr(self, fld)!r}"
446447
for fld in self.__class__.fields
447448
if hasattr(self, fld)))
448449

449-
def register_fields(self, new_fields):
450+
def register_fields(self, new_fields: Set[str]) -> None:
450451
try:
451452
fields = self.__class__.fields
452453
except AttributeError:
453454
self.__class__.fields = fields = set()
454455

455456
fields.update(new_fields)
456457

457-
def __getattr__(self, name):
458+
def __getattr__(self, name: str) -> None:
458459
# This method is implemented to avoid pylint 'no-member' errors for
459460
# attribute access.
460461
raise AttributeError(
@@ -465,13 +466,13 @@ def __getattr__(self, name):
465466
class Record(RecordWithoutPickling):
466467
__slots__: ClassVar[List[str]] = []
467468

468-
def __getstate__(self):
469+
def __getstate__(self) -> Dict[str, Any]:
469470
return {
470471
key: getattr(self, key)
471472
for key in self.__class__.fields
472473
if hasattr(self, key)}
473474

474-
def __setstate__(self, valuedict):
475+
def __setstate__(self, valuedict: Dict[str, Any]) -> None:
475476
try:
476477
fields = self.__class__.fields
477478
except AttributeError:
@@ -481,30 +482,33 @@ def __setstate__(self, valuedict):
481482
fields.add(key)
482483
setattr(self, key, value)
483484

484-
def __eq__(self, other):
485+
def __eq__(self, other: Any) -> bool:
485486
if self is other:
486487
return True
488+
if not isinstance(other, Record):
489+
return False
487490
return (self.__class__ == other.__class__
488491
and self.__getstate__() == other.__getstate__())
489492

490-
def __ne__(self, other):
493+
def __ne__(self, other: Any) -> bool:
491494
return not self.__eq__(other)
492495

493496

494497
class ImmutableRecordWithoutPickling(RecordWithoutPickling):
495498
"""Hashable record. Does not explicitly enforce immutability."""
496-
def __init__(self, *args, **kwargs):
499+
def __init__(self, *args: Any, **kwargs: Any) -> None:
497500
RecordWithoutPickling.__init__(self, *args, **kwargs)
498-
self._cached_hash = None
501+
self._cached_hash: Optional[int] = None
499502

500-
def __hash__(self):
503+
def __hash__(self) -> int:
501504
# This attribute may vanish during pickling.
502505
if getattr(self, "_cached_hash", None) is None:
503506
self._cached_hash = hash(
504507
(type(self),) + tuple(getattr(self, field)
505508
for field in self.__class__.fields))
506509

507-
return self._cached_hash
510+
from typing import cast
511+
return cast(int, self._cached_hash)
508512

509513

510514
class ImmutableRecord(ImmutableRecordWithoutPickling, Record):

0 commit comments

Comments
 (0)