Skip to content

Commit 2cf6a39

Browse files
authored
style: Enforce type annotations (#7)
1 parent 3cfe686 commit 2cf6a39

67 files changed

Lines changed: 342 additions & 310 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

dataframely/_base_collection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class Metadata:
7979
members: dict[str, MemberInfo] = field(default_factory=dict)
8080
filters: dict[str, Filter] = field(default_factory=dict)
8181

82-
def update(self, other: Self):
82+
def update(self, other: Self) -> None:
8383
self.members.update(other.members)
8484
self.filters.update(other.filters)
8585

@@ -92,7 +92,7 @@ def __new__(
9292
namespace: dict[str, Any],
9393
*args: Any,
9494
**kwargs: Any,
95-
):
95+
) -> CollectionMeta:
9696
result = Metadata()
9797
for base in bases:
9898
result.update(mcs._get_metadata_recursively(base))

dataframely/_base_schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class Metadata:
5858
columns: dict[str, Column] = field(default_factory=dict)
5959
rules: dict[str, Rule] = field(default_factory=dict)
6060

61-
def update(self, other: Self):
61+
def update(self, other: Self) -> None:
6262
self.columns.update(other.columns)
6363
self.rules.update(other.rules)
6464

@@ -71,7 +71,7 @@ def __new__(
7171
namespace: dict[str, Any],
7272
*args: Any,
7373
**kwargs: Any,
74-
):
74+
) -> SchemaMeta:
7575
result = Metadata()
7676
for base in bases:
7777
result.update(mcs._get_metadata_recursively(base))

dataframely/_compat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
class _DummyModule: # pragma: no cover
9-
def __init__(self, module: str):
9+
def __init__(self, module: str) -> None:
1010
self.module = module
1111

1212
def __getattr__(self, name: str) -> Any:

dataframely/_filter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
class Filter(Generic[C]):
1313
"""Internal class representing logic for filtering members of a collection."""
1414

15-
def __init__(self, logic: Callable[[C], pl.LazyFrame]):
15+
def __init__(self, logic: Callable[[C], pl.LazyFrame]) -> None:
1616
self.logic = logic
1717

1818

dataframely/_rule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
class Rule:
1313
"""Internal class representing validation rules."""
1414

15-
def __init__(self, expr: pl.Expr):
15+
def __init__(self, expr: pl.Expr) -> None:
1616
self.expr = expr
1717

1818

1919
class GroupRule(Rule):
2020
"""Rule that is evaluated on a group of columns."""
2121

22-
def __init__(self, expr: pl.Expr, group_columns: list[str]):
22+
def __init__(self, expr: pl.Expr, group_columns: list[str]) -> None:
2323
super().__init__(expr)
2424
self.group_columns = group_columns
2525

dataframely/collection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def collect_all(self) -> Self:
498498

499499
# ---------------------------------- PERSISTENCE --------------------------------- #
500500

