|
19 | 19 |
|
20 | 20 | from __future__ import annotations |
21 | 21 |
|
| 22 | +import dataclasses |
22 | 23 | from typing import Any |
23 | 24 | from typing import Dict |
24 | 25 | from typing import Optional |
25 | 26 | from typing import Sequence |
26 | 27 | from typing import Tuple |
27 | 28 |
|
28 | 29 | from apache_beam.typehints import typehints |
| 30 | +from apache_beam.typehints.native_type_compatibility import match_is_dataclass |
29 | 31 | from apache_beam.typehints.native_type_compatibility import match_is_named_tuple |
30 | 32 | from apache_beam.typehints.schema_registry import SchemaTypeRegistry |
31 | 33 |
|
@@ -56,18 +58,14 @@ def __init__( |
56 | 58 | for guidance on creating PCollections with inferred schemas. |
57 | 59 |
|
58 | 60 | Note RowTypeConstraint does not currently store arbitrary functions for |
59 | | - converting to/from the user type. Instead, we only support ``NamedTuple`` |
60 | | - user types and make the follow assumptions: |
| 61 | + converting to/from the user type. Instead, we support ``NamedTuple`` and |
| 62 | + ``dataclasses`` user types and make the follow assumptions: |
61 | 63 |
|
62 | 64 | - The user type can be constructed with field values as arguments in order |
63 | 65 | (i.e. ``constructor(*field_values)``). |
64 | 66 | - Field values can be accessed from instances of the user type by attribute |
65 | 67 | (i.e. with ``getattr(obj, field_name)``). |
66 | 68 |
|
67 | | - In the future we will add support for dataclasses |
68 | | - ([#22085](https://github.com/apache/beam/issues/22085)) which also satisfy |
69 | | - these assumptions. |
70 | | -
|
71 | 69 | The RowTypeConstraint constructor should not be called directly (even |
72 | 70 | internally to Beam). Prefer static methods ``from_user_type`` or |
73 | 71 | ``from_fields``. |
@@ -127,6 +125,29 @@ def from_user_type( |
127 | 125 | field_options=field_options, |
128 | 126 | field_descriptions=field_descriptions) |
129 | 127 |
|
| 128 | + if match_is_dataclass(user_type): |
| 129 | + fields = [(field.name, field.type) |
| 130 | + for field in dataclasses.fields(user_type)] |
| 131 | + |
| 132 | + field_descriptions = getattr(user_type, '_field_descriptions', None) |
| 133 | + |
| 134 | + if _user_type_is_generated(user_type): |
| 135 | + return RowTypeConstraint.from_fields( |
| 136 | + fields, |
| 137 | + schema_id=getattr(user_type, _BEAM_SCHEMA_ID), |
| 138 | + schema_options=schema_options, |
| 139 | + field_options=field_options, |
| 140 | + field_descriptions=field_descriptions) |
| 141 | + |
| 142 | + # TODO(https://github.com/apache/beam/issues/22125): Add user API for |
| 143 | + # specifying schema/field options |
| 144 | + return RowTypeConstraint( |
| 145 | + fields=fields, |
| 146 | + user_type=user_type, |
| 147 | + schema_options=schema_options, |
| 148 | + field_options=field_options, |
| 149 | + field_descriptions=field_descriptions) |
| 150 | + |
130 | 151 | return None |
131 | 152 |
|
132 | 153 | @staticmethod |
|
0 commit comments