Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dataframely/_base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class Metadata:
members: dict[str, MemberInfo] = field(default_factory=dict)
filters: dict[str, Filter] = field(default_factory=dict)

def update(self, other: Self):
def update(self, other: Self) -> None:
self.members.update(other.members)
self.filters.update(other.filters)

Expand All @@ -92,7 +92,7 @@ def __new__(
namespace: dict[str, Any],
*args: Any,
**kwargs: Any,
):
) -> CollectionMeta:
result = Metadata()
for base in bases:
result.update(mcs._get_metadata_recursively(base))
Expand Down
4 changes: 2 additions & 2 deletions dataframely/_base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Metadata:
columns: dict[str, Column] = field(default_factory=dict)
rules: dict[str, Rule] = field(default_factory=dict)

def update(self, other: Self):
def update(self, other: Self) -> None:
self.columns.update(other.columns)
self.rules.update(other.rules)

Expand All @@ -71,7 +71,7 @@ def __new__(
namespace: dict[str, Any],
*args: Any,
**kwargs: Any,
):
) -> SchemaMeta:
result = Metadata()
for base in bases:
result.update(mcs._get_metadata_recursively(base))
Expand Down
2 changes: 1 addition & 1 deletion dataframely/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class _DummyModule: # pragma: no cover
def __init__(self, module: str):
def __init__(self, module: str) -> None:
self.module = module

def __getattr__(self, name: str) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion dataframely/_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class Filter(Generic[C]):
"""Internal class representing logic for filtering members of a collection."""

def __init__(self, logic: Callable[[C], pl.LazyFrame]):
def __init__(self, logic: Callable[[C], pl.LazyFrame]) -> None:
self.logic = logic


Expand Down
4 changes: 2 additions & 2 deletions dataframely/_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
class Rule:
"""Internal class representing validation rules."""

def __init__(self, expr: pl.Expr):
def __init__(self, expr: pl.Expr) -> None:
self.expr = expr


class GroupRule(Rule):
"""Rule that is evaluated on a group of columns."""

def __init__(self, expr: pl.Expr, group_columns: list[str]):
def __init__(self, expr: pl.Expr, group_columns: list[str]) -> None:
super().__init__(expr)
self.group_columns = group_columns

Expand Down
4 changes: 2 additions & 2 deletions dataframely/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def collect_all(self) -> Self:

# ---------------------------------- PERSISTENCE --------------------------------- #

