Skip to content

Commit 4e49db9

Browse files
committed
Introduce register_row to register with both coder and schema registry
Save schema registry id->type mapping
1 parent 1e00d27 commit 4e49db9

File tree

7 files changed

+46
-15
lines changed

7 files changed

+46
-15
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
22
"comment": "Modify this file in a trivial way to cause this test suite to run",
3-
"modification": 15
3+
"modification": 16
44
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
22
"comment": "Modify this file in a trivial way to cause this test suite to run",
3-
"modification": 16
3+
"modification": 17
44
}

sdks/python/apache_beam/coders/typecoders.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def _normalize_typehint_type(typehint_type):
124124
def register_coder(
125125
self, typehint_type: Any,
126126
typehint_coder_class: Type[coders.Coder]) -> None:
127+
"Register a user type with a coder"
127128
if not isinstance(typehint_coder_class, type):
128129
raise TypeError(
129130
'Coder registration requires a coder class object. '
@@ -133,6 +134,20 @@ def register_coder(
133134
self._register_coder_internal(
134135
self._normalize_typehint_type(typehint_type), typehint_coder_class)
135136

137+
def register_row(self, typehint_type: Any) -> None:
138+
"""
139+
Register a user type with a Beam Row.
140+
141+
This registers the type with a RowCoder and register its schema.
142+
"""
143+
from apache_beam.typehints.schemas import typing_to_runner_api
144+
from apache_beam.coders import RowCoder
145+
# Register with row coder
146+
self.register_coder(typehint_type, RowCoder)
147+
# This call generated a schema id for the type and register it with
148+
# schema registry
149+
typing_to_runner_api(typehint_type)
150+
136151
def get_coder(self, typehint: Any) -> coders.Coder:
137152
if typehint and typehint.__module__ == '__main__':
138153
# See https://github.com/apache/beam/issues/21541

sdks/python/apache_beam/internal/cloudpickle_pickler.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,20 +256,27 @@ def dump_session(file_path):
256256
# dump supported Beam Registries (currently only logical type registry)
257257
from apache_beam.coders import typecoders
258258
from apache_beam.typehints import schemas
259+
from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
259260

260261
with _pickle_lock, open(file_path, 'wb') as file:
261262
coder_reg = typecoders.registry.get_custom_type_coder_tuples()
262263
logical_type_reg = schemas.LogicalType._known_logical_types.copy_custom()
264+
schema_reg = SCHEMA_REGISTRY.get_registered_typings()
263265

264266
pickler = cloudpickle.CloudPickler(file)
265267
# TODO(https://github.com/apache/beam/issues/18500) add file system registry
266268
# once implemented
267-
pickler.dump({"coder": coder_reg, "logical_type": logical_type_reg})
269+
pickler.dump({
270+
"coder": coder_reg,
271+
"logical_type": logical_type_reg,
272+
"schema": schema_reg
273+
})
268274

269275

270276
def load_session(file_path):
271277
from apache_beam.coders import typecoders
272278
from apache_beam.typehints import schemas
279+
from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
273280

274281
with _pickle_lock, open(file_path, 'rb') as file:
275282
registries = cloudpickle.load(file)
@@ -284,3 +291,7 @@ def load_session(file_path):
284291
schemas.LogicalType._known_logical_types.load(registries["logical_type"])
285292
else:
286293
_LOGGER.warning('No logical type registry found in saved session')
294+
if "schema" in registries:
295+
SCHEMA_REGISTRY.load_registered_typings(registries["schema"])
296+
else:
297+
_LOGGER.warning('No schema registry found in saved session')

sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
("f_timestamp", Timestamp), ("f_decimal", Decimal),
6565
("f_date", datetime.date), ("f_time", datetime.time)],
6666
)
67-
coders.registry.register_coder(JdbcTestRow, coders.RowCoder)
67+
coders.registry.register_row(JdbcTestRow)
6868

6969
CustomSchemaRow = typing.NamedTuple(
7070
"CustomSchemaRow",
@@ -82,11 +82,11 @@
8282
("renamed_time", datetime.time),
8383
],
8484
)
85-
coders.registry.register_coder(CustomSchemaRow, coders.RowCoder)
85+
coders.registry.register_row(CustomSchemaRow)
8686

8787
SimpleRow = typing.NamedTuple(
8888
"SimpleRow", [("id", int), ("name", str), ("value", float)])
89-
coders.registry.register_coder(SimpleRow, coders.RowCoder)
89+
coders.registry.register_row(SimpleRow)
9090

9191

9292
@pytest.mark.uses_gcp_java_expansion_service

sdks/python/apache_beam/typehints/row_type_test.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ def test_group_by_key_namedtuple(self):
4444
MyNamedTuple = typing.NamedTuple(
4545
"MyNamedTuple", [("id", int), ("name", str)])
4646

47-
beam.coders.typecoders.registry.register_coder(
48-
MyNamedTuple, beam.coders.RowCoder)
47+
beam.coders.typecoders.registry.register_row(MyNamedTuple)
4948

5049
def generate(num: int):
5150
for i in range(100):
@@ -69,8 +68,7 @@ class MyDataClass:
6968
id: int
7069
name: str
7170

72-
beam.coders.typecoders.registry.register_coder(
73-
MyDataClass, beam.coders.RowCoder)
71+
beam.coders.typecoders.registry.register_row(MyDataClass)
7472

7573
def generate(num: int):
7674
for i in range(100):
@@ -122,10 +120,8 @@ class DataClassInt:
122120
class DataClassStr(DataClassInt):
123121
name: str
124122

125-
beam.coders.typecoders.registry.register_coder(
126-
DataClassInt, beam.coders.RowCoder)
127-
beam.coders.typecoders.registry.register_coder(
128-
DataClassStr, beam.coders.RowCoder)
123+
beam.coders.typecoders.registry.register_row(DataClassInt)
124+
beam.coders.typecoders.registry.register_row(DataClassStr)
129125

130126
def generate(num: int):
131127
for i in range(10):

sdks/python/apache_beam/typehints/schema_registry.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
class SchemaTypeRegistry(object):
2727
def __init__(self):
2828
self.by_id = {}
29-
self.by_typing = {}
29+
self.by_typing = {} # currently not used
3030

3131
def generate_new_id(self):
3232
for _ in range(100):
@@ -43,6 +43,15 @@ def add(self, typing, schema):
4343
if schema.id:
4444
self.by_id[schema.id] = (typing, schema)
4545

46+
def load_registered_typings(self, by_id):
47+
for id, typing in by_id.items():
48+
if id not in self.by_id:
49+
self.by_id[id] = (typing, None)
50+
51+
def get_registered_typings(self):
52+
# Used by save_main_session, as pb2.schema isn't picklable
53+
return {k: v[0] for k, v in self.by_id.items()}
54+
4655
def get_typing_by_id(self, unique_id):
4756
if not unique_id:
4857
return None

0 commit comments

Comments
 (0)