-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy path_mixins.py
More file actions
112 lines (88 loc) · 3.69 KB
/
_mixins.py
File metadata and controls
112 lines (88 loc) · 3.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright (c) QuantCo 2025-2026
# SPDX-License-Identifier: BSD-3-Clause
import sys
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar
import polars as pl
# `Base` is the class the mixins below inherit from. Type checkers see
# `Column`; at runtime `TYPE_CHECKING` is false, so `._base` is not imported
# by this statement and the mixins inherit from plain `object`.
# NOTE(review): presumably this lets type checkers resolve the `super()` calls
# inside the mixins against `Column` -- confirm against `._base`.
if TYPE_CHECKING:  # pragma: no cover
    from ._base import Column

    Base = Column
else:
    Base = object
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
# ----------------------------------- ORDINAL MIXIN ---------------------------------- #
class Comparable(Protocol):
    """Structural type for values that support the ``>`` and ``>=`` operators.

    Both operators take another value of the same type and return a ``bool``.
    """

    def __gt__(self, other: Self, /) -> bool:
        """Return whether ``self`` is strictly greater than ``other``."""

    def __ge__(self, other: Self, /) -> bool:
        """Return whether ``self`` is greater than or equal to ``other``."""


# Type of the bound values accepted by `OrdinalMixin`; restricted to types
# that provide the comparisons declared by `Comparable`.
T = TypeVar("T", bound=Comparable)
class OrdinalMixin(Generic[T], Base):
    """Mixin that adds optional lower and upper bounds to an ordinal type.

    Each side can be bounded inclusively (``min`` / ``max``) or exclusively
    (``min_exclusive`` / ``max_exclusive``), but never both ways at once.
    """

    def __init__(
        self,
        *,
        min: T | None = None,
        min_exclusive: T | None = None,
        max: T | None = None,
        max_exclusive: T | None = None,
        **kwargs: Any,
    ):
        """Check the bounds for consistency and store them on the instance.

        Args:
            min: Inclusive lower bound, or ``None`` to leave it unset.
            min_exclusive: Exclusive lower bound, or ``None`` to leave it unset.
            max: Inclusive upper bound, or ``None`` to leave it unset.
            max_exclusive: Exclusive upper bound, or ``None`` to leave it unset.
            kwargs: Passed on unchanged to the next ``__init__`` in the MRO.

        Raises:
            ValueError: If both the inclusive and the exclusive variant of the
                same bound are given, or if the lower bound is too large for
                the upper bound (only two inclusive bounds may be equal).
        """
        # A side may be bounded inclusively or exclusively, not both.
        if not (min is None or min_exclusive is None):
            raise ValueError("At most one of `min` and `min_exclusive` must be set.")
        if not (max is None or max_exclusive is None):
            raise ValueError("At most one of `max` and `max_exclusive` must be set.")

        # Compare whichever lower bound is set against the inclusive upper bound.
        if max is not None:
            if min is not None and min > max:
                raise ValueError("`min` must not be greater than `max`.")
            if min_exclusive is not None and min_exclusive >= max:
                raise ValueError(
                    "`min_exclusive` must not be greater or equal to `max`."
                )

        # Same comparison against the exclusive upper bound; equality is
        # rejected here because the upper bound itself is not admissible.
        if max_exclusive is not None:
            if min is not None and min >= max_exclusive:
                raise ValueError(
                    "`min` must not be greater or equal to `max_exclusive`."
                )
            if min_exclusive is not None and min_exclusive >= max_exclusive:
                raise ValueError(
                    "`min_exclusive` must not be greater or equal to `max_exclusive`."
                )

        super().__init__(**kwargs)
        self.min = min
        self.min_exclusive = min_exclusive
        self.max = max
        self.max_exclusive = max_exclusive

    def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
        """Extend the inherited validation rules with one rule per set bound.

        Args:
            expr: Expression the bound comparisons are applied to.

        Returns:
            The mapping returned by the parent implementation, with an
            additional entry named after each bound that is not ``None``
            (``"min"``, ``"min_exclusive"``, ``"max"``, ``"max_exclusive"``).
        """
        rules = super().validation_rules(expr)
        if (lower := self.min) is not None:
            rules["min"] = expr >= lower  # type: ignore
        if (strict_lower := self.min_exclusive) is not None:
            rules["min_exclusive"] = expr > strict_lower  # type: ignore
        if (upper := self.max) is not None:
            rules["max"] = expr <= upper  # type: ignore
        if (strict_upper := self.max_exclusive) is not None:
            rules["max_exclusive"] = expr < strict_upper  # type: ignore
        return rules

    def _pydantic_field_kwargs(self) -> dict[str, Any]:
        """Extend the inherited field keyword arguments with the set bounds.

        Returns:
            The mapping returned by the parent implementation, with ``min``,
            ``min_exclusive``, ``max`` and ``max_exclusive`` added under the
            keys ``"ge"``, ``"gt"``, ``"le"`` and ``"lt"`` respectively
            whenever they are not ``None``.
        """
        field_kwargs = super()._pydantic_field_kwargs()
        constraints = (
            ("ge", self.min),
            ("gt", self.min_exclusive),
            ("le", self.max),
            ("lt", self.max_exclusive),
        )
        for keyword, bound in constraints:
            if bound is not None:
                field_kwargs[keyword] = bound
        return field_kwargs
# ------------------------------------ IS IN MIXIN ----------------------------------- #
# Element type of the values an `IsInMixin` accepts.
U = TypeVar("U")


class IsInMixin(Generic[U], Base):
    """Mixin for types whose values can be restricted to a fixed set of options."""

    def __init__(self, *, is_in: Sequence[U] | None = None, **kwargs: Any) -> None:
        """Store the permitted values.

        Args:
            is_in: Values to allow, or ``None`` to add no membership rule.
            kwargs: Passed on unchanged to the next ``__init__`` in the MRO.
        """
        super().__init__(**kwargs)
        self.is_in = is_in

    def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
        """Extend the inherited validation rules with a membership rule.

        Args:
            expr: Expression the membership test is applied to.

        Returns:
            The mapping returned by the parent implementation, with an
            ``"is_in"`` entry added when ``is_in`` is not ``None``.
        """
        rules = super().validation_rules(expr)
        allowed = self.is_in
        if allowed is None:
            return rules
        rules["is_in"] = expr.is_in(allowed)
        return rules