From 9d5edf6dc9429bd9aeb5ca4591f763e0f2ded63b Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Sat, 19 Apr 2025 23:41:16 +0200 Subject: [PATCH 1/2] style: Enforce type annotations --- dataframely/_base_collection.py | 4 +-- dataframely/_base_schema.py | 4 +-- dataframely/_compat.py | 2 +- dataframely/_filter.py | 2 +- dataframely/_rule.py | 4 +-- dataframely/collection.py | 4 +-- dataframely/columns/_mixins.py | 2 +- dataframely/columns/decimal.py | 4 ++- dataframely/config.py | 10 +++--- dataframely/exc.py | 18 ++++++---- dataframely/failure.py | 6 ++-- dataframely/mypy.py | 13 +++---- dataframely/random.py | 2 +- docs/conf.py | 2 +- docs/sites/quickstart.rst | 4 +-- pyproject.toml | 1 + tests/collection/test_base.py | 24 +++++++------ tests/collection/test_cast.py | 10 +++--- tests/collection/test_create_empty.py | 2 +- tests/collection/test_filter_one_to_n.py | 2 +- tests/collection/test_filter_validate.py | 16 ++++----- tests/collection/test_ignore_in_filter.py | 6 ++-- tests/collection/test_implementation.py | 28 +++++++-------- tests/collection/test_optional_members.py | 6 ++-- tests/collection/test_sample.py | 14 ++++---- tests/collection/test_validate_input.py | 4 +-- tests/column_types/test_any.py | 2 +- tests/column_types/test_datetime.py | 18 ++++++---- tests/column_types/test_decimal.py | 14 ++++---- tests/column_types/test_enum.py | 4 +-- tests/column_types/test_float.py | 26 +++++++------- tests/column_types/test_integer.py | 26 +++++++------- tests/column_types/test_list.py | 22 ++++++------ tests/column_types/test_string.py | 8 ++--- tests/column_types/test_struct.py | 16 ++++----- tests/columns/test_alias.py | 6 ++-- tests/columns/test_check.py | 2 +- tests/columns/test_default_dtypes.py | 2 +- tests/columns/test_metadata.py | 2 +- tests/columns/test_pyarrow.py | 18 +++++----- tests/columns/test_rules.py | 6 ++-- tests/columns/test_sample.py | 36 ++++++++++--------- tests/columns/test_sql_schema.py | 14 ++++---- tests/columns/test_str.py | 8 
++--- tests/columns/test_utils.py | 10 +++--- .../core_validation/test_column_validation.py | 4 +-- .../core_validation/test_dtype_validation.py | 8 ++--- tests/core_validation/test_rule_evaluation.py | 14 ++++---- tests/functional/test_concat.py | 4 +-- tests/functional/test_relationships.py | 4 +-- tests/schema/__init__.py | 2 -- tests/schema/test_base.py | 16 ++++----- tests/schema/test_cast.py | 6 ++-- tests/schema/test_create_empty.py | 4 +-- tests/schema/test_create_empty_if_none.py | 4 +-- tests/schema/test_filter.py | 18 ++++++---- tests/schema/test_inheritance.py | 2 +- tests/schema/test_rule_implementation.py | 8 ++--- tests/schema/test_sample.py | 22 ++++++------ tests/schema/test_validate.py | 16 +++++---- tests/test_compat.py | 2 +- tests/test_config.py | 8 ++--- tests/test_exc.py | 6 ++-- tests/test_extre.py | 18 +++++----- tests/test_failure_info.py | 2 +- tests/test_random.py | 26 +++++++------- tests/test_typing.py | 26 +++++++------- 67 files changed, 343 insertions(+), 311 deletions(-) delete mode 100644 tests/schema/__init__.py diff --git a/dataframely/_base_collection.py b/dataframely/_base_collection.py index 5d2ae12..a8ecee2 100644 --- a/dataframely/_base_collection.py +++ b/dataframely/_base_collection.py @@ -79,7 +79,7 @@ class Metadata: members: dict[str, MemberInfo] = field(default_factory=dict) filters: dict[str, Filter] = field(default_factory=dict) - def update(self, other: Self): + def update(self, other: Self) -> None: self.members.update(other.members) self.filters.update(other.filters) @@ -92,7 +92,7 @@ def __new__( namespace: dict[str, Any], *args: Any, **kwargs: Any, - ): + ) -> CollectionMeta: result = Metadata() for base in bases: result.update(mcs._get_metadata_recursively(base)) diff --git a/dataframely/_base_schema.py b/dataframely/_base_schema.py index e17e538..f123e09 100644 --- a/dataframely/_base_schema.py +++ b/dataframely/_base_schema.py @@ -58,7 +58,7 @@ class Metadata: columns: dict[str, Column] = 
field(default_factory=dict) rules: dict[str, Rule] = field(default_factory=dict) - def update(self, other: Self): + def update(self, other: Self) -> None: self.columns.update(other.columns) self.rules.update(other.rules) @@ -71,7 +71,7 @@ def __new__( namespace: dict[str, Any], *args: Any, **kwargs: Any, - ): + ) -> SchemaMeta: result = Metadata() for base in bases: result.update(mcs._get_metadata_recursively(base)) diff --git a/dataframely/_compat.py b/dataframely/_compat.py index 3768dff..daf369e 100644 --- a/dataframely/_compat.py +++ b/dataframely/_compat.py @@ -6,7 +6,7 @@ class _DummyModule: # pragma: no cover - def __init__(self, module: str): + def __init__(self, module: str) -> None: self.module = module def __getattr__(self, name: str) -> Any: diff --git a/dataframely/_filter.py b/dataframely/_filter.py index adfb89e..b99d95e 100644 --- a/dataframely/_filter.py +++ b/dataframely/_filter.py @@ -12,7 +12,7 @@ class Filter(Generic[C]): """Internal class representing logic for filtering members of a collection.""" - def __init__(self, logic: Callable[[C], pl.LazyFrame]): + def __init__(self, logic: Callable[[C], pl.LazyFrame]) -> None: self.logic = logic diff --git a/dataframely/_rule.py b/dataframely/_rule.py index 607233c..02df0e2 100644 --- a/dataframely/_rule.py +++ b/dataframely/_rule.py @@ -12,14 +12,14 @@ class Rule: """Internal class representing validation rules.""" - def __init__(self, expr: pl.Expr): + def __init__(self, expr: pl.Expr) -> None: self.expr = expr class GroupRule(Rule): """Rule that is evaluated on a group of columns.""" - def __init__(self, expr: pl.Expr, group_columns: list[str]): + def __init__(self, expr: pl.Expr, group_columns: list[str]) -> None: super().__init__(expr) self.group_columns = group_columns diff --git a/dataframely/collection.py b/dataframely/collection.py index f80e889..293e592 100644 --- a/dataframely/collection.py +++ b/dataframely/collection.py @@ -498,7 +498,7 @@ def collect_all(self) -> Self: # 
---------------------------------- PERSISTENCE --------------------------------- # - def write_parquet(self, directory: Path): + def write_parquet(self, directory: Path) -> None: """Write the members of this collection to Parquet files in a directory. This method writes one Parquet file per member into the provided directory. @@ -590,7 +590,7 @@ def _init(cls, data: Mapping[str, FrameType], /) -> Self: return out @classmethod - def _validate_input_keys(cls, data: Mapping[str, FrameType], /): + def _validate_input_keys(cls, data: Mapping[str, FrameType], /) -> None: actual = set(data) missing = cls.required_members() - actual diff --git a/dataframely/columns/_mixins.py b/dataframely/columns/_mixins.py index 33883dd..cf03f45 100644 --- a/dataframely/columns/_mixins.py +++ b/dataframely/columns/_mixins.py @@ -83,7 +83,7 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: class IsInMixin(Generic[U], Base): """Mixin to use for types implementing "is in".""" - def __init__(self, *, is_in: Sequence[U] | None = None, **kwargs: Any): + def __init__(self, *, is_in: Sequence[U] | None = None, **kwargs: Any) -> None: super().__init__(**kwargs) self.is_in = is_in diff --git a/dataframely/columns/decimal.py b/dataframely/columns/decimal.py index 5cdb429..5bb0aeb 100644 --- a/dataframely/columns/decimal.py +++ b/dataframely/columns/decimal.py @@ -148,7 +148,9 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: # --------------------------------------- UTILS -------------------------------------- # -def _validate(value: decimal.Decimal, precision: int | None, scale: int, name: str): +def _validate( + value: decimal.Decimal, precision: int | None, scale: int, name: str +) -> None: exponent = value.as_tuple().exponent if not isinstance(exponent, int): raise ValueError(f"Encountered 'inf' or 'NaN' for `{name}`.") diff --git a/dataframely/config.py b/dataframely/config.py index bdf8009..196dc01 100644 --- a/dataframely/config.py +++ 
b/dataframely/config.py @@ -25,23 +25,23 @@ class Config(contextlib.ContextDecorator): #: Singleton stack to track where to go back after exiting a context. _stack: list[Options] = [] - def __init__(self, **options: Unpack[Options]): + def __init__(self, **options: Unpack[Options]) -> None: self._local_options: Options = {**default_options(), **options} @staticmethod - def set_max_sampling_iterations(iterations: int): + def set_max_sampling_iterations(iterations: int) -> None: """Set the maximum number of sampling iterations to use on :meth:`Schema.sample`.""" Config.options["max_sampling_iterations"] = iterations @staticmethod - def restore_defaults(): + def restore_defaults() -> None: """Restore the defaults of the configuration.""" Config.options = default_options() # ------------------------------------ CONTEXT ----------------------------------- # - def __enter__(self): + def __enter__(self) -> None: Config._stack.append(Config.options) Config.options = self._local_options @@ -50,5 +50,5 @@ def __exit__( exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ): + ) -> None: Config.options = Config._stack.pop() diff --git a/dataframely/exc.py b/dataframely/exc.py index 474ad51..f341012 100644 --- a/dataframely/exc.py +++ b/dataframely/exc.py @@ -11,7 +11,7 @@ class ValidationError(Exception): """Error raised when :mod:`dataframely` validation encounters an issue.""" - def __init__(self, message: str): + def __init__(self, message: str) -> None: super().__init__() self.message = message @@ -22,7 +22,9 @@ def __str__(self) -> str: class DtypeValidationError(ValidationError): """Validation error raised when column dtypes are wrong.""" - def __init__(self, errors: dict[str, tuple[PolarsDataType, PolarsDataType]]): + def __init__( + self, errors: dict[str, tuple[PolarsDataType, PolarsDataType]] + ) -> None: super().__init__(f"{len(errors)} columns have an invalid dtype") self.errors = errors @@ -37,7 +39,7 @@ def 
__str__(self) -> str: class RuleValidationError(ValidationError): """Complex validation error raised when rule validation fails.""" - def __init__(self, errors: dict[str, int]): + def __init__(self, errors: dict[str, int]) -> None: super().__init__(f"{len(errors)} rules failed validation") # Split into schema errors and column errors @@ -75,11 +77,11 @@ def __str__(self) -> str: class MemberValidationError(ValidationError): """Validation error raised when multiple members of a collection fail validation.""" - def __init__(self, errors: dict[str, ValidationError]): + def __init__(self, errors: dict[str, ValidationError]) -> None: super().__init__(f"{len(errors)} members failed validation") self.errors = errors - def __str__(self): + def __str__(self) -> str: details = [ f" > Member '{name}' failed validation:\n" + "\n".join(" " + line for line in str(error).split("\n")) @@ -95,7 +97,7 @@ class ImplementationError(Exception): class AnnotationImplementationError(ImplementationError): """Error raised when the annotations of a collection are invalid.""" - def __init__(self, attr: str, kls: type): + def __init__(self, attr: str, kls: type) -> None: message = ( "Annotations of a 'dy.Collection' may only be an (optional) " f"'dy.LazyFrame', but \"{attr}\" has type '{kls}'." @@ -106,7 +108,9 @@ def __init__(self, attr: str, kls: type): class RuleImplementationError(ImplementationError): """Error raised when a rule is implemented incorrectly.""" - def __init__(self, name: str, return_dtype: pl.DataType, is_group_rule: bool): + def __init__( + self, name: str, return_dtype: pl.DataType, is_group_rule: bool + ) -> None: if is_group_rule: details = ( " When implementing a group rule (i.e. when using the `group_by` " diff --git a/dataframely/failure.py b/dataframely/failure.py index ea0804b..9be9b5a 100644 --- a/dataframely/failure.py +++ b/dataframely/failure.py @@ -28,7 +28,9 @@ class FailureInfo(Generic[S]): #: The schema used to create the input data frame. 
schema: type[S] - def __init__(self, lf: pl.LazyFrame, rule_columns: list[str], schema: type[S]): + def __init__( + self, lf: pl.LazyFrame, rule_columns: list[str], schema: type[S] + ) -> None: self._lf = lf self._rule_columns = rule_columns self.schema = schema @@ -71,7 +73,7 @@ def __len__(self) -> int: # ---------------------------------- PERSISTENCE --------------------------------- # - def write_parquet(self, file: str | Path | IO[bytes]): + def write_parquet(self, file: str | Path | IO[bytes]) -> None: """Write the failure info to a Parquet file. Args: diff --git a/dataframely/mypy.py b/dataframely/mypy.py index 3a9e497..4bd1ac6 100644 --- a/dataframely/mypy.py +++ b/dataframely/mypy.py @@ -50,7 +50,7 @@ # --------------------------------------- RULES -------------------------------------- # -def mark_rules_as_staticmethod(ctx: ClassDefContext): +def mark_rules_as_staticmethod(ctx: ClassDefContext) -> None: """Mark all methods decorated with `@rule` as `staticmethod`s.""" info = ctx.cls.info for sym in info.names.values(): @@ -199,7 +199,7 @@ def _convert_dy_column_to_dtype( def store_typed_dict_type_for_schema( ctx: ClassDefContext, schema_registry: dict[str, TypedDictType], -): +) -> None: """Add `TypedDictType` inferred from the schema's columns to a given registry.""" schema_type = ctx.cls.info @@ -336,11 +336,13 @@ def alter_dataframe_iter_rows_return_type( class DataframelyPlugin(Plugin): - def __init__(self, options: Options): + def __init__(self, options: Options) -> None: super().__init__(options) self.schema_registry: dict[str, TypedDictType] = {} - def get_base_class_hook(self, fullname: str): + def get_base_class_hook( + self, fullname: str + ) -> Callable[[ClassDefContext], None] | None: # Given a class, check whether it is a subclass of `dy.Schema`. If so, mark # all methods decorated with `@rule` as staticmethods. 
# Also, store the `TypedDictType` for the schema in a registry to allow downstream @@ -349,13 +351,12 @@ def get_base_class_hook(self, fullname: str): if sym and isinstance(sym.node, TypeInfo): if any(base.fullname == SCHEMA_FULLNAME for base in sym.node.mro): - def _hook(ctx: ClassDefContext) -> bool: + def _hook(ctx: ClassDefContext) -> None: mark_rules_as_staticmethod(ctx) store_typed_dict_type_for_schema( ctx, self.schema_registry, ) - return True return _hook return None diff --git a/dataframely/random.py b/dataframely/random.py index 9998752..2fbcd9c 100644 --- a/dataframely/random.py +++ b/dataframely/random.py @@ -33,7 +33,7 @@ class Generator: seeding. """ - def __init__(self, seed: int | None = None): + def __init__(self, seed: int | None = None) -> None: """ Args: seed: The seed to use for initializing the random number generator used diff --git a/docs/conf.py b/docs/conf.py index eacb58d..8175cf4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -68,7 +68,7 @@ # Copied and adapted from # https://github.com/pandas-dev/pandas/blob/4a14d064187367cacab3ff4652a12a0e45d0711b/doc/source/conf.py#L613-L659 # Required configuration function to use sphinx.ext.linkcode -def linkcode_resolve(domain, info): +def linkcode_resolve(domain: str, info: dict[str, str]) -> str | None: """Determine the URL corresponding to a given Python object.""" if domain != "py": return None diff --git a/docs/sites/quickstart.rst b/docs/sites/quickstart.rst index c0d4534..f340300 100644 --- a/docs/sites/quickstart.rst +++ b/docs/sites/quickstart.rst @@ -162,7 +162,7 @@ expectations on the schema of the data frame, e.g.: :: - def train_model(df: dy.DataFrame[HouseSchema]): + def train_model(df: dy.DataFrame[HouseSchema]) -> None: ... 
The type checker (typically ``mypy``) then ensures that it is actually a @@ -197,7 +197,7 @@ In this case, ``good`` remains to be a ``dy.DataFrame[HouseSchema]``, albeit wit The ``failure`` object is of type :class:`~dataframely.FailureInfo` and provides means to inspect the reasons for validation failures for invalid rows. -Given the example data above and the schema that we defined, we know that rows 2, 3, 4, and 5 are invalid (0-indexed): +Given the example data above and the schema that we defined, we know that rows 2, 3, 4, and 5 are invalid (0-indexed): - Row 2 has a zip code that does not appear at least twice - Row 3 has a NULL value for the number of bedrooms diff --git a/pyproject.toml b/pyproject.toml index 088c2b8..9d9de4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,7 @@ quote-style = "double" [tool.mypy] check_untyped_defs = true +disallow_untyped_defs = true exclude = ["docs/"] explicit_package_bases = true no_implicit_optional = true diff --git a/tests/collection/test_base.py b/tests/collection/test_base.py index 9f606ec..1b6c3fa 100644 --- a/tests/collection/test_base.py +++ b/tests/collection/test_base.py @@ -25,32 +25,32 @@ class MyCollection(dy.Collection): second: dy.LazyFrame[MySecondSchema] | None -def test_common_primary_keys(): +def test_common_primary_keys() -> None: assert MyCollection.common_primary_keys() == ["a"] -def test_members(): +def test_members() -> None: members = MyCollection.members() assert not members["first"].is_optional assert members["second"].is_optional -def test_member_schemas(): +def test_member_schemas() -> None: schemas = MyCollection.member_schemas() assert schemas == {"first": MyFirstSchema, "second": MySecondSchema} -def test_required_members(): +def test_required_members() -> None: required_members = MyCollection.required_members() assert required_members == {"first"} -def test_optional_members(): +def test_optional_members() -> None: optional_members = MyCollection.optional_members() 
assert optional_members == {"second"} -def test_cast(): +def test_cast() -> None: collection = MyCollection.cast( { "first": pl.LazyFrame({"a": [1, 2, 3]}), @@ -74,7 +74,7 @@ def test_cast(): {"first": pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt8})}, ], ) -def test_to_dict(expected: dict[str, pl.LazyFrame]): +def test_to_dict(expected: dict[str, pl.LazyFrame]) -> None: collection = MyCollection.validate(expected) # Check that export looks as expected @@ -87,7 +87,7 @@ def test_to_dict(expected: dict[str, pl.LazyFrame]): assert MyCollection.is_valid(observed) -def test_collect_all(): +def test_collect_all() -> None: collection = MyCollection.cast( { "first": pl.LazyFrame({"a": [1, 2, 3]}).filter(pl.col("a") < 3), @@ -106,7 +106,7 @@ def test_collect_all(): assert len(out.second.collect()) == 2 -def test_collect_all_optional(): +def test_collect_all_optional() -> None: collection = MyCollection.cast({"first": pl.LazyFrame({"a": [1, 2, 3]})}) out = collection.collect_all() @@ -118,7 +118,9 @@ def test_collect_all_optional(): @pytest.mark.parametrize( "read_fn", [MyCollection.scan_parquet, MyCollection.read_parquet] ) -def test_read_write_parquet(tmp_path: Path, read_fn: Callable[[Path], MyCollection]): +def test_read_write_parquet( + tmp_path: Path, read_fn: Callable[[Path], MyCollection] +) -> None: collection = MyCollection.cast( { "first": pl.LazyFrame({"a": [1, 2, 3]}), @@ -139,7 +141,7 @@ def test_read_write_parquet(tmp_path: Path, read_fn: Callable[[Path], MyCollecti ) def test_read_write_parquet_optional( tmp_path: Path, read_fn: Callable[[Path], MyCollection] -): +) -> None: collection = MyCollection.cast({"first": pl.LazyFrame({"a": [1, 2, 3]})}) collection.write_parquet(tmp_path) diff --git a/tests/collection/test_cast.py b/tests/collection/test_cast.py index 2488d79..a6ed911 100644 --- a/tests/collection/test_cast.py +++ b/tests/collection/test_cast.py @@ -22,7 +22,7 @@ class Collection(dy.Collection): @pytest.mark.parametrize("df_type", 
[pl.DataFrame, pl.LazyFrame]) -def test_cast_valid(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): +def test_cast_valid(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None: first = df_type({"a": [3]}) second = df_type({"a": [1]}) out = Collection.cast({"first": first, "second": second}) # type: ignore @@ -32,7 +32,7 @@ def test_cast_valid(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame]) -def test_cast_valid_optional(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): +def test_cast_valid_optional(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None: first = df_type({"a": [3]}) out = Collection.cast({"first": first}) # type: ignore assert out.first.collect_schema() == FirstSchema.polars_schema() @@ -40,19 +40,19 @@ def test_cast_valid_optional(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame]) -def test_cast_invalid_members(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): +def test_cast_invalid_members(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None: first = df_type({"a": [3]}) with pytest.raises(ValueError): Collection.cast({"third": first}) # type: ignore -def test_cast_invalid_member_schema_eager(): +def test_cast_invalid_member_schema_eager() -> None: first = pl.DataFrame({"b": [3]}) with pytest.raises(plexc.ColumnNotFoundError): Collection.cast({"first": first}) -def test_cast_invalid_member_schema_lazy(): +def test_cast_invalid_member_schema_lazy() -> None: first = pl.LazyFrame({"b": [3]}) collection = Collection.cast({"first": first}) with pytest.raises(plexc.ColumnNotFoundError): diff --git a/tests/collection/test_create_empty.py b/tests/collection/test_create_empty.py index eed7401..3b17fad 100644 --- a/tests/collection/test_create_empty.py +++ b/tests/collection/test_create_empty.py @@ -20,7 +20,7 @@ class MyCollection(dy.Collection): second: dy.LazyFrame[MySecondSchema] | None -def 
test_create_empty(): +def test_create_empty() -> None: collection = MyCollection.create_empty() assert collection.first.collect().height == 0 assert collection.first.collect_schema() == MyFirstSchema.polars_schema() diff --git a/tests/collection/test_filter_one_to_n.py b/tests/collection/test_filter_one_to_n.py index 0e6f414..17ffd7a 100644 --- a/tests/collection/test_filter_one_to_n.py +++ b/tests/collection/test_filter_one_to_n.py @@ -28,7 +28,7 @@ def not_car_with_vin_123(self) -> pl.LazyFrame: return self.cars.filter(pl.col("vin") != pl.lit("123")) -def test_valid_failure_infos(): +def test_valid_failure_infos() -> None: cars = {"vin": ["123", "456"], "manufacturer": ["BMW", "Mercedes"]} car_parts: dict[str, list[Any]] = { "vin": ["123", "123", "456"], diff --git a/tests/collection/test_filter_validate.py b/tests/collection/test_filter_validate.py index 727f67c..4c6cb15 100644 --- a/tests/collection/test_filter_validate.py +++ b/tests/collection/test_filter_validate.py @@ -82,7 +82,7 @@ def data_with_filter_with_rule_violation() -> tuple[pl.LazyFrame, pl.LazyFrame]: def test_filter_without_filter_without_rule_violation( data_without_filter_without_rule_violation: tuple[pl.LazyFrame, pl.LazyFrame], -): +) -> None: out, failure = SimpleCollection.filter( { "first": data_without_filter_without_rule_violation[0], @@ -99,7 +99,7 @@ def test_filter_without_filter_without_rule_violation( def test_filter_without_filter_with_rule_violation( data_without_filter_with_rule_violation: tuple[pl.LazyFrame, pl.LazyFrame], -): +) -> None: out, failure = SimpleCollection.filter( { "first": data_without_filter_with_rule_violation[0], @@ -116,7 +116,7 @@ def test_filter_without_filter_with_rule_violation( def test_filter_with_filter_without_rule_violation( data_with_filter_without_rule_violation: tuple[pl.LazyFrame, pl.LazyFrame], -): +) -> None: out, failure = MyCollection.filter( { "first": data_with_filter_without_rule_violation[0], @@ -139,7 +139,7 @@ def 
test_filter_with_filter_without_rule_violation( def test_filter_with_filter_with_rule_violation( data_with_filter_with_rule_violation: tuple[pl.LazyFrame, pl.LazyFrame], -): +) -> None: out, failure = MyCollection.filter( { "first": data_with_filter_with_rule_violation[0], @@ -159,7 +159,7 @@ def test_filter_with_filter_with_rule_violation( def test_validate_without_filter_without_rule_violation( data_without_filter_without_rule_violation: tuple[pl.LazyFrame, pl.LazyFrame], -): +) -> None: data = { "first": data_without_filter_without_rule_violation[0], "second": data_without_filter_without_rule_violation[1], @@ -174,7 +174,7 @@ def test_validate_without_filter_without_rule_violation( def test_validate_without_filter_with_rule_violation( data_without_filter_with_rule_violation: tuple[pl.LazyFrame, pl.LazyFrame], -): +) -> None: data = { "first": data_without_filter_with_rule_violation[0], "second": data_without_filter_with_rule_violation[1], @@ -194,7 +194,7 @@ def test_validate_without_filter_with_rule_violation( def test_validate_with_filter_without_rule_violation( data_with_filter_without_rule_violation: tuple[pl.LazyFrame, pl.LazyFrame], -): +) -> None: data = { "first": data_with_filter_without_rule_violation[0], "second": data_with_filter_without_rule_violation[1], @@ -215,7 +215,7 @@ def test_validate_with_filter_without_rule_violation( def test_validate_with_filter_with_rule_violation( data_with_filter_with_rule_violation: tuple[pl.LazyFrame, pl.LazyFrame], -): +) -> None: data = { "first": data_with_filter_with_rule_violation[0], "second": data_with_filter_with_rule_violation[1], diff --git a/tests/collection/test_ignore_in_filter.py b/tests/collection/test_ignore_in_filter.py index e7a5753..54b8ce7 100644 --- a/tests/collection/test_ignore_in_filter.py +++ b/tests/collection/test_ignore_in_filter.py @@ -46,12 +46,12 @@ def custom_filter_on_ignored(self) -> pl.LazyFrame: return used_a_ids -def test_collection_ignore_in_filter_meta(): +def 
test_collection_ignore_in_filter_meta() -> None: assert MyTestCollection.non_ignored_members() == {"a", "b"} assert MyTestCollection.ignored_members() == {"ignored"} -def test_collection_ignore_in_filter(): +def test_collection_ignore_in_filter() -> None: success, failure = MyTestCollection.filter( { "a": pl.LazyFrame({"a_id": [1, 2, 3]}), @@ -65,7 +65,7 @@ def test_collection_ignore_in_filter(): assert failure["ignored"].invalid().height == 0 -def test_collection_ignore_in_filter_failure(): +def test_collection_ignore_in_filter_failure() -> None: success, failure = MyTestCollection.filter( { "a": pl.LazyFrame({"a_id": [1, 2, 3]}), diff --git a/tests/collection/test_implementation.py b/tests/collection/test_implementation.py index 13ba7da..74c4094 100644 --- a/tests/collection/test_implementation.py +++ b/tests/collection/test_implementation.py @@ -16,7 +16,7 @@ class MyTestSchema(dy.Schema): a = dy.Integer(primary_key=True) -def test_annotation_type_failure(): +def test_annotation_type_failure() -> None: with pytest.raises( AnnotationImplementationError, ): @@ -29,7 +29,7 @@ def test_annotation_type_failure(): ) -def test_annotation_union_success(): +def test_annotation_union_success() -> None: """When we use a union annotation, it must contain one typed LazyFrame and None.""" create_collection_raw( "test", @@ -39,7 +39,7 @@ def test_annotation_union_success(): ) -def test_annotation_union_with_data_frame(): +def test_annotation_union_with_data_frame() -> None: """When we use a union annotation, it must contain one typed LazyFrame and None.""" with pytest.raises(AnnotationImplementationError): create_collection_raw( @@ -50,7 +50,7 @@ def test_annotation_union_with_data_frame(): ) -def test_annotation_union_too_many_arg_failure(): +def test_annotation_union_too_many_arg_failure() -> None: """Unions should have a maximum of two types in them.""" with pytest.raises(AnnotationImplementationError): @@ -66,7 +66,7 @@ def test_annotation_union_too_many_arg_failure(): ) 
-def test_annotation_union_conflicting_types_failure(): +def test_annotation_union_conflicting_types_failure() -> None: """Unions should contain a maximum of one non-None type.""" with pytest.raises(AnnotationImplementationError): @@ -81,7 +81,7 @@ def test_annotation_union_conflicting_types_failure(): ) -def test_annotation_only_none_failure(): +def test_annotation_only_none_failure() -> None: """Annotations must not just be None.""" with pytest.raises(AnnotationImplementationError): create_collection_raw( @@ -92,7 +92,7 @@ def test_annotation_only_none_failure(): ) -def test_annotation_invalid_type_failure(): +def test_annotation_invalid_type_failure() -> None: """First argument of union must be a LazyFrame.""" with pytest.raises(AnnotationImplementationError): create_collection_raw( @@ -103,7 +103,7 @@ def test_annotation_invalid_type_failure(): ) -def test_explicit_annotation_type_failure_no_frame_type(): +def test_explicit_annotation_type_failure_no_frame_type() -> None: """First argument of the annotated union must be a LazyFrame.""" with pytest.raises(AnnotationImplementationError): create_collection_raw( @@ -114,7 +114,7 @@ def test_explicit_annotation_type_failure_no_frame_type(): ) -def test_explicit_annotation_type_failure_too_many_args(): +def test_explicit_annotation_type_failure_too_many_args() -> None: """Annotations should have a maximum of two arguments in them.""" with pytest.raises(AnnotationImplementationError): create_collection_raw( @@ -129,7 +129,7 @@ def test_explicit_annotation_type_failure_too_many_args(): ) -def test_explicit_annotation_type_failure_arg1_type(): +def test_explicit_annotation_type_failure_arg1_type() -> None: """The second argument of the annotated union must be a CollectionMember.""" with pytest.raises(AnnotationImplementationError): create_collection_raw( @@ -140,7 +140,7 @@ def test_explicit_annotation_type_failure_arg1_type(): ) -def test_name_overlap(): +def test_name_overlap() -> None: with pytest.raises( 
ImplementationError, match=r"Filters defined on the collection must not be named the same", @@ -155,7 +155,7 @@ def test_name_overlap(): ) -def test_collection_no_primary_key_success(): +def test_collection_no_primary_key_success() -> None: """It's ok not to have primary keys if there are no filters.""" create_collection( "test", @@ -165,7 +165,7 @@ def test_collection_no_primary_key_success(): ) -def test_collection_no_primary_key_failure(): +def test_collection_no_primary_key_failure() -> None: """If you have a filter, you must also have a primary key.""" with pytest.raises( ImplementationError, @@ -180,7 +180,7 @@ def test_collection_no_primary_key_failure(): ) -def test_collection_primary_key_but_not_common(): +def test_collection_primary_key_but_not_common() -> None: """If you have a filter, you must also have a common primary key between members.""" with pytest.raises( ImplementationError, diff --git a/tests/collection/test_optional_members.py b/tests/collection/test_optional_members.py index 7782f30..d2afc1b 100644 --- a/tests/collection/test_optional_members.py +++ b/tests/collection/test_optional_members.py @@ -15,17 +15,17 @@ class MyCollection(dy.Collection): second: dy.LazyFrame[TestSchema] | None -def test_collection_optional_member(): +def test_collection_optional_member() -> None: MyCollection.validate({"first": pl.LazyFrame({"a": [1, 2, 3]})}) -def test_filter_failure_info_keys_only_required(): +def test_filter_failure_info_keys_only_required() -> None: out, failure = MyCollection.filter({"first": pl.LazyFrame({"a": [1, 2, 3]})}) assert out.second is None assert set(failure.keys()) == {"first"} -def test_filter_failure_info_keys_required_and_optional(): +def test_filter_failure_info_keys_required_and_optional() -> None: out, failure = MyCollection.filter( { "first": pl.LazyFrame({"a": [1, 2, 3]}), diff --git a/tests/collection/test_sample.py b/tests/collection/test_sample.py index e439b83..4034387 100644 --- a/tests/collection/test_sample.py +++ 
b/tests/collection/test_sample.py @@ -78,14 +78,14 @@ def _preprocess_sample( @pytest.mark.parametrize("n", [0, 1000]) -def test_sample_rows(n: int): +def test_sample_rows(n: int) -> None: collection = MyCollection.sample(n) assert collection.first.collect()["a"].to_list() == list(range(n)) assert collection.second is not None assert collection.second.collect().is_empty() -def test_sample_with_overrides(): +def test_sample_with_overrides() -> None: collection = MyCollection.sample( overrides=[ {"first": {"b": 4}, "second": [{"c": 3}, {"c": 4}]}, @@ -101,27 +101,27 @@ def test_sample_with_overrides(): @pytest.mark.parametrize("n", [0, 1000]) -def test_sample_without_dependent_members(n: int): +def test_sample_without_dependent_members(n: int) -> None: collection = SmallCollection.sample(n) assert collection.first.collect().height == n @pytest.mark.parametrize("n", [0, 1000]) -def test_sample_with_ignored_members(n: int): +def test_sample_with_ignored_members(n: int) -> None: collection = IgnoringCollection.sample(n) assert collection.first.collect()["a"].to_list() == list(range(n)) -def test_sample_num_rows_mismatch(): +def test_sample_num_rows_mismatch() -> None: with pytest.raises(ValueError, match=r"`num_rows` mismatches"): MyCollection.sample(num_rows=1, overrides=[]) -def test_sample_no_common_primary_key(): +def test_sample_no_common_primary_key() -> None: with pytest.raises(ValueError, match=r"must contain the common primary keys"): ErroneousCollection.sample() -def test_sample_no_overwrite(): +def test_sample_no_overwrite() -> None: with pytest.raises(ValueError, match=r"`_preprocess_sample` must be overwritten"): IncompleteCollection.sample() diff --git a/tests/collection/test_validate_input.py b/tests/collection/test_validate_input.py index 3da9a06..22749d7 100644 --- a/tests/collection/test_validate_input.py +++ b/tests/collection/test_validate_input.py @@ -16,12 +16,12 @@ class MyCollection(dy.Collection): second: dy.LazyFrame[TestSchema] | None -def 
test_collection_missing_required_member(): +def test_collection_missing_required_member() -> None: with pytest.raises(ValueError): MyCollection.validate({"second": pl.LazyFrame({"a": [1, 2, 3]})}) -def test_collection_superfluous_member(): +def test_collection_superfluous_member() -> None: with pytest.warns(Warning): MyCollection.validate( { diff --git a/tests/column_types/test_any.py b/tests/column_types/test_any.py index e53d377..363863f 100644 --- a/tests/column_types/test_any.py +++ b/tests/column_types/test_any.py @@ -17,6 +17,6 @@ class AnySchema(dy.Schema): "data", [{"a": [None]}, {"a": [True, None]}, {"a": ["foo"]}, {"a": [3.5]}], ) -def test_any_dtype_passes(data: dict[str, Any]): +def test_any_dtype_passes(data: dict[str, Any]) -> None: df = pl.DataFrame(data) assert AnySchema.is_valid(df) diff --git a/tests/column_types/test_datetime.py b/tests/column_types/test_datetime.py index 44263e3..d67a2fb 100644 --- a/tests/column_types/test_datetime.py +++ b/tests/column_types/test_datetime.py @@ -134,7 +134,9 @@ ), ], ) -def test_args_consistency_min_max(column_type: type[Column], kwargs: dict[str, Any]): +def test_args_consistency_min_max( + column_type: type[Column], kwargs: dict[str, Any] +) -> None: with pytest.raises(ValueError): column_type(**kwargs) @@ -170,7 +172,9 @@ def test_args_consistency_min_max(column_type: type[Column], kwargs: dict[str, A (dy.Duration, {"max_exclusive": dt.timedelta(minutes=30), "resolution": "1h"}), ], ) -def test_args_resolution_invalid(column_type: type[Column], kwargs: dict[str, Any]): +def test_args_resolution_invalid( + column_type: type[Column], kwargs: dict[str, Any] +) -> None: with pytest.raises(ValueError): column_type(**kwargs) @@ -200,7 +204,9 @@ def test_args_resolution_invalid(column_type: type[Column], kwargs: dict[str, An (dy.Duration, {"max_exclusive": dt.timedelta(hours=3), "resolution": "1h"}), ], ) -def test_args_resolution_valid(column_type: type[Column], kwargs: dict[str, Any]): +def 
test_args_resolution_valid( + column_type: type[Column], kwargs: dict[str, Any] +) -> None: column_type(**kwargs) @@ -331,7 +337,7 @@ def test_args_resolution_valid(column_type: type[Column], kwargs: dict[str, Any] ) def test_validate_min_max( column: Column, values: list[Any], valid: dict[str, list[bool]] -): +) -> None: lf = pl.LazyFrame({"a": values}) actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a")))) expected = pl.LazyFrame(valid) @@ -374,7 +380,7 @@ def test_validate_min_max( ) def test_validate_resolution( column: Column, values: list[Any], valid: dict[str, list[bool]] -): +) -> None: lf = pl.LazyFrame({"a": values}) actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a")))) expected = pl.LazyFrame(valid) @@ -389,7 +395,7 @@ def test_validate_resolution( ) ], ) -def test_sample_resolution(column: dy.Column): +def test_sample_resolution(column: dy.Column) -> None: generator = Generator(seed=42) samples = column.sample(generator, n=10_000) schema = create_schema("test", {"a": column}) diff --git a/tests/column_types/test_decimal.py b/tests/column_types/test_decimal.py index ace6a42..a25ac08 100644 --- a/tests/column_types/test_decimal.py +++ b/tests/column_types/test_decimal.py @@ -29,7 +29,7 @@ class DecimalSchema(dy.Schema): {"max": decimal.Decimal(2), "max_exclusive": decimal.Decimal(2)}, ], ) -def test_args_consistency_min_max(kwargs: dict[str, Any]): +def test_args_consistency_min_max(kwargs: dict[str, Any]) -> None: with pytest.raises(ValueError): dy.Decimal(**kwargs) @@ -47,7 +47,7 @@ def test_args_consistency_min_max(kwargs: dict[str, Any]): dict(precision=2, max=decimal.Decimal("100")), ], ) -def test_invalid_args(kwargs: dict[str, Any]): +def test_invalid_args(kwargs: dict[str, Any]) -> None: with pytest.raises(ValueError): dy.Decimal(**kwargs) @@ -55,7 +55,7 @@ def test_invalid_args(kwargs: dict[str, Any]): @pytest.mark.parametrize( "dtype", [pl.Decimal, pl.Decimal(12), pl.Decimal(None, 8), 
pl.Decimal(6, 2)] ) -def test_any_decimal_dtype_passes(dtype: DataTypeClass): +def test_any_decimal_dtype_passes(dtype: DataTypeClass) -> None: df = pl.DataFrame(schema={"a": dtype}) assert DecimalSchema.is_valid(df) @@ -63,7 +63,7 @@ def test_any_decimal_dtype_passes(dtype: DataTypeClass): @pytest.mark.parametrize( "dtype", [pl.Boolean, pl.String] + list(INTEGER_DTYPES) + list(FLOAT_DTYPES) ) -def test_non_decimal_dtype_fails(dtype: DataTypeClass): +def test_non_decimal_dtype_fails(dtype: DataTypeClass) -> None: df = pl.DataFrame(schema={"a": dtype}) assert not DecimalSchema.is_valid(df) @@ -75,7 +75,7 @@ def test_non_decimal_dtype_fails(dtype: DataTypeClass): (False, {"min_exclusive": [False, False, False, True, True]}), ], ) -def test_validate_min(inclusive: bool, valid: dict[str, list[bool]]): +def test_validate_min(inclusive: bool, valid: dict[str, list[bool]]) -> None: kwargs = {("min" if inclusive else "min_exclusive"): decimal.Decimal(3)} column = dy.Decimal(**kwargs) # type: ignore lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) @@ -91,7 +91,7 @@ def test_validate_min(inclusive: bool, valid: dict[str, list[bool]]): (False, {"max_exclusive": [True, True, False, False, False]}), ], ) -def test_validate_max(inclusive: bool, valid: dict[str, list[bool]]): +def test_validate_max(inclusive: bool, valid: dict[str, list[bool]]) -> None: kwargs = {("max" if inclusive else "max_exclusive"): decimal.Decimal(3)} column = dy.Decimal(**kwargs) # type: ignore lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) @@ -141,7 +141,7 @@ def test_validate_range( min_inclusive: bool, max_inclusive: bool, valid: dict[str, list[bool]], -): +) -> None: kwargs = { ("min" if min_inclusive else "min_exclusive"): decimal.Decimal(2), ("max" if max_inclusive else "max_exclusive"): decimal.Decimal(4), diff --git a/tests/column_types/test_enum.py b/tests/column_types/test_enum.py index 1ecb730..5a0d8a3 100644 --- a/tests/column_types/test_enum.py +++ b/tests/column_types/test_enum.py @@ -26,7 +26,7 @@ def 
test_valid( dy_enum: dy.Enum, pl_dtype: pl.Enum, valid: bool, -): +) -> None: schema = create_schema("test", {"a": dy_enum}) df = df_type({"a": ["x", "y", "x", "x"]}).cast(pl_dtype) assert schema.is_valid(df) == valid @@ -48,7 +48,7 @@ def test_valid_cast( data: Any, valid: bool, df_type: type[pl.DataFrame] | type[pl.LazyFrame], -): +) -> None: schema = create_schema("test", {"a": enum}) df = df_type(data) assert schema.is_valid(df, cast=True) == valid diff --git a/tests/column_types/test_float.py b/tests/column_types/test_float.py index 09dc003..d214d39 100644 --- a/tests/column_types/test_float.py +++ b/tests/column_types/test_float.py @@ -34,7 +34,7 @@ class IntegerSchema(dy.Schema): ) def test_args_consistency_min_max( column_type: type[_BaseFloat], kwargs: dict[str, Any] -): +) -> None: with pytest.raises(ValueError): column_type(**kwargs) @@ -50,19 +50,19 @@ def test_args_consistency_min_max( (dy.Float64, dict(max=float("inf"))), ], ) -def test_invalid_args(column_type: type[_BaseFloat], kwargs: dict[str, Any]): +def test_invalid_args(column_type: type[_BaseFloat], kwargs: dict[str, Any]) -> None: with pytest.raises(ValueError): column_type(**kwargs) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) -def test_any_integer_dtype_passes(dtype: DataTypeClass): +def test_any_integer_dtype_passes(dtype: DataTypeClass) -> None: df = pl.DataFrame(schema={"a": dtype}) assert IntegerSchema.is_valid(df) @pytest.mark.parametrize("dtype", [pl.Boolean, pl.String] + list(INTEGER_DTYPES)) -def test_non_integer_dtype_fails(dtype: DataTypeClass): +def test_non_integer_dtype_fails(dtype: DataTypeClass) -> None: df = pl.DataFrame(schema={"a": dtype}) assert not IntegerSchema.is_valid(df) @@ -77,7 +77,7 @@ def test_non_integer_dtype_fails(dtype: DataTypeClass): ) def test_validate_min( column_type: type[_BaseFloat], inclusive: bool, valid: dict[str, list[bool]] -): +) -> None: kwargs = {("min" if inclusive else "min_exclusive"): 3} column = column_type(**kwargs) # type: ignore lf = 
pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) @@ -96,7 +96,7 @@ def test_validate_min( ) def test_validate_max( column_type: type[_BaseFloat], inclusive: bool, valid: dict[str, list[bool]] -): +) -> None: kwargs = {("max" if inclusive else "max_exclusive"): 3} column = column_type(**kwargs) # type: ignore lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) @@ -148,7 +148,7 @@ def test_validate_range( min_inclusive: bool, max_inclusive: bool, valid: dict[str, list[bool]], -): +) -> None: kwargs = { ("min" if min_inclusive else "min_exclusive"): 2, ("max" if max_inclusive else "max_exclusive"): 4, @@ -162,7 +162,7 @@ def test_validate_range( @pytest.mark.parametrize("inf", [np.inf, -np.inf, float("inf"), float("-inf")]) @pytest.mark.parametrize("nan", [np.nan, float("nan"), float("NaN")]) -def test_validate_inf_nan(inf: Any, nan: Any): +def test_validate_inf_nan(inf: Any, nan: Any) -> None: column = dy.Float(allow_inf_nan=False) lf = pl.LazyFrame({"a": pl.Series([inf, 2.0, nan, 4.0, 5.0])}) actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a")))) @@ -172,7 +172,7 @@ def test_validate_inf_nan(inf: Any, nan: Any): @pytest.mark.parametrize("inf", [np.inf, -np.inf, float("inf"), float("-inf")]) @pytest.mark.parametrize("nan", [np.nan, float("nan"), float("NaN")]) -def test_validate_allow_inf_nan(inf: Any, nan: Any): +def test_validate_allow_inf_nan(inf: Any, nan: Any) -> None: column = dy.Float(allow_inf_nan=True) lf = pl.LazyFrame({"a": pl.Series([inf, 2.0, nan, 4.0, 5.0])}) actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a")))) @@ -181,27 +181,27 @@ def test_validate_allow_inf_nan(inf: Any, nan: Any): ) -def test_sample_unchecked_min_0(): +def test_sample_unchecked_min_0() -> None: column = dy.Float(min=0, max=10) actual = column._sample_unchecked(dy.random.Generator(), n=10000) assert actual.min() >= 0, "There should be no negative values" # type: ignore -def test_sample_unchecked_nan(): +def test_sample_unchecked_nan() -> None: 
column = dy.Float(min=0, max=10, allow_inf_nan=True) actual = column._sample_unchecked(dy.random.Generator(), n=10000) nan_count = actual.is_nan().sum() assert 0.01 * len(actual) < nan_count < 0.1 * len(actual) -def test_sample_unchecked_unbounded(): +def test_sample_unchecked_unbounded() -> None: column = dy.Float(allow_inf_nan=False) actual = column._sample_unchecked(dy.random.Generator(), n=10000) assert actual.is_nan().sum() == 0 assert actual.is_infinite().sum() == 0 -def test_sample_unchecked_inf(): +def test_sample_unchecked_inf() -> None: column = dy.Float(allow_inf_nan=True) actual = column._sample_unchecked(dy.random.Generator(), n=10000) inf_count = actual.is_infinite().sum() diff --git a/tests/column_types/test_integer.py b/tests/column_types/test_integer.py index bff3434..21e9bdc 100644 --- a/tests/column_types/test_integer.py +++ b/tests/column_types/test_integer.py @@ -32,13 +32,13 @@ class IntegerSchema(dy.Schema): ) def test_args_consistency_min_max( column_type: type[_BaseInteger], kwargs: dict[str, Any] -): +) -> None: with pytest.raises(ValueError): column_type(**kwargs) @pytest.mark.parametrize("column_type", INTEGER_COLUMN_TYPES) -def test_invalid_args_min_max(column_type: type[_BaseInteger]): +def test_invalid_args_min_max(column_type: type[_BaseInteger]) -> None: with pytest.raises(ValueError): column_type(min=column_type.min_value - 1) with pytest.raises(ValueError): @@ -54,26 +54,28 @@ def test_invalid_args_min_max(column_type: type[_BaseInteger]): {"min": 1, "max": 5, "is_in": [2, 3, 4]}, ], ) -def test_invalid_args_is_in(column_type: type[_BaseInteger], kwargs: dict[str, Any]): +def test_invalid_args_is_in( + column_type: type[_BaseInteger], kwargs: dict[str, Any] +) -> None: with pytest.raises(ValueError): column_type(**kwargs) @pytest.mark.parametrize("dtype", INTEGER_DTYPES) -def test_any_integer_dtype_passes(dtype: DataTypeClass): +def test_any_integer_dtype_passes(dtype: DataTypeClass) -> None: df = pl.DataFrame(schema={"a": dtype}) 
assert IntegerSchema.is_valid(df) @pytest.mark.parametrize("dtype", [pl.Boolean, pl.String] + list(FLOAT_DTYPES)) -def test_non_integer_dtype_fails(dtype: DataTypeClass): +def test_non_integer_dtype_fails(dtype: DataTypeClass) -> None: df = pl.DataFrame(schema={"a": dtype}) assert not IntegerSchema.is_valid(df) @pytest.mark.parametrize("column_type", INTEGER_COLUMN_TYPES) @pytest.mark.parametrize("inclusive", [True, False]) -def test_validate_min(column_type: type[_BaseInteger], inclusive: bool): +def test_validate_min(column_type: type[_BaseInteger], inclusive: bool) -> None: kwargs = {("min" if inclusive else "min_exclusive"): 3} column = column_type(**kwargs) # type: ignore lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) @@ -85,7 +87,7 @@ def test_validate_min(column_type: type[_BaseInteger], inclusive: bool): @pytest.mark.parametrize("column_type", INTEGER_COLUMN_TYPES) @pytest.mark.parametrize("inclusive", [True, False]) -def test_validate_max(column_type: type[_BaseInteger], inclusive: bool): +def test_validate_max(column_type: type[_BaseInteger], inclusive: bool) -> None: kwargs = {("max" if inclusive else "max_exclusive"): 3} column = column_type(**kwargs) # type: ignore lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) @@ -100,7 +102,7 @@ def test_validate_max(column_type: type[_BaseInteger], inclusive: bool): @pytest.mark.parametrize("max_inclusive", [True, False]) def test_validate_range( column_type: type[_BaseInteger], min_inclusive: bool, max_inclusive: bool -): +) -> None: kwargs = { ("min" if min_inclusive else "min_exclusive"): 2, ("max" if max_inclusive else "max_exclusive"): 4, @@ -120,7 +122,7 @@ def test_validate_range( @pytest.mark.parametrize("column_type", INTEGER_COLUMN_TYPES) -def test_validate_is_in(column_type: type[_BaseInteger]): +def test_validate_is_in(column_type: type[_BaseInteger]) -> None: column = column_type(is_in=[3, 5]) lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) actual = evaluate_rules(lf, 
rules_from_exprs(column.validation_rules(pl.col("a")))) @@ -142,7 +144,7 @@ def test_validate_is_in(column_type: type[_BaseInteger]): (dy.UInt64, 8), ], ) -def test_num_bytes(column_type: type[_BaseInteger], num_bytes: int): +def test_num_bytes(column_type: type[_BaseInteger], num_bytes: int) -> None: assert column_type.num_bytes == num_bytes @@ -160,7 +162,7 @@ def test_num_bytes(column_type: type[_BaseInteger], num_bytes: int): (dy.UInt64, True), ], ) -def test_is_unsigned(column_type: type[_BaseInteger], is_unsigned: bool): +def test_is_unsigned(column_type: type[_BaseInteger], is_unsigned: bool) -> None: assert column_type.is_unsigned == is_unsigned @@ -180,6 +182,6 @@ def test_is_unsigned(column_type: type[_BaseInteger], is_unsigned: bool): ) def test_type_min_max_values( column_type: type[_BaseInteger], min_value: int, max_value: int -): +) -> None: assert column_type.min_value == min_value assert column_type.max_value == max_value diff --git a/tests/column_types/test_list.py b/tests/column_types/test_list.py index 9c14c81..d7401fc 100644 --- a/tests/column_types/test_list.py +++ b/tests/column_types/test_list.py @@ -10,12 +10,12 @@ @pytest.mark.parametrize("inner", [dy.Int64(), dy.Integer()]) -def test_integer_list(inner: Column): +def test_integer_list(inner: Column) -> None: schema = create_schema("test", {"a": dy.List(inner)}) assert schema.is_valid(pl.DataFrame({"a": [[1], [2], [3]]})) -def test_invalid_inner_type(): +def test_invalid_inner_type() -> None: schema = create_schema("test", {"a": dy.List(dy.Int64())}) assert not schema.is_valid(pl.DataFrame({"a": [["1"], ["2"], ["3"]]})) @@ -50,16 +50,16 @@ def test_invalid_inner_type(): ), ], ) -def test_validate_dtype(column: Column, dtype: pl.DataType, is_valid: bool): +def test_validate_dtype(column: Column, dtype: pl.DataType, is_valid: bool) -> None: assert column.validate_dtype(dtype) == is_valid -def test_nested_lists(): +def test_nested_lists() -> None: schema = create_schema("test", {"a": 
dy.List(dy.List(dy.Int64()))}) assert schema.is_valid(pl.DataFrame({"a": [[[1]], [[2]], [[3]]]})) -def test_list_with_pk(): +def test_list_with_pk() -> None: schema = create_schema( "test", {"a": dy.List(dy.String(), primary_key=True)}, @@ -70,7 +70,7 @@ def test_list_with_pk(): assert failures.counts() == {"primary_key": 2} -def test_list_with_rules(): +def test_list_with_rules() -> None: schema = create_schema( "test", {"a": dy.List(dy.String(min_length=2, nullable=False))} ) @@ -80,7 +80,7 @@ def test_list_with_rules(): assert failures.counts() == {"a|inner_nullability": 1, "a|inner_min_length": 1} -def test_nested_list_with_rules(): +def test_nested_list_with_rules() -> None: schema = create_schema( "test", {"a": dy.List(dy.List(dy.String(min_length=2, nullable=False)))} ) @@ -94,7 +94,7 @@ def test_nested_list_with_rules(): } -def test_list_length_rules(): +def test_list_length_rules() -> None: schema = create_schema( "test", { @@ -111,7 +111,7 @@ def test_list_length_rules(): assert validation_mask(df, failures).to_list() == [True, False, False, True, False] -def test_outer_inner_nullability(): +def test_outer_inner_nullability() -> None: schema = create_schema( "test", { @@ -125,7 +125,7 @@ def test_outer_inner_nullability(): schema.validate(df, cast=True) -def test_inner_primary_key(): +def test_inner_primary_key() -> None: schema = create_schema("test", {"a": dy.List(dy.Integer(primary_key=True))}) df = pl.DataFrame({"a": [[1, 2, 3], [1, 1, 2], [1, 1], [1, 4]]}) _, failure = schema.filter(df) @@ -147,7 +147,7 @@ def test_inner_primary_key_struct( second_primary_key: bool, failure_count: int, mask: list[bool], -): +) -> None: schema = create_schema( "test", { diff --git a/tests/column_types/test_string.py b/tests/column_types/test_string.py index 10d7629..1003102 100644 --- a/tests/column_types/test_string.py +++ b/tests/column_types/test_string.py @@ -8,7 +8,7 @@ from dataframely.testing import evaluate_rules, rules_from_exprs -def 
test_validate_min_length(): +def test_validate_min_length() -> None: column = dy.String(min_length=2) lf = pl.LazyFrame({"a": ["foo", "x"]}) actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a")))) @@ -16,7 +16,7 @@ def test_validate_min_length(): assert_frame_equal(actual, expected) -def test_validate_max_length(): +def test_validate_max_length() -> None: column = dy.String(max_length=2) lf = pl.LazyFrame({"a": ["foo", "x"]}) actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a")))) @@ -24,7 +24,7 @@ def test_validate_max_length(): assert_frame_equal(actual, expected) -def test_validate_regex(): +def test_validate_regex() -> None: column = dy.String(regex="[0-9][a-z]$") lf = pl.LazyFrame({"a": ["33x", "3x", "44"]}) actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a")))) @@ -32,7 +32,7 @@ def test_validate_regex(): assert_frame_equal(actual, expected) -def test_validate_all_rules(): +def test_validate_all_rules() -> None: column = dy.String(nullable=False, min_length=2, max_length=4) lf = pl.LazyFrame({"a": ["foo", "x", "foobar", None]}) actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a")))) diff --git a/tests/column_types/test_struct.py b/tests/column_types/test_struct.py index 73e8db4..3d392bb 100644 --- a/tests/column_types/test_struct.py +++ b/tests/column_types/test_struct.py @@ -10,7 +10,7 @@ from dataframely.testing import create_schema -def test_simple_struct(): +def test_simple_struct() -> None: schema = create_schema( "test", {"s": dy.Struct({"a": dy.Integer(), "b": dy.String()})} ) @@ -74,16 +74,16 @@ def test_simple_struct(): ), ], ) -def test_validate_dtype(column: Column, dtype: pl.DataType, is_valid: bool): +def test_validate_dtype(column: Column, dtype: pl.DataType, is_valid: bool) -> None: assert column.validate_dtype(dtype) == is_valid -def test_invalid_inner_type(): +def test_invalid_inner_type() -> None: schema = create_schema("test", 
{"a": dy.Struct({"a": dy.Int64()})}) assert not schema.is_valid(pl.DataFrame({"a": [{"a": "1"}, {"a": "2"}]})) -def test_nested_structs(): +def test_nested_structs() -> None: schema = create_schema( "test", { @@ -100,7 +100,7 @@ def test_nested_structs(): ) -def test_struct_with_pk(): +def test_struct_with_pk() -> None: schema = create_schema( "test", {"s": dy.Struct({"a": dy.String(), "b": dy.Integer()}, primary_key=True)}, @@ -115,7 +115,7 @@ def test_struct_with_pk(): assert failures.counts() == {"primary_key": 2} -def test_struct_with_rules(): +def test_struct_with_rules() -> None: schema = create_schema( "test", {"s": dy.Struct({"a": dy.String(min_length=2, nullable=False)})} ) @@ -127,7 +127,7 @@ def test_struct_with_rules(): assert failures.counts() == {"s|inner_a_nullability": 1, "s|inner_a_min_length": 1} -def test_nested_struct_with_rules(): +def test_nested_struct_with_rules() -> None: schema = create_schema( "test", { @@ -149,7 +149,7 @@ def test_nested_struct_with_rules(): } -def test_outer_inner_nullability(): +def test_outer_inner_nullability() -> None: schema = create_schema( "test", { diff --git a/tests/columns/test_alias.py b/tests/columns/test_alias.py index 96eed2c..ebffe5b 100644 --- a/tests/columns/test_alias.py +++ b/tests/columns/test_alias.py @@ -10,15 +10,15 @@ class AliasSchema(dy.Schema): a = dy.Int64(alias="hello world: col with space!") -def test_column_names(): +def test_column_names() -> None: assert AliasSchema.column_names() == ["hello world: col with space!"] -def test_validation(): +def test_validation() -> None: df = pl.DataFrame({"hello world: col with space!": [1, 2]}) assert AliasSchema.is_valid(df) -def test_create_empty(): +def test_create_empty() -> None: df = AliasSchema.create_empty() assert AliasSchema.is_valid(df) diff --git a/tests/columns/test_check.py b/tests/columns/test_check.py index c9ae745..17dde02 100644 --- a/tests/columns/test_check.py +++ b/tests/columns/test_check.py @@ -12,7 +12,7 @@ class 
CheckSchema(dy.Schema): b = dy.String(min_length=3, check=lambda col: col.str.contains("x")) -def test_check(): +def test_check() -> None: df = pl.DataFrame({"a": [7, 3, 15], "b": ["abc", "xyz", "x"]}) _, failures = CheckSchema.filter(df) assert validation_mask(df, failures).to_list() == [False, True, False] diff --git a/tests/columns/test_default_dtypes.py b/tests/columns/test_default_dtypes.py index 53c7ff6..5086065 100644 --- a/tests/columns/test_default_dtypes.py +++ b/tests/columns/test_default_dtypes.py @@ -40,7 +40,7 @@ (dy.Enum(["a", "b"]), pl.Enum(["a", "b"])), ], ) -def test_default_dtype(column: Column, dtype: pl.DataType): +def test_default_dtype(column: Column, dtype: pl.DataType) -> None: schema = create_schema("test", {"a": column}) df = schema.create_empty() assert df.schema["a"] == dtype diff --git a/tests/columns/test_metadata.py b/tests/columns/test_metadata.py index 91b3119..556ab95 100644 --- a/tests/columns/test_metadata.py +++ b/tests/columns/test_metadata.py @@ -9,7 +9,7 @@ class SchemaWithMetadata(dy.Schema): b = dy.String() -def test_metadata(): +def test_metadata() -> None: assert SchemaWithMetadata.a.metadata == { "masked": True, "comment": "foo", diff --git a/tests/columns/test_pyarrow.py b/tests/columns/test_pyarrow.py index b41d33a..680ea1b 100644 --- a/tests/columns/test_pyarrow.py +++ b/tests/columns/test_pyarrow.py @@ -14,14 +14,14 @@ @pytest.mark.parametrize("column_type", ALL_COLUMN_TYPES) -def test_equal_to_polars_schema(column_type: type[Column]): +def test_equal_to_polars_schema(column_type: type[Column]) -> None: schema = create_schema("test", {"a": column_type()}) actual = schema.pyarrow_schema() expected = schema.create_empty().to_arrow().schema assert actual == expected -def test_equal_polars_schema_enum(): +def test_equal_polars_schema_enum() -> None: schema = create_schema("test", {"a": dy.Enum(["a", "b"])}) actual = schema.pyarrow_schema() expected = schema.create_empty().to_arrow().schema @@ -34,7 +34,7 @@ def 
test_equal_polars_schema_enum(): + [dy.List(t()) for t in ALL_COLUMN_TYPES] + [dy.Struct({"a": t()}) for t in ALL_COLUMN_TYPES], ) -def test_equal_polars_schema_list(inner: Column): +def test_equal_polars_schema_list(inner: Column) -> None: schema = create_schema("test", {"a": dy.List(inner)}) actual = schema.pyarrow_schema() expected = schema.create_empty().to_arrow().schema @@ -47,7 +47,7 @@ def test_equal_polars_schema_list(inner: Column): + [dy.Struct({"a": t()}) for t in ALL_COLUMN_TYPES] + [dy.List(t()) for t in ALL_COLUMN_TYPES], ) -def test_equal_polars_schema_struct(inner: Column): +def test_equal_polars_schema_struct(inner: Column) -> None: schema = create_schema("test", {"a": dy.Struct({"a": inner})}) actual = schema.pyarrow_schema() expected = schema.create_empty().to_arrow().schema @@ -56,13 +56,13 @@ def test_equal_polars_schema_struct(inner: Column): @pytest.mark.parametrize("column_type", COLUMN_TYPES + SUPERTYPE_COLUMN_TYPES) @pytest.mark.parametrize("nullable", [True, False]) -def test_nullability_information(column_type: type[Column], nullable: bool): +def test_nullability_information(column_type: type[Column], nullable: bool) -> None: schema = create_schema("test", {"a": column_type(nullable=nullable)}) assert ("not null" in str(schema.pyarrow_schema())) != nullable @pytest.mark.parametrize("nullable", [True, False]) -def test_nullability_information_enum(nullable: bool): +def test_nullability_information_enum(nullable: bool) -> None: schema = create_schema("test", {"a": dy.Enum(["a", "b"], nullable=nullable)}) assert ("not null" in str(schema.pyarrow_schema())) != nullable @@ -74,7 +74,7 @@ def test_nullability_information_enum(nullable: bool): + [dy.Struct({"a": t()}) for t in ALL_COLUMN_TYPES], ) @pytest.mark.parametrize("nullable", [True, False]) -def test_nullability_information_list(inner: Column, nullable: bool): +def test_nullability_information_list(inner: Column, nullable: bool) -> None: schema = create_schema("test", {"a": 
dy.List(inner, nullable=nullable)}) assert ("not null" in str(schema.pyarrow_schema())) != nullable @@ -86,11 +86,11 @@ def test_nullability_information_list(inner: Column, nullable: bool): + [dy.List(t()) for t in ALL_COLUMN_TYPES], ) @pytest.mark.parametrize("nullable", [True, False]) -def test_nullability_information_struct(inner: Column, nullable: bool): +def test_nullability_information_struct(inner: Column, nullable: bool) -> None: schema = create_schema("test", {"a": dy.Struct({"a": inner}, nullable=nullable)}) assert ("not null" in str(schema.pyarrow_schema())) != nullable -def test_multiple_columns(): +def test_multiple_columns() -> None: schema = create_schema("test", {"a": dy.Int32(nullable=False), "b": dy.Integer()}) assert str(schema.pyarrow_schema()).split("\n") == ["a: int32 not null", "b: int64"] diff --git a/tests/columns/test_rules.py b/tests/columns/test_rules.py index 6fdf021..620ebf8 100644 --- a/tests/columns/test_rules.py +++ b/tests/columns/test_rules.py @@ -17,7 +17,7 @@ @pytest.mark.parametrize("column_type", COLUMN_TYPES + SUPERTYPE_COLUMN_TYPES) @pytest.mark.parametrize("nullable", [True, False]) -def test_rule_count_nullability(column_type: type[Column], nullable: bool): +def test_rule_count_nullability(column_type: type[Column], nullable: bool) -> None: column = column_type(nullable=nullable) assert len(column.validation_rules(pl.col("a"))) == int(not nullable) + ( 1 if isinstance(column, _BaseFloat) else 0 @@ -25,7 +25,7 @@ def test_rule_count_nullability(column_type: type[Column], nullable: bool): @pytest.mark.parametrize("column_type", COLUMN_TYPES + SUPERTYPE_COLUMN_TYPES) -def test_nullability_rule_for_primary_key(column_type: type[Column]): +def test_nullability_rule_for_primary_key(column_type: type[Column]) -> None: column = column_type(primary_key=True) assert len(column.validation_rules(pl.col("a"))) == ( 2 @@ -35,7 +35,7 @@ def test_nullability_rule_for_primary_key(column_type: type[Column]): 
@pytest.mark.parametrize("column_type", COLUMN_TYPES + SUPERTYPE_COLUMN_TYPES) -def test_nullability_rule(column_type: type[Column]): +def test_nullability_rule(column_type: type[Column]) -> None: column = column_type(nullable=False) lf = pl.LazyFrame({"a": [None]}, schema={"a": column.dtype}) actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a")))) diff --git a/tests/columns/test_sample.py b/tests/columns/test_sample.py index cf0fa63..25b55a8 100644 --- a/tests/columns/test_sample.py +++ b/tests/columns/test_sample.py @@ -26,7 +26,9 @@ def generator() -> Generator: @pytest.mark.parametrize("column_type", COLUMN_TYPES + SUPERTYPE_COLUMN_TYPES) -def test_sample_custom_check(column_type: type[dy.Column], generator: Generator): +def test_sample_custom_check( + column_type: type[dy.Column], generator: Generator +) -> None: column = column_type(check=lambda expr: expr) with pytest.raises(ValueError): column.sample(generator) @@ -36,7 +38,7 @@ def test_sample_custom_check(column_type: type[dy.Column], generator: Generator) @pytest.mark.parametrize("nullable", [True, False]) def test_sample_valid( column_type: type[dy.Column], nullable: bool, generator: Generator -): +) -> None: if issubclass(column_type, _BaseFloat): # let's avoid sampling NaN and Inf by setting min/max column: dy.Column = column_type(nullable=nullable, min=-10_000, max=10_000) @@ -47,7 +49,7 @@ def test_sample_valid( assert math.isclose(cast(float, samples.is_null().mean()), 0.1, abs_tol=0.01) -def test_sample_any(generator: Generator): +def test_sample_any(generator: Generator) -> None: column = dy.Any() samples = sample_and_validate(column, generator, n=100) assert samples.is_null().all() @@ -61,7 +63,7 @@ def test_sample_integer_min_max( min_kwargs: dict[str, Any], max_kwargs: dict[str, Any], generator: Generator, -): +) -> None: column = column_type(**min_kwargs, **max_kwargs) samples = sample_and_validate(column, generator, n=10_000) if min_kwargs and max_kwargs: @@ -78,7 
+80,9 @@ def test_sample_integer_min_max( @pytest.mark.parametrize("column_type", INTEGER_COLUMN_TYPES) -def test_sample_integer_is_in(column_type: type[dy.Column], generator: Generator): +def test_sample_integer_is_in( + column_type: type[dy.Column], generator: Generator +) -> None: column = column_type(is_in=[4, 5, 6]) # type: ignore samples = sample_and_validate(column, generator, n=10_000) assert math.isclose(samples.mean(), 5, abs_tol=0.1) # type: ignore @@ -94,7 +98,7 @@ def test_sample_integer_is_in(column_type: type[dy.Column], generator: Generator dy.String(regex=".*", min_length=1, max_length=5), ], ) -def test_sample_string_invalid(column: dy.Column, generator: Generator): +def test_sample_string_invalid(column: dy.Column, generator: Generator) -> None: with pytest.raises(ValueError): column.sample(generator) @@ -109,18 +113,18 @@ def test_sample_string_invalid(column: dy.Column, generator: Generator): dy.String(regex="[abc]def(ghi)?"), ], ) -def test_sample_string(column: dy.Column, generator: Generator): +def test_sample_string(column: dy.Column, generator: Generator) -> None: sample_and_validate(column, generator, n=10_000) -def test_sample_decimal(generator: Generator): +def test_sample_decimal(generator: Generator) -> None: column = dy.Decimal(precision=3, scale=2, max_exclusive=decimal.Decimal("6.5")) samples = sample_and_validate(column, generator, n=100_000) assert samples.min() == decimal.Decimal("-9.99") assert samples.max() == decimal.Decimal("6.49") -def test_sample_date(generator: Generator): +def test_sample_date(generator: Generator) -> None: column = dy.Date( min=dt.date(2020, 1, 1), max=dt.date(2021, 12, 1), resolution="1mo" ) @@ -129,7 +133,7 @@ def test_sample_date(generator: Generator): assert samples.max() == dt.date(2021, 12, 1) -def test_sample_date_9999(generator: Generator): +def test_sample_date_9999(generator: Generator) -> None: column = dy.Date( min=dt.date(9998, 1, 1), max=dt.date(9999, 12, 1), resolution="1mo" ) @@ -138,7 
+142,7 @@ def test_sample_date_9999(generator: Generator): assert samples.max() == dt.date(9999, 12, 1) -def test_sample_datetime(generator: Generator): +def test_sample_datetime(generator: Generator) -> None: column = dy.Datetime( min=dt.datetime(2020, 1, 1), max_exclusive=dt.datetime(2022, 1, 1), @@ -149,14 +153,14 @@ def test_sample_datetime(generator: Generator): assert samples.max() == dt.datetime(2021, 12, 31) -def test_sample_time(generator: Generator): +def test_sample_time(generator: Generator) -> None: column = dy.Time(min=dt.time(), max=dt.time(23, 59), resolution="1m") samples = sample_and_validate(column, generator, n=1_000_000) assert samples.min() == dt.time() assert samples.max() == dt.time(23, 59) -def test_sample_duration(generator: Generator): +def test_sample_duration(generator: Generator) -> None: column = dy.Duration( min=dt.timedelta(hours=24), max=dt.timedelta(hours=120), resolution="12h" ) @@ -165,19 +169,19 @@ def test_sample_duration(generator: Generator): assert samples.max() == dt.timedelta(hours=120) -def test_sample_enum(generator: Generator): +def test_sample_enum(generator: Generator) -> None: column = dy.Enum(["a", "b", "c"], nullable=False) samples = sample_and_validate(column, generator, n=10_000) assert set(samples) == {"a", "b", "c"} -def test_sample_list(generator: Generator): +def test_sample_list(generator: Generator) -> None: column = dy.List(dy.String(regex="[abc]"), min_length=5, max_length=10) samples = sample_and_validate(column, generator, n=10_000) assert set(samples.list.len()) == set(range(5, 11)) -def test_sample_struct(generator: Generator): +def test_sample_struct(generator: Generator) -> None: column = dy.Struct({"a": dy.String(regex="[abc]"), "b": dy.String(regex="[a-z]xx")}) samples = sample_and_validate(column, generator, n=10_000) assert len(samples) == 10_000 diff --git a/tests/columns/test_sql_schema.py b/tests/columns/test_sql_schema.py index 0b40d71..09a6367 100644 --- a/tests/columns/test_sql_schema.py 
+++ b/tests/columns/test_sql_schema.py @@ -48,7 +48,7 @@ (dy.Enum(["a", "abc"]), "VARCHAR(3)"), ], ) -def test_mssql_datatype(column: Column, datatype: str): +def test_mssql_datatype(column: Column, datatype: str) -> None: dialect = MSDialect_pyodbc() schema = create_schema("test", {"a": column}) columns = schema.sql_schema(dialect) @@ -92,7 +92,7 @@ def test_mssql_datatype(column: Column, datatype: str): (dy.Enum(["a", "abc"]), "VARCHAR(3)"), ], ) -def test_postgres_datatype(column: Column, datatype: str): +def test_postgres_datatype(column: Column, datatype: str) -> None: dialect = PGDialect_psycopg2() schema = create_schema("test", {"a": column}) columns = schema.sql_schema(dialect) @@ -105,7 +105,7 @@ def test_postgres_datatype(column: Column, datatype: str): @pytest.mark.parametrize("dialect", [MSDialect_pyodbc()]) def test_sql_nullability( column_type: type[Column], nullable: bool, dialect: sa.Dialect -): +) -> None: schema = create_schema("test", {"a": column_type(nullable=nullable)}) columns = schema.sql_schema(dialect) assert len(columns) == 1 @@ -117,7 +117,7 @@ def test_sql_nullability( @pytest.mark.parametrize("dialect", [MSDialect_pyodbc(), PGDialect_psycopg2()]) def test_sql_primary_key( column_type: type[Column], primary_key: bool, dialect: sa.Dialect -): +) -> None: schema = create_schema("test", {"a": column_type(primary_key=primary_key)}) columns = schema.sql_schema(dialect) assert len(columns) == 1 @@ -126,13 +126,13 @@ def test_sql_primary_key( @pytest.mark.parametrize("dialect", [MSDialect_pyodbc(), PGDialect_psycopg2()]) -def test_sql_multiple_columns(dialect: sa.Dialect): +def test_sql_multiple_columns(dialect: sa.Dialect) -> None: schema = create_schema("test", {"a": dy.Int32(nullable=False), "b": dy.Integer()}) assert len(schema.sql_schema(dialect)) == 2 @pytest.mark.parametrize("dialect", [MSDialect_pyodbc(), PGDialect_psycopg2()]) -def test_raise_for_list_column(dialect: sa.Dialect): +def test_raise_for_list_column(dialect: sa.Dialect) -> 
None: with pytest.raises( NotImplementedError, match="SQL column cannot have 'List' type." ): @@ -140,7 +140,7 @@ def test_raise_for_list_column(dialect: sa.Dialect): @pytest.mark.parametrize("dialect", [MSDialect_pyodbc(), PGDialect_psycopg2()]) -def test_raise_for_struct_column(dialect: sa.Dialect): +def test_raise_for_struct_column(dialect: sa.Dialect) -> None: with pytest.raises( NotImplementedError, match="SQL column cannot have 'Struct' type." ): diff --git a/tests/columns/test_str.py b/tests/columns/test_str.py index 2a835ab..0aa284d 100644 --- a/tests/columns/test_str.py +++ b/tests/columns/test_str.py @@ -9,21 +9,21 @@ @pytest.mark.parametrize("column_type", ALL_COLUMN_TYPES) -def test_string_representation(column_type: type[Column]): +def test_string_representation(column_type: type[Column]) -> None: column = column_type() assert str(column) == column_type.__name__.lower() -def test_string_representation_enum(): +def test_string_representation_enum() -> None: column = dy.Enum(["a", "b"]) assert str(column) == dy.Enum.__name__.lower() -def test_string_representation_list(): +def test_string_representation_list() -> None: column = dy.List(dy.String()) assert str(column) == dy.List.__name__.lower() -def test_string_representation_struct(): +def test_string_representation_struct() -> None: column = dy.Struct({"a": dy.String()}) assert str(column) == dy.Struct.__name__.lower() diff --git a/tests/columns/test_utils.py b/tests/columns/test_utils.py index 1dff4e7..4279d8d 100644 --- a/tests/columns/test_utils.py +++ b/tests/columns/test_utils.py @@ -5,22 +5,22 @@ from dataframely.columns._utils import first_non_null -def test_first_non_null_basic(): +def test_first_non_null_basic() -> None: assert first_non_null(1, 2, default=3) == 1 assert first_non_null(None, 2, default=3) == 2 assert first_non_null(None, None, default=3) == 3 -def test_first_non_null_allow_null_response(): +def test_first_non_null_allow_null_response() -> None: assert first_non_null(None, 
None, None, allow_null_response=True) is None -def test_first_non_null_with_terminal(): +def test_first_non_null_with_terminal() -> None: assert first_non_null(None, None, None, default=42) == 42 assert first_non_null(None, 3, None, default=42) == 3 -def test_first_non_null_mixed_types(): +def test_first_non_null_mixed_types() -> None: assert first_non_null(None, "a", default=3) == "a" assert first_non_null(None, 0, default="b") == 0 # 0 is a valid non-null value assert ( @@ -28,6 +28,6 @@ def test_first_non_null_mixed_types(): ) # False is a valid non-null value -def test_first_non_null_with_kwargs(): +def test_first_non_null_with_kwargs() -> None: assert first_non_null(None, None, allow_null_response=True) is None assert first_non_null(None, None, default="fallback") == "fallback" diff --git a/tests/core_validation/test_column_validation.py b/tests/core_validation/test_column_validation.py index 74d9afd..1058cd0 100644 --- a/tests/core_validation/test_column_validation.py +++ b/tests/core_validation/test_column_validation.py @@ -8,7 +8,7 @@ from dataframely.exc import ValidationError -def test_success(): +def test_success() -> None: df = pl.DataFrame(schema={k: pl.Int64() for k in ["a", "b"]}) lf = validate_columns(df.lazy(), actual=df.schema.keys(), expected=["a"]) assert set(lf.collect_schema().names()) == {"a"} @@ -21,7 +21,7 @@ def test_success(): (["c"], ["a", "b"], r"2 columns in the schema are missing.*'a'.*'b'"), ], ) -def test_failure(actual: list[str], expected: list[str], error: str): +def test_failure(actual: list[str], expected: list[str], error: str) -> None: df = pl.DataFrame(schema={k: pl.Int64() for k in actual}) with pytest.raises(ValidationError, match=error): validate_columns(df.lazy(), actual=df.schema.keys(), expected=expected) diff --git a/tests/core_validation/test_dtype_validation.py b/tests/core_validation/test_dtype_validation.py index 8091426..718b4cd 100644 --- a/tests/core_validation/test_dtype_validation.py +++ 
b/tests/core_validation/test_dtype_validation.py @@ -31,7 +31,7 @@ def test_success( actual: dict[str, pl.DataType], expected: dict[str, Column], casting: DtypeCasting, -): +) -> None: df = pl.DataFrame(schema=actual) lf = validate_dtypes( df.lazy(), actual=df.schema, expected=expected, casting=casting @@ -63,7 +63,7 @@ def test_failure( expected: dict[str, Column], error: str, fail_columns: set[str], -): +) -> None: df = pl.DataFrame(schema=actual) try: validate_dtypes(df.lazy(), actual=df.schema, expected=expected, casting="none") @@ -73,7 +73,7 @@ def test_failure( assert re.match(error, str(exc)) -def test_lenient_casting(): +def test_lenient_casting() -> None: lf = pl.LazyFrame( {"a": [1, 2, 3], "b": ["foo", "12", "1313"]}, schema={"a": pl.Int64(), "b": pl.String()}, @@ -91,7 +91,7 @@ def test_lenient_casting(): assert_frame_equal(actual, expected) -def test_strict_casting(): +def test_strict_casting() -> None: lf = pl.LazyFrame( {"a": [1, 2, 3], "b": ["foo", "12", "1313"]}, schema={"a": pl.Int64(), "b": pl.String()}, diff --git a/tests/core_validation/test_rule_evaluation.py b/tests/core_validation/test_rule_evaluation.py index a3bbe25..2877be5 100644 --- a/tests/core_validation/test_rule_evaluation.py +++ b/tests/core_validation/test_rule_evaluation.py @@ -8,7 +8,7 @@ from dataframely.testing import evaluate_rules -def test_single_column_single_rule(): +def test_single_column_single_rule() -> None: lf = pl.LazyFrame({"a": [1, 2]}) rules = { "a|min": Rule(pl.col("a") >= 2), @@ -19,7 +19,7 @@ def test_single_column_single_rule(): assert_frame_equal(actual, expected) -def test_single_column_multi_rule(): +def test_single_column_multi_rule() -> None: lf = pl.LazyFrame({"a": [1, 2, 3]}) rules = { "a|min": Rule(pl.col("a") >= 2), @@ -33,7 +33,7 @@ def test_single_column_multi_rule(): assert_frame_equal(actual, expected) -def test_multi_column_multi_rule(): +def test_multi_column_multi_rule() -> None: lf = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) rules = { 
"a|min": Rule(pl.col("a") >= 2), @@ -52,7 +52,7 @@ def test_multi_column_multi_rule(): assert_frame_equal(actual, expected) -def test_cross_column_rule(): +def test_cross_column_rule() -> None: lf = pl.LazyFrame({"a": [1, 1, 2, 2], "b": [1, 1, 1, 2]}) rules = {"primary_key": Rule(~pl.struct("a", "b").is_duplicated())} actual = evaluate_rules(lf, rules) @@ -61,7 +61,7 @@ def test_cross_column_rule(): assert_frame_equal(actual, expected) -def test_group_rule(): +def test_group_rule() -> None: lf = pl.LazyFrame({"a": [1, 1, 2, 2, 3], "b": [1, 1, 1, 2, 1]}) rules: dict[str, Rule] = { "unique_b": GroupRule(pl.col("b").n_unique() == 1, group_columns=["a"]) @@ -72,7 +72,7 @@ def test_group_rule(): assert_frame_equal(actual, expected) -def test_simple_rule_and_group_rule(): +def test_simple_rule_and_group_rule() -> None: lf = pl.LazyFrame({"a": [1, 1, 2, 2, 3], "b": [1, 1, 1, 2, 1]}) rules: dict[str, Rule] = { "b|max": Rule(pl.col("b") <= 1), @@ -89,7 +89,7 @@ def test_simple_rule_and_group_rule(): assert_frame_equal(actual, expected, check_column_order=False) -def test_multiple_group_rules(): +def test_multiple_group_rules() -> None: lf = pl.LazyFrame({"a": [1, 1, 2, 2, 3], "b": [1, 1, 1, 2, 1]}) rules: dict[str, Rule] = { "unique_b": GroupRule(pl.col("b").n_unique() == 1, group_columns=["a"]), diff --git a/tests/functional/test_concat.py b/tests/functional/test_concat.py index a2fa934..1c6976c 100644 --- a/tests/functional/test_concat.py +++ b/tests/functional/test_concat.py @@ -17,7 +17,7 @@ class SimpleCollection(dy.Collection): third: dy.LazyFrame[MySchema] | None -def test_concat(): +def test_concat() -> None: col1 = SimpleCollection.cast({"first": pl.LazyFrame({"a": [1, 2, 3]})}) col2 = SimpleCollection.cast( { @@ -38,6 +38,6 @@ def test_concat(): assert concat["third"].collect().get_column("a").to_list() == list(range(7, 10)) -def test_concat_empty(): +def test_concat_empty() -> None: with pytest.raises(ValueError): dy.concat_collection_members([]) diff --git 
a/tests/functional/test_relationships.py b/tests/functional/test_relationships.py index 7263986..b650e11 100644 --- a/tests/functional/test_relationships.py +++ b/tests/functional/test_relationships.py @@ -60,7 +60,7 @@ def employees() -> dy.LazyFrame[EmployeeSchema]: def test_one_to_one( departments: dy.LazyFrame[DepartmentSchema], managers: dy.LazyFrame[ManagerSchema], -): +) -> None: actual = dy.filter_relationship_one_to_one( departments, managers, on="department_id" ) @@ -70,7 +70,7 @@ def test_one_to_one( def test_one_to_at_least_one( departments: dy.LazyFrame[DepartmentSchema], employees: dy.LazyFrame[EmployeeSchema], -): +) -> None: actual = dy.filter_relationship_one_to_at_least_one( departments, employees, on="department_id" ) diff --git a/tests/schema/__init__.py b/tests/schema/__init__.py deleted file mode 100644 index e047415..0000000 --- a/tests/schema/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) QuantCo 2025-2025 -# SPDX-License-Identifier: BSD-3-Clause diff --git a/tests/schema/test_base.py b/tests/schema/test_base.py index 1e3600f..6813680 100644 --- a/tests/schema/test_base.py +++ b/tests/schema/test_base.py @@ -18,11 +18,11 @@ class MySchema(dy.Schema): d = dy.Any(alias="e") -def test_column_names(): +def test_column_names() -> None: assert MySchema.column_names() == ["a", "b", "c", "e"] -def test_columns(): +def test_columns() -> None: columns = MySchema.columns() assert isinstance(columns["a"], dy.Integer) assert isinstance(columns["b"], dy.String) @@ -30,7 +30,7 @@ def test_columns(): assert isinstance(columns["e"], dy.Any) -def test_nullability(): +def test_nullability() -> None: columns = MySchema.columns() assert not columns["a"].nullable assert not columns["b"].nullable @@ -38,11 +38,11 @@ def test_nullability(): assert columns["e"].nullable -def test_primary_keys(): +def test_primary_keys() -> None: assert MySchema.primary_keys() == ["a", "b"] -def test_no_rule_named_primary_key(): +def test_no_rule_named_primary_key() -> 
None: with pytest.raises(ImplementationError): create_schema( "test", @@ -51,14 +51,14 @@ def test_no_rule_named_primary_key(): ) -def test_col(): +def test_col() -> None: assert MySchema.a.col.__dict__ == pl.col("a").__dict__ assert MySchema.b.col.__dict__ == pl.col("b").__dict__ assert MySchema.c.col.__dict__ == pl.col("c").__dict__ assert MySchema.d.col.__dict__ == pl.col("e").__dict__ -def test_col_raise_if_none(): +def test_col_raise_if_none() -> None: class InvalidSchema(dy.Schema): a = dy.Integer() @@ -68,7 +68,7 @@ class InvalidSchema(dy.Schema): InvalidSchema.a.col -def test_col_in_polars_expression(): +def test_col_in_polars_expression() -> None: df = ( pl.DataFrame({"a": [1, 2], "b": ["a", "b"], "c": [1.0, 2.0], "e": [None, None]}) .filter((MySchema.b.col == "a") & (MySchema.a.col > 0)) diff --git a/tests/schema/test_cast.py b/tests/schema/test_cast.py index 44f7e04..45c7e49 100644 --- a/tests/schema/test_cast.py +++ b/tests/schema/test_cast.py @@ -25,20 +25,20 @@ class MySchema(dy.Schema): ) def test_cast_valid( df_type: type[pl.DataFrame] | type[pl.LazyFrame], data: dict[str, Any] -): +) -> None: df = df_type(data) out = MySchema.cast(df) assert isinstance(out, df_type) assert out.lazy().collect_schema() == MySchema.polars_schema() -def test_cast_invalid_schema_eager(): +def test_cast_invalid_schema_eager() -> None: df = pl.DataFrame({"a": [1]}) with pytest.raises(plexc.ColumnNotFoundError): MySchema.cast(df) -def test_cast_invalid_schema_lazy(): +def test_cast_invalid_schema_lazy() -> None: lf = pl.LazyFrame({"a": [1]}) lf = MySchema.cast(lf) with pytest.raises(plexc.ColumnNotFoundError): diff --git a/tests/schema/test_create_empty.py b/tests/schema/test_create_empty.py index 0035400..4090a8b 100644 --- a/tests/schema/test_create_empty.py +++ b/tests/schema/test_create_empty.py @@ -13,7 +13,7 @@ class MySchema(dy.Schema): @pytest.mark.parametrize("with_arg", [True, False]) -def test_create_empty_eager(with_arg: bool): +def 
test_create_empty_eager(with_arg: bool) -> None: if with_arg: df = MySchema.create_empty(lazy=False) else: @@ -25,7 +25,7 @@ def test_create_empty_eager(with_arg: bool): assert len(df) == 0 -def test_create_empty_lazy(): +def test_create_empty_lazy() -> None: df = MySchema.create_empty(lazy=True) assert isinstance(df, pl.LazyFrame) assert df.collect_schema().names() == ["a", "b"] diff --git a/tests/schema/test_create_empty_if_none.py b/tests/schema/test_create_empty_if_none.py index 73c4f5a..876899b 100644 --- a/tests/schema/test_create_empty_if_none.py +++ b/tests/schema/test_create_empty_if_none.py @@ -15,7 +15,7 @@ class MySchema(dy.Schema): @pytest.mark.parametrize("lazy_in", [True, False]) @pytest.mark.parametrize("lazy_out", [True, False]) -def test_create_empty_if_none_non_none(lazy_in: bool, lazy_out: bool): +def test_create_empty_if_none_non_none(lazy_in: bool, lazy_out: bool) -> None: # Arrange df_raw = MySchema.validate(pl.DataFrame({"a": [1], "b": ["foo"]})) df = df_raw.lazy() if lazy_in else df_raw @@ -32,7 +32,7 @@ def test_create_empty_if_none_non_none(lazy_in: bool, lazy_out: bool): @pytest.mark.parametrize("lazy", [True, False]) -def test_create_empty_if_none_none(lazy: bool): +def test_create_empty_if_none_none(lazy: bool) -> None: # Act result = MySchema.create_empty_if_none(None, lazy=lazy) diff --git a/tests/schema/test_filter.py b/tests/schema/test_filter.py index e4342b0..2874c4b 100644 --- a/tests/schema/test_filter.py +++ b/tests/schema/test_filter.py @@ -28,7 +28,7 @@ class MySchema(dy.Schema): ) def test_filter_extra_columns( schema: dict[str, DataTypeClass], expected_columns: list[str] | None -): +) -> None: df = pl.DataFrame(schema=schema) try: filtered, _ = MySchema.filter(df) @@ -47,7 +47,9 @@ def test_filter_extra_columns( ({"a": pl.String, "b": pl.String}, True, True), ], ) -def test_filter_dtypes(schema: dict[str, DataTypeClass], cast: bool, success: bool): +def test_filter_dtypes( + schema: dict[str, DataTypeClass], cast: bool, 
success: bool +) -> None: df = pl.DataFrame(schema=schema) try: MySchema.filter(df, cast=cast) @@ -89,7 +91,7 @@ def test_filter_failure( failure_mask: list[bool], counts: dict[str, int], cooccurrence_counts: dict[frozenset[str], int], -): +) -> None: df = df_type({"a": data_a, "b": data_b}) df_valid, failures = MySchema.filter(df) assert isinstance(df_valid, pl.DataFrame) @@ -101,7 +103,7 @@ def test_filter_failure( @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame]) -def test_filter_no_rules(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): +def test_filter_no_rules(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None: schema = create_schema("test", {"a": dy.Int64()}) df = df_type({"a": [1, 2, 3]}) df_valid, failures = schema.filter(df) @@ -113,7 +115,9 @@ def test_filter_no_rules(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame]) -def test_filter_with_rule_all_valid(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): +def test_filter_with_rule_all_valid( + df_type: type[pl.DataFrame] | type[pl.LazyFrame], +) -> None: schema = create_schema("test", {"a": dy.String(min_length=3)}) df = df_type({"a": ["foo", "foobar"]}) df_valid, failures = schema.filter(df) @@ -125,7 +129,7 @@ def test_filter_with_rule_all_valid(df_type: type[pl.DataFrame] | type[pl.LazyFr @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame]) -def test_filter_cast(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): +def test_filter_cast(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None: data = { # validation: [true, true, false, false, false, false] "a": ["1", "2", "foo", None, "123x", "9223372036854775808"], @@ -152,7 +156,7 @@ def test_filter_cast(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): } -def test_filter_nondeterministic_lazyframe(): +def test_filter_nondeterministic_lazyframe() -> None: n = 10_000 lf = pl.LazyFrame( { diff --git a/tests/schema/test_inheritance.py 
b/tests/schema/test_inheritance.py index c6bdc36..14d7b5a 100644 --- a/tests/schema/test_inheritance.py +++ b/tests/schema/test_inheritance.py @@ -16,7 +16,7 @@ class GrandchildSchema(ChildSchema): c = dy.Integer() -def test_columns(): +def test_columns() -> None: assert ParentSchema.column_names() == ["a"] assert ChildSchema.column_names() == ["a", "b"] assert GrandchildSchema.column_names() == ["a", "b", "c"] diff --git a/tests/schema/test_rule_implementation.py b/tests/schema/test_rule_implementation.py index 2530eeb..e8a3b97 100644 --- a/tests/schema/test_rule_implementation.py +++ b/tests/schema/test_rule_implementation.py @@ -10,7 +10,7 @@ from dataframely.testing import create_schema -def test_group_rule_group_by_error(): +def test_group_rule_group_by_error() -> None: with pytest.raises( ImplementationError, match=( @@ -29,7 +29,7 @@ def test_group_rule_group_by_error(): ) -def test_rule_implementation_error(): +def test_rule_implementation_error() -> None: with pytest.raises( RuleImplementationError, match=r"rule 'integer_rule'.*returns dtype 'Int64'" ): @@ -40,7 +40,7 @@ def test_rule_implementation_error(): ) -def test_group_rule_implementation_error(): +def test_group_rule_implementation_error() -> None: with pytest.raises( RuleImplementationError, match=( @@ -55,7 +55,7 @@ def test_group_rule_implementation_error(): ) -def test_rule_column_overlap_error(): +def test_rule_column_overlap_error() -> None: with pytest.raises( ImplementationError, match=r"Rules and columns must not be named equally but found 1 overlaps", diff --git a/tests/schema/test_sample.py b/tests/schema/test_sample.py index 87751f6..e14daf3 100644 --- a/tests/schema/test_sample.py +++ b/tests/schema/test_sample.py @@ -60,7 +60,7 @@ def minimum_two_per_a() -> pl.Expr: @pytest.mark.parametrize("n", [0, 1000]) -def test_sample_deterministic(n: int): +def test_sample_deterministic(n: int) -> None: with dy.Config(max_sampling_iterations=1): df = MySimpleSchema.sample(n) 
MySimpleSchema.validate(df) @@ -68,27 +68,27 @@ def test_sample_deterministic(n: int): @pytest.mark.parametrize("schema", [PrimaryKeySchema, CheckSchema, ComplexSchema]) @pytest.mark.parametrize("n", [0, 1000]) -def test_sample_fuzzy(schema: type[dy.Schema], n: int): +def test_sample_fuzzy(schema: type[dy.Schema], n: int) -> None: df = schema.sample(n, generator=Generator(seed=42)) assert len(df) == n schema.validate(df) -def test_sample_fuzzy_failure(): +def test_sample_fuzzy_failure() -> None: with pytest.raises(ValueError): with dy.Config(max_sampling_iterations=5): ComplexSchema.sample(1000, generator=Generator(seed=42)) @pytest.mark.parametrize("n", [1, 1000]) -def test_sample_overrides(n: int): +def test_sample_overrides(n: int) -> None: df = CheckSchema.sample(overrides={"b": range(n)}) CheckSchema.validate(df) assert len(df) == n assert df.get_column("b").to_list() == list(range(n)) -def test_sample_overrides_with_removing_groups(): +def test_sample_overrides_with_removing_groups() -> None: generator = Generator() n = 333 # we cannot use something too large here or we'll never return overrides = np.random.randint(100, size=n) @@ -99,7 +99,7 @@ def test_sample_overrides_with_removing_groups(): @pytest.mark.parametrize("n", [1, 1000]) -def test_sample_overrides_allow_no_fuzzy(n: int): +def test_sample_overrides_allow_no_fuzzy(n: int) -> None: with dy.Config(max_sampling_iterations=1): df = CheckSchema.sample(n, overrides={"b": [0] * n}) CheckSchema.validate(df) @@ -108,29 +108,29 @@ def test_sample_overrides_allow_no_fuzzy(n: int): @pytest.mark.parametrize("n", [1, 1000]) -def test_sample_overrides_full(n: int): +def test_sample_overrides_full(n: int) -> None: df = CheckSchema.sample(n) df_override = CheckSchema.sample(n, overrides=df.to_dict()) assert_frame_equal(df, df_override) -def test_sample_overrides_row_layout(): +def test_sample_overrides_row_layout() -> None: df = MySimpleSchema.sample(overrides=[{"a": 1}, {"a": 2}, {"a": 3}]) assert len(df) == 3 
assert df.get_column("a").to_list() == [1, 2, 3] -def test_sample_overrides_invalid_column(): +def test_sample_overrides_invalid_column() -> None: with pytest.raises(ValueError, match=r"not in the schema"): MySimpleSchema.sample(overrides={"foo": []}) -def test_sample_overrides_invalid_length(): +def test_sample_overrides_invalid_length() -> None: with pytest.raises(ValueError, match=r"`num_rows` is different"): MySimpleSchema.sample(3, overrides={"a": [1, 2]}) -def test_sample_no_overrides_no_num_rows(): +def test_sample_no_overrides_no_num_rows() -> None: # This case infers `num_rows == 1` df = MySimpleSchema.sample() MySimpleSchema.validate(df) diff --git a/tests/schema/test_validate.py b/tests/schema/test_validate.py index 59afa89..b45e9a8 100644 --- a/tests/schema/test_validate.py +++ b/tests/schema/test_validate.py @@ -32,7 +32,7 @@ def b_unique_within_a() -> pl.Expr: @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame]) -def test_missing_columns(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): +def test_missing_columns(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None: df = df_type({"a": [1], "b": [""]}) with pytest.raises(ValidationError): MySchema.validate(df) @@ -43,7 +43,7 @@ def test_missing_columns(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame]) -def test_invalid_dtype(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): +def test_invalid_dtype(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None: df = df_type({"a": [1], "b": [1], "c": [1]}) try: MySchema.validate(df) @@ -54,7 +54,7 @@ def test_invalid_dtype(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame]) -def test_invalid_dtype_cast(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): +def test_invalid_dtype_cast(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None: df = df_type({"a": [1], "b": [1], "c": [1]}) actual = 
MySchema.validate(df, cast=True) expected = pl.DataFrame({"a": [1], "b": ["1"], "c": ["1"]}) @@ -66,7 +66,9 @@ def test_invalid_dtype_cast(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame]) -def test_invalid_column_contents(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): +def test_invalid_column_contents( + df_type: type[pl.DataFrame] | type[pl.LazyFrame], +) -> None: df = df_type({"a": [1, 2, 3], "b": ["x", "longtext", None], "c": ["1", None, "3"]}) try: MySchema.validate(df) @@ -78,7 +80,7 @@ def test_invalid_column_contents(df_type: type[pl.DataFrame] | type[pl.LazyFrame @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame]) -def test_invalid_primary_key(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): +def test_invalid_primary_key(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None: df = df_type({"a": [1, 1], "b": ["x", "y"], "c": ["1", "2"]}) try: MySchema.validate(df) @@ -90,7 +92,7 @@ def test_invalid_primary_key(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame]) -def test_violated_custom_rule(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): +def test_violated_custom_rule(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None: df = df_type({"a": [1, 1, 2, 3, 3], "b": [2, 2, 2, 4, 5]}) try: MyComplexSchema.validate(df) @@ -104,7 +106,7 @@ def test_violated_custom_rule(df_type: type[pl.DataFrame] | type[pl.LazyFrame]): @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame]) def test_success_multi_row_strip_cast( df_type: type[pl.DataFrame] | type[pl.LazyFrame], -): +) -> None: df = df_type( {"a": [1, 2, 3], "b": ["x", "y", "z"], "c": [1, None, None], "d": [1, 2, 3]} ) diff --git a/tests/test_compat.py b/tests/test_compat.py index f753e76..5f5d75c 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -6,7 +6,7 @@ from dataframely._compat import _DummyModule -def test_dummy_module(): 
+def test_dummy_module() -> None: module = "sqlalchemy" dm = _DummyModule(module=module) assert dm.module == module diff --git a/tests/test_config.py b/tests/test_config.py index f4fcd2a..ba3007e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -4,13 +4,13 @@ import dataframely as dy -def test_config_global(): +def test_config_global() -> None: dy.Config.set_max_sampling_iterations(50) assert dy.Config.options["max_sampling_iterations"] == 50 dy.Config.restore_defaults() -def test_config_local(): +def test_config_local() -> None: try: with dy.Config(max_sampling_iterations=35): assert dy.Config.options["max_sampling_iterations"] == 35 @@ -19,7 +19,7 @@ def test_config_local(): dy.Config.restore_defaults() -def test_config_local_nested(): +def test_config_local_nested() -> None: try: with dy.Config(max_sampling_iterations=35): assert dy.Config.options["max_sampling_iterations"] == 35 @@ -31,7 +31,7 @@ def test_config_local_nested(): dy.Config.restore_defaults() -def test_config_global_local(): +def test_config_global_local() -> None: try: dy.Config.set_max_sampling_iterations(50) assert dy.Config.options["max_sampling_iterations"] == 50 diff --git a/tests/test_exc.py b/tests/test_exc.py index 2629501..c1af303 100644 --- a/tests/test_exc.py +++ b/tests/test_exc.py @@ -6,13 +6,13 @@ from dataframely.exc import DtypeValidationError, RuleValidationError, ValidationError -def test_validation_error_str(): +def test_validation_error_str() -> None: message = "validation failed" exc = ValidationError(message) assert str(exc) == message -def test_dtype_validation_error_str(): +def test_dtype_validation_error_str() -> None: exc = DtypeValidationError( errors={"a": (pl.Int64, pl.String), "b": (pl.Boolean, pl.String)} ) @@ -23,7 +23,7 @@ def test_dtype_validation_error_str(): ] -def test_rule_validation_error_str(): +def test_rule_validation_error_str() -> None: exc = RuleValidationError( { "b|max_length": 1500, diff --git a/tests/test_extre.py 
b/tests/test_extre.py index c6ee883..9f657b6 100644 --- a/tests/test_extre.py +++ b/tests/test_extre.py @@ -25,14 +25,14 @@ ) def test_matching_string_length( regex: str, expected_lower: int, expected_upper: int | None -): +) -> None: actual_lower, actual_upper = extre.matching_string_length(regex) assert actual_lower == expected_lower assert actual_upper == expected_upper @pytest.mark.parametrize("regex", [r"(?=[A-Za-z\d])"]) -def test_failing_matching_string_length(regex: str): +def test_failing_matching_string_length(regex: str) -> None: with pytest.raises(ValueError): extre.matching_string_length(regex) @@ -52,40 +52,40 @@ def test_failing_matching_string_length(regex: str): @pytest.mark.parametrize("regex", TEST_REGEXES) -def test_sample_one(regex: str): +def test_sample_one(regex: str) -> None: sample = extre.sample(regex, max_repetitions=10) assert re.fullmatch(regex, sample) is not None @pytest.mark.parametrize("regex", TEST_REGEXES) -def test_sample_many(regex: str): +def test_sample_many(regex: str) -> None: samples = extre.sample(regex, n=100, max_repetitions=10) assert all(re.fullmatch(regex, s) is not None for s in samples) -def test_sample_equal_alternation_probabilities(): +def test_sample_equal_alternation_probabilities() -> None: n = 100_000 samples = extre.sample("a|b|c", n=n) np.allclose(np.unique_counts(samples).counts / n, np.ones(3) / 3, atol=0.01) -def test_sample_max_repetitions(): +def test_sample_max_repetitions() -> None: samples = extre.sample(".*", n=100_000, max_repetitions=10) assert max(len(s) for s in samples) == 10 assert math.isclose(np.mean([len(s) for s in samples]), 5, abs_tol=0.05) -def test_sample_equal_class_probabilities(): +def test_sample_equal_class_probabilities() -> None: n = 1_000_000 samples = extre.sample("[a-z0-9]", n=n) np.allclose(np.unique_counts(samples).counts / n, np.ones(36) / 36, atol=0.001) -def test_sample_one_seed(): +def test_sample_one_seed() -> None: choices = [extre.sample("a|b", seed=42) for _ in 
range(10_000)] assert len(set(choices)) == 1 -def test_sample_many_seed(): +def test_sample_many_seed() -> None: choices = extre.sample("a|b", n=10_000, seed=42) assert len(set(choices)) == 2 diff --git a/tests/test_failure_info.py b/tests/test_failure_info.py index 1c60ea2..c8d736e 100644 --- a/tests/test_failure_info.py +++ b/tests/test_failure_info.py @@ -14,7 +14,7 @@ class MySchema(dy.Schema): b = dy.Integer(nullable=False, is_in=[1, 2, 3, 5, 7, 11]) -def test_read_write_parquet(tmp_path: Path): +def test_read_write_parquet(tmp_path: Path) -> None: df = pl.DataFrame( { "a": [4, 5, 6, 6, 7, 8], diff --git a/tests/test_random.py b/tests/test_random.py index 65198a3..4755a93 100644 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -22,12 +22,12 @@ def generator() -> Generator: # -------------------------------- GENERAL PROPERTIES -------------------------------- # -def test_seeding_constant(): +def test_seeding_constant() -> None: results = {Generator(seed=42).sample_seed() for _ in range(1000)} assert len(results) == 1 -def test_seeding_nonconstant(): +def test_seeding_nonconstant() -> None: results = {Generator().sample_seed() for _ in range(1000)} assert len(results) > 1 @@ -55,7 +55,7 @@ def test_seeding_nonconstant(): @pytest.mark.parametrize("n", [1, 100]) def test_sample_correct_n( generator: Generator, fn: Callable[[Generator, int], pl.Series], n: int -): +) -> None: assert len(fn(generator, n)) == n @@ -94,7 +94,7 @@ def test_sample_correct_null_probability( generator: Generator, fn: Callable[[Generator, int, float], pl.Series], null_probability: float, -): +) -> None: n = 100_000 assert math.isclose( fn(generator, n, null_probability).is_null().sum() / n, @@ -106,7 +106,7 @@ def test_sample_correct_null_probability( # ---------------------------- INDIVIDUAL SAMPLING METHODS --------------------------- # -def test_sample_int(generator: Generator): +def test_sample_int(generator: Generator) -> None: samples = generator.sample_int(100_000, min=1, 
max=4) assert samples.min() == 1 assert samples.max() == 3 @@ -114,24 +114,24 @@ def test_sample_int(generator: Generator): @pytest.mark.parametrize("p_true", [0, 0.1, 0.5, None, 0.9, 1.0]) -def test_sample_bool(generator: Generator, p_true: bool | None): +def test_sample_bool(generator: Generator, p_true: bool | None) -> None: samples = generator.sample_bool(100_000, p_true=p_true) assert math.isclose(samples.mean(), p_true or 0.5, abs_tol=0.01) # type: ignore -def test_sample_float(generator: Generator): +def test_sample_float(generator: Generator) -> None: samples = generator.sample_float(100_000, min=1, max=3) assert samples.min() >= 1 # type: ignore assert samples.max() < 3 # type: ignore assert math.isclose(samples.mean(), 2, abs_tol=0.01) # type: ignore -def test_sample_string(generator: Generator): +def test_sample_string(generator: Generator) -> None: samples = generator.sample_string(100_000, regex="[abc]d") assert (samples.str.len_bytes() == 2).all() -def test_sample_choice(generator: Generator): +def test_sample_choice(generator: Generator) -> None: samples = generator.sample_choice(100_000, choices=[1, 2, 3]) assert np.allclose( samples.value_counts().sort("").get_column("count") / 100_000, @@ -141,7 +141,7 @@ def test_sample_choice(generator: Generator): @pytest.mark.parametrize("weight_factor", [0.01, 1, 1000]) -def test_sample_choice_weights(generator: Generator, weight_factor: float): +def test_sample_choice_weights(generator: Generator, weight_factor: float) -> None: with pytest.raises(ValueError): generator.sample_choice( 100, choices=[1, 2, 3], null_probability=0.1, weights=[1] @@ -200,7 +200,7 @@ def test_sample_resolutions( fn: Callable[[Generator, str], pl.Series], column_type: type[dy.Column], resolution: str, -): +) -> None: samples = fn(generator, resolution) schema = create_schema("test", {"a": column_type(resolution=resolution)}) # type: ignore schema.validate(samples.to_frame("a")) @@ -247,6 +247,8 @@ def test_sample_resolutions( ), ], 
) -def test_sample_invalid_arg(generator: Generator, fn: Callable[[Generator], pl.Series]): +def test_sample_invalid_arg( + generator: Generator, fn: Callable[[Generator], pl.Series] +) -> None: with pytest.raises(ValueError): fn(generator) diff --git a/tests/test_typing.py b/tests/test_typing.py index fe88d63..3fae814 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -32,17 +32,17 @@ class Schema(dy.Schema): a = dy.Int64() -def test_data_frame_lazy(): +def test_data_frame_lazy() -> None: df = Schema.create_empty() df.lazy() -def test_lazy_frame_lazy(): +def test_lazy_frame_lazy() -> None: df = Schema.create_empty(lazy=True) df.lazy() -def test_lazy_frame_collect(): +def test_lazy_frame_collect() -> None: df = Schema.create_empty(lazy=True) df.collect() @@ -66,7 +66,7 @@ class MyCollection(dy.Collection): second: dy.LazyFrame[MySecondSchema] -def test_collection_filter_return_value(): +def test_collection_filter_return_value() -> None: _, failure = MyCollection.filter( {"first": pl.LazyFrame(), "second": pl.LazyFrame()}, ) @@ -121,7 +121,9 @@ def my_schema_df() -> dy.DataFrame[MySchema]: ) -def test_iter_rows_assignment_correct_type(my_schema_df: dy.DataFrame[MySchema]): +def test_iter_rows_assignment_correct_type( + my_schema_df: dy.DataFrame[MySchema], +) -> None: entry = next(my_schema_df.iter_rows(named=True)) a: int = entry["a"] # noqa: F841 @@ -129,7 +131,7 @@ def test_iter_rows_assignment_correct_type(my_schema_df: dy.DataFrame[MySchema]) c: list[Any] = entry["custom_col_list"] # noqa: F841 -def test_iter_rows_schema_subtypes(my_schema_df: dy.DataFrame[MySchema]): +def test_iter_rows_schema_subtypes(my_schema_df: dy.DataFrame[MySchema]) -> None: class MySubSchema(MySchema): i = dy.Int64() @@ -150,32 +152,32 @@ class MySubSubSchema(MySubSchema): j2: int = entry2["j"] # noqa: F841 -def test_iter_rows_assignment_wrong_type(my_schema_df: dy.DataFrame[MySchema]): +def test_iter_rows_assignment_wrong_type(my_schema_df: dy.DataFrame[MySchema]) -> 
None: entry = next(my_schema_df.iter_rows(named=True)) a: int = entry["b"] # type: ignore[assignment] # noqa: F841 -def test_iter_rows_read_only(my_schema_df: dy.DataFrame[MySchema]): +def test_iter_rows_read_only(my_schema_df: dy.DataFrame[MySchema]) -> None: entry = next(my_schema_df.iter_rows(named=True)) entry["a"] = 1 # type: ignore[typeddict-readonly-mutated] -def test_iter_rows_missing_key(my_schema_df: dy.DataFrame[MySchema]): +def test_iter_rows_missing_key(my_schema_df: dy.DataFrame[MySchema]) -> None: entry = next(my_schema_df.iter_rows(named=True)) _ = entry["i"] # type: ignore[misc] -def test_iter_rows_without_named(my_schema_df: dy.DataFrame[MySchema]): +def test_iter_rows_without_named(my_schema_df: dy.DataFrame[MySchema]) -> None: # Make sure we don't accidentally override the return type of `iter_rows` with `named=False`. entry = next(my_schema_df.iter_rows(named=False)) _ = entry["g"] # type: ignore[call-overload] -def test_iter_rows_imported_schema(): +def test_iter_rows_imported_schema() -> None: my_imported_schema_df = MyImportedSchema.validate( pl.DataFrame( { @@ -200,7 +202,7 @@ def test_iter_rows_imported_schema(): _ = entry["i"] # type: ignore[misc] -def test_iter_rows_imported_subschema(): +def test_iter_rows_imported_subschema() -> None: class MySubFromImportedSchema(MyImportedSchema): i = dy.Int64() From 9990d6aeec674e45f0fb564c6b828e319bf756e0 Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Sun, 20 Apr 2025 10:34:28 +0200 Subject: [PATCH 2/2] Update quickstart.rst Co-authored-by: Daniel Elsner --- docs/sites/quickstart.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sites/quickstart.rst b/docs/sites/quickstart.rst index f340300..8361023 100644 --- a/docs/sites/quickstart.rst +++ b/docs/sites/quickstart.rst @@ -197,7 +197,7 @@ In this case, ``good`` remains to be a ``dy.DataFrame[HouseSchema]``, albeit wit The ``failure`` object is of type :class:`~dataframely.FailureInfo` and provides means to inspect the 
reasons for validation failures for invalid rows. -Given the example data above and the schema that we defined, we know that rows 2, 3, 4, and 5 are invalid (0-indexed) -> None: +Given the example data above and the schema that we defined, we know that rows 2, 3, 4, and 5 are invalid (0-indexed): - Row 2 has a zip code that does not appear at least twice - Row 3 has a NULL value for the number of bedrooms