|
19 | 19 |
|
20 | 20 | # pytype: skip-file |
21 | 21 |
|
| 22 | +import dataclasses |
22 | 23 | import itertools |
23 | 24 | import pickle |
24 | 25 | import unittest |
@@ -388,6 +389,23 @@ def test_namedtuple_roundtrip(self, user_type): |
388 | 389 | self.assertIsInstance(roundtripped, row_type.RowTypeConstraint) |
389 | 390 | self.assert_namedtuple_equivalent(roundtripped.user_type, user_type) |
390 | 391 |
|
| 392 | + def test_dataclass_roundtrip(self): |
| 393 | + @dataclasses.dataclass |
| 394 | + class SimpleDataclass: |
| 395 | + id: np.int64 |
| 396 | + name: str |
| 397 | + |
| 398 | + roundtripped = typing_from_runner_api( |
| 399 | + typing_to_runner_api( |
| 400 | + SimpleDataclass, schema_registry=SchemaTypeRegistry()), |
| 401 | + schema_registry=SchemaTypeRegistry()) |
| 402 | + |
| 403 | + self.assertIsInstance(roundtripped, row_type.RowTypeConstraint) |
| 404 | + # The roundtripped user_type is generated as a NamedTuple, so we can't test equivalence directly with the dataclass. |
| 405 | + # Instead, let's verify annotations. |
| 406 | + self.assertEqual( |
| 407 | + roundtripped.user_type.__annotations__, SimpleDataclass.__annotations__) |
| 408 | + |
391 | 409 | def test_row_type_constraint_to_schema(self): |
392 | 410 | result_type = typing_to_runner_api( |
393 | 411 | row_type.RowTypeConstraint.from_fields([ |
@@ -646,6 +664,48 @@ def test_trivial_example(self): |
646 | 664 | expected.row_type.schema.fields, |
647 | 665 | typing_to_runner_api(MyCuteClass).row_type.schema.fields) |
648 | 666 |
|
| 667 | + def test_trivial_example_dataclass(self): |
| 668 | + @dataclasses.dataclass |
| 669 | + class MyCuteDataclass: |
| 670 | + name: str |
| 671 | + age: Optional[int] |
| 672 | + interests: List[str] |
| 673 | + height: float |
| 674 | + blob: ByteString |
| 675 | + |
| 676 | + expected = schema_pb2.FieldType( |
| 677 | + row_type=schema_pb2.RowType( |
| 678 | + schema=schema_pb2.Schema( |
| 679 | + fields=[ |
| 680 | + schema_pb2.Field( |
| 681 | + name='name', |
| 682 | + type=schema_pb2.FieldType( |
| 683 | + atomic_type=schema_pb2.STRING), |
| 684 | + ), |
| 685 | + schema_pb2.Field( |
| 686 | + name='age', |
| 687 | + type=schema_pb2.FieldType( |
| 688 | + nullable=True, atomic_type=schema_pb2.INT64)), |
| 689 | + schema_pb2.Field( |
| 690 | + name='interests', |
| 691 | + type=schema_pb2.FieldType( |
| 692 | + array_type=schema_pb2.ArrayType( |
| 693 | + element_type=schema_pb2.FieldType( |
| 694 | + atomic_type=schema_pb2.STRING)))), |
| 695 | + schema_pb2.Field( |
| 696 | + name='height', |
| 697 | + type=schema_pb2.FieldType( |
| 698 | + atomic_type=schema_pb2.DOUBLE)), |
| 699 | + schema_pb2.Field( |
| 700 | + name='blob', |
| 701 | + type=schema_pb2.FieldType( |
| 702 | + atomic_type=schema_pb2.BYTES)), |
| 703 | + ]))) |
| 704 | + |
| 705 | + self.assertEqual( |
| 706 | + expected.row_type.schema.fields, |
| 707 | + typing_to_runner_api(MyCuteDataclass).row_type.schema.fields) |
| 708 | + |
649 | 709 | def test_user_type_annotated_with_id_after_conversion(self): |
650 | 710 | MyCuteClass = NamedTuple('MyCuteClass', [ |
651 | 711 | ('name', str), |
|
0 commit comments