The dataclasses module, new in Python 3.7, uses the variable annotation syntax introduced in Python 3.6 (PEP 526) to create fully functional classes for storing data.
Some consider dataclasses to be an upgrade over namedtuple.
In short, the @dataclass decorator reads the class's type annotations and generates methods such as __init__, __repr__, and __eq__.
-
Example 1
(Script available)
from dataclasses import dataclass


@dataclass
class Vector2:
    """A 2-D vector whose __init__, __repr__ and __eq__ are generated by @dataclass.

    The annotations are not enforced: Vector2(x=1) stores the int 1, and the
    default for y is the int 0.
    """

    x: float
    y: float = 0

    def __add__(self, other):
        """Return the component-wise sum of two vectors.

        For non-Vector2 operands this returns NotImplemented, so Python raises
        TypeError. Without that check, an AttributeError would escape from
        inside the method.
        """
        if not isinstance(other, Vector2):
            return NotImplemented
        return Vector2(self.x + other.x, self.y + other.y)


if __name__ == '__main__':
    # Shows the methods that @dataclass generated.
    help(Vector2)
    v1 = Vector2(x=1)
    print(v1)
    v2 = Vector2(1, 0)
    print(v2)
    print(v1 == v2)
    v1.y = 3
    print(v1 + v2)
Note that type annotations do not enforce variable types, and dataclass adds no special logic for this either. To enforce variable types, see the next example.
-
Example 2
(Script available)
import operator
import sys
from dataclasses import dataclass, fields


def enforce_types(data_cls):
    """Coerce values assigned to dataclass fields to the field's annotated type.

    Apply this above @dataclass, because it reads the __dataclass_fields__
    mapping that @dataclass creates. Non-frozen dataclasses assign fields in
    __init__ via normal attribute assignment, so constructor arguments are
    coerced too.

    A value is converted with field_type(value) only when the annotation is a
    real class and the value is not already an instance of it. This avoids,
    for example, copying a list assigned to a `list` field. Non-class
    annotations are left unchecked instead of failing with an obscure error.
    Examples are string annotations and typing constructs such as List[int]
    or ClassVar[int].
    """
    original_setattr = data_cls.__setattr__

    def __setattr__(self, key, value):
        field = type(self).__dataclass_fields__.get(key)
        if field is not None:
            field_type = field.type
            if isinstance(field_type, type) and not isinstance(value, field_type):
                value = field_type(value)
        return original_setattr(self, key, value)

    data_cls.__setattr__ = __setattr__
    return data_cls


def math_dataclass(data_cls):
    """Add element-wise +, -, * and / to a dataclass.

    An operator is installed only if the class does not already provide it,
    either directly or by inheritance. Each installed method applies the
    operator field by field and builds a new instance from the results.
    Operands that are not instances of data_cls get NotImplemented, so Python
    falls back to the reflected operator and otherwise raises TypeError.
    """
    operators = {
        '__add__': operator.add,
        '__sub__': operator.sub,
        '__mul__': operator.mul,
        '__truediv__': operator.truediv,
    }

    def make_method(name, op):
        def method(self, other):
            if not isinstance(other, data_cls):
                return NotImplemented
            # fields() excludes ClassVar/InitVar pseudo-fields, and init=False
            # fields cannot be passed back to the constructor.
            return type(self)(**{
                f.name: op(getattr(self, f.name), getattr(other, f.name))
                for f in fields(self)
                if f.init
            })

        method.__name__ = name
        method.__qualname__ = f'{data_cls.__qualname__}.{name}'
        return method

    for name, op in operators.items():
        if not hasattr(data_cls, name):
            setattr(data_cls, name, make_method(name, op))
    return data_cls


class OperationNotAllowed(Exception):
    """Raised when an arithmetic operator is deliberately unsupported."""


@math_dataclass
@enforce_types
@dataclass
class Vector2:
    """2-D vector whose fields are coerced to float on assignment.

    + and - are generated by math_dataclass. * is overridden as the dot
    product, and / is explicitly forbidden.
    """

    x: float
    y: float

    def __mul__(self, other):
        """Return the dot product of two vectors."""
        if not isinstance(other, Vector2):
            return NotImplemented
        return self.x * other.x + self.y * other.y

    def __truediv__(self, other):
        """Reject division between vectors."""
        raise OperationNotAllowed


if __name__ == '__main__':
    v1 = Vector2(1, 2)
    print(v1)
    v1.x = 0
    print(v1)
    v2 = Vector2(1.5, 2.5)
    print(v1 + v2)
    print(v1 * v2)
    try:
        print(v1 / v2)
    except OperationNotAllowed as e:
        print(repr(e), file=sys.stderr)