-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy path_base_schema.py
More file actions
329 lines (272 loc) · 12.6 KB
/
_base_schema.py
File metadata and controls
329 lines (272 loc) · 12.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
# Copyright (c) QuantCo 2025-2026
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import ast
import inspect
import sys
import textwrap
from abc import ABCMeta
from copy import copy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import polars as pl
from ._rule import DtypeCastRule, GroupRule, Rule, RuleFactory
from .columns import Column
from .exc import ImplementationError
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
# Name of the class attribute under which a schema's collected columns are stored.
_COLUMN_ATTR = "__dataframely_columns__"
# Name of the class attribute under which a schema's instantiated rules are stored.
_RULE_ATTR = "__dataframely_rules__"
# Prefix of the column names that the dtype-cast rules compare the cast columns against.
ORIGINAL_COLUMN_PREFIX = "__DATAFRAMELY_ORIGINAL__"
# --------------------------------------- UTILS -------------------------------------- #
def _extract_column_docstrings(cls: type) -> dict[str, str]:
    """Extract docstrings for class attributes from source code.

    This function parses the source code of a class to find string literals
    that immediately follow attribute assignments. These are treated as
    documentation strings for those attributes. Both plain assignments
    (``x = ...``) and annotated assignments with a value (``x: T = ...``) are
    recognized.

    Args:
        cls: The class to extract docstrings from.

    Returns:
        A dictionary mapping attribute names to their docstrings. The dictionary is
        empty if the source of the class is unavailable or cannot be parsed.
    """
    # Only these two calls can raise the handled exceptions, so keep the `try` tight.
    try:
        source = inspect.getsource(cls)
        # Dedent to handle indented class definitions
        tree = ast.parse(textwrap.dedent(source))
    except (OSError, TypeError, SyntaxError):
        # Source not available or cannot be parsed
        return {}

    # Find the class definition (`ast.walk` yields the outermost class first)
    class_def = next(
        (node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)), None
    )
    if class_def is None:
        return {}

    # Inspect every pair of consecutive statements in the class body
    docstrings: dict[str, str] = {}
    for current, following in zip(class_def.body, class_def.body[1:]):
        # The second statement must be a bare string literal
        if not (
            isinstance(following, ast.Expr)
            and isinstance(following.value, ast.Constant)
            and isinstance(following.value.value, str)
        ):
            continue

        # The first statement must assign to at least one plain name
        targets: list[ast.expr]
        if isinstance(current, ast.Assign):
            targets = current.targets
        elif isinstance(current, ast.AnnAssign) and current.value is not None:
            targets = [current.target]
        else:
            continue

        for target in targets:
            if isinstance(target, ast.Name):
                docstrings[target.id] = following.value.value
    return docstrings
def _build_rules(
    custom: dict[str, Rule], columns: dict[str, Column], *, with_cast: bool
) -> dict[str, Rule]:
    """Assemble the full set of validation rules for a set of columns.

    The result contains, in this order, the custom rules, a ``primary_key`` rule (if
    any column is part of the primary key), the rules of the individual columns named
    ``<column>|<rule>`` and, if requested, one ``<column>|dtype`` rule per column.

    Args:
        custom: The custom rules. This mapping is not modified.
        columns: The columns to derive rules from, keyed by column name.
        with_cast: Whether to add rules checking for dtype casting failures.

    Returns:
        A new dictionary mapping rule names to rules.
    """
    # NOTE: Work on a copy so that the custom rules are never modified in place
    combined: dict[str, Rule] = copy(custom)

    # Uniqueness of the primary key is only validated if there is a primary key
    key_columns = _primary_key(columns)
    if key_columns:
        combined["primary_key"] = Rule(~pl.struct(key_columns).is_duplicated())

    # Rules defined by the individual columns
    for col_name, column in columns.items():
        for rule_name, expr in column.validation_rules(pl.col(col_name)).items():
            combined[f"{col_name}|{rule_name}"] = Rule(expr)

    if not with_cast:
        return combined

    # Lenient dtype casting sets values to `null` whenever casting fails. Hence, a
    # failed cast shows up as a mismatch between the nullability of the cast column
    # and the nullability of the original column.
    # NOTE: This check assumes that both the original and cast column are present in
    #  the data frame.
    for col_name in columns:
        cast_is_null = pl.col(col_name).is_null()
        original_is_null = pl.col(f"{ORIGINAL_COLUMN_PREFIX}{col_name}").is_null()
        combined[f"{col_name}|dtype"] = DtypeCastRule(cast_is_null == original_is_null)
    return combined
def _primary_key(columns: dict[str, Column]) -> list[str]:
    """Return the names of all columns flagged as primary key, in mapping order."""
    return [name for name, column in columns.items() if column.primary_key]