def write_parquet(self, directory: Path):
def write_parquet(self, directory: Path) -> None:
"""Write the members of this collection to Parquet files in a directory.

This method writes one Parquet file per member into the provided directory.
Expand Down Expand Up @@ -590,7 +590,7 @@ def _init(cls, data: Mapping[str, FrameType], /) -> Self:
return out

@classmethod
def _validate_input_keys(cls, data: Mapping[str, FrameType], /):
def _validate_input_keys(cls, data: Mapping[str, FrameType], /) -> None:
actual = set(data)

missing = cls.required_members() - actual
Expand Down
2 changes: 1 addition & 1 deletion dataframely/columns/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
class IsInMixin(Generic[U], Base):
"""Mixin to use for types implementing "is in"."""

def __init__(self, *, is_in: Sequence[U] | None = None, **kwargs: Any):
def __init__(self, *, is_in: Sequence[U] | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.is_in = is_in

Expand Down
4 changes: 3 additions & 1 deletion dataframely/columns/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
# --------------------------------------- UTILS -------------------------------------- #


def _validate(value: decimal.Decimal, precision: int | None, scale: int, name: str):
def _validate(
value: decimal.Decimal, precision: int | None, scale: int, name: str
) -> None:
exponent = value.as_tuple().exponent
if not isinstance(exponent, int):
raise ValueError(f"Encountered 'inf' or 'NaN' for `{name}`.")
Expand Down
10 changes: 5 additions & 5 deletions dataframely/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,23 @@ class Config(contextlib.ContextDecorator):
#: Singleton stack to track where to go back after exiting a context.
_stack: list[Options] = []

def __init__(self, **options: Unpack[Options]):
def __init__(self, **options: Unpack[Options]) -> None:
self._local_options: Options = {**default_options(), **options}

@staticmethod
def set_max_sampling_iterations(iterations: int):
def set_max_sampling_iterations(iterations: int) -> None:
"""Set the maximum number of sampling iterations to use on
:meth:`Schema.sample`."""
Config.options["max_sampling_iterations"] = iterations

@staticmethod
def restore_defaults():
def restore_defaults() -> None:
"""Restore the defaults of the configuration."""
Config.options = default_options()

# ------------------------------------ CONTEXT ----------------------------------- #

def __enter__(self):
def __enter__(self) -> None:
Config._stack.append(Config.options)
Config.options = self._local_options

Expand All @@ -50,5 +50,5 @@ def __exit__(
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
):
) -> None:
Config.options = Config._stack.pop()
18 changes: 11 additions & 7 deletions dataframely/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class ValidationError(Exception):
"""Error raised when :mod:`dataframely` validation encounters an issue."""

def __init__(self, message: str):
def __init__(self, message: str) -> None:
super().__init__()
self.message = message

Expand All @@ -22,7 +22,9 @@ def __str__(self) -> str:
class DtypeValidationError(ValidationError):
"""Validation error raised when column dtypes are wrong."""

def __init__(self, errors: dict[str, tuple[PolarsDataType, PolarsDataType]]):
def __init__(
self, errors: dict[str, tuple[PolarsDataType, PolarsDataType]]
) -> None:
super().__init__(f"{len(errors)} columns have an invalid dtype")
self.errors = errors

Expand All @@ -37,7 +39,7 @@ def __str__(self) -> str:
class RuleValidationError(ValidationError):
"""Complex validation error raised when rule validation fails."""

def __init__(self, errors: dict[str, int]):
def __init__(self, errors: dict[str, int]) -> None:
super().__init__(f"{len(errors)} rules failed validation")

# Split into schema errors and column errors
Expand Down Expand Up @@ -75,11 +77,11 @@ def __str__(self) -> str:
class MemberValidationError(ValidationError):
"""Validation error raised when multiple members of a collection fail validation."""

def __init__(self, errors: dict[str, ValidationError]):
def __init__(self, errors: dict[str, ValidationError]) -> None:
super().__init__(f"{len(errors)} members failed validation")
self.errors = errors

def __str__(self):
def __str__(self) -> str:
details = [
f" > Member '{name}' failed validation:\n"
+ "\n".join(" " + line for line in str(error).split("\n"))
Expand All @@ -95,7 +97,7 @@ class ImplementationError(Exception):
class AnnotationImplementationError(ImplementationError):
"""Error raised when the annotations of a collection are invalid."""

def __init__(self, attr: str, kls: type):
def __init__(self, attr: str, kls: type) -> None:
message = (
"Annotations of a 'dy.Collection' may only be an (optional) "
f"'dy.LazyFrame', but \"{attr}\" has type '{kls}'."
Expand All @@ -106,7 +108,9 @@ def __init__(self, attr: str, kls: type):
class RuleImplementationError(ImplementationError):
"""Error raised when a rule is implemented incorrectly."""

def __init__(self, name: str, return_dtype: pl.DataType, is_group_rule: bool):
def __init__(
self, name: str, return_dtype: pl.DataType, is_group_rule: bool
) -> None:
if is_group_rule:
details = (
" When implementing a group rule (i.e. when using the `group_by` "
Expand Down
6 changes: 4 additions & 2 deletions dataframely/failure.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class FailureInfo(Generic[S]):
#: The schema used to create the input data frame.
schema: type[S]

def __init__(self, lf: pl.LazyFrame, rule_columns: list[str], schema: type[S]):
def __init__(
self, lf: pl.LazyFrame, rule_columns: list[str], schema: type[S]
) -> None:
self._lf = lf
self._rule_columns = rule_columns
self.schema = schema
Expand Down Expand Up @@ -71,7 +73,7 @@ def __len__(self) -> int:

# ---------------------------------- PERSISTENCE --------------------------------- #

def write_parquet(self, file: str | Path | IO[bytes]):
def write_parquet(self, file: str | Path | IO[bytes]) -> None:
"""Write the failure info to a Parquet file.

Args:
Expand Down
13 changes: 7 additions & 6 deletions dataframely/mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
# --------------------------------------- RULES -------------------------------------- #


def mark_rules_as_staticmethod(ctx: ClassDefContext):
def mark_rules_as_staticmethod(ctx: ClassDefContext) -> None:
"""Mark all methods decorated with `@rule` as `staticmethod`s."""
info = ctx.cls.info
for sym in info.names.values():
Expand Down Expand Up @@ -199,7 +199,7 @@ def _convert_dy_column_to_dtype(
def store_typed_dict_type_for_schema(
ctx: ClassDefContext,
schema_registry: dict[str, TypedDictType],
):
) -> None:
"""Add `TypedDictType` inferred from the schema's columns to a given registry."""

