-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy patharray.py
More file actions
178 lines (150 loc) · 6.51 KB
/
array.py
File metadata and controls
178 lines (150 loc) · 6.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# Copyright (c) QuantCo 2025-2026
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import math
import sys
import warnings
from collections.abc import Sequence
from typing import Any, cast
import polars as pl
from dataframely._compat import pa, sa, sa_TypeEngine
from dataframely.random import Generator
from ._base import Check, Column
from ._registry import column_from_dict, register
from .list import _list_primary_key_check
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
@register
class Array(Column):
"""A fixed-shape array column."""
def __init__(
self,
inner: Column,
shape: int | tuple[int, ...],
*,
nullable: bool = False,
primary_key: bool = False,
unique: bool = False,
check: Check | None = None,
alias: str | None = None,
metadata: dict[str, Any] | None = None,
):
"""
Args:
inner: The inner column type.
shape: The shape of the array.
nullable: Whether this column may contain null values.
primary_key: Whether this column is part of the primary key of the schema.
unique: Whether this column must contain unique values. Unlike `primary_key`,
this checks uniqueness for this column independently. Multiple columns
can each have `unique=True` without forming a composite constraint.
check: A custom rule or multiple rules to run for this column. This can be:
- A single callable that returns a non-aggregated boolean expression.
The name of the rule is derived from the callable name, or defaults to
"check" for lambdas.
- A list of callables, where each callable returns a non-aggregated
boolean expression. The name of the rule is derived from the callable
name, or defaults to "check" for lambdas. Where multiple rules result
in the same name, the suffix __i is appended to the name.
- A dictionary mapping rule names to callables, where each callable
returns a non-aggregated boolean expression.
All rule names provided here are given the prefix `"check_"`.
alias: An overwrite for this column's name which allows for using a column
name that is not a valid Python identifier. Especially note that setting
this option does _not_ allow to refer to the column with two different
names, the specified alias is the only valid name.
metadata: A dictionary of metadata to attach to the column.
"""
super().__init__(
nullable=nullable,
primary_key=primary_key,
unique=unique,
check=check,
alias=alias,
metadata=metadata,
)
self.inner = inner
self.shape = shape if isinstance(shape, tuple) else (shape,)
@property
def dtype(self) -> pl.DataType:
return pl.Array(self.inner.dtype, self.shape)
def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
inner_rules = {
f"inner_{rule_name}": expr.arr.eval(inner_expr).arr.all()
for rule_name, inner_expr in self.inner.validation_rules(
pl.element()
).items()
}
array_rules: dict[str, pl.Expr] = {}
if (rule := _list_primary_key_check(expr.arr, self.inner)) is not None:
array_rules["primary_key"] = rule
if self.unique:
# Wrap the column in a struct to make `is_unique` work with arrays:
# https://github.com/pola-rs/polars/issues/27286
array_rules["unique"] = pl.struct(expr).is_unique()
return {
**super().validation_rules(expr),
**array_rules,
**inner_rules,
}
def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
match dialect.name:
case "postgresql":
# Note that the length of the array in each dimension is not supported in SQLAlchemy
# That is because PostgreSQL does not enforce the length anyway
return sa.ARRAY(
self.inner.sqlalchemy_dtype(dialect), dimensions=len(self.shape)
)
case _:
raise NotImplementedError(
f"SQL column cannot have 'Array' type for dialect '{dialect}'."
)
def _pyarrow_field_of_shape(self, shape: Sequence[int]) -> pa.Field:
if shape:
size, *rest = shape
inner_type = self._pyarrow_field_of_shape(rest)
return pa.field("item", pa.list_(inner_type, size), nullable=True)
else:
return self.inner.pyarrow_field("item")
@property
def pyarrow_dtype(self) -> pa.DataType:
return self._pyarrow_field_of_shape(self.shape).type
@property
def _python_type(self) -> Any:
inner_type = self.inner.pydantic_field()
return list[inner_type] # type: ignore
def _pydantic_field_kwargs(self) -> dict[str, Any]:
if len(self.shape) != 1:
warnings.warn(
"Multi-dimensional arrays are flattened for pydantic validation."
)
return {
**super()._pydantic_field_kwargs(),
"min_length": math.prod(self.shape),
"max_length": math.prod(self.shape),
}
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
# Sample the inner elements in a flat series
n_elements = n * math.prod(self.shape)
all_elements = self.inner.sample(generator, n_elements)
# Finally, apply a null mask
return generator._apply_null_mask(
all_elements.reshape((n, *self.shape)),
null_probability=self._null_probability,
)
def _attributes_match(
self, lhs: Any, rhs: Any, name: str, column_expr: pl.Expr
) -> bool:
if name == "inner":
return cast(Column, lhs).matches(cast(Column, rhs), pl.element())
return super()._attributes_match(lhs, rhs, name, column_expr)
def as_dict(self, expr: pl.Expr) -> dict[str, Any]:
result = super().as_dict(expr)
result["inner"] = self.inner.as_dict(pl.element())
return result
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Self:
data["inner"] = column_from_dict(data["inner"])
return super().from_dict(data)