# ------------------------------------------------------------------------------------ #
# SCHEMA META #
# ------------------------------------------------------------------------------------ #
@dataclass
class Metadata:
    """Container collecting the columns and rule factories that belong to a schema."""

    # Column definitions, keyed by column name
    columns: dict[str, Column] = field(default_factory=dict)
    # Factories for custom rules, keyed by rule name
    rules: dict[str, RuleFactory] = field(default_factory=dict)

    def update(self, other: Self) -> None:
        """Merge ``other`` into this object in place.

        Entries of ``other`` take precedence over entries with the same key that
        already exist in this object.
        """
        for col_name, column in other.columns.items():
            self.columns[col_name] = column
        for rule_name, factory in other.rules.items():
            self.rules[rule_name] = factory
class SchemaMeta(ABCMeta):
    """Metaclass which gathers the columns and rules of a schema class.

    When a class is created, the columns and rule factories of all bases and of the
    class body are collected, stored on the class and checked for common
    implementation mistakes.
    """

    def __new__(
        mcs,  # noqa: N804
        name: str,
        bases: tuple[type[object], ...],
        namespace: dict[str, Any],
        *args: Any,
        **kwargs: Any,
    ) -> SchemaMeta:
        """Create the schema class and verify that it is implemented correctly.

        Raises:
            ImplementationError: If a custom rule is named ``primary_key``, if a rule
                and a column share a name or if a group rule references columns
                which are not part of the schema.
            TypeError: If a member of the class body is a tuple starting with a
                column, a column type instead of an instance, or a polars data type.
        """
        # Collect metadata from the bases first such that definitions in the class
        # body take precedence.
        collected = Metadata()
        for parent in bases:
            collected.update(mcs._get_metadata_recursively(parent))
        collected.update(mcs._get_metadata(namespace))
        namespace[_COLUMN_ATTR] = collected.columns
        cls = super().__new__(mcs, name, bases, namespace, *args, **kwargs)

        # Attach docstrings found in the source code to the columns
        attribute_docs = _extract_column_docstrings(cls)
        for column in collected.columns.values():
            # Docstrings are keyed by the attribute name (not the alias), so look up
            # the first attribute in the class body that holds this very column.
            attribute_name = next(
                (key for key, member in namespace.items() if member is column),
                None,
            )
            if not attribute_name or attribute_name not in attribute_docs:
                continue
            # Never overwrite documentation that the column already has
            if column.doc is None:
                column.doc = attribute_docs[attribute_name]

        # Rules can only be created now as the metadata merely holds rule factories
        # which require the class.
        rules = {}
        for factory_name, factory in collected.rules.items():
            rules[factory_name] = factory.make(cls)
        setattr(cls, _RULE_ATTR, rules)

        # All columns and custom rules are known at this point, so the definition of
        # the schema can be checked...
        # 1) No rule may share its name with a column. Rules for dtype casting are
        #    included here as we assume that users cast dtypes.
        clashes = set(collected.columns).intersection(
            _build_rules(rules, collected.columns, with_cast=True)
        )
        if clashes:
            clash_list = ", ".join(sorted(f"'{col}'" for col in clashes))
            raise ImplementationError(
                "Rules and columns must not be named equally but found "
                f"{len(clashes)} overlaps: {clash_list}."
            )

        # 2) All columns referenced by group rules must exist.
        for rule_name, rule in rules.items():
            if not isinstance(rule, GroupRule):
                continue
            unknown = set(rule.group_columns).difference(collected.columns)
            if unknown:
                unknown_list = ", ".join(sorted(f"'{col}'" for col in unknown))
                raise ImplementationError(
                    f"Group validation rule '{rule_name}' has been implemented "
                    f"incorrectly. It references {len(unknown)} columns "
                    f"which are not in the schema: {unknown_list}."
                )

        # 3) No member may be pathological (i.e., a user error).
        for attr, member in namespace.items():
            if attr.startswith("__"):
                continue

            # A tuple of columns is commonly caused by a trailing comma
            if isinstance(member, tuple) and len(member) > 0:
                if isinstance(member[0], Column):
                    raise TypeError(
                        f"Column '{attr}' is defined as a tuple of dy.Column. "
                        "Did you accidentally add a trailing comma?"
                    )

            member_is_type = isinstance(member, type)

            # A column type instead of an instance (e.g., dy.Float64 instead of
            # dy.Float64())
            if member_is_type and issubclass(member, Column):
                raise TypeError(
                    f"Column '{attr}' is a type, not an instance. "
                    "Schema members must be of type Column not type[Column]. "
                    "Did you forget to add parentheses?"
                )

            # A polars data type, either as instance or as type (e.g., pl.String() or
            # pl.String instead of dy.String())
            is_dtype_instance = isinstance(member, pl.DataType)
            if is_dtype_instance or (
                member_is_type and issubclass(member, pl.DataType)
            ):
                if is_dtype_instance:
                    value_type, example = "instance", "pl.String()"
                else:
                    value_type, example = "type", "pl.String"
                raise TypeError(
                    f"Schema member '{attr}' is a polars DataType {value_type}. "
                    f"Use dataframely column types (e.g., dy.String()) instead of polars types (e.g., {example})."
                )

        return cls

    if not TYPE_CHECKING:
        # `__getattribute__` is only defined at runtime. Type checkers then fall back
        # to the default metaclass behavior and, thus, are still able to detect
        # access to attributes which do not exist.
        def __getattribute__(cls, name: str) -> Any:
            """Look up ``name`` and name the result if it is a column."""
            attribute = super().__getattribute__(name)
            if not isinstance(attribute, Column):
                return attribute
            # The name of a column is set dynamically on access: its alias if it has
            # one, the attribute name otherwise.
            attribute._name = attribute.alias or name
            return attribute

    @staticmethod
    def _get_metadata_recursively(kls: type[object]) -> Metadata:
        """Collect the metadata of ``kls`` and, recursively, of all of its bases.

        Bases are processed before ``kls`` itself such that definitions in ``kls``
        take precedence.
        """
        collected = Metadata()
        for parent in kls.__bases__:
            collected.update(SchemaMeta._get_metadata_recursively(parent))
        own = SchemaMeta._get_metadata(kls.__dict__)  # type: ignore
        collected.update(own)
        return collected

    @staticmethod
    def _get_metadata(source: dict[str, Any]) -> Metadata:
        """Collect columns and rule factories from a class namespace.

        Entries whose name starts with a double underscore are ignored. Columns are
        keyed by their alias if they have one and by their attribute name otherwise.

        Raises:
            ImplementationError: If a rule factory is named ``primary_key``.
        """
        collected = Metadata()
        members = [
            (attr, member)
            for attr, member in source.items()
            if not attr.startswith("__")
        ]
        for attr, member in members:
            if isinstance(member, Column):
                collected.columns[member.alias or attr] = member
            if isinstance(member, RuleFactory):
                # Custom rules must not clash with internal rules
                if attr == "primary_key":
                    raise ImplementationError(
                        "Custom validation rule must not be named `primary_key`."
                    )
                collected.rules[attr] = member
        return collected

    def __repr__(cls) -> str:
        """Render the name, the columns and the custom rules of the schema."""
        lines = [
            f'[Schema "{cls.__name__}"]',
            textwrap.indent("Columns:", prefix=" " * 2),
        ]
        lines.extend(
            textwrap.indent(f'- "{col_name}": {column!r}', prefix=" " * 4)
            for col_name, column in cls.columns().items()  # type: ignore[attr-defined]
        )
        custom_rules = cls._schema_validation_rules()  # type: ignore[attr-defined]
        if custom_rules:
            lines.append(textwrap.indent("Rules:", prefix=" " * 2))
            lines.extend(
                textwrap.indent(f'- "{rule_name}": {rule!r}', prefix=" " * 4)
                for rule_name, rule in custom_rules.items()
            )
        # The representation ends with a line break
        return "\n".join(lines) + "\n"