schema_type = ctx.cls.info
Expand Down Expand Up @@ -336,11 +336,13 @@ def alter_dataframe_iter_rows_return_type(


class DataframelyPlugin(Plugin):
def __init__(self, options: Options):
def __init__(self, options: Options) -> None:
super().__init__(options)
self.schema_registry: dict[str, TypedDictType] = {}

def get_base_class_hook(self, fullname: str):
def get_base_class_hook(
self, fullname: str
) -> Callable[[ClassDefContext], None] | None:
# Given a class, check whether it is a subclass of `dy.Schema`. If so, mark
# all methods decorated with `@rule` as staticmethods.
# Also, store the `TypedDictType` for the schema in a registry to allow downstream
Expand All @@ -349,13 +351,12 @@ def get_base_class_hook(self, fullname: str):
if sym and isinstance(sym.node, TypeInfo):
if any(base.fullname == SCHEMA_FULLNAME for base in sym.node.mro):

def _hook(ctx: ClassDefContext) -> bool:
def _hook(ctx: ClassDefContext) -> None:
mark_rules_as_staticmethod(ctx)
store_typed_dict_type_for_schema(
ctx,
self.schema_registry,
)
return True

return _hook
return None
Expand Down
2 changes: 1 addition & 1 deletion dataframely/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Generator:
seeding.
"""

def __init__(self, seed: int | None = None):
def __init__(self, seed: int | None = None) -> None:
"""
Args:
seed: The seed to use for initializing the random number generator used
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
# Copied and adapted from
# https://github.com/pandas-dev/pandas/blob/4a14d064187367cacab3ff4652a12a0e45d0711b/doc/source/conf.py#L613-L659
# Required configuration function to use sphinx.ext.linkcode
def linkcode_resolve(domain, info):
def linkcode_resolve(domain: str, info: dict[str, str]) -> str | None:
"""Determine the URL corresponding to a given Python object."""
if domain != "py":
return None
Expand Down
2 changes: 1 addition & 1 deletion docs/sites/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ expectations on the schema of the data frame, e.g.:

::

def train_model(df: dy.DataFrame[HouseSchema]):
def train_model(df: dy.DataFrame[HouseSchema]) -> None:
...

The type checker (typically ``mypy``) then ensures that it is actually a
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ quote-style = "double"

[tool.mypy]
check_untyped_defs = true
disallow_untyped_defs = true
exclude = ["docs/"]
explicit_package_bases = true
no_implicit_optional = true
Expand Down
24 changes: 13 additions & 11 deletions tests/collection/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,32 @@ class MyCollection(dy.Collection):
second: dy.LazyFrame[MySecondSchema] | None


def test_common_primary_keys():
def test_common_primary_keys() -> None:
assert MyCollection.common_primary_keys() == ["a"]


def test_members():
def test_members() -> None:
members = MyCollection.members()
assert not members["first"].is_optional
assert members["second"].is_optional


def test_member_schemas():
def test_member_schemas() -> None:
schemas = MyCollection.member_schemas()
assert schemas == {"first": MyFirstSchema, "second": MySecondSchema}


def test_required_members():
def test_required_members() -> None:
required_members = MyCollection.required_members()
assert required_members == {"first"}


def test_optional_members():
def test_optional_members() -> None:
optional_members = MyCollection.optional_members()
assert optional_members == {"second"}


def test_cast():
def test_cast() -> None:
collection = MyCollection.cast(
{
"first": pl.LazyFrame({"a": [1, 2, 3]}),
Expand All @@ -74,7 +74,7 @@ def test_cast():
{"first": pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt8})},
],
)
def test_to_dict(expected: dict[str, pl.LazyFrame]):
def test_to_dict(expected: dict[str, pl.LazyFrame]) -> None:
collection = MyCollection.validate(expected)

# Check that export looks as expected
Expand All @@ -87,7 +87,7 @@ def test_to_dict(expected: dict[str, pl.LazyFrame]):
assert MyCollection.is_valid(observed)


def test_collect_all():
def test_collect_all() -> None:
collection = MyCollection.cast(
{
"first": pl.LazyFrame({"a": [1, 2, 3]}).filter(pl.col("a") < 3),
Expand All @@ -106,7 +106,7 @@ def test_collect_all():
assert len(out.second.collect()) == 2


def test_collect_all_optional():
def test_collect_all_optional() -> None:
collection = MyCollection.cast({"first": pl.LazyFrame({"a": [1, 2, 3]})})
out = collection.collect_all()

Expand All @@ -118,7 +118,9 @@ def test_collect_all_optional():
@pytest.mark.parametrize(
"read_fn", [MyCollection.scan_parquet, MyCollection.read_parquet]
)
def test_read_write_parquet(tmp_path: Path, read_fn: Callable[[Path], MyCollection]):
def test_read_write_parquet(
tmp_path: Path, read_fn: Callable[[Path], MyCollection]
) -> None:
collection = MyCollection.cast(
{
"first": pl.LazyFrame({"a": [1, 2, 3]}),
Expand All @@ -139,7 +141,7 @@ def test_read_write_parquet(tmp_path: Path, read_fn: Callable[[Path], MyCollecti
)
def test_read_write_parquet_optional(
tmp_path: Path, read_fn: Callable[[Path], MyCollection]
):
) -> None:
collection = MyCollection.cast({"first": pl.LazyFrame({"a": [1, 2, 3]})})
collection.write_parquet(tmp_path)

Expand Down
Loading
Loading