Skip to content

Commit ce04905

Browse files
committed
Pushed test coverage to 100%
1 parent c7e2132 commit ce04905

1 file changed

Lines changed: 65 additions & 10 deletions

File tree

src/flexible_schema/pyarrow.py

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,16 @@ class PyArrowSchema(Schema):
4646
numeric_value=1.0,
4747
text_value=None,
4848
parent_codes=None)
49+
50+
You can also validate tables with this class:
51+
4952
>>> data_tbl = pa.Table.from_pydict({
53+
... "subject_id": [1, 2, 3],
5054
... "time": [
5155
... datetime.datetime(2021, 3, 1),
5256
... datetime.datetime(2021, 4, 1),
5357
... datetime.datetime(2021, 5, 1),
5458
... ],
55-
... "subject_id": [1, 2, 3],
5659
... "code": ["A", "B", "C"],
5760
... })
5861
>>> Data.validate(data_tbl)
@@ -71,6 +74,43 @@ class PyArrowSchema(Schema):
7174
numeric_value: [[null,null,null]]
7275
text_value: [[null,null,null]]
7376
parent_codes: [[null,null,null]]
77+
78+
Including casting and reordering columns:
79+
80+
>>> data_tbl = pa.Table.from_pydict({
81+
... "time": [
82+
... datetime.datetime(2021, 3, 1),
83+
... datetime.datetime(2021, 4, 1),
84+
... datetime.datetime(2021, 5, 1),
85+
... ],
86+
... "subject_id": [1, 2, 3],
87+
... "code": ["A", "B", "C"],
88+
... }, schema=pa.schema(
89+
... [
90+
... pa.field("time", pa.timestamp("us")),
91+
... pa.field("subject_id", pa.int32()),
92+
... pa.field("code", pa.string()),
93+
... ]
94+
... ))
95+
>>> Data.validate(data_tbl)
96+
pyarrow.Table
97+
subject_id: int64
98+
time: timestamp[us]
99+
code: string
100+
numeric_value: float
101+
text_value: string
102+
parent_codes: list<item: string>
103+
child 0, item: string
104+
----
105+
subject_id: [[1,2,3]]
106+
time: [[2021-03-01 00:00:00.000000,2021-04-01 00:00:00.000000,2021-05-01 00:00:00.000000]]
107+
code: [["A","B","C"]]
108+
numeric_value: [[null,null,null]]
109+
text_value: [[null,null,null]]
110+
parent_codes: [[null,null,null]]
111+
112+
And handling extra columns:
113+
74114
>>> data_tbl_with_extra = pa.Table.from_pydict({
75115
... "time": [
76116
... datetime.datetime(2021, 3, 1),
@@ -112,17 +152,16 @@ class PyArrowSchema(Schema):
112152
DataType(int64)
113153
>>> Data.code_dtype
114154
DataType(string)
115-
>>> data_tbl_with_extra = pa.Table.from_pydict({
116-
... "subject_id": [4, 5],
117-
... "code": ["D", "E"],
118-
... })
119-
>>> Data.validate(data_tbl_with_extra)
155+
>>> Data.validate(pa.Table.from_pydict({"subject_id": [4, 5], "code": ["D", "E"]}))
120156
pyarrow.Table
121157
subject_id: int64
122158
code: string
123159
----
124160
subject_id: [[4,5]]
125161
code: [["D","E"]]
162+
163+
Errors will be raised when extra columns are present inappropriately or mandatory columns are missing:
164+
126165
>>> data_tbl_with_extra = pa.Table.from_pydict({
127166
... "subject_id": [4, 5],
128167
... "code": ["D", "E"],
@@ -132,6 +171,25 @@ class PyArrowSchema(Schema):
132171
Traceback (most recent call last):
133172
...
134173
flexible_schema.base.SchemaValidationError: Unexpected extra columns: {'extra_1'}
174+
>>> Data.validate(pa.Table.from_pydict({ "subject_id": [4, 5], }))
175+
Traceback (most recent call last):
176+
...
177+
flexible_schema.base.SchemaValidationError: Missing mandatory columns: {'code'}
178+
179+
Or when columns can't be cast properly:
180+
181+
>>> Data.validate(pa.Table.from_pydict({"subject_id": ["A", "B"], "code": ["D", "E"]}))
182+
Traceback (most recent call last):
183+
...
184+
flexible_schema.base.SchemaValidationError: Column 'subject_id' cast failed: ...
185+
186+
Not all types are supported:
187+
188+
>>> class Data(PyArrowSchema):
189+
... foo: dict[str, str]
190+
Traceback (most recent call last):
191+
...
192+
ValueError: Unsupported type: dict[str, str]
135193
"""
136194

137195
PYTHON_TO_PYARROW: ClassVar[dict[Any, pa.DataType]] = {
@@ -165,13 +223,10 @@ def _remap_type_internal(cls, field_type: Any) -> pa.DataType:
165223
@classmethod
166224
def validate(
167225
cls,
168-
table: pa.Table | dict[str, list[Any]],
226+
table: pa.Table,
169227
reorder_columns: bool = True,
170228
cast_types: bool = True,
171229
) -> pa.Table:
172-
if isinstance(table, dict):
173-
table = pa.Table.from_pydict(table)
174-
175230
table_cols = set(table.column_names)
176231
mandatory_cols = {f.name for f in fields(cls) if not cls._is_optional(f.type)}
177232
all_defined_cols = {f.name for f in fields(cls)}

0 commit comments

Comments
 (0)