class BaseSchema(metaclass=SchemaMeta):
    """Internal base class which allows referencing schemas without introducing
    cyclical dependencies."""

    @classmethod
    def column_names(cls) -> list[str]:
        """Return the names of the columns of this schema."""
        return list(getattr(cls, _COLUMN_ATTR))

    @classmethod
    def columns(cls) -> dict[str, Column]:
        """Return the column definitions of this schema, keyed by column name."""
        schema_columns: dict[str, Column] = getattr(cls, _COLUMN_ATTR)
        for col_name, column in schema_columns.items():
            # The name of each column is set dynamically from its key
            column._name = col_name
        return schema_columns

    @classmethod
    def primary_key(cls) -> list[str]:
        """Return the names of the primary key columns of this schema (possibly
        empty)."""
        return _primary_key(cls.columns())

    @classmethod
    def _validation_rules(cls, *, with_cast: bool) -> dict[str, Rule]:
        """Build all validation rules from the custom rules and the columns.

        Args:
            with_cast: Whether to include rules for dtype casting.
        """
        custom_rules = cls._schema_validation_rules()
        return _build_rules(custom_rules, cls.columns(), with_cast=with_cast)

    @classmethod
    def _schema_validation_rules(cls) -> dict[str, Rule]:
        """Return the custom rules stored on this schema class."""
        return getattr(cls, _RULE_ATTR)