501-
def write_parquet(self, directory: Path):
501+
def write_parquet(self, directory: Path) -> None:
502502
"""Write the members of this collection to Parquet files in a directory.
503503
504504
This method writes one Parquet file per member into the provided directory.
@@ -590,7 +590,7 @@ def _init(cls, data: Mapping[str, FrameType], /) -> Self:
590590
return out
591591

592592
@classmethod
593-
def _validate_input_keys(cls, data: Mapping[str, FrameType], /):
593+
def _validate_input_keys(cls, data: Mapping[str, FrameType], /) -> None:
594594
actual = set(data)
595595

596596
missing = cls.required_members() - actual

dataframely/columns/_mixins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
8383
class IsInMixin(Generic[U], Base):
8484
"""Mixin to use for types implementing "is in"."""
8585

86-
def __init__(self, *, is_in: Sequence[U] | None = None, **kwargs: Any):
86+
def __init__(self, *, is_in: Sequence[U] | None = None, **kwargs: Any) -> None:
8787
super().__init__(**kwargs)
8888
self.is_in = is_in
8989

dataframely/columns/decimal.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,9 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
148148
# --------------------------------------- UTILS -------------------------------------- #
149149

150150

151-
def _validate(value: decimal.Decimal, precision: int | None, scale: int, name: str):
151+
def _validate(
152+
value: decimal.Decimal, precision: int | None, scale: int, name: str
153+
) -> None:
152154
exponent = value.as_tuple().exponent
153155
if not isinstance(exponent, int):
154156
raise ValueError(f"Encountered 'inf' or 'NaN' for `{name}`.")

dataframely/config.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,23 @@ class Config(contextlib.ContextDecorator):
2525
#: Singleton stack to track where to go back after exiting a context.
2626
_stack: list[Options] = []
2727

28-
def __init__(self, **options: Unpack[Options]):
28+
def __init__(self, **options: Unpack[Options]) -> None:
2929
self._local_options: Options = {**default_options(), **options}
3030

3131
@staticmethod
32-
def set_max_sampling_iterations(iterations: int):
32+
def set_max_sampling_iterations(iterations: int) -> None:
3333
"""Set the maximum number of sampling iterations to use on
3434
:meth:`Schema.sample`."""
3535
Config.options["max_sampling_iterations"] = iterations
3636

3737
@staticmethod
38-
def restore_defaults():
38+
def restore_defaults() -> None:
3939
"""Restore the defaults of the configuration."""
4040
Config.options = default_options()
4141

4242
# ------------------------------------ CONTEXT ----------------------------------- #
4343

44-
def __enter__(self):
44+
def __enter__(self) -> None:
4545
Config._stack.append(Config.options)
4646
Config.options = self._local_options
4747

@@ -50,5 +50,5 @@ def __exit__(
5050
exc_type: type[BaseException] | None,
5151
exc_val: BaseException | None,
5252
exc_tb: TracebackType | None,
53-
):
53+
) -> None:
5454
Config.options = Config._stack.pop()

dataframely/exc.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
class ValidationError(Exception):
1212
"""Error raised when :mod:`dataframely` validation encounters an issue."""
1313

14-
def __init__(self, message: str):
14+
def __init__(self, message: str) -> None:
1515
super().__init__()
1616
self.message = message
1717

@@ -22,7 +22,9 @@ def __str__(self) -> str:
2222
class DtypeValidationError(ValidationError):
2323
"""Validation error raised when column dtypes are wrong."""
2424

25-
def __init__(self, errors: dict[str, tuple[PolarsDataType, PolarsDataType]]):
25+
def __init__(
26+
self, errors: dict[str, tuple[PolarsDataType, PolarsDataType]]
27+
) -> None:
2628
super().__init__(f"{len(errors)} columns have an invalid dtype")
2729
self.errors = errors
2830

@@ -37,7 +39,7 @@ def __str__(self) -> str:
3739
class RuleValidationError(ValidationError):
3840
"""Complex validation error raised when rule validation fails."""
3941

40-
def __init__(self, errors: dict[str, int]):
42+
def __init__(self, errors: dict[str, int]) -> None:
4143
super().__init__(f"{len(errors)} rules failed validation")
4244

4345
# Split into schema errors and column errors
@@ -75,11 +77,11 @@ def __str__(self) -> str:
7577
class MemberValidationError(ValidationError):
7678
"""Validation error raised when multiple members of a collection fail validation."""
7779

78-
def __init__(self, errors: dict[str, ValidationError]):
80+
def __init__(self, errors: dict[str, ValidationError]) -> None:
7981
super().__init__(f"{len(errors)} members failed validation")
8082
self.errors = errors
8183

82-
def __str__(self):
84+
def __str__(self) -> str:
8385
details = [
8486
f" > Member '{name}' failed validation:\n"
8587
+ "\n".join(" " + line for line in str(error).split("\n"))
@@ -95,7 +97,7 @@ class ImplementationError(Exception):
9597
class AnnotationImplementationError(ImplementationError):
9698
"""Error raised when the annotations of a collection are invalid."""
9799

98-
def __init__(self, attr: str, kls: type):
100+
def __init__(self, attr: str, kls: type) -> None:
99101
message = (
100102
"Annotations of a 'dy.Collection' may only be an (optional) "
101103
f"'dy.LazyFrame', but \"{attr}\" has type '{kls}'."
@@ -106,7 +108,9 @@ def __init__(self, attr: str, kls: type):
106108
class RuleImplementationError(ImplementationError):
107109
"""Error raised when a rule is implemented incorrectly."""
108110

109-
def __init__(self, name: str, return_dtype: pl.DataType, is_group_rule: bool):
111+
def __init__(
112+
self, name: str, return_dtype: pl.DataType, is_group_rule: bool
113+
) -> None:
110114
if is_group_rule:
111115
details = (
112116
" When implementing a group rule (i.e. when using the `group_by` "

0 commit comments

Comments
 (0)