From fe3bfd9c13ee5f4817b8fdcc592b1b3a78f2d3b3 Mon Sep 17 00:00:00 2001 From: "anthropic-code-agent[bot]" <242468646+Claude@users.noreply.github.com> Date: Thu, 9 Apr 2026 21:09:10 +0000 Subject: [PATCH 01/12] Initial plan From 4de9e25f8c4cfa629e1e16c8618e85cbd19bc2ed Mon Sep 17 00:00:00 2001 From: "anthropic-code-agent[bot]" <242468646+Claude@users.noreply.github.com> Date: Thu, 9 Apr 2026 21:22:00 +0000 Subject: [PATCH 02/12] Add pydantic conversion support to Column and Schema classes Agent-Logs-Url: https://github.com/Quantco/dataframely/sessions/c7146408-7625-42bd-93a1-9681303abab7 Co-authored-by: borchero <22455425+borchero@users.noreply.github.com> --- dataframely/columns/_base.py | 31 +++++ dataframely/columns/any.py | 7 + dataframely/columns/array.py | 26 ++++ dataframely/columns/binary.py | 8 ++ dataframely/columns/bool.py | 8 ++ dataframely/columns/categorical.py | 9 ++ dataframely/columns/datetime.py | 197 +++++++++++++++++++++++++++++ dataframely/columns/decimal.py | 53 +++++++- dataframely/columns/enum.py | 19 ++- dataframely/columns/float.py | 50 ++++++++ dataframely/columns/integer.py | 57 ++++++++- dataframely/columns/list.py | 32 +++++ dataframely/columns/object.py | 9 ++ dataframely/columns/string.py | 34 ++++- dataframely/columns/struct.py | 23 ++++ dataframely/schema.py | 54 ++++++++ 16 files changed, 613 insertions(+), 4 deletions(-) diff --git a/dataframely/columns/_base.py b/dataframely/columns/_base.py index e01a1c8..9cf8ee9 100644 --- a/dataframely/columns/_base.py +++ b/dataframely/columns/_base.py @@ -222,6 +222,37 @@ def pyarrow_field(self, name: str) -> pa.Field: def pyarrow_dtype(self) -> pa.DataType: """The :mod:`pyarrow` dtype equivalent of this column data type.""" + # ----------------------------------- PYDANTIC ----------------------------------- # + + def pydantic_field(self) -> Any: + """Obtain a pydantic field type for this column definition. + + Returns: + A pydantic-compatible type annotation that includes structured constraints + (e.g., min, max, regex) but excludes custom checks. + + Warning: + Custom checks defined via the `check` parameter are not included in the + returned pydantic field. A UserWarning is raised if custom checks are present. + """ + import warnings + + from dataframely._compat import pydantic + + if self.check is not None: + warnings.warn( + f"Custom checks for column '{self.name or self.__class__.__name__}' " + "are not translated to pydantic constraints.", + UserWarning, + stacklevel=2, + ) + + return self._pydantic_field_inner() + + @abstractmethod + def _pydantic_field_inner(self) -> Any: + """Subclasses implement this to return the actual pydantic field type.""" + # ------------------------------------ HELPER ------------------------------------ # @property diff --git a/dataframely/columns/any.py b/dataframely/columns/any.py index ca3ca22..44cdef7 100644 --- a/dataframely/columns/any.py +++ b/dataframely/columns/any.py @@ -79,3 +79,10 @@ def pyarrow_dtype(self) -> pa.DataType: def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: return pl.repeat(None, n, dtype=pl.Null, eager=True) + + def _pydantic_field_inner(self) -> type: + """Return pydantic field type for Any column.""" + from typing import Any as AnyType + + # Any columns are always nullable + return AnyType diff --git a/dataframely/columns/array.py b/dataframely/columns/array.py index e100449..8ad0b55 100644 --- a/dataframely/columns/array.py +++ b/dataframely/columns/array.py @@ -148,3 +148,29 @@ def as_dict(self, expr: pl.Expr) -> dict[str, Any]: def from_dict(cls, data: dict[str, Any]) -> Self: data["inner"] = column_from_dict(data["inner"]) return super().from_dict(data) + + def _pydantic_field_inner(self) -> type: + """Return pydantic field type for Array column.""" + import warnings + from typing import Union + + warnings.warn( + f"Array column '{self.name or self.__class__.__name__}' cannot be fully " + "translated to pydantic. Using list as the base type.", + UserWarning, + stacklevel=3, + ) + + # Get the inner type + inner_type = self.inner.pydantic_field() + + # Build the type annotation - use list for arrays + from typing import List as ListType + + base_type = ListType[inner_type] # type: ignore + + # Handle nullability + if self.nullable: + return Union[base_type, None] # type: ignore + + return base_type # type: ignore diff --git a/dataframely/columns/binary.py b/dataframely/columns/binary.py index a6a6b4b..677df3d 100644 --- a/dataframely/columns/binary.py +++ b/dataframely/columns/binary.py @@ -38,3 +38,11 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: max_bytes=32, null_probability=self._null_probability, ) + + def _pydantic_field_inner(self) -> type[bytes] | None: + """Return pydantic field type for binary column.""" + if self.nullable: + from typing import Union + + return Union[bytes, None] # type: ignore + return bytes diff --git a/dataframely/columns/bool.py b/dataframely/columns/bool.py index fcc3e91..1708b99 100644 --- a/dataframely/columns/bool.py +++ b/dataframely/columns/bool.py @@ -29,3 +29,11 @@ def pyarrow_dtype(self) -> pa.DataType: def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: return generator.sample_bool(n, null_probability=self._null_probability) + + def _pydantic_field_inner(self) -> type[bool] | None: + """Return pydantic field type for bool column.""" + if self.nullable: + from typing import Union + + return Union[bool, None] # type: ignore + return bool diff --git a/dataframely/columns/categorical.py b/dataframely/columns/categorical.py index cdba3d1..db34c88 100644 --- a/dataframely/columns/categorical.py +++ b/dataframely/columns/categorical.py @@ -76,3 +76,12 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: return generator.sample_string( n, regex=r"[a-z]{1,2}", null_probability=self._null_probability ).cast(self.dtype) + + def _pydantic_field_inner(self) -> type[str] | None: + """Return pydantic field type for categorical column.""" + # Categorical is essentially a string type + if self.nullable: + from typing import Union + + return Union[str, None] # type: ignore + return str diff --git a/dataframely/columns/datetime.py b/dataframely/columns/datetime.py index 030eee5..818efad 100644 --- a/dataframely/columns/datetime.py +++ b/dataframely/columns/datetime.py @@ -149,6 +149,50 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ) + def _pydantic_field_inner(self) -> type: + """Return pydantic field type for Date column.""" + import datetime as dt + import warnings + from typing import Annotated + + from dataframely._compat import pydantic + + # Warn about untranslated constraints + if self.resolution is not None: + warnings.warn( + f"Date column '{self.name or self.__class__.__name__}' has a resolution " + "constraint that cannot be translated to pydantic.", + UserWarning, + stacklevel=3, + ) + + # Build constraints + merged_kwargs = {} + if self.min is not None: + merged_kwargs["ge"] = self.min + if self.min_exclusive is not None: + merged_kwargs["gt"] = self.min_exclusive + if self.max is not None: + merged_kwargs["le"] = self.max + if self.max_exclusive is not None: + merged_kwargs["lt"] = self.max_exclusive + + # Build the type annotation + base_type = dt.date + + if merged_kwargs: + annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] + else: + annotated_type = base_type + + # Handle nullability + if self.nullable: + from typing import Union + + return Union[annotated_type, None] # type: ignore + + return annotated_type # type: ignore + @register class Time(OrdinalMixin[dt.time], Column): @@ -278,6 +322,50 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ) + def _pydantic_field_inner(self) -> type: + """Return pydantic field type for Time column.""" + import datetime as dt + import warnings + from typing import Annotated + + from dataframely._compat import pydantic + + # Warn about untranslated constraints + if self.resolution is not None: + warnings.warn( + f"Time column '{self.name or self.__class__.__name__}' has a resolution " + "constraint that cannot be translated to pydantic.", + UserWarning, + stacklevel=3, + ) + + # Build constraints + merged_kwargs = {} + if self.min is not None: + merged_kwargs["ge"] = self.min + if self.min_exclusive is not None: + merged_kwargs["gt"] = self.min_exclusive + if self.max is not None: + merged_kwargs["le"] = self.max + if self.max_exclusive is not None: + merged_kwargs["lt"] = self.max_exclusive + + # Build the type annotation + base_type = dt.time + + if merged_kwargs: + annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] + else: + annotated_type = base_type + + # Handle nullability + if self.nullable: + from typing import Union + + return Union[annotated_type, None] # type: ignore + + return annotated_type # type: ignore + @register class Datetime(OrdinalMixin[dt.datetime], Column): @@ -425,6 +513,64 @@ def _attributes_match( return lhs.utcoffset(now) == rhs.utcoffset(now) return super()._attributes_match(lhs, rhs, name, column_expr) + def _pydantic_field_inner(self) -> type: + """Return pydantic field type for Datetime column.""" + import datetime as dt + import warnings + from typing import Annotated + + from dataframely._compat import pydantic + + # Warn about untranslated constraints + if self.resolution is not None: + warnings.warn( + f"Datetime column '{self.name or self.__class__.__name__}' has a resolution " + "constraint that cannot be translated to pydantic.", + UserWarning, + stacklevel=3, + ) + if self.time_zone is not None: + warnings.warn( + f"Datetime column '{self.name or self.__class__.__name__}' has a time_zone " + "constraint that cannot be translated to pydantic.", + UserWarning, + stacklevel=3, + ) + if self.time_unit != "us": + warnings.warn( + f"Datetime column '{self.name or self.__class__.__name__}' has a time_unit " + "constraint that cannot be translated to pydantic.", + UserWarning, + stacklevel=3, + ) + + # Build constraints + merged_kwargs = {} + if self.min is not None: + merged_kwargs["ge"] = self.min + if self.min_exclusive is not None: + merged_kwargs["gt"] = self.min_exclusive + if self.max is not None: + merged_kwargs["le"] = self.max + if self.max_exclusive is not None: + merged_kwargs["lt"] = self.max_exclusive + + # Build the type annotation + base_type = dt.datetime + + if merged_kwargs: + annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] + else: + annotated_type = base_type + + # Handle nullability + if self.nullable: + from typing import Union + + return Union[annotated_type, None] # type: ignore + + return annotated_type # type: ignore + @register class Duration(OrdinalMixin[dt.timedelta], Column): @@ -550,6 +696,57 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ) + def _pydantic_field_inner(self) -> type: + """Return pydantic field type for Duration column.""" + import datetime as dt + import warnings + from typing import Annotated + + from dataframely._compat import pydantic + + # Warn about untranslated constraints + if self.resolution is not None: + warnings.warn( + f"Duration column '{self.name or self.__class__.__name__}' has a resolution " + "constraint that cannot be translated to pydantic.", + UserWarning, + stacklevel=3, + ) + if self.time_unit != "us": + warnings.warn( + f"Duration column '{self.name or self.__class__.__name__}' has a time_unit " + "constraint that cannot be translated to pydantic.", + UserWarning, + stacklevel=3, + ) + + # Build constraints + merged_kwargs = {} + if self.min is not None: + merged_kwargs["ge"] = self.min + if self.min_exclusive is not None: + merged_kwargs["gt"] = self.min_exclusive + if self.max is not None: + merged_kwargs["le"] = self.max + if self.max_exclusive is not None: + merged_kwargs["lt"] = self.max_exclusive + + # Build the type annotation + base_type = dt.timedelta + + if merged_kwargs: + annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] + else: + annotated_type = base_type + + # Handle nullability + if self.nullable: + from typing import Union + + return Union[annotated_type, None] # type: ignore + + return annotated_type # type: ignore + # --------------------------------------- UTILS -------------------------------------- # diff --git a/dataframely/columns/decimal.py b/dataframely/columns/decimal.py index e29ca76..09cab12 100644 --- a/dataframely/columns/decimal.py +++ b/dataframely/columns/decimal.py @@ -8,7 +8,7 @@ import polars as pl -from dataframely._compat import pa, sa, sa_TypeEngine +from dataframely._compat import pa, pydantic, sa, sa_TypeEngine from dataframely._polars import PolarsDataType from dataframely.random import Generator @@ -169,6 +169,57 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: ) return ((samples * 10**self.scale).floor() / 10**self.scale).cast(self.dtype) + def _pydantic_field_inner(self) -> type: + """Return pydantic field type for Decimal column.""" + import decimal + import warnings + from typing import Annotated + + from dataframely._compat import pydantic + + # Warn about untranslated constraints + if self.precision is not None: + warnings.warn( + f"Decimal column '{self.name or self.__class__.__name__}' has a precision " + "constraint that cannot be translated to pydantic.", + UserWarning, + stacklevel=3, + ) + if self.scale != 0: + warnings.warn( + f"Decimal column '{self.name or self.__class__.__name__}' has a scale " + "constraint that cannot be translated to pydantic.", + UserWarning, + stacklevel=3, + ) + + # Build constraints + merged_kwargs = {} + if self.min is not None: + merged_kwargs["ge"] = self.min + if self.min_exclusive is not None: + merged_kwargs["gt"] = self.min_exclusive + if self.max is not None: + merged_kwargs["le"] = self.max + if self.max_exclusive is not None: + merged_kwargs["lt"] = self.max_exclusive + + # Build the type annotation + base_type = decimal.Decimal + + if merged_kwargs: + annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] + else: + annotated_type = base_type + + # Handle nullability + if self.nullable: + from typing import Union + + return Union[annotated_type, None] # type: ignore + + return annotated_type # type: ignore + # --------------------------------------- UTILS -------------------------------------- # diff --git a/dataframely/columns/enum.py b/dataframely/columns/enum.py index 878c0d2..f958143 100644 --- a/dataframely/columns/enum.py +++ b/dataframely/columns/enum.py @@ -10,7 +10,7 @@ import polars as pl -from dataframely._compat import pa, sa, sa_TypeEngine +from dataframely._compat import pa, pydantic, sa, sa_TypeEngine from dataframely._polars import PolarsDataType from dataframely.random import Generator @@ -101,3 +101,20 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: choices=self.categories, null_probability=self._null_probability, ).cast(self.dtype) + + def _pydantic_field_inner(self) -> type: + """Return pydantic field type for Enum column.""" + from typing import Literal, Union + + # Use Literal for enum categories + if len(self.categories) == 0: + # Empty enum, use str as fallback + base_type = str + else: + base_type = Literal[tuple(self.categories)] # type: ignore + + # Handle nullability + if self.nullable: + return Union[base_type, None] # type: ignore + + return base_type # type: ignore diff --git a/dataframely/columns/float.py b/dataframely/columns/float.py index 2f5a5bd..4a416e6 100644 --- a/dataframely/columns/float.py +++ b/dataframely/columns/float.py @@ -142,6 +142,56 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: inf_probability=self._inf_probability, ).cast(self.dtype) + def _pydantic_field_inner(self) -> type[float] | None: + """Return pydantic field type for float column.""" + import warnings + from typing import Annotated + + from dataframely._compat import pydantic + + # Warn about untranslated constraints + if not self.allow_inf: + warnings.warn( + f"Float column '{self.name or self.__class__.__name__}' does not allow " + "infinity values, but this constraint cannot be translated to pydantic.", + UserWarning, + stacklevel=3, + ) + if not self.allow_nan: + warnings.warn( + f"Float column '{self.name or self.__class__.__name__}' does not allow " + "NaN values, but this constraint cannot be translated to pydantic.", + UserWarning, + stacklevel=3, + ) + + # Build constraints + merged_kwargs = {} + if self.min is not None: + merged_kwargs["ge"] = self.min + if self.min_exclusive is not None: + merged_kwargs["gt"] = self.min_exclusive + if self.max is not None: + merged_kwargs["le"] = self.max + if self.max_exclusive is not None: + merged_kwargs["lt"] = self.max_exclusive + + # Build the type annotation + base_type = float + + if merged_kwargs: + annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] + else: + annotated_type = base_type + + # Handle nullability + if self.nullable: + from typing import Union + + return Union[annotated_type, None] # type: ignore + + return annotated_type # type: ignore + # ------------------------------------------------------------------------------------ # diff --git a/dataframely/columns/integer.py b/dataframely/columns/integer.py index 9fe99bc..ef9061e 100644 --- a/dataframely/columns/integer.py +++ b/dataframely/columns/integer.py @@ -10,7 +10,7 @@ import polars as pl from polars.datatypes.group import INTEGER_DTYPES -from dataframely._compat import pa, sa, sa_mssql, sa_TypeEngine +from dataframely._compat import pa, pydantic, sa, sa_mssql, sa_TypeEngine from dataframely._polars import PolarsDataType from dataframely.random import Generator @@ -142,6 +142,61 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ).cast(self.dtype) + def _pydantic_field_inner(self) -> type[int] | None: + """Return pydantic field type for integer column.""" + import warnings + from typing import Annotated + + from dataframely._compat import pydantic + + # Build constraints list + constraints = [] + + if self.is_in is not None: + # Use Literal for enumerations + from typing import Literal + + # Convert to tuple since Literal requires literal values + return Literal[tuple(self.is_in)] # type: ignore + + # Add range constraints + if self.min is not None: + constraints.append(pydantic.Field(ge=self.min)) + if self.min_exclusive is not None: + constraints.append(pydantic.Field(gt=self.min_exclusive)) + if self.max is not None: + constraints.append(pydantic.Field(le=self.max)) + if self.max_exclusive is not None: + constraints.append(pydantic.Field(lt=self.max_exclusive)) + + # Build the type annotation + base_type = int + + if constraints: + # Merge all Field constraints + merged_kwargs = {} + for constraint in constraints: + if hasattr(constraint, "ge") and constraint.ge is not None: + merged_kwargs["ge"] = constraint.ge + if hasattr(constraint, "gt") and constraint.gt is not None: + merged_kwargs["gt"] = constraint.gt + if hasattr(constraint, "le") and constraint.le is not None: + merged_kwargs["le"] = constraint.le + if hasattr(constraint, "lt") and constraint.lt is not None: + merged_kwargs["lt"] = constraint.lt + + annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] + else: + annotated_type = base_type + + # Handle nullability + if self.nullable: + from typing import Union + + return Union[annotated_type, None] # type: ignore + + return annotated_type # type: ignore + # ------------------------------------------------------------------------------------ # diff --git a/dataframely/columns/list.py b/dataframely/columns/list.py index cd5ed80..e620221 100644 --- a/dataframely/columns/list.py +++ b/dataframely/columns/list.py @@ -177,6 +177,38 @@ def from_dict(cls, data: dict[str, Any]) -> Self: data["inner"] = column_from_dict(data["inner"]) return super().from_dict(data) + def _pydantic_field_inner(self) -> type: + """Return pydantic field type for List column.""" + from typing import Annotated, Union + + from dataframely._compat import pydantic + + # Get the inner type + inner_type = self.inner.pydantic_field() + + # Build constraints + merged_kwargs = {} + if self.min_length is not None: + merged_kwargs["min_length"] = self.min_length + if self.max_length is not None: + merged_kwargs["max_length"] = self.max_length + + # Build the type annotation + from typing import List as ListType + + base_type = ListType[inner_type] # type: ignore + + if merged_kwargs: + annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] + else: + annotated_type = base_type + + # Handle nullability + if self.nullable: + return Union[annotated_type, None] # type: ignore + + return annotated_type # type: ignore + def _list_primary_key_check( list_expr: ExprListNameSpace | ExprArrayNameSpace, inner: Column diff --git a/dataframely/columns/object.py b/dataframely/columns/object.py index 31e8650..fc99d20 100644 --- a/dataframely/columns/object.py +++ b/dataframely/columns/object.py @@ -71,3 +71,12 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: raise NotImplementedError( "Random data sampling not implemented for 'Object' type." ) + + def _pydantic_field_inner(self) -> type: + """Return pydantic field type for Object column.""" + from typing import Any as AnyType, Union + + # Object columns can contain any Python object + if self.nullable: + return Union[AnyType, None] # type: ignore + return AnyType diff --git a/dataframely/columns/string.py b/dataframely/columns/string.py index c21ba7e..055fad0 100644 --- a/dataframely/columns/string.py +++ b/dataframely/columns/string.py @@ -7,7 +7,7 @@ import polars as pl -from dataframely._compat import pa, sa, sa_TypeEngine +from dataframely._compat import pa, pydantic, sa, sa_TypeEngine from dataframely._native import regex_matching_string_length from dataframely.random import Generator @@ -137,3 +137,35 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: regex=regex, null_probability=self._null_probability, ) + + def _pydantic_field_inner(self) -> type[str] | None: + """Return pydantic field type for string column.""" + from typing import Annotated + + from dataframely._compat import pydantic + + # Build constraints + merged_kwargs = {} + if self.min_length is not None: + merged_kwargs["min_length"] = self.min_length + if self.max_length is not None: + merged_kwargs["max_length"] = self.max_length + if self.regex is not None: + # Pydantic uses 'pattern' for regex validation + merged_kwargs["pattern"] = self.regex + + # Build the type annotation + base_type = str + + if merged_kwargs: + annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] + else: + annotated_type = base_type + + # Handle nullability + if self.nullable: + from typing import Union + + return Union[annotated_type, None] # type: ignore + + return annotated_type # type: ignore diff --git a/dataframely/columns/struct.py b/dataframely/columns/struct.py index edf597b..584472c 100644 --- a/dataframely/columns/struct.py +++ b/dataframely/columns/struct.py @@ -156,3 +156,26 @@ def from_dict(cls, data: dict[str, Any]) -> Self: name: column_from_dict(col) for name, col in data["inner"].items() } return super().from_dict(data) + + def _pydantic_field_inner(self) -> type: + """Return pydantic field type for Struct column.""" + from typing import Union + + from dataframely._compat import pydantic + + # Build a pydantic model for the struct + fields = {} + for field_name, col in self.inner.items(): + field_type = col.pydantic_field() + fields[field_name] = (field_type, ...) + + # Create a dynamic pydantic model + model = pydantic.create_model( + f"{self.name or 'Struct'}Model", **fields # type: ignore + ) + + # Handle nullability + if self.nullable: + return Union[model, None] # type: ignore + + return model # type: ignore diff --git a/dataframely/schema.py b/dataframely/schema.py index 1d526ff..f79148f 100644 --- a/dataframely/schema.py +++ b/dataframely/schema.py @@ -1339,6 +1339,60 @@ def to_pyarrow_schema(cls) -> pa.Schema: [col.pyarrow_field(name) for name, col in cls.columns().items()] ) + @classmethod + def to_pydantic_model(cls) -> type: + """Convert this schema to a pydantic model. + + This method creates a pydantic model with fields corresponding to the columns + in this schema. Structured constraints (min, max, regex, etc.) are translated + to pydantic field constraints where possible. Custom checks and group rules + are not translated. + + Returns: + A pydantic model class that can be used for data validation. + + Warning: + Custom checks defined via the `check` parameter on columns are not + included in the pydantic model. A UserWarning is raised if custom checks + are present. + + Warning: + Group rules defined on the schema are not translated to pydantic validators. + A UserWarning is raised if group rules are present. + + Example: + >>> class MySchema(dy.Schema): + ... x = dy.Integer(min=0, max=100) + ... y = dy.String(regex=r"^[A-Z]+$") + ... + >>> Model = MySchema.to_pydantic_model() + >>> # Now you can use Model for validation, e.g., with LLM APIs + """ + import warnings + + from dataframely._compat import pydantic + + # Check for group rules and warn if present + if cls._schema_validation_rules(): + warnings.warn( + f"Schema '{cls.__name__}' has group rules that cannot be " + "translated to pydantic validators.", + UserWarning, + stacklevel=2, + ) + + # Build field definitions for the pydantic model + fields = {} + for col_name, col in cls.columns().items(): + # Get the pydantic field type for this column + field_type = col.pydantic_field() + # All fields are required (use ... as default) + fields[col_name] = (field_type, ...) + + # Create the pydantic model dynamically + model_name = f"{cls.__name__}PydanticModel" + return pydantic.create_model(model_name, **fields) # type: ignore + # ----------------------------------- EQUALITY ----------------------------------- # @classmethod From 07e707ac318047afbf607df3b720932fff3a83ef Mon Sep 17 00:00:00 2001 From: "anthropic-code-agent[bot]" <242468646+Claude@users.noreply.github.com> Date: Thu, 9 Apr 2026 21:29:27 +0000 Subject: [PATCH 03/12] Add pydantic model conversion with tests and fix lint issues Agent-Logs-Url: https://github.com/Quantco/dataframely/sessions/c7146408-7625-42bd-93a1-9681303abab7 Co-authored-by: borchero <22455425+borchero@users.noreply.github.com> --- dataframely/columns/_base.py | 2 - dataframely/columns/array.py | 3 +- dataframely/columns/decimal.py | 2 - dataframely/columns/enum.py | 2 +- dataframely/columns/integer.py | 64 ++--- dataframely/columns/list.py | 3 +- dataframely/columns/object.py | 3 +- dataframely/columns/string.py | 2 - dataframely/columns/struct.py | 3 +- tests/test_pydantic_conversion.py | 392 ++++++++++++++++++++++++++++++ 10 files changed, 419 insertions(+), 57 deletions(-) create mode 100644 tests/test_pydantic_conversion.py diff --git a/dataframely/columns/_base.py b/dataframely/columns/_base.py index 9cf8ee9..ded2573 100644 --- a/dataframely/columns/_base.py +++ b/dataframely/columns/_base.py @@ -237,8 +237,6 @@ def pydantic_field(self) -> Any: """ import warnings - from dataframely._compat import pydantic - if self.check is not None: warnings.warn( f"Custom checks for column '{self.name or self.__class__.__name__}' " diff --git a/dataframely/columns/array.py b/dataframely/columns/array.py index 8ad0b55..f13e92d 100644 --- a/dataframely/columns/array.py +++ b/dataframely/columns/array.py @@ -165,9 +165,8 @@ def _pydantic_field_inner(self) -> type: inner_type = self.inner.pydantic_field() # Build the type annotation - use list for arrays - from typing import List as ListType - base_type = ListType[inner_type] # type: ignore + base_type = list[inner_type] # type: ignore # Handle nullability if self.nullable: diff --git a/dataframely/columns/decimal.py b/dataframely/columns/decimal.py index 09cab12..e959a72 100644 --- a/dataframely/columns/decimal.py +++ b/dataframely/columns/decimal.py @@ -175,8 +175,6 @@ def _pydantic_field_inner(self) -> type: import warnings from typing import Annotated - from dataframely._compat import pydantic - # Warn about untranslated constraints if self.precision is not None: warnings.warn( diff --git a/dataframely/columns/enum.py b/dataframely/columns/enum.py index f958143..e5d9aad 100644 --- a/dataframely/columns/enum.py +++ b/dataframely/columns/enum.py @@ -10,7 +10,7 @@ import polars as pl -from dataframely._compat import pa, pydantic, sa, sa_TypeEngine +from dataframely._compat import pa, sa, sa_TypeEngine from dataframely._polars import PolarsDataType from dataframely.random import Generator diff --git a/dataframely/columns/integer.py b/dataframely/columns/integer.py index ef9061e..f6f4b47 100644 --- a/dataframely/columns/integer.py +++ b/dataframely/columns/integer.py @@ -144,58 +144,34 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: def _pydantic_field_inner(self) -> type[int] | None: """Return pydantic field type for integer column.""" - import warnings - from typing import Annotated - - from dataframely._compat import pydantic - - # Build constraints list - constraints = [] + from typing import Annotated, Literal, Union + # Build constraints if self.is_in is not None: # Use Literal for enumerations - from typing import Literal - - # Convert to tuple since Literal requires literal values - return Literal[tuple(self.is_in)] # type: ignore - - # Add range constraints - if self.min is not None: - constraints.append(pydantic.Field(ge=self.min)) - if self.min_exclusive is not None: - constraints.append(pydantic.Field(gt=self.min_exclusive)) - if self.max is not None: - constraints.append(pydantic.Field(le=self.max)) - if self.max_exclusive is not None: - constraints.append(pydantic.Field(lt=self.max_exclusive)) - - # Build the type annotation - base_type = int - - if constraints: - # Merge all Field constraints - merged_kwargs = {} - for constraint in constraints: - if hasattr(constraint, "ge") and constraint.ge is not None: - merged_kwargs["ge"] = constraint.ge - if hasattr(constraint, "gt") and constraint.gt is not None: - merged_kwargs["gt"] = constraint.gt - if hasattr(constraint, "le") and constraint.le is not None: - merged_kwargs["le"] = constraint.le - if hasattr(constraint, "lt") and constraint.lt is not None: - merged_kwargs["lt"] = constraint.lt - - annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] + base_type = Literal[tuple(self.is_in)] # type: ignore else: - annotated_type = base_type + # Build range constraints in a single Field + field_kwargs = {} + if self.min is not None: + field_kwargs["ge"] = self.min + if self.min_exclusive is not None: + field_kwargs["gt"] = self.min_exclusive + if self.max is not None: + field_kwargs["le"] = self.max + if self.max_exclusive is not None: + field_kwargs["lt"] = self.max_exclusive + + if field_kwargs: + base_type = Annotated[int, pydantic.Field(**field_kwargs)] # type: ignore + else: + base_type = int # Handle nullability if self.nullable: - from typing import Union - - return Union[annotated_type, None] # type: ignore + return Union[base_type, None] # type: ignore - return annotated_type # type: ignore + return base_type # type: ignore # ------------------------------------------------------------------------------------ # diff --git a/dataframely/columns/list.py b/dataframely/columns/list.py index e620221..b6b9904 100644 --- a/dataframely/columns/list.py +++ b/dataframely/columns/list.py @@ -194,9 +194,8 @@ def _pydantic_field_inner(self) -> type: merged_kwargs["max_length"] = self.max_length # Build the type annotation - from typing import List as ListType - base_type = ListType[inner_type] # type: ignore + base_type = list[inner_type] # type: ignore if merged_kwargs: annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] diff --git a/dataframely/columns/object.py b/dataframely/columns/object.py index fc99d20..78c6deb 100644 --- a/dataframely/columns/object.py +++ b/dataframely/columns/object.py @@ -74,7 +74,8 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: def _pydantic_field_inner(self) -> type: """Return pydantic field type for Object column.""" - from typing import Any as AnyType, Union + from typing import Any as AnyType + from typing import Union # Object columns can contain any Python object if self.nullable: diff --git a/dataframely/columns/string.py b/dataframely/columns/string.py index 055fad0..e3ac12a 100644 --- a/dataframely/columns/string.py +++ b/dataframely/columns/string.py @@ -142,8 +142,6 @@ def _pydantic_field_inner(self) -> type[str] | None: """Return pydantic field type for string column.""" from typing import Annotated - from dataframely._compat import pydantic - # Build constraints merged_kwargs = {} if self.min_length is not None: diff --git a/dataframely/columns/struct.py b/dataframely/columns/struct.py index 584472c..76fc8b0 100644 --- a/dataframely/columns/struct.py +++ b/dataframely/columns/struct.py @@ -171,7 +171,8 @@ def _pydantic_field_inner(self) -> type: # Create a dynamic pydantic model model = pydantic.create_model( - f"{self.name or 'Struct'}Model", **fields # type: ignore + f"{self.name or 'Struct'}Model", + **fields, # type: ignore ) # Handle nullability diff --git a/tests/test_pydantic_conversion.py b/tests/test_pydantic_conversion.py new file mode 100644 index 0000000..c6a9ed0 --- /dev/null +++ b/tests/test_pydantic_conversion.py @@ -0,0 +1,392 @@ +# Copyright (c) QuantCo 2025-2026 +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for pydantic conversion functionality.""" + +import datetime as dt +import decimal +import warnings + +import pytest + +import dataframely as dy +from dataframely._compat import pydantic + +pytestmark = pytest.mark.with_optionals + + +def test_integer_column_pydantic_field() -> None: + # Arrange + column = dy.Integer(min=0, max=100) + column._name = "test_col" + + # Act + field_type = column.pydantic_field() + + # Assert + # Should be annotated int with constraints + # We can test by creating a pydantic model with it + Model = pydantic.create_model("TestModel", value=(field_type, ...)) + + # Valid value + instance = Model(value=50) + assert instance.value == 50 + + # Invalid value (too small) + with pytest.raises(pydantic.ValidationError): + Model(value=-1) + + # Invalid value (too large) + with pytest.raises(pydantic.ValidationError): + Model(value=101) + + +def test_integer_column_with_is_in() -> None: + # Arrange + column = dy.Integer(is_in=[1, 2, 3]) + column._name = "test_col" + + # Act + field_type = column.pydantic_field() + + # Assert + Model = pydantic.create_model("TestModel", value=(field_type, ...)) + + # Valid values + for val in [1, 2, 3]: + instance = Model(value=val) + assert instance.value == val + + # Invalid value + with pytest.raises(pydantic.ValidationError): + Model(value=4) + + +def test_string_column_pydantic_field() -> None: + # Arrange + column = dy.String(min_length=3, max_length=10, regex=r"^[A-Z]+$") + column._name = "test_col" + + # Act + field_type = column.pydantic_field() + + # Assert + Model = pydantic.create_model("TestModel", value=(field_type, ...)) + + # Valid value + instance = Model(value="HELLO") + assert instance.value == "HELLO" + + # Invalid: too short + with pytest.raises(pydantic.ValidationError): + Model(value="HI") + + # Invalid: too long + with pytest.raises(pydantic.ValidationError): + Model(value="VERYLONGSTRING") + + # Invalid: doesn't match regex + with pytest.raises(pydantic.ValidationError): + Model(value="hello") + + +def test_float_column_pydantic_field() -> None: + # Arrange + column = dy.Float(min=0.0, max=1.0) + column._name = "test_col" + + # Act + with warnings.catch_warnings(record=True): + # Suppress warnings about allow_inf and allow_nan + warnings.simplefilter("always") + field_type = column.pydantic_field() + + # Assert + Model = pydantic.create_model("TestModel", value=(field_type, ...)) + + # Valid value + instance = Model(value=0.5) + assert instance.value == 0.5 + + # Invalid: too small + with pytest.raises(pydantic.ValidationError): + Model(value=-0.1) + + # Invalid: too large + with pytest.raises(pydantic.ValidationError): + Model(value=1.1) + + +def test_bool_column_pydantic_field() -> None: + # Arrange + column = dy.Bool() + column._name = "test_col" + + # Act + field_type = column.pydantic_field() + + # Assert + Model = pydantic.create_model("TestModel", value=(field_type, ...)) + + # Valid values + for val in [True, False]: + instance = Model(value=val) + assert instance.value == val + + +def test_date_column_pydantic_field() -> None: + # Arrange + column = dy.Date(min=dt.date(2020, 1, 1), max=dt.date(2025, 12, 31)) + column._name = "test_col" + + # Act + with warnings.catch_warnings(record=True): + # Suppress warnings about resolution + warnings.simplefilter("always") + field_type = column.pydantic_field() + + # Assert + Model = pydantic.create_model("TestModel", value=(field_type, ...)) + + # Valid value + instance = Model(value=dt.date(2023, 6, 15)) + assert instance.value == dt.date(2023, 6, 15) + + # Invalid: too early + with pytest.raises(pydantic.ValidationError): + Model(value=dt.date(2019, 1, 1)) + + # Invalid: too late + with pytest.raises(pydantic.ValidationError): + Model(value=dt.date(2026, 1, 1)) + + +def test_datetime_column_pydantic_field() -> None: + # Arrange + column = dy.Datetime( + min=dt.datetime(2020, 1, 1), max=dt.datetime(2025, 12, 31, 23, 59, 59) + ) + column._name = "test_col" + + # Act + with warnings.catch_warnings(record=True): + # Suppress warnings + warnings.simplefilter("always") + field_type = column.pydantic_field() + + # Assert + Model = pydantic.create_model("TestModel", value=(field_type, ...)) + + # Valid value + instance = Model(value=dt.datetime(2023, 6, 15, 12, 0, 0)) + assert instance.value == dt.datetime(2023, 6, 15, 12, 0, 0) + + +def test_enum_column_pydantic_field() -> None: + # Arrange + column = dy.Enum(categories=["red", "green", "blue"]) + column._name = "test_col" + + # Act + field_type = column.pydantic_field() + + # Assert + Model = pydantic.create_model("TestModel", value=(field_type, ...)) + + # Valid values + for val in ["red", "green", "blue"]: + instance = Model(value=val) + assert instance.value == val + + # Invalid value + with pytest.raises(pydantic.ValidationError): + Model(value="yellow") + + +def test_list_column_pydantic_field() -> None: + # Arrange + inner = dy.Integer(min=0, max=100) + column = dy.List(inner, min_length=2, max_length=5) + column._name = "test_col" + + # Act + field_type = column.pydantic_field() + + # Assert + Model = pydantic.create_model("TestModel", value=(field_type, ...)) + + # Valid value + instance = Model(value=[1, 2, 3]) + assert instance.value == [1, 2, 3] + + # Invalid: too short + with pytest.raises(pydantic.ValidationError): + Model(value=[1]) + + # Invalid: too long + with pytest.raises(pydantic.ValidationError): + Model(value=[1, 2, 3, 4, 5, 6]) + + # Invalid: element out of range + with pytest.raises(pydantic.ValidationError): + Model(value=[1, 2, 101]) + + +def test_struct_column_pydantic_field() -> None: + # Arrange + column = dy.Struct({"x": dy.Integer(min=0), "y": dy.String(max_length=10)}) + column._name = "test_col" + + # Act + field_type = column.pydantic_field() + + # Assert + Model = pydantic.create_model("TestModel", value=(field_type, ...)) + + # Valid value + instance = Model(value={"x": 5, "y": "hello"}) + assert instance.value.x == 5 + assert instance.value.y == "hello" + + # Invalid: x out of range + with pytest.raises(pydantic.ValidationError): + Model(value={"x": -1, "y": "hello"}) + + +def test_nullable_column_pydantic_field() -> None: + # Arrange + column = dy.Integer(min=0, max=100, nullable=True) + column._name = "test_col" + + # Act + field_type = column.pydantic_field() + + # Assert + Model = pydantic.create_model("TestModel", value=(field_type, ...)) + + # Valid: None + instance = Model(value=None) + assert instance.value is None + + # Valid: integer + instance = Model(value=50) + assert instance.value == 50 + + +def test_column_with_custom_check_raises_warning() -> None: + # Arrange + column = dy.Integer(min=0, max=100, check=lambda x: x.is_even()) + column._name = "test_col" + + # Act & Assert + with pytest.warns(UserWarning, match="Custom checks .* are not translated"): + column.pydantic_field() + + +def test_schema_to_pydantic_model() -> None: + # Arrange + class MySchema(dy.Schema): + x = dy.Integer(min=0, max=100) + y = dy.String(regex=r"^[A-Z]+$") + z = dy.Float(nullable=True, allow_inf=True, allow_nan=True) + + # Act + Model = MySchema.to_pydantic_model() + + # Assert + # Valid instance + instance = Model(x=50, y="HELLO", z=3.14) + assert instance.x == 50 + assert instance.y == "HELLO" + assert instance.z == 3.14 + + # Valid with None + instance = Model(x=50, y="HELLO", z=None) + assert instance.z is None + + # Invalid: x out of range + with pytest.raises(pydantic.ValidationError): + Model(x=101, y="HELLO", z=3.14) + + # Invalid: y doesn't match regex + with pytest.raises(pydantic.ValidationError): + Model(x=50, y="hello", z=3.14) + + +def test_schema_with_nested_struct() -> None: + # Arrange + class NestedSchema(dy.Schema): + point = dy.Struct( + { + "x": dy.Float(allow_inf=True, allow_nan=True), + "y": dy.Float(allow_inf=True, allow_nan=True), + } + ) + label = dy.String() + + # Act + Model = NestedSchema.to_pydantic_model() + + # Assert + instance = Model(point={"x": 1.0, "y": 2.0}, label="A") + assert instance.point.x == 1.0 + assert instance.point.y == 2.0 + assert instance.label == "A" + + +def test_schema_with_list_of_ints() -> None: + # Arrange + class ListSchema(dy.Schema): + numbers = dy.List(dy.Integer(min=0), min_length=1, max_length=10) + + # Act + Model = ListSchema.to_pydantic_model() + + # Assert + instance = Model(numbers=[1, 2, 3, 4, 5]) + assert instance.numbers == [1, 2, 3, 4, 5] + + # Invalid: contains negative number + with pytest.raises(pydantic.ValidationError): + Model(numbers=[1, -2, 3]) + + +def test_decimal_column_pydantic_field() -> None: + # Arrange + column = dy.Decimal( + precision=10, scale=2, min=decimal.Decimal("0.00"), max=decimal.Decimal("100.00") + ) + column._name = "test_col" + + # Act + with warnings.catch_warnings(record=True): + # Suppress warnings about precision and scale + warnings.simplefilter("always") + field_type = column.pydantic_field() + + # Assert + Model = pydantic.create_model("TestModel", value=(field_type, ...)) + + # Valid value + instance = Model(value=decimal.Decimal("50.00")) + assert instance.value == decimal.Decimal("50.00") + + # Invalid: out of range + with pytest.raises(pydantic.ValidationError): + Model(value=decimal.Decimal("150.00")) + + +def test_schema_with_group_rules_raises_warning() -> None: + # Arrange + class SchemaWithRules(dy.Schema): + x = dy.Integer() + y = dy.Integer() + + # Add a custom group rule + @dy.rule() + def sum_check(cls, lf): # type: ignore + return (lf.select("x") + lf.select("y") > 0).to_series() + + # Act & Assert + with pytest.warns(UserWarning, match="group rules that cannot be translated"): + SchemaWithRules.to_pydantic_model() From 00e94d25158e6a38a071de586100f6ce9af524bd Mon Sep 17 00:00:00 2001 From: "anthropic-code-agent[bot]" <242468646+Claude@users.noreply.github.com> Date: Thu, 9 Apr 2026 22:08:12 +0000 Subject: [PATCH 04/12] Addressing PR comments Agent-Logs-Url: https://github.com/Quantco/dataframely/sessions/5299dcd4-f2fc-4862-b4b6-79fe11566ad8 Co-authored-by: borchero <22455425+borchero@users.noreply.github.com> --- dataframely/columns/_base.py | 10 +- dataframely/columns/_mixins.py | 32 ++++++ dataframely/columns/array.py | 15 +-- dataframely/columns/binary.py | 6 +- dataframely/columns/bool.py | 6 +- dataframely/columns/categorical.py | 7 +- dataframely/columns/datetime.py | 150 ++++------------------------- dataframely/columns/decimal.py | 54 +++-------- dataframely/columns/enum.py | 10 +- dataframely/columns/float.py | 54 +++-------- dataframely/columns/integer.py | 15 +-- dataframely/columns/list.py | 14 +-- dataframely/columns/object.py | 6 +- dataframely/columns/string.py | 17 +--- dataframely/columns/struct.py | 10 +- dataframely/schema.py | 14 +-- tests/test_pydantic_conversion.py | 8 +- 17 files changed, 116 insertions(+), 312 deletions(-) diff --git a/dataframely/columns/_base.py b/dataframely/columns/_base.py index ded2573..ec54ae5 100644 --- a/dataframely/columns/_base.py +++ b/dataframely/columns/_base.py @@ -240,9 +240,7 @@ def pydantic_field(self) -> Any: if self.check is not None: warnings.warn( f"Custom checks for column '{self.name or self.__class__.__name__}' " - "are not translated to pydantic constraints.", - UserWarning, - stacklevel=2, + "are not translated to pydantic constraints." ) return self._pydantic_field_inner() @@ -251,6 +249,12 @@ def pydantic_field(self) -> Any: def _pydantic_field_inner(self) -> Any: """Subclasses implement this to return the actual pydantic field type.""" + def _make_nullable_type(self, base_type: Any) -> Any: + """Helper to make a type nullable if needed.""" + if self.nullable: + return base_type | None + return base_type + # ------------------------------------ HELPER ------------------------------------ # @property diff --git a/dataframely/columns/_mixins.py b/dataframely/columns/_mixins.py index 4e1ada4..2b68f12 100644 --- a/dataframely/columns/_mixins.py +++ b/dataframely/columns/_mixins.py @@ -80,6 +80,38 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: result["max_exclusive"] = expr < self.max_exclusive # type: ignore return result + def _add_ordinal_constraints_to_pydantic_field( + self, base_type: Any + ) -> tuple[Any, dict[str, Any]]: + """Helper method to add ordinal constraints to a pydantic field. + + Returns: + A tuple of (annotated_type_or_base_type, constraint_kwargs). + If there are constraints, the first element is an Annotated type, + otherwise it's the base_type unchanged. + """ + from typing import Annotated + + from dataframely._compat import pydantic + + constraint_kwargs = {} + if self.min is not None: + constraint_kwargs["ge"] = self.min + if self.min_exclusive is not None: + constraint_kwargs["gt"] = self.min_exclusive + if self.max is not None: + constraint_kwargs["le"] = self.max + if self.max_exclusive is not None: + constraint_kwargs["lt"] = self.max_exclusive + + if constraint_kwargs: + return ( + Annotated[base_type, pydantic.Field(**constraint_kwargs)], # type: ignore[call-overload] + constraint_kwargs, + ) + + return base_type, constraint_kwargs + # ------------------------------------ IS IN MIXIN ----------------------------------- # diff --git a/dataframely/columns/array.py b/dataframely/columns/array.py index f13e92d..a86fe00 100644 --- a/dataframely/columns/array.py +++ b/dataframely/columns/array.py @@ -152,24 +152,13 @@ def from_dict(cls, data: dict[str, Any]) -> Self: def _pydantic_field_inner(self) -> type: """Return pydantic field type for Array column.""" import warnings - from typing import Union warnings.warn( f"Array column '{self.name or self.__class__.__name__}' cannot be fully " - "translated to pydantic. Using list as the base type.", - UserWarning, - stacklevel=3, + "translated to pydantic. Using list as the base type." ) - # Get the inner type inner_type = self.inner.pydantic_field() - - # Build the type annotation - use list for arrays - base_type = list[inner_type] # type: ignore - # Handle nullability - if self.nullable: - return Union[base_type, None] # type: ignore - - return base_type # type: ignore + return self._make_nullable_type(base_type) diff --git a/dataframely/columns/binary.py b/dataframely/columns/binary.py index 677df3d..cb8719f 100644 --- a/dataframely/columns/binary.py +++ b/dataframely/columns/binary.py @@ -41,8 +41,4 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: def _pydantic_field_inner(self) -> type[bytes] | None: """Return pydantic field type for binary column.""" - if self.nullable: - from typing import Union - - return Union[bytes, None] # type: ignore - return bytes + return self._make_nullable_type(bytes) diff --git a/dataframely/columns/bool.py b/dataframely/columns/bool.py index 1708b99..68355c5 100644 --- a/dataframely/columns/bool.py +++ b/dataframely/columns/bool.py @@ -32,8 +32,4 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: def _pydantic_field_inner(self) -> type[bool] | None: """Return pydantic field type for bool column.""" - if self.nullable: - from typing import Union - - return Union[bool, None] # type: ignore - return bool + return self._make_nullable_type(bool) diff --git a/dataframely/columns/categorical.py b/dataframely/columns/categorical.py index db34c88..5e3e397 100644 --- a/dataframely/columns/categorical.py +++ b/dataframely/columns/categorical.py @@ -79,9 +79,4 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: def _pydantic_field_inner(self) -> type[str] | None: """Return pydantic field type for categorical column.""" - # Categorical is essentially a string type - if self.nullable: - from typing import Union - - return Union[str, None] # type: ignore - return str + return self._make_nullable_type(str) diff --git a/dataframely/columns/datetime.py b/dataframely/columns/datetime.py index 818efad..9f4fa29 100644 --- a/dataframely/columns/datetime.py +++ b/dataframely/columns/datetime.py @@ -153,45 +153,19 @@ def _pydantic_field_inner(self) -> type: """Return pydantic field type for Date column.""" import datetime as dt import warnings - from typing import Annotated - - from dataframely._compat import pydantic # Warn about untranslated constraints if self.resolution is not None: warnings.warn( f"Date column '{self.name or self.__class__.__name__}' has a resolution " - "constraint that cannot be translated to pydantic.", - UserWarning, - stacklevel=3, + "constraint that cannot be translated to pydantic." ) - # Build constraints - merged_kwargs = {} - if self.min is not None: - merged_kwargs["ge"] = self.min - if self.min_exclusive is not None: - merged_kwargs["gt"] = self.min_exclusive - if self.max is not None: - merged_kwargs["le"] = self.max - if self.max_exclusive is not None: - merged_kwargs["lt"] = self.max_exclusive - - # Build the type annotation - base_type = dt.date - - if merged_kwargs: - annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] - else: - annotated_type = base_type + # Build the type annotation using mixin helper + annotated_type, _ = self._add_ordinal_constraints_to_pydantic_field(dt.date) # Handle nullability - if self.nullable: - from typing import Union - - return Union[annotated_type, None] # type: ignore - - return annotated_type # type: ignore + return self._make_nullable_type(annotated_type) @register @@ -326,45 +300,19 @@ def _pydantic_field_inner(self) -> type: """Return pydantic field type for Time column.""" import datetime as dt import warnings - from typing import Annotated - - from dataframely._compat import pydantic # Warn about untranslated constraints if self.resolution is not None: warnings.warn( f"Time column '{self.name or self.__class__.__name__}' has a resolution " - "constraint that cannot be translated to pydantic.", - UserWarning, - stacklevel=3, + "constraint that cannot be translated to pydantic." ) - # Build constraints - merged_kwargs = {} - if self.min is not None: - merged_kwargs["ge"] = self.min - if self.min_exclusive is not None: - merged_kwargs["gt"] = self.min_exclusive - if self.max is not None: - merged_kwargs["le"] = self.max - if self.max_exclusive is not None: - merged_kwargs["lt"] = self.max_exclusive - - # Build the type annotation - base_type = dt.time - - if merged_kwargs: - annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] - else: - annotated_type = base_type + # Build the type annotation using mixin helper + annotated_type, _ = self._add_ordinal_constraints_to_pydantic_field(dt.time) # Handle nullability - if self.nullable: - from typing import Union - - return Union[annotated_type, None] # type: ignore - - return annotated_type # type: ignore + return self._make_nullable_type(annotated_type) @register @@ -517,59 +465,29 @@ def _pydantic_field_inner(self) -> type: """Return pydantic field type for Datetime column.""" import datetime as dt import warnings - from typing import Annotated - - from dataframely._compat import pydantic # Warn about untranslated constraints if self.resolution is not None: warnings.warn( f"Datetime column '{self.name or self.__class__.__name__}' has a resolution " - "constraint that cannot be translated to pydantic.", - UserWarning, - stacklevel=3, + "constraint that cannot be translated to pydantic." ) if self.time_zone is not None: warnings.warn( f"Datetime column '{self.name or self.__class__.__name__}' has a time_zone " - "constraint that cannot be translated to pydantic.", - UserWarning, - stacklevel=3, + "constraint that cannot be translated to pydantic." ) if self.time_unit != "us": warnings.warn( f"Datetime column '{self.name or self.__class__.__name__}' has a time_unit " - "constraint that cannot be translated to pydantic.", - UserWarning, - stacklevel=3, + "constraint that cannot be translated to pydantic." ) - # Build constraints - merged_kwargs = {} - if self.min is not None: - merged_kwargs["ge"] = self.min - if self.min_exclusive is not None: - merged_kwargs["gt"] = self.min_exclusive - if self.max is not None: - merged_kwargs["le"] = self.max - if self.max_exclusive is not None: - merged_kwargs["lt"] = self.max_exclusive - - # Build the type annotation - base_type = dt.datetime - - if merged_kwargs: - annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] - else: - annotated_type = base_type + # Build the type annotation using mixin helper + annotated_type, _ = self._add_ordinal_constraints_to_pydantic_field(dt.datetime) # Handle nullability - if self.nullable: - from typing import Union - - return Union[annotated_type, None] # type: ignore - - return annotated_type # type: ignore + return self._make_nullable_type(annotated_type) @register @@ -700,52 +618,26 @@ def _pydantic_field_inner(self) -> type: """Return pydantic field type for Duration column.""" import datetime as dt import warnings - from typing import Annotated - - from dataframely._compat import pydantic # Warn about untranslated constraints if self.resolution is not None: warnings.warn( f"Duration column '{self.name or self.__class__.__name__}' has a resolution " - "constraint that cannot be translated to pydantic.", - UserWarning, - stacklevel=3, + "constraint that cannot be translated to pydantic." ) if self.time_unit != "us": warnings.warn( f"Duration column '{self.name or self.__class__.__name__}' has a time_unit " - "constraint that cannot be translated to pydantic.", - UserWarning, - stacklevel=3, + "constraint that cannot be translated to pydantic." ) - # Build constraints - merged_kwargs = {} - if self.min is not None: - merged_kwargs["ge"] = self.min - if self.min_exclusive is not None: - merged_kwargs["gt"] = self.min_exclusive - if self.max is not None: - merged_kwargs["le"] = self.max - if self.max_exclusive is not None: - merged_kwargs["lt"] = self.max_exclusive - - # Build the type annotation - base_type = dt.timedelta - - if merged_kwargs: - annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] - else: - annotated_type = base_type + # Build the type annotation using mixin helper + annotated_type, _ = self._add_ordinal_constraints_to_pydantic_field( + dt.timedelta + ) # Handle nullability - if self.nullable: - from typing import Union - - return Union[annotated_type, None] # type: ignore - - return annotated_type # type: ignore + return self._make_nullable_type(annotated_type) # --------------------------------------- UTILS -------------------------------------- # diff --git a/dataframely/columns/decimal.py b/dataframely/columns/decimal.py index e959a72..afb4e63 100644 --- a/dataframely/columns/decimal.py +++ b/dataframely/columns/decimal.py @@ -8,7 +8,7 @@ import polars as pl -from dataframely._compat import pa, pydantic, sa, sa_TypeEngine +from dataframely._compat import pa, sa, sa_TypeEngine from dataframely._polars import PolarsDataType from dataframely.random import Generator @@ -173,50 +173,20 @@ def _pydantic_field_inner(self) -> type: """Return pydantic field type for Decimal column.""" import decimal import warnings - from typing import Annotated - - # Warn about untranslated constraints - if self.precision is not None: - warnings.warn( - f"Decimal column '{self.name or self.__class__.__name__}' has a precision " - "constraint that cannot be translated to pydantic.", - UserWarning, - stacklevel=3, - ) - if self.scale != 0: - warnings.warn( - f"Decimal column '{self.name or self.__class__.__name__}' has a scale " - "constraint that cannot be translated to pydantic.", - UserWarning, - stacklevel=3, - ) - - # Build constraints - merged_kwargs = {} - if self.min is not None: - merged_kwargs["ge"] = self.min - if self.min_exclusive is not None: - merged_kwargs["gt"] = self.min_exclusive - if self.max is not None: - merged_kwargs["le"] = self.max - if self.max_exclusive is not None: - merged_kwargs["lt"] = self.max_exclusive - - # Build the type annotation - base_type = decimal.Decimal - - if merged_kwargs: - annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] - else: - annotated_type = base_type - # Handle nullability - if self.nullable: - from typing import Union + # Always warn about untranslated constraints + warnings.warn( + f"Decimal column '{self.name or self.__class__.__name__}' has precision and scale " + "constraints that cannot be translated to pydantic." + ) - return Union[annotated_type, None] # type: ignore + # Build the type annotation using mixin helper + annotated_type, _ = self._add_ordinal_constraints_to_pydantic_field( + decimal.Decimal + ) - return annotated_type # type: ignore + # Handle nullability + return self._make_nullable_type(annotated_type) # --------------------------------------- UTILS -------------------------------------- # diff --git a/dataframely/columns/enum.py b/dataframely/columns/enum.py index e5d9aad..498264f 100644 --- a/dataframely/columns/enum.py +++ b/dataframely/columns/enum.py @@ -104,17 +104,11 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: def _pydantic_field_inner(self) -> type: """Return pydantic field type for Enum column.""" - from typing import Literal, Union + from typing import Literal - # Use Literal for enum categories if len(self.categories) == 0: - # Empty enum, use str as fallback base_type = str else: base_type = Literal[tuple(self.categories)] # type: ignore - # Handle nullability - if self.nullable: - return Union[base_type, None] # type: ignore - - return base_type # type: ignore + return self._make_nullable_type(base_type) diff --git a/dataframely/columns/float.py b/dataframely/columns/float.py index 4a416e6..876678e 100644 --- a/dataframely/columns/float.py +++ b/dataframely/columns/float.py @@ -145,52 +145,30 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: def _pydantic_field_inner(self) -> type[float] | None: """Return pydantic field type for float column.""" import warnings - from typing import Annotated - - from dataframely._compat import pydantic # Warn about untranslated constraints - if not self.allow_inf: + if self.allow_inf == self.allow_nan and not self.allow_inf: warnings.warn( f"Float column '{self.name or self.__class__.__name__}' does not allow " - "infinity values, but this constraint cannot be translated to pydantic.", - UserWarning, - stacklevel=3, + "infinity or NaN values, but this constraint cannot be translated to pydantic." ) - if not self.allow_nan: - warnings.warn( - f"Float column '{self.name or self.__class__.__name__}' does not allow " - "NaN values, but this constraint cannot be translated to pydantic.", - UserWarning, - stacklevel=3, - ) - - # Build constraints - merged_kwargs = {} - if self.min is not None: - merged_kwargs["ge"] = self.min - if self.min_exclusive is not None: - merged_kwargs["gt"] = self.min_exclusive - if self.max is not None: - merged_kwargs["le"] = self.max - if self.max_exclusive is not None: - merged_kwargs["lt"] = self.max_exclusive - - # Build the type annotation - base_type = float - - if merged_kwargs: - annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] else: - annotated_type = base_type + if not self.allow_inf: + warnings.warn( + f"Float column '{self.name or self.__class__.__name__}' does not allow " + "infinity values, but this constraint cannot be translated to pydantic." + ) + if not self.allow_nan: + warnings.warn( + f"Float column '{self.name or self.__class__.__name__}' does not allow " + "NaN values, but this constraint cannot be translated to pydantic." + ) + + # Build the type annotation using mixin helper + annotated_type, _ = self._add_ordinal_constraints_to_pydantic_field(float) # Handle nullability - if self.nullable: - from typing import Union - - return Union[annotated_type, None] # type: ignore - - return annotated_type # type: ignore + return self._make_nullable_type(annotated_type) # ------------------------------------------------------------------------------------ # diff --git a/dataframely/columns/integer.py b/dataframely/columns/integer.py index f6f4b47..ba0c972 100644 --- a/dataframely/columns/integer.py +++ b/dataframely/columns/integer.py @@ -144,14 +144,11 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: def _pydantic_field_inner(self) -> type[int] | None: """Return pydantic field type for integer column.""" - from typing import Annotated, Literal, Union + from typing import Annotated, Literal - # Build constraints if self.is_in is not None: - # Use Literal for enumerations - base_type = Literal[tuple(self.is_in)] # type: ignore + base_type: Any = Literal[tuple(self.is_in)] # type: ignore else: - # Build range constraints in a single Field field_kwargs = {} if self.min is not None: field_kwargs["ge"] = self.min @@ -163,15 +160,11 @@ def _pydantic_field_inner(self) -> type[int] | None: field_kwargs["lt"] = self.max_exclusive if field_kwargs: - base_type = Annotated[int, pydantic.Field(**field_kwargs)] # type: ignore + base_type = Annotated[int, pydantic.Field(**field_kwargs)] # type: ignore[call-overload] else: base_type = int - # Handle nullability - if self.nullable: - return Union[base_type, None] # type: ignore - - return base_type # type: ignore + return self._make_nullable_type(base_type) # ------------------------------------------------------------------------------------ # diff --git a/dataframely/columns/list.py b/dataframely/columns/list.py index b6b9904..e5c60ee 100644 --- a/dataframely/columns/list.py +++ b/dataframely/columns/list.py @@ -179,34 +179,26 @@ def from_dict(cls, data: dict[str, Any]) -> Self: def _pydantic_field_inner(self) -> type: """Return pydantic field type for List column.""" - from typing import Annotated, Union + from typing import Annotated from dataframely._compat import pydantic - # Get the inner type inner_type = self.inner.pydantic_field() - # Build constraints merged_kwargs = {} if self.min_length is not None: merged_kwargs["min_length"] = self.min_length if self.max_length is not None: merged_kwargs["max_length"] = self.max_length - # Build the type annotation - base_type = list[inner_type] # type: ignore if merged_kwargs: - annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] + annotated_type: Any = Annotated[base_type, pydantic.Field(**merged_kwargs)] # type: ignore[call-overload, misc] else: annotated_type = base_type - # Handle nullability - if self.nullable: - return Union[annotated_type, None] # type: ignore - - return annotated_type # type: ignore + return self._make_nullable_type(annotated_type) def _list_primary_key_check( diff --git a/dataframely/columns/object.py b/dataframely/columns/object.py index 78c6deb..a87e67a 100644 --- a/dataframely/columns/object.py +++ b/dataframely/columns/object.py @@ -75,9 +75,5 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: def _pydantic_field_inner(self) -> type: """Return pydantic field type for Object column.""" from typing import Any as AnyType - from typing import Union - # Object columns can contain any Python object - if self.nullable: - return Union[AnyType, None] # type: ignore - return AnyType + return self._make_nullable_type(AnyType) diff --git a/dataframely/columns/string.py b/dataframely/columns/string.py index e3ac12a..3a5ba3f 100644 --- a/dataframely/columns/string.py +++ b/dataframely/columns/string.py @@ -142,28 +142,17 @@ def _pydantic_field_inner(self) -> type[str] | None: """Return pydantic field type for string column.""" from typing import Annotated - # Build constraints merged_kwargs = {} if self.min_length is not None: merged_kwargs["min_length"] = self.min_length if self.max_length is not None: merged_kwargs["max_length"] = self.max_length if self.regex is not None: - # Pydantic uses 'pattern' for regex validation merged_kwargs["pattern"] = self.regex - # Build the type annotation - base_type = str - if merged_kwargs: - annotated_type = Annotated[base_type, pydantic.Field(**merged_kwargs)] + annotated_type: Any = Annotated[str, pydantic.Field(**merged_kwargs)] # type: ignore[call-overload, misc] else: - annotated_type = base_type - - # Handle nullability - if self.nullable: - from typing import Union - - return Union[annotated_type, None] # type: ignore + annotated_type = str - return annotated_type # type: ignore + return self._make_nullable_type(annotated_type) diff --git a/dataframely/columns/struct.py b/dataframely/columns/struct.py index 76fc8b0..669adfb 100644 --- a/dataframely/columns/struct.py +++ b/dataframely/columns/struct.py @@ -159,24 +159,16 @@ def from_dict(cls, data: dict[str, Any]) -> Self: def _pydantic_field_inner(self) -> type: """Return pydantic field type for Struct column.""" - from typing import Union - from dataframely._compat import pydantic - # Build a pydantic model for the struct fields = {} for field_name, col in self.inner.items(): field_type = col.pydantic_field() fields[field_name] = (field_type, ...) - # Create a dynamic pydantic model model = pydantic.create_model( f"{self.name or 'Struct'}Model", **fields, # type: ignore ) - # Handle nullability - if self.nullable: - return Union[model, None] # type: ignore - - return model # type: ignore + return self._make_nullable_type(model) diff --git a/dataframely/schema.py b/dataframely/schema.py index f79148f..ff2a180 100644 --- a/dataframely/schema.py +++ b/dataframely/schema.py @@ -16,7 +16,7 @@ import polars.exceptions as plexc from polars._typing import FileSource -from dataframely._compat import deltalake +from dataframely._compat import deltalake, pydantic from ._base_schema import ORIGINAL_COLUMN_PREFIX, BaseSchema from ._compat import PartitionSchemeOrSinkDirectory, pa, sa @@ -1368,17 +1368,11 @@ def to_pydantic_model(cls) -> type: >>> Model = MySchema.to_pydantic_model() >>> # Now you can use Model for validation, e.g., with LLM APIs """ - import warnings - - from dataframely._compat import pydantic - # Check for group rules and warn if present if cls._schema_validation_rules(): warnings.warn( - f"Schema '{cls.__name__}' has group rules that cannot be " - "translated to pydantic validators.", - UserWarning, - stacklevel=2, + f"Schema '{cls.__name__}' has group rules. These are not currently " + "translated to pydantic validators." ) # Build field definitions for the pydantic model @@ -1390,7 +1384,7 @@ def to_pydantic_model(cls) -> type: fields[col_name] = (field_type, ...) # Create the pydantic model dynamically - model_name = f"{cls.__name__}PydanticModel" + model_name = f"{cls.__name__.removesuffix('Schema')}Model" return pydantic.create_model(model_name, **fields) # type: ignore # ----------------------------------- EQUALITY ----------------------------------- # diff --git a/tests/test_pydantic_conversion.py b/tests/test_pydantic_conversion.py index c6a9ed0..dfb4b6a 100644 --- a/tests/test_pydantic_conversion.py +++ b/tests/test_pydantic_conversion.py @@ -1,6 +1,5 @@ # Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause - """Tests for pydantic conversion functionality.""" import datetime as dt @@ -354,7 +353,10 @@ class ListSchema(dy.Schema): def test_decimal_column_pydantic_field() -> None: # Arrange column = dy.Decimal( - precision=10, scale=2, min=decimal.Decimal("0.00"), max=decimal.Decimal("100.00") + precision=10, + scale=2, + min=decimal.Decimal("0.00"), + max=decimal.Decimal("100.00"), ) column._name = "test_col" @@ -388,5 +390,5 @@ def sum_check(cls, lf): # type: ignore return (lf.select("x") + lf.select("y") > 0).to_series() # Act & Assert - with pytest.warns(UserWarning, match="group rules that cannot be translated"): + with pytest.warns(UserWarning, match="group rules.*not currently translated"): SchemaWithRules.to_pydantic_model() From bbfde68194cce66af56e3fac7ced6d5b5ef66408 Mon Sep 17 00:00:00 2001 From: "anthropic-code-agent[bot]" <242468646+Claude@users.noreply.github.com> Date: Thu, 9 Apr 2026 22:30:15 +0000 Subject: [PATCH 05/12] Refactor pydantic field generation into _python_type and _pydantic_field_kwargs Agent-Logs-Url: https://github.com/Quantco/dataframely/sessions/d86654c8-ab94-47c3-a954-51d8f34341d7 Co-authored-by: borchero <22455425+borchero@users.noreply.github.com> --- dataframely/columns/_base.py | 43 ++++++++++++++++++++++++------ dataframely/columns/_mixins.py | 33 ++++++----------------- dataframely/columns/any.py | 4 +-- dataframely/columns/array.py | 8 +++--- dataframely/columns/binary.py | 6 ++--- dataframely/columns/bool.py | 6 ++--- dataframely/columns/categorical.py | 6 ++--- dataframely/columns/datetime.py | 42 +++++++++-------------------- dataframely/columns/decimal.py | 12 +++------ dataframely/columns/enum.py | 11 +++----- dataframely/columns/float.py | 10 +++---- dataframely/columns/integer.py | 26 ++++-------------- dataframely/columns/list.py | 27 +++++++------------ dataframely/columns/object.py | 6 ++--- dataframely/columns/string.py | 24 ++++++++--------- dataframely/columns/struct.py | 6 ++--- 16 files changed, 111 insertions(+), 159 deletions(-) diff --git a/dataframely/columns/_base.py b/dataframely/columns/_base.py index ec54ae5..de67ad3 100644 --- a/dataframely/columns/_base.py +++ b/dataframely/columns/_base.py @@ -236,6 +236,9 @@ def pydantic_field(self) -> Any: returned pydantic field. A UserWarning is raised if custom checks are present. """ import warnings + from typing import Annotated + + from dataframely._compat import pydantic if self.check is not None: warnings.warn( @@ -243,17 +246,41 @@ def pydantic_field(self) -> Any: "are not translated to pydantic constraints." ) - return self._pydantic_field_inner() + # Get the base Python type + python_type = self._python_type() + + # Apply nullability to the type + if self.nullable: + python_type = python_type | None + + # Get pydantic field kwargs + field_kwargs = self._pydantic_field_kwargs() + + # Construct the annotated type if there are constraints + if field_kwargs: + return Annotated[python_type, pydantic.Field(**field_kwargs)] # type: ignore[call-overload] + + return python_type @abstractmethod - def _pydantic_field_inner(self) -> Any: - """Subclasses implement this to return the actual pydantic field type.""" + def _python_type(self) -> type: + """Return the base Python type for this column. - def _make_nullable_type(self, base_type: Any) -> Any: - """Helper to make a type nullable if needed.""" - if self.nullable: - return base_type | None - return base_type + Subclasses must implement this to return the appropriate Python type + (e.g., int, str, float, etc.) without any pydantic constraints or nullability. + """ + + def _pydantic_field_kwargs(self) -> dict[str, Any]: + """Return kwargs for pydantic.Field initialization. + + This method should be extended by subclasses and mixins to add their + specific constraints. Subclasses should call super() and extend the + returned dictionary. + + Returns: + A dictionary of kwargs to pass to pydantic.Field. + """ + return {} # ------------------------------------ HELPER ------------------------------------ # diff --git a/dataframely/columns/_mixins.py b/dataframely/columns/_mixins.py index 2b68f12..137f925 100644 --- a/dataframely/columns/_mixins.py +++ b/dataframely/columns/_mixins.py @@ -80,37 +80,20 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: result["max_exclusive"] = expr < self.max_exclusive # type: ignore return result - def _add_ordinal_constraints_to_pydantic_field( - self, base_type: Any - ) -> tuple[Any, dict[str, Any]]: - """Helper method to add ordinal constraints to a pydantic field. + def _pydantic_field_kwargs(self) -> dict[str, Any]: + """Add ordinal constraints to pydantic field kwargs.""" + kwargs = super()._pydantic_field_kwargs() - Returns: - A tuple of (annotated_type_or_base_type, constraint_kwargs). - If there are constraints, the first element is an Annotated type, - otherwise it's the base_type unchanged. - """ - from typing import Annotated - - from dataframely._compat import pydantic - - constraint_kwargs = {} if self.min is not None: - constraint_kwargs["ge"] = self.min + kwargs["ge"] = self.min if self.min_exclusive is not None: - constraint_kwargs["gt"] = self.min_exclusive + kwargs["gt"] = self.min_exclusive if self.max is not None: - constraint_kwargs["le"] = self.max + kwargs["le"] = self.max if self.max_exclusive is not None: - constraint_kwargs["lt"] = self.max_exclusive - - if constraint_kwargs: - return ( - Annotated[base_type, pydantic.Field(**constraint_kwargs)], # type: ignore[call-overload] - constraint_kwargs, - ) + kwargs["lt"] = self.max_exclusive - return base_type, constraint_kwargs + return kwargs # ------------------------------------ IS IN MIXIN ----------------------------------- # diff --git a/dataframely/columns/any.py b/dataframely/columns/any.py index 44cdef7..0b5429a 100644 --- a/dataframely/columns/any.py +++ b/dataframely/columns/any.py @@ -80,8 +80,8 @@ def pyarrow_dtype(self) -> pa.DataType: def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: return pl.repeat(None, n, dtype=pl.Null, eager=True) - def _pydantic_field_inner(self) -> type: - """Return pydantic field type for Any column.""" + def _python_type(self) -> type: + """Return the base Python type for Any column.""" from typing import Any as AnyType # Any columns are always nullable diff --git a/dataframely/columns/array.py b/dataframely/columns/array.py index a86fe00..2ce3d9b 100644 --- a/dataframely/columns/array.py +++ b/dataframely/columns/array.py @@ -149,8 +149,8 @@ def from_dict(cls, data: dict[str, Any]) -> Self: data["inner"] = column_from_dict(data["inner"]) return super().from_dict(data) - def _pydantic_field_inner(self) -> type: - """Return pydantic field type for Array column.""" + def _python_type(self) -> type: + """Return the base Python type for Array column.""" import warnings warnings.warn( @@ -159,6 +159,4 @@ def _pydantic_field_inner(self) -> type: ) inner_type = self.inner.pydantic_field() - base_type = list[inner_type] # type: ignore - - return self._make_nullable_type(base_type) + return list[inner_type] # type: ignore diff --git a/dataframely/columns/binary.py b/dataframely/columns/binary.py index cb8719f..2935d04 100644 --- a/dataframely/columns/binary.py +++ b/dataframely/columns/binary.py @@ -39,6 +39,6 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ) - def _pydantic_field_inner(self) -> type[bytes] | None: - """Return pydantic field type for binary column.""" - return self._make_nullable_type(bytes) + def _python_type(self) -> type: + """Return the base Python type for binary column.""" + return bytes diff --git a/dataframely/columns/bool.py b/dataframely/columns/bool.py index 68355c5..29f8073 100644 --- a/dataframely/columns/bool.py +++ b/dataframely/columns/bool.py @@ -30,6 +30,6 @@ def pyarrow_dtype(self) -> pa.DataType: def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: return generator.sample_bool(n, null_probability=self._null_probability) - def _pydantic_field_inner(self) -> type[bool] | None: - """Return pydantic field type for bool column.""" - return self._make_nullable_type(bool) + def _python_type(self) -> type: + """Return the base Python type for bool column.""" + return bool diff --git a/dataframely/columns/categorical.py b/dataframely/columns/categorical.py index 5e3e397..b89d81d 100644 --- a/dataframely/columns/categorical.py +++ b/dataframely/columns/categorical.py @@ -77,6 +77,6 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: n, regex=r"[a-z]{1,2}", null_probability=self._null_probability ).cast(self.dtype) - def _pydantic_field_inner(self) -> type[str] | None: - """Return pydantic field type for categorical column.""" - return self._make_nullable_type(str) + def _python_type(self) -> type: + """Return the base Python type for categorical column.""" + return str diff --git a/dataframely/columns/datetime.py b/dataframely/columns/datetime.py index 9f4fa29..f0ca3fb 100644 --- a/dataframely/columns/datetime.py +++ b/dataframely/columns/datetime.py @@ -149,8 +149,8 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ) - def _pydantic_field_inner(self) -> type: - """Return pydantic field type for Date column.""" + def _python_type(self) -> type: + """Return the base Python type for Date column.""" import datetime as dt import warnings @@ -161,11 +161,7 @@ def _pydantic_field_inner(self) -> type: "constraint that cannot be translated to pydantic." ) - # Build the type annotation using mixin helper - annotated_type, _ = self._add_ordinal_constraints_to_pydantic_field(dt.date) - - # Handle nullability - return self._make_nullable_type(annotated_type) + return dt.date @register @@ -296,8 +292,8 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ) - def _pydantic_field_inner(self) -> type: - """Return pydantic field type for Time column.""" + def _python_type(self) -> type: + """Return the base Python type for Time column.""" import datetime as dt import warnings @@ -308,11 +304,7 @@ def _pydantic_field_inner(self) -> type: "constraint that cannot be translated to pydantic." ) - # Build the type annotation using mixin helper - annotated_type, _ = self._add_ordinal_constraints_to_pydantic_field(dt.time) - - # Handle nullability - return self._make_nullable_type(annotated_type) + return dt.time @register @@ -461,8 +453,8 @@ def _attributes_match( return lhs.utcoffset(now) == rhs.utcoffset(now) return super()._attributes_match(lhs, rhs, name, column_expr) - def _pydantic_field_inner(self) -> type: - """Return pydantic field type for Datetime column.""" + def _python_type(self) -> type: + """Return the base Python type for Datetime column.""" import datetime as dt import warnings @@ -483,11 +475,7 @@ def _pydantic_field_inner(self) -> type: "constraint that cannot be translated to pydantic." ) - # Build the type annotation using mixin helper - annotated_type, _ = self._add_ordinal_constraints_to_pydantic_field(dt.datetime) - - # Handle nullability - return self._make_nullable_type(annotated_type) + return dt.datetime @register @@ -614,8 +602,8 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ) - def _pydantic_field_inner(self) -> type: - """Return pydantic field type for Duration column.""" + def _python_type(self) -> type: + """Return the base Python type for Duration column.""" import datetime as dt import warnings @@ -631,13 +619,7 @@ def _pydantic_field_inner(self) -> type: "constraint that cannot be translated to pydantic." ) - # Build the type annotation using mixin helper - annotated_type, _ = self._add_ordinal_constraints_to_pydantic_field( - dt.timedelta - ) - - # Handle nullability - return self._make_nullable_type(annotated_type) + return dt.timedelta # --------------------------------------- UTILS -------------------------------------- # diff --git a/dataframely/columns/decimal.py b/dataframely/columns/decimal.py index afb4e63..11ad0fa 100644 --- a/dataframely/columns/decimal.py +++ b/dataframely/columns/decimal.py @@ -169,8 +169,8 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: ) return ((samples * 10**self.scale).floor() / 10**self.scale).cast(self.dtype) - def _pydantic_field_inner(self) -> type: - """Return pydantic field type for Decimal column.""" + def _python_type(self) -> type: + """Return the base Python type for Decimal column.""" import decimal import warnings @@ -180,13 +180,7 @@ def _pydantic_field_inner(self) -> type: "constraints that cannot be translated to pydantic." ) - # Build the type annotation using mixin helper - annotated_type, _ = self._add_ordinal_constraints_to_pydantic_field( - decimal.Decimal - ) - - # Handle nullability - return self._make_nullable_type(annotated_type) + return decimal.Decimal # --------------------------------------- UTILS -------------------------------------- # diff --git a/dataframely/columns/enum.py b/dataframely/columns/enum.py index 498264f..08629bb 100644 --- a/dataframely/columns/enum.py +++ b/dataframely/columns/enum.py @@ -102,13 +102,10 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ).cast(self.dtype) - def _pydantic_field_inner(self) -> type: - """Return pydantic field type for Enum column.""" + def _python_type(self) -> type: + """Return the base Python type for Enum column.""" from typing import Literal if len(self.categories) == 0: - base_type = str - else: - base_type = Literal[tuple(self.categories)] # type: ignore - - return self._make_nullable_type(base_type) + return str + return Literal[tuple(self.categories)] # type: ignore diff --git a/dataframely/columns/float.py b/dataframely/columns/float.py index 876678e..af8973b 100644 --- a/dataframely/columns/float.py +++ b/dataframely/columns/float.py @@ -142,8 +142,8 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: inf_probability=self._inf_probability, ).cast(self.dtype) - def _pydantic_field_inner(self) -> type[float] | None: - """Return pydantic field type for float column.""" + def _python_type(self) -> type: + """Return the base Python type for float column.""" import warnings # Warn about untranslated constraints @@ -164,11 +164,7 @@ def _pydantic_field_inner(self) -> type[float] | None: "NaN values, but this constraint cannot be translated to pydantic." ) - # Build the type annotation using mixin helper - annotated_type, _ = self._add_ordinal_constraints_to_pydantic_field(float) - - # Handle nullability - return self._make_nullable_type(annotated_type) + return float # ------------------------------------------------------------------------------------ # diff --git a/dataframely/columns/integer.py b/dataframely/columns/integer.py index ba0c972..7bfc401 100644 --- a/dataframely/columns/integer.py +++ b/dataframely/columns/integer.py @@ -142,29 +142,13 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ).cast(self.dtype) - def _pydantic_field_inner(self) -> type[int] | None: - """Return pydantic field type for integer column.""" - from typing import Annotated, Literal + def _python_type(self) -> type: + """Return the base Python type for integer column.""" + from typing import Literal if self.is_in is not None: - base_type: Any = Literal[tuple(self.is_in)] # type: ignore - else: - field_kwargs = {} - if self.min is not None: - field_kwargs["ge"] = self.min - if self.min_exclusive is not None: - field_kwargs["gt"] = self.min_exclusive - if self.max is not None: - field_kwargs["le"] = self.max - if self.max_exclusive is not None: - field_kwargs["lt"] = self.max_exclusive - - if field_kwargs: - base_type = Annotated[int, pydantic.Field(**field_kwargs)] # type: ignore[call-overload] - else: - base_type = int - - return self._make_nullable_type(base_type) + return Literal[tuple(self.is_in)] # type: ignore + return int # ------------------------------------------------------------------------------------ # diff --git a/dataframely/columns/list.py b/dataframely/columns/list.py index e5c60ee..2fb81c8 100644 --- a/dataframely/columns/list.py +++ b/dataframely/columns/list.py @@ -177,28 +177,21 @@ def from_dict(cls, data: dict[str, Any]) -> Self: data["inner"] = column_from_dict(data["inner"]) return super().from_dict(data) - def _pydantic_field_inner(self) -> type: - """Return pydantic field type for List column.""" - from typing import Annotated - - from dataframely._compat import pydantic - + def _python_type(self) -> type: + """Return the base Python type for List column.""" inner_type = self.inner.pydantic_field() + return list[inner_type] # type: ignore + + def _pydantic_field_kwargs(self) -> dict[str, Any]: + """Return pydantic field kwargs for list constraints.""" + kwargs = super()._pydantic_field_kwargs() - merged_kwargs = {} if self.min_length is not None: - merged_kwargs["min_length"] = self.min_length + kwargs["min_length"] = self.min_length if self.max_length is not None: - merged_kwargs["max_length"] = self.max_length - - base_type = list[inner_type] # type: ignore - - if merged_kwargs: - annotated_type: Any = Annotated[base_type, pydantic.Field(**merged_kwargs)] # type: ignore[call-overload, misc] - else: - annotated_type = base_type + kwargs["max_length"] = self.max_length - return self._make_nullable_type(annotated_type) + return kwargs def _list_primary_key_check( diff --git a/dataframely/columns/object.py b/dataframely/columns/object.py index a87e67a..e3e3155 100644 --- a/dataframely/columns/object.py +++ b/dataframely/columns/object.py @@ -72,8 +72,8 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: "Random data sampling not implemented for 'Object' type." ) - def _pydantic_field_inner(self) -> type: - """Return pydantic field type for Object column.""" + def _python_type(self) -> type: + """Return the base Python type for Object column.""" from typing import Any as AnyType - return self._make_nullable_type(AnyType) + return AnyType diff --git a/dataframely/columns/string.py b/dataframely/columns/string.py index 3a5ba3f..45328db 100644 --- a/dataframely/columns/string.py +++ b/dataframely/columns/string.py @@ -138,21 +138,19 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ) - def _pydantic_field_inner(self) -> type[str] | None: - """Return pydantic field type for string column.""" - from typing import Annotated + def _python_type(self) -> type: + """Return the base Python type for string column.""" + return str + + def _pydantic_field_kwargs(self) -> dict[str, Any]: + """Return pydantic field kwargs for string constraints.""" + kwargs = super()._pydantic_field_kwargs() - merged_kwargs = {} if self.min_length is not None: - merged_kwargs["min_length"] = self.min_length + kwargs["min_length"] = self.min_length if self.max_length is not None: - merged_kwargs["max_length"] = self.max_length + kwargs["max_length"] = self.max_length if self.regex is not None: - merged_kwargs["pattern"] = self.regex - - if merged_kwargs: - annotated_type: Any = Annotated[str, pydantic.Field(**merged_kwargs)] # type: ignore[call-overload, misc] - else: - annotated_type = str + kwargs["pattern"] = self.regex - return self._make_nullable_type(annotated_type) + return kwargs diff --git a/dataframely/columns/struct.py b/dataframely/columns/struct.py index 669adfb..d89fbfc 100644 --- a/dataframely/columns/struct.py +++ b/dataframely/columns/struct.py @@ -157,8 +157,8 @@ def from_dict(cls, data: dict[str, Any]) -> Self: } return super().from_dict(data) - def _pydantic_field_inner(self) -> type: - """Return pydantic field type for Struct column.""" + def _python_type(self) -> type: + """Return the base Python type for Struct column.""" from dataframely._compat import pydantic fields = {} @@ -171,4 +171,4 @@ def _pydantic_field_inner(self) -> type: **fields, # type: ignore ) - return self._make_nullable_type(model) + return model From f730d9d6886e3bb7e3ecac8193909df7252605f8 Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Fri, 10 Apr 2026 17:00:39 +0200 Subject: [PATCH 06/12] Clean up package code --- dataframely/columns/_base.py | 35 +-- dataframely/columns/_mixins.py | 3 - dataframely/columns/any.py | 13 +- dataframely/columns/array.py | 30 ++- dataframely/columns/binary.py | 10 +- dataframely/columns/bool.py | 10 +- dataframely/columns/categorical.py | 8 +- dataframely/columns/datetime.py | 123 ++++----- dataframely/columns/decimal.py | 23 +- dataframely/columns/enum.py | 14 +- dataframely/columns/float.py | 63 +++-- dataframely/columns/integer.py | 26 +- dataframely/columns/list.py | 28 +- dataframely/columns/object.py | 10 +- dataframely/columns/string.py | 32 +-- dataframely/columns/struct.py | 22 +- dataframely/schema.py | 50 +--- tests/test_pydantic_conversion.py | 394 ----------------------------- 18 files changed, 218 insertions(+), 676 deletions(-) delete mode 100644 tests/test_pydantic_conversion.py diff --git a/dataframely/columns/_base.py b/dataframely/columns/_base.py index de67ad3..d664639 100644 --- a/dataframely/columns/_base.py +++ b/dataframely/columns/_base.py @@ -5,14 +5,15 @@ import inspect import sys +import warnings from abc import ABC, abstractmethod from collections import Counter from collections.abc import Callable, Mapping, Sequence -from typing import Any, TypeAlias, cast +from typing import Annotated, Any, TypeAlias, cast import polars as pl -from dataframely._compat import pa, sa, sa_TypeEngine +from dataframely._compat import pa, pydantic, sa, sa_TypeEngine from dataframely._polars import PolarsDataType from dataframely.random import Generator @@ -229,46 +230,30 @@ def pydantic_field(self) -> Any: Returns: A pydantic-compatible type annotation that includes structured constraints - (e.g., min, max, regex) but excludes custom checks. + (such as `min`, `max`, ...). Warning: - Custom checks defined via the `check` parameter are not included in the - returned pydantic field. A UserWarning is raised if custom checks are present. + Custom checks are not translated to pydantic validators. """ - import warnings - from typing import Annotated - - from dataframely._compat import pydantic - if self.check is not None: warnings.warn( f"Custom checks for column '{self.name or self.__class__.__name__}' " "are not translated to pydantic constraints." ) - # Get the base Python type - python_type = self._python_type() - - # Apply nullability to the type + python_type = self._python_type if self.nullable: python_type = python_type | None - # Get pydantic field kwargs field_kwargs = self._pydantic_field_kwargs() - - # Construct the annotated type if there are constraints if field_kwargs: - return Annotated[python_type, pydantic.Field(**field_kwargs)] # type: ignore[call-overload] - + return Annotated[python_type, pydantic.Field(**field_kwargs)] return python_type + @property @abstractmethod - def _python_type(self) -> type: - """Return the base Python type for this column. - - Subclasses must implement this to return the appropriate Python type - (e.g., int, str, float, etc.) without any pydantic constraints or nullability. - """ + def _python_type(self) -> Any: + """The native Python type corresponding to this column definition.""" def _pydantic_field_kwargs(self) -> dict[str, Any]: """Return kwargs for pydantic.Field initialization. diff --git a/dataframely/columns/_mixins.py b/dataframely/columns/_mixins.py index 137f925..d6d61c6 100644 --- a/dataframely/columns/_mixins.py +++ b/dataframely/columns/_mixins.py @@ -81,9 +81,7 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: return result def _pydantic_field_kwargs(self) -> dict[str, Any]: - """Add ordinal constraints to pydantic field kwargs.""" kwargs = super()._pydantic_field_kwargs() - if self.min is not None: kwargs["ge"] = self.min if self.min_exclusive is not None: @@ -92,7 +90,6 @@ def _pydantic_field_kwargs(self) -> dict[str, Any]: kwargs["le"] = self.max if self.max_exclusive is not None: kwargs["lt"] = self.max_exclusive - return kwargs diff --git a/dataframely/columns/any.py b/dataframely/columns/any.py index 0b5429a..988c788 100644 --- a/dataframely/columns/any.py +++ b/dataframely/columns/any.py @@ -3,6 +3,8 @@ from __future__ import annotations +from typing import Any as AnyType + import polars as pl from dataframely._compat import pa, sa, sa_mssql, sa_TypeEngine @@ -77,12 +79,9 @@ def pyarrow_field(self, name: str) -> pa.Field: def pyarrow_dtype(self) -> pa.DataType: return pa.null() + @property + def _python_type(self) -> AnyType: + return AnyType + def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: return pl.repeat(None, n, dtype=pl.Null, eager=True) - - def _python_type(self) -> type: - """Return the base Python type for Any column.""" - from typing import Any as AnyType - - # Any columns are always nullable - return AnyType diff --git a/dataframely/columns/array.py b/dataframely/columns/array.py index 2ce3d9b..44112a7 100644 --- a/dataframely/columns/array.py +++ b/dataframely/columns/array.py @@ -5,6 +5,7 @@ import math import sys +import warnings from collections.abc import Sequence from typing import Any, Literal, cast @@ -121,6 +122,23 @@ def _pyarrow_field_of_shape(self, shape: Sequence[int]) -> pa.Field: def pyarrow_dtype(self) -> pa.DataType: return self._pyarrow_field_of_shape(self.shape).type + @property + def _python_type(self) -> Any: + inner_type = self.inner.pydantic_field() + return list[inner_type] # type: ignore + + def _pydantic_field_kwargs(self) -> dict[str, Any]: + if len(self.shape) != 1: + warnings.warn( + "Multi-dimensional arrays are flattened for pydantic validation." + ) + + return { + **super()._pydantic_field_kwargs(), + "min_items": math.prod(self.shape), + "max_items": math.prod(self.shape), + } + def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: # Sample the inner elements in a flat series n_elements = n * math.prod(self.shape) @@ -148,15 +166,3 @@ def as_dict(self, expr: pl.Expr) -> dict[str, Any]: def from_dict(cls, data: dict[str, Any]) -> Self: data["inner"] = column_from_dict(data["inner"]) return super().from_dict(data) - - def _python_type(self) -> type: - """Return the base Python type for Array column.""" - import warnings - - warnings.warn( - f"Array column '{self.name or self.__class__.__name__}' cannot be fully " - "translated to pydantic. Using list as the base type." - ) - - inner_type = self.inner.pydantic_field() - return list[inner_type] # type: ignore diff --git a/dataframely/columns/binary.py b/dataframely/columns/binary.py index 2935d04..888d1bd 100644 --- a/dataframely/columns/binary.py +++ b/dataframely/columns/binary.py @@ -3,6 +3,8 @@ from __future__ import annotations +from typing import Any + import polars as pl from dataframely._compat import pa, sa, sa_TypeEngine @@ -31,6 +33,10 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: def pyarrow_dtype(self) -> pa.DataType: return pa.large_binary() + @property + def _python_type(self) -> Any: + return bytes + def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: return generator.sample_binary( n, @@ -38,7 +44,3 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: max_bytes=32, null_probability=self._null_probability, ) - - def _python_type(self) -> type: - """Return the base Python type for binary column.""" - return bytes diff --git a/dataframely/columns/bool.py b/dataframely/columns/bool.py index 29f8073..a1e48d1 100644 --- a/dataframely/columns/bool.py +++ b/dataframely/columns/bool.py @@ -3,6 +3,8 @@ from __future__ import annotations +from typing import Any + import polars as pl from dataframely._compat import pa, sa, sa_TypeEngine @@ -27,9 +29,9 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: def pyarrow_dtype(self) -> pa.DataType: return pa.bool_() + @property + def _python_type(self) -> Any: + return bool + def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: return generator.sample_bool(n, null_probability=self._null_probability) - - def _python_type(self) -> type: - """Return the base Python type for bool column.""" - return bool diff --git a/dataframely/columns/categorical.py b/dataframely/columns/categorical.py index b89d81d..37d795e 100644 --- a/dataframely/columns/categorical.py +++ b/dataframely/columns/categorical.py @@ -71,12 +71,12 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: def pyarrow_dtype(self) -> pa.DataType: return pa.dictionary(pa.uint32(), pa.large_string()) + @property + def _python_type(self) -> Any: + return str + def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: # We simply sample low-cardinality strings here return generator.sample_string( n, regex=r"[a-z]{1,2}", null_probability=self._null_probability ).cast(self.dtype) - - def _python_type(self) -> type: - """Return the base Python type for categorical column.""" - return str diff --git a/dataframely/columns/datetime.py b/dataframely/columns/datetime.py index f0ca3fb..56b8df7 100644 --- a/dataframely/columns/datetime.py +++ b/dataframely/columns/datetime.py @@ -4,6 +4,7 @@ from __future__ import annotations import datetime as dt +import warnings from typing import Any, cast import polars as pl @@ -132,6 +133,16 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: def pyarrow_dtype(self) -> pa.DataType: return pa.date32() + @property + def _python_type(self) -> Any: + return dt.date + + def _pydantic_field_kwargs(self) -> dict[str, Any]: + if self.resolution is not None: + warnings.warn("Date resolution is not translated to a pydantic constraint.") + + return super()._pydantic_field_kwargs() + def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: return generator.sample_date( n, @@ -149,20 +160,6 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ) - def _python_type(self) -> type: - """Return the base Python type for Date column.""" - import datetime as dt - import warnings - - # Warn about untranslated constraints - if self.resolution is not None: - warnings.warn( - f"Date column '{self.name or self.__class__.__name__}' has a resolution " - "constraint that cannot be translated to pydantic." - ) - - return dt.date - @register class Time(OrdinalMixin[dt.time], Column): @@ -275,6 +272,19 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: def pyarrow_dtype(self) -> pa.DataType: return pa.time64("ns") + @property + def _python_type(self) -> Any: + return dt.time + + def _pydantic_field_kwargs(self) -> dict[str, Any]: + if self.resolution is not None: + warnings.warn( + f"Time column '{self.name or self.__class__.__name__}' has a resolution " + "constraint that cannot be translated to pydantic." + ) + + return super()._pydantic_field_kwargs() + def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: return generator.sample_time( n, @@ -292,20 +302,6 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ) - def _python_type(self) -> type: - """Return the base Python type for Time column.""" - import datetime as dt - import warnings - - # Warn about untranslated constraints - if self.resolution is not None: - warnings.warn( - f"Time column '{self.name or self.__class__.__name__}' has a resolution " - "constraint that cannot be translated to pydantic." - ) - - return dt.time - @register class Datetime(OrdinalMixin[dt.datetime], Column): @@ -422,6 +418,22 @@ def pyarrow_dtype(self) -> pa.DataType: ) return pa.timestamp(self.time_unit, time_zone) + @property + def _python_type(self) -> Any: + return dt.datetime + + def _pydantic_field_kwargs(self) -> dict[str, Any]: + if self.resolution is not None: + warnings.warn( + "Datetime resolution is not translated to a pydantic constraint." + ) + if self.time_zone is not None: + warnings.warn( + "Datetime time zone is not translated to a pydantic constraint." + ) + + return super()._pydantic_field_kwargs() + def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: return generator.sample_datetime( n, @@ -453,30 +465,6 @@ def _attributes_match( return lhs.utcoffset(now) == rhs.utcoffset(now) return super()._attributes_match(lhs, rhs, name, column_expr) - def _python_type(self) -> type: - """Return the base Python type for Datetime column.""" - import datetime as dt - import warnings - - # Warn about untranslated constraints - if self.resolution is not None: - warnings.warn( - f"Datetime column '{self.name or self.__class__.__name__}' has a resolution " - "constraint that cannot be translated to pydantic." - ) - if self.time_zone is not None: - warnings.warn( - f"Datetime column '{self.name or self.__class__.__name__}' has a time_zone " - "constraint that cannot be translated to pydantic." - ) - if self.time_unit != "us": - warnings.warn( - f"Datetime column '{self.name or self.__class__.__name__}' has a time_unit " - "constraint that cannot be translated to pydantic." - ) - - return dt.datetime - @register class Duration(OrdinalMixin[dt.timedelta], Column): @@ -583,6 +571,18 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: def pyarrow_dtype(self) -> pa.DataType: return pa.duration(self.time_unit) + @property + def _python_type(self) -> Any: + return dt.timedelta + + def _pydantic_field_kwargs(self) -> dict[str, Any]: + if self.resolution is not None: + warnings.warn( + "Duration resolution is not translated to a pydantic constraint." + ) + + return super()._pydantic_field_kwargs() + def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: # NOTE: If no duration is specified, we default to 100 years return generator.sample_duration( @@ -602,25 +602,6 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ) - def _python_type(self) -> type: - """Return the base Python type for Duration column.""" - import datetime as dt - import warnings - - # Warn about untranslated constraints - if self.resolution is not None: - warnings.warn( - f"Duration column '{self.name or self.__class__.__name__}' has a resolution " - "constraint that cannot be translated to pydantic." - ) - if self.time_unit != "us": - warnings.warn( - f"Duration column '{self.name or self.__class__.__name__}' has a time_unit " - "constraint that cannot be translated to pydantic." - ) - - return dt.timedelta - # --------------------------------------- UTILS -------------------------------------- # diff --git a/dataframely/columns/decimal.py b/dataframely/columns/decimal.py index 11ad0fa..dcea545 100644 --- a/dataframely/columns/decimal.py +++ b/dataframely/columns/decimal.py @@ -128,6 +128,16 @@ def pyarrow_dtype(self) -> pa.DataType: # We do not use decimal256 since its values cannot be represented in SQL Server. return pa.decimal128(self.precision or 38, self.scale) + @property + def _python_type(self) -> Any: + return decimal.Decimal + + def _pydantic_field_kwargs(self) -> dict[str, Any]: + return { + **super()._pydantic_field_kwargs(), + "decimal_places": self.scale, + } + def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: # NOTE: Default precision to 38 for sampling, just like for SQL and Pyarrow precision = self.precision or 38 @@ -169,19 +179,6 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: ) return ((samples * 10**self.scale).floor() / 10**self.scale).cast(self.dtype) - def _python_type(self) -> type: - """Return the base Python type for Decimal column.""" - import decimal - import warnings - - # Always warn about untranslated constraints - warnings.warn( - f"Decimal column '{self.name or self.__class__.__name__}' has precision and scale " - "constraints that cannot be translated to pydantic." - ) - - return decimal.Decimal - # --------------------------------------- UTILS -------------------------------------- # diff --git a/dataframely/columns/enum.py b/dataframely/columns/enum.py index 08629bb..4bf63b6 100644 --- a/dataframely/columns/enum.py +++ b/dataframely/columns/enum.py @@ -6,7 +6,7 @@ import enum from collections.abc import Iterable from inspect import isclass -from typing import Any +from typing import Any, Literal import polars as pl @@ -95,17 +95,13 @@ def pyarrow_dtype(self) -> pa.DataType: dtype = pa.uint32() return pa.dictionary(dtype, pa.large_string()) + @property + def _python_type(self) -> Any: + return Literal[tuple(self.categories)] + def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: return generator.sample_choice( n, choices=self.categories, null_probability=self._null_probability, ).cast(self.dtype) - - def _python_type(self) -> type: - """Return the base Python type for Enum column.""" - from typing import Literal - - if len(self.categories) == 0: - return str - return Literal[tuple(self.categories)] # type: ignore diff --git a/dataframely/columns/float.py b/dataframely/columns/float.py index af8973b..7cc2ca9 100644 --- a/dataframely/columns/float.py +++ b/dataframely/columns/float.py @@ -5,6 +5,7 @@ import math import sys +import warnings from abc import abstractmethod from typing import Any @@ -101,6 +102,44 @@ def max_value(self) -> float: def min_value(self) -> float: """Minimum value of the column's type.""" + @property + def _python_type(self) -> Any: + # Warn about untranslated constraints + if self.allow_inf == self.allow_nan and not self.allow_inf: + warnings.warn( + f"Float column '{self.name or self.__class__.__name__}' does not allow " + "infinity or NaN values, but this constraint cannot be translated to pydantic." + ) + else: + if not self.allow_inf: + warnings.warn( + f"Float column '{self.name or self.__class__.__name__}' does not allow " + "infinity values, but this constraint cannot be translated to pydantic." + ) + if not self.allow_nan: + warnings.warn( + f"Float column '{self.name or self.__class__.__name__}' does not allow " + "NaN values, but this constraint cannot be translated to pydantic." + ) + + return float + + def _pydantic_field_kwargs(self) -> dict[str, Any]: + if self.allow_inf != self.allow_nan: + warnings.warn( + "Unequal settings of `allow_inf` and `allow_nan` cannot be translated to " + "pydantic constraints." + ) + + kwargs = super()._pydantic_field_kwargs() + if self.allow_inf == self.allow_nan: + kwargs["allow_inf_nan"] = self.allow_inf + if "le" not in kwargs: + kwargs["le"] = self.max_value + if "ge" not in kwargs: + kwargs["ge"] = self.min_value + return kwargs + @property def _nan_probability(self) -> float: """Private utility for the null probability used during sampling.""" @@ -142,30 +181,6 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: inf_probability=self._inf_probability, ).cast(self.dtype) - def _python_type(self) -> type: - """Return the base Python type for float column.""" - import warnings - - # Warn about untranslated constraints - if self.allow_inf == self.allow_nan and not self.allow_inf: - warnings.warn( - f"Float column '{self.name or self.__class__.__name__}' does not allow " - "infinity or NaN values, but this constraint cannot be translated to pydantic." - ) - else: - if not self.allow_inf: - warnings.warn( - f"Float column '{self.name or self.__class__.__name__}' does not allow " - "infinity values, but this constraint cannot be translated to pydantic." - ) - if not self.allow_nan: - warnings.warn( - f"Float column '{self.name or self.__class__.__name__}' does not allow " - "NaN values, but this constraint cannot be translated to pydantic." - ) - - return float - # ------------------------------------------------------------------------------------ # diff --git a/dataframely/columns/integer.py b/dataframely/columns/integer.py index 7bfc401..93b945e 100644 --- a/dataframely/columns/integer.py +++ b/dataframely/columns/integer.py @@ -5,12 +5,12 @@ from abc import abstractmethod from collections.abc import Sequence -from typing import Any +from typing import Any, Literal import polars as pl from polars.datatypes.group import INTEGER_DTYPES -from dataframely._compat import pa, pydantic, sa, sa_mssql, sa_TypeEngine +from dataframely._compat import pa, sa, sa_mssql, sa_TypeEngine from dataframely._polars import PolarsDataType from dataframely.random import Generator @@ -114,6 +114,20 @@ def min_value(self) -> int: """Minimum value of the column's type.""" return 0 if self.is_unsigned else -(2 ** (self.num_bytes * 8 - 1)) + @property + def _python_type(self) -> Any: + if self.is_in is not None: + return Literal[tuple(self.is_in)] + return int + + def _pydantic_field_kwargs(self) -> dict[str, Any]: + kwargs = super()._pydantic_field_kwargs() + if "le" not in kwargs: + kwargs["le"] = self.max_value + if "ge" not in kwargs: + kwargs["ge"] = self.min_value + return kwargs + def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: if self.is_in is not None: return generator.sample_choice( @@ -142,14 +156,6 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: null_probability=self._null_probability, ).cast(self.dtype) - def _python_type(self) -> type: - """Return the base Python type for integer column.""" - from typing import Literal - - if self.is_in is not None: - return Literal[tuple(self.is_in)] # type: ignore - return int - # ------------------------------------------------------------------------------------ # diff --git a/dataframely/columns/list.py b/dataframely/columns/list.py index 2fb81c8..5ca1547 100644 --- a/dataframely/columns/list.py +++ b/dataframely/columns/list.py @@ -133,6 +133,18 @@ def pyarrow_dtype(self) -> pa.DataType: # NOTE: Polars uses `large_list`s by default. return pa.large_list(self.inner.pyarrow_field("item")) + def _python_type(self) -> Any: + inner_type = self.inner.pydantic_field() + return list[inner_type] # type: ignore + + def _pydantic_field_kwargs(self) -> dict[str, Any]: + kwargs = super()._pydantic_field_kwargs() + if self.min_length is not None: + kwargs["min_length"] = self.min_length + if self.max_length is not None: + kwargs["max_length"] = self.max_length + return kwargs + def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: # First, sample the number of items per list element # NOTE: We default to 32 for the upper bound as we need some kind of reasonable @@ -177,22 +189,6 @@ def from_dict(cls, data: dict[str, Any]) -> Self: data["inner"] = column_from_dict(data["inner"]) return super().from_dict(data) - def _python_type(self) -> type: - """Return the base Python type for List column.""" - inner_type = self.inner.pydantic_field() - return list[inner_type] # type: ignore - - def _pydantic_field_kwargs(self) -> dict[str, Any]: - """Return pydantic field kwargs for list constraints.""" - kwargs = super()._pydantic_field_kwargs() - - if self.min_length is not None: - kwargs["min_length"] = self.min_length - if self.max_length is not None: - kwargs["max_length"] = self.max_length - - return kwargs - def _list_primary_key_check( list_expr: ExprListNameSpace | ExprArrayNameSpace, inner: Column diff --git a/dataframely/columns/object.py b/dataframely/columns/object.py index e3e3155..18e125c 100644 --- a/dataframely/columns/object.py +++ b/dataframely/columns/object.py @@ -67,13 +67,11 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: def pyarrow_dtype(self) -> pa.DataType: raise NotImplementedError("PyArrow column cannot have 'Object' type.") + @property + def _python_type(self) -> Any: + return Any + def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: raise NotImplementedError( "Random data sampling not implemented for 'Object' type." ) - - def _python_type(self) -> type: - """Return the base Python type for Object column.""" - from typing import Any as AnyType - - return AnyType diff --git a/dataframely/columns/string.py b/dataframely/columns/string.py index 45328db..c253749 100644 --- a/dataframely/columns/string.py +++ b/dataframely/columns/string.py @@ -7,7 +7,7 @@ import polars as pl -from dataframely._compat import pa, pydantic, sa, sa_TypeEngine +from dataframely._compat import pa, sa, sa_TypeEngine from dataframely._native import regex_matching_string_length from dataframely.random import Generator @@ -112,6 +112,19 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: def pyarrow_dtype(self) -> pa.DataType: return pa.large_string() + def _python_type(self) -> Any: + return str + + def _pydantic_field_kwargs(self) -> dict[str, Any]: + kwargs = super()._pydantic_field_kwargs() + if self.min_length is not None: + kwargs["min_length"] = self.min_length + if self.max_length is not None: + kwargs["max_length"] = self.max_length + if self.regex is not None: + kwargs["pattern"] = self.regex + return kwargs + def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: if ( self.min_length is not None or self.max_length is not None @@ -137,20 +150,3 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: regex=regex, null_probability=self._null_probability, ) - - def _python_type(self) -> type: - """Return the base Python type for string column.""" - return str - - def _pydantic_field_kwargs(self) -> dict[str, Any]: - """Return pydantic field kwargs for string constraints.""" - kwargs = super()._pydantic_field_kwargs() - - if self.min_length is not None: - kwargs["min_length"] = self.min_length - if self.max_length is not None: - kwargs["max_length"] = self.max_length - if self.regex is not None: - kwargs["pattern"] = self.regex - - return kwargs diff --git a/dataframely/columns/struct.py b/dataframely/columns/struct.py index d89fbfc..c452871 100644 --- a/dataframely/columns/struct.py +++ b/dataframely/columns/struct.py @@ -8,7 +8,7 @@ import polars as pl -from dataframely._compat import pa, sa, sa_postgresql, sa_TypeEngine +from dataframely._compat import pa, pydantic, sa, sa_postgresql, sa_TypeEngine from dataframely._polars import PolarsDataType from dataframely.random import Generator @@ -117,6 +117,10 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: def pyarrow_dtype(self) -> pa.DataType: return pa.struct([col.pyarrow_field(name) for name, col in self.inner.items()]) + def _python_type(self) -> Any: + fields = {name: col.pydantic_field() for name, col in self.inner.items()} + return pydantic.create_model("StructModel", **fields) + def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: series = ( pl.DataFrame( @@ -156,19 +160,3 @@ def from_dict(cls, data: dict[str, Any]) -> Self: name: column_from_dict(col) for name, col in data["inner"].items() } return super().from_dict(data) - - def _python_type(self) -> type: - """Return the base Python type for Struct column.""" - from dataframely._compat import pydantic - - fields = {} - for field_name, col in self.inner.items(): - field_type = col.pydantic_field() - fields[field_name] = (field_type, ...) - - model = pydantic.create_model( - f"{self.name or 'Struct'}Model", - **fields, # type: ignore - ) - - return model diff --git a/dataframely/schema.py b/dataframely/schema.py index ff2a180..aed7e7d 100644 --- a/dataframely/schema.py +++ b/dataframely/schema.py @@ -16,10 +16,8 @@ import polars.exceptions as plexc from polars._typing import FileSource -from dataframely._compat import deltalake, pydantic - from ._base_schema import ORIGINAL_COLUMN_PREFIX, BaseSchema -from ._compat import PartitionSchemeOrSinkDirectory, pa, sa +from ._compat import PartitionSchemeOrSinkDirectory, deltalake, pa, pydantic, sa from ._match_to_schema import match_to_schema from ._native import format_rule_failures from ._plugin import all_rules, all_rules_horizontal, all_rules_required @@ -1340,52 +1338,26 @@ def to_pyarrow_schema(cls) -> pa.Schema: ) @classmethod - def to_pydantic_model(cls) -> type: + def to_pydantic_model(cls) -> type[pydantic.BaseModel]: """Convert this schema to a pydantic model. - This method creates a pydantic model with fields corresponding to the columns - in this schema. Structured constraints (min, max, regex, etc.) are translated - to pydantic field constraints where possible. Custom checks and group rules - are not translated. + The pydantic model includes all columns defined in the schema along with their + (structured) constraints. Custom checks and schema-level rules are not included + in the pydantic model. Returns: - A pydantic model class that can be used for data validation. - - Warning: - Custom checks defined via the `check` parameter on columns are not - included in the pydantic model. A UserWarning is raised if custom checks - are present. - - Warning: - Group rules defined on the schema are not translated to pydantic validators. - A UserWarning is raised if group rules are present. - - Example: - >>> class MySchema(dy.Schema): - ... x = dy.Integer(min=0, max=100) - ... y = dy.String(regex=r"^[A-Z]+$") - ... - >>> Model = MySchema.to_pydantic_model() - >>> # Now you can use Model for validation, e.g., with LLM APIs + A :mod:`pydantic` model class. """ - # Check for group rules and warn if present if cls._schema_validation_rules(): warnings.warn( - f"Schema '{cls.__name__}' has group rules. These are not currently " - "translated to pydantic validators." + "pydantic models generated from schemas do not include schema-level rules." ) - # Build field definitions for the pydantic model - fields = {} - for col_name, col in cls.columns().items(): - # Get the pydantic field type for this column - field_type = col.pydantic_field() - # All fields are required (use ... as default) - fields[col_name] = (field_type, ...) - - # Create the pydantic model dynamically model_name = f"{cls.__name__.removesuffix('Schema')}Model" - return pydantic.create_model(model_name, **fields) # type: ignore + fields = { + col_name: col.pydantic_field() for col_name, col in cls.columns().items() + } + return pydantic.create_model(model_name, **fields) # ----------------------------------- EQUALITY ----------------------------------- # diff --git a/tests/test_pydantic_conversion.py b/tests/test_pydantic_conversion.py deleted file mode 100644 index dfb4b6a..0000000 --- a/tests/test_pydantic_conversion.py +++ /dev/null @@ -1,394 +0,0 @@ -# Copyright (c) QuantCo 2025-2026 -# SPDX-License-Identifier: BSD-3-Clause -"""Tests for pydantic conversion functionality.""" - -import datetime as dt -import decimal -import warnings - -import pytest - -import dataframely as dy -from dataframely._compat import pydantic - -pytestmark = pytest.mark.with_optionals - - -def test_integer_column_pydantic_field() -> None: - # Arrange - column = dy.Integer(min=0, max=100) - column._name = "test_col" - - # Act - field_type = column.pydantic_field() - - # Assert - # Should be annotated int with constraints - # We can test by creating a pydantic model with it - Model = pydantic.create_model("TestModel", value=(field_type, ...)) - - # Valid value - instance = Model(value=50) - assert instance.value == 50 - - # Invalid value (too small) - with pytest.raises(pydantic.ValidationError): - Model(value=-1) - - # Invalid value (too large) - with pytest.raises(pydantic.ValidationError): - Model(value=101) - - -def test_integer_column_with_is_in() -> None: - # Arrange - column = dy.Integer(is_in=[1, 2, 3]) - column._name = "test_col" - - # Act - field_type = column.pydantic_field() - - # Assert - Model = pydantic.create_model("TestModel", value=(field_type, ...)) - - # Valid values - for val in [1, 2, 3]: - instance = Model(value=val) - assert instance.value == val - - # Invalid value - with pytest.raises(pydantic.ValidationError): - Model(value=4) - - -def test_string_column_pydantic_field() -> None: - # Arrange - column = dy.String(min_length=3, max_length=10, regex=r"^[A-Z]+$") - column._name = "test_col" - - # Act - field_type = column.pydantic_field() - - # Assert - Model = pydantic.create_model("TestModel", value=(field_type, ...)) - - # Valid value - instance = Model(value="HELLO") - assert instance.value == "HELLO" - - # Invalid: too short - with pytest.raises(pydantic.ValidationError): - Model(value="HI") - - # Invalid: too long - with pytest.raises(pydantic.ValidationError): - Model(value="VERYLONGSTRING") - - # Invalid: doesn't match regex - with pytest.raises(pydantic.ValidationError): - Model(value="hello") - - -def test_float_column_pydantic_field() -> None: - # Arrange - column = dy.Float(min=0.0, max=1.0) - column._name = "test_col" - - # Act - with warnings.catch_warnings(record=True): - # Suppress warnings about allow_inf and allow_nan - warnings.simplefilter("always") - field_type = column.pydantic_field() - - # Assert - Model = pydantic.create_model("TestModel", value=(field_type, ...)) - - # Valid value - instance = Model(value=0.5) - assert instance.value == 0.5 - - # Invalid: too small - with pytest.raises(pydantic.ValidationError): - Model(value=-0.1) - - # Invalid: too large - with pytest.raises(pydantic.ValidationError): - Model(value=1.1) - - -def test_bool_column_pydantic_field() -> None: - # Arrange - column = dy.Bool() - column._name = "test_col" - - # Act - field_type = column.pydantic_field() - - # Assert - Model = pydantic.create_model("TestModel", value=(field_type, ...)) - - # Valid values - for val in [True, False]: - instance = Model(value=val) - assert instance.value == val - - -def test_date_column_pydantic_field() -> None: - # Arrange - column = dy.Date(min=dt.date(2020, 1, 1), max=dt.date(2025, 12, 31)) - column._name = "test_col" - - # Act - with warnings.catch_warnings(record=True): - # Suppress warnings about resolution - warnings.simplefilter("always") - field_type = column.pydantic_field() - - # Assert - Model = pydantic.create_model("TestModel", value=(field_type, ...)) - - # Valid value - instance = Model(value=dt.date(2023, 6, 15)) - assert instance.value == dt.date(2023, 6, 15) - - # Invalid: too early - with pytest.raises(pydantic.ValidationError): - Model(value=dt.date(2019, 1, 1)) - - # Invalid: too late - with pytest.raises(pydantic.ValidationError): - Model(value=dt.date(2026, 1, 1)) - - -def test_datetime_column_pydantic_field() -> None: - # Arrange - column = dy.Datetime( - min=dt.datetime(2020, 1, 1), max=dt.datetime(2025, 12, 31, 23, 59, 59) - ) - column._name = "test_col" - - # Act - with warnings.catch_warnings(record=True): - # Suppress warnings - warnings.simplefilter("always") - field_type = column.pydantic_field() - - # Assert - Model = pydantic.create_model("TestModel", value=(field_type, ...)) - - # Valid value - instance = Model(value=dt.datetime(2023, 6, 15, 12, 0, 0)) - assert instance.value == dt.datetime(2023, 6, 15, 12, 0, 0) - - -def test_enum_column_pydantic_field() -> None: - # Arrange - column = dy.Enum(categories=["red", "green", "blue"]) - column._name = "test_col" - - # Act - field_type = column.pydantic_field() - - # Assert - Model = pydantic.create_model("TestModel", value=(field_type, ...)) - - # Valid values - for val in ["red", "green", "blue"]: - instance = Model(value=val) - assert instance.value == val - - # Invalid value - with pytest.raises(pydantic.ValidationError): - Model(value="yellow") - - -def test_list_column_pydantic_field() -> None: - # Arrange - inner = dy.Integer(min=0, max=100) - column = dy.List(inner, min_length=2, max_length=5) - column._name = "test_col" - - # Act - field_type = column.pydantic_field() - - # Assert - Model = pydantic.create_model("TestModel", value=(field_type, ...)) - - # Valid value - instance = Model(value=[1, 2, 3]) - assert instance.value == [1, 2, 3] - - # Invalid: too short - with pytest.raises(pydantic.ValidationError): - Model(value=[1]) - - # Invalid: too long - with pytest.raises(pydantic.ValidationError): - Model(value=[1, 2, 3, 4, 5, 6]) - - # Invalid: element out of range - with pytest.raises(pydantic.ValidationError): - Model(value=[1, 2, 101]) - - -def test_struct_column_pydantic_field() -> None: - # Arrange - column = dy.Struct({"x": dy.Integer(min=0), "y": dy.String(max_length=10)}) - column._name = "test_col" - - # Act - field_type = column.pydantic_field() - - # Assert - Model = pydantic.create_model("TestModel", value=(field_type, ...)) - - # Valid value - instance = Model(value={"x": 5, "y": "hello"}) - assert instance.value.x == 5 - assert instance.value.y == "hello" - - # Invalid: x out of range - with pytest.raises(pydantic.ValidationError): - Model(value={"x": -1, "y": "hello"}) - - -def test_nullable_column_pydantic_field() -> None: - # Arrange - column = dy.Integer(min=0, max=100, nullable=True) - column._name = "test_col" - - # Act - field_type = column.pydantic_field() - - # Assert - Model = pydantic.create_model("TestModel", value=(field_type, ...)) - - # Valid: None - instance = Model(value=None) - assert instance.value is None - - # Valid: integer - instance = Model(value=50) - assert instance.value == 50 - - -def test_column_with_custom_check_raises_warning() -> None: - # Arrange - column = dy.Integer(min=0, max=100, check=lambda x: x.is_even()) - column._name = "test_col" - - # Act & Assert - with pytest.warns(UserWarning, match="Custom checks .* are not translated"): - column.pydantic_field() - - -def test_schema_to_pydantic_model() -> None: - # Arrange - class MySchema(dy.Schema): - x = dy.Integer(min=0, max=100) - y = dy.String(regex=r"^[A-Z]+$") - z = dy.Float(nullable=True, allow_inf=True, allow_nan=True) - - # Act - Model = MySchema.to_pydantic_model() - - # Assert - # Valid instance - instance = Model(x=50, y="HELLO", z=3.14) - assert instance.x == 50 - assert instance.y == "HELLO" - assert instance.z == 3.14 - - # Valid with None - instance = Model(x=50, y="HELLO", z=None) - assert instance.z is None - - # Invalid: x out of range - with pytest.raises(pydantic.ValidationError): - Model(x=101, y="HELLO", z=3.14) - - # Invalid: y doesn't match regex - with pytest.raises(pydantic.ValidationError): - Model(x=50, y="hello", z=3.14) - - -def test_schema_with_nested_struct() -> None: - # Arrange - class NestedSchema(dy.Schema): - point = dy.Struct( - { - "x": dy.Float(allow_inf=True, allow_nan=True), - "y": dy.Float(allow_inf=True, allow_nan=True), - } - ) - label = dy.String() - - # Act - Model = NestedSchema.to_pydantic_model() - - # Assert - instance = Model(point={"x": 1.0, "y": 2.0}, label="A") - assert instance.point.x == 1.0 - assert instance.point.y == 2.0 - assert instance.label == "A" - - -def test_schema_with_list_of_ints() -> None: - # Arrange - class ListSchema(dy.Schema): - numbers = dy.List(dy.Integer(min=0), min_length=1, max_length=10) - - # Act - Model = ListSchema.to_pydantic_model() - - # Assert - instance = Model(numbers=[1, 2, 3, 4, 5]) - assert instance.numbers == [1, 2, 3, 4, 5] - - # Invalid: contains negative number - with pytest.raises(pydantic.ValidationError): - Model(numbers=[1, -2, 3]) - - -def test_decimal_column_pydantic_field() -> None: - # Arrange - column = dy.Decimal( - precision=10, - scale=2, - min=decimal.Decimal("0.00"), - max=decimal.Decimal("100.00"), - ) - column._name = "test_col" - - # Act - with warnings.catch_warnings(record=True): - # Suppress warnings about precision and scale - warnings.simplefilter("always") - field_type = column.pydantic_field() - - # Assert - Model = pydantic.create_model("TestModel", value=(field_type, ...)) - - # Valid value - instance = Model(value=decimal.Decimal("50.00")) - assert instance.value == decimal.Decimal("50.00") - - # Invalid: out of range - with pytest.raises(pydantic.ValidationError): - Model(value=decimal.Decimal("150.00")) - - -def test_schema_with_group_rules_raises_warning() -> None: - # Arrange - class SchemaWithRules(dy.Schema): - x = dy.Integer() - y = dy.Integer() - - # Add a custom group rule - @dy.rule() - def sum_check(cls, lf): # type: ignore - return (lf.select("x") + lf.select("y") > 0).to_series() - - # Act & Assert - with pytest.warns(UserWarning, match="group rules.*not currently translated"): - SchemaWithRules.to_pydantic_model() From 1d97ed3afea7d722b27784234f64b27a30cd753a Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Fri, 10 Apr 2026 17:44:47 +0200 Subject: [PATCH 07/12] Clean tests --- dataframely/columns/array.py | 4 +- dataframely/columns/float.py | 18 -- dataframely/columns/list.py | 1 + dataframely/columns/string.py | 1 + dataframely/columns/struct.py | 1 + dataframely/schema.py | 4 +- tests/columns/test_pydantic.py | 341 ++++++++++++++++++++++++++++ tests/schema/test_pydantic_model.py | 63 +++++ 8 files changed, 410 insertions(+), 23 deletions(-) create mode 100644 tests/columns/test_pydantic.py create mode 100644 tests/schema/test_pydantic_model.py diff --git a/dataframely/columns/array.py b/dataframely/columns/array.py index 44112a7..b6a26b4 100644 --- a/dataframely/columns/array.py +++ b/dataframely/columns/array.py @@ -135,8 +135,8 @@ def _pydantic_field_kwargs(self) -> dict[str, Any]: return { **super()._pydantic_field_kwargs(), - "min_items": math.prod(self.shape), - "max_items": math.prod(self.shape), + "min_length": math.prod(self.shape), + "max_length": math.prod(self.shape), } def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: diff --git a/dataframely/columns/float.py b/dataframely/columns/float.py index 7cc2ca9..8b2b2f1 100644 --- a/dataframely/columns/float.py +++ b/dataframely/columns/float.py @@ -104,24 +104,6 @@ def min_value(self) -> float: @property def _python_type(self) -> Any: - # Warn about untranslated constraints - if self.allow_inf == self.allow_nan and not self.allow_inf: - warnings.warn( - f"Float column '{self.name or self.__class__.__name__}' does not allow " - "infinity or NaN values, but this constraint cannot be translated to pydantic." - ) - else: - if not self.allow_inf: - warnings.warn( - f"Float column '{self.name or self.__class__.__name__}' does not allow " - "infinity values, but this constraint cannot be translated to pydantic." - ) - if not self.allow_nan: - warnings.warn( - f"Float column '{self.name or self.__class__.__name__}' does not allow " - "NaN values, but this constraint cannot be translated to pydantic." - ) - return float def _pydantic_field_kwargs(self) -> dict[str, Any]: diff --git a/dataframely/columns/list.py b/dataframely/columns/list.py index 5ca1547..ad300f2 100644 --- a/dataframely/columns/list.py +++ b/dataframely/columns/list.py @@ -133,6 +133,7 @@ def pyarrow_dtype(self) -> pa.DataType: # NOTE: Polars uses `large_list`s by default. return pa.large_list(self.inner.pyarrow_field("item")) + @property def _python_type(self) -> Any: inner_type = self.inner.pydantic_field() return list[inner_type] # type: ignore diff --git a/dataframely/columns/string.py b/dataframely/columns/string.py index c253749..d0e80e0 100644 --- a/dataframely/columns/string.py +++ b/dataframely/columns/string.py @@ -112,6 +112,7 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: def pyarrow_dtype(self) -> pa.DataType: return pa.large_string() + @property def _python_type(self) -> Any: return str diff --git a/dataframely/columns/struct.py b/dataframely/columns/struct.py index c452871..ebf34e7 100644 --- a/dataframely/columns/struct.py +++ b/dataframely/columns/struct.py @@ -117,6 +117,7 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: def pyarrow_dtype(self) -> pa.DataType: return pa.struct([col.pyarrow_field(name) for name, col in self.inner.items()]) + @property def _python_type(self) -> Any: fields = {name: col.pydantic_field() for name, col in self.inner.items()} return pydantic.create_model("StructModel", **fields) diff --git a/dataframely/schema.py b/dataframely/schema.py index aed7e7d..87fa15a 100644 --- a/dataframely/schema.py +++ b/dataframely/schema.py @@ -1349,9 +1349,7 @@ def to_pydantic_model(cls) -> type[pydantic.BaseModel]: A :mod:`pydantic` model class. """ if cls._schema_validation_rules(): - warnings.warn( - "pydantic models generated from schemas do not include schema-level rules." - ) + warnings.warn("pydantic models do not include schema-level rules.") model_name = f"{cls.__name__.removesuffix('Schema')}Model" fields = { diff --git a/tests/columns/test_pydantic.py b/tests/columns/test_pydantic.py new file mode 100644 index 0000000..653fe7b --- /dev/null +++ b/tests/columns/test_pydantic.py @@ -0,0 +1,341 @@ +# Copyright (c) QuantCo 2025-2026 +# SPDX-License-Identifier: BSD-3-Clause + +import datetime as dt +import decimal +from typing import Any, Literal, get_args, get_origin + +import pytest + +import dataframely as dy +from dataframely._compat import pydantic +from dataframely.columns import Column +from dataframely.testing import ALL_COLUMN_TYPES, COLUMN_TYPES, SUPERTYPE_COLUMN_TYPES + +pytestmark = pytest.mark.with_optionals + + +# ------------------------------------ BASIC TESTS ----------------------------------- # + + +@pytest.mark.parametrize("column_type", ALL_COLUMN_TYPES) +def test_pydantic_field_returns(column_type: type[Column]) -> None: + field = column_type().pydantic_field() + assert field is not None + + +@pytest.mark.parametrize( + ("column", "expected_type"), + [ + (dy.Any(), Any), + (dy.Bool(), bool), + (dy.Date(), dt.date), + (dy.Datetime(), dt.datetime), + (dy.Time(), dt.time), + (dy.Decimal(), decimal.Decimal), + (dy.Duration(), dt.timedelta), + (dy.Float32(), float), + (dy.Float64(), float), + (dy.Int8(), int), + (dy.Int16(), int), + (dy.Int32(), int), + (dy.Int64(), int), + (dy.UInt8(), int), + (dy.UInt16(), int), + (dy.UInt32(), int), + (dy.UInt64(), int), + (dy.String(), str), + (dy.Categorical(), str), + (dy.Binary(), bytes), + (dy.Float(), float), + (dy.Integer(), int), + (dy.Object(), Any), + ], +) +def test_python_type(column: Column, expected_type: type) -> None: + assert column._python_type == expected_type + + +# ----------------------------------- NULLABILITY ------------------------------------ # + + +def test_nullable_any() -> None: + col = dy.Any() + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val=None) # should not raise + + +@pytest.mark.parametrize("column_type", COLUMN_TYPES + SUPERTYPE_COLUMN_TYPES) +def test_nullable_includes_none(column_type: type[Column]) -> None: + col = column_type(nullable=True) + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val=None) # should not raise + + +@pytest.mark.parametrize("column_type", COLUMN_TYPES + SUPERTYPE_COLUMN_TYPES) +def test_non_nullable_rejects_none(column_type: type[Column]) -> None: + col = column_type(nullable=False) + model = pydantic.create_model("Test", val=col.pydantic_field()) + with pytest.raises(pydantic.ValidationError): + model(val=None) + + +# --------------------------------------- DATE --------------------------------------- # + + +def test_date_min_max() -> None: + col = dy.Date(min=dt.date(2020, 1, 1), max=dt.date(2020, 12, 31)) + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val=dt.date(2020, 6, 15)) + with pytest.raises(pydantic.ValidationError): + model(val=dt.date(2019, 12, 31)) + with pytest.raises(pydantic.ValidationError): + model(val=dt.date(2021, 1, 1)) + + +def test_date_resolution_warning() -> None: + col = dy.Date(resolution="1w") + with pytest.warns( + match="Date resolution is not translated to a pydantic constraint" + ): + col.pydantic_field() + + +# ------------------------------------- DATETIME ------------------------------------- # + + +def test_datetime_min_max() -> None: + col = dy.Datetime( + min=dt.datetime(2020, 1, 1), + max=dt.datetime(2020, 12, 31), + ) + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val=dt.datetime(2020, 6, 15)) + with pytest.raises(pydantic.ValidationError): + model(val=dt.datetime(2019, 12, 31)) + + +def test_datetime_resolution_warning() -> None: + col = dy.Datetime(resolution="1h") + with pytest.warns( + match="Datetime resolution is not translated to a pydantic constraint" + ): + col.pydantic_field() + + +def test_datetime_timezone_warning() -> None: + col = dy.Datetime(time_zone="Etc/UTC") + with pytest.warns( + match="Datetime time zone is not translated to a pydantic constraint" + ): + col.pydantic_field() + + +# -------------------------------------- DECIMAL ------------------------------------- # + + +def test_decimal_scale() -> None: + col = dy.Decimal(precision=5, scale=2) + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val=decimal.Decimal("1.23")) + + +def test_decimal_min_max() -> None: + col = dy.Decimal(min=decimal.Decimal("0"), max=decimal.Decimal("100")) + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val=decimal.Decimal("0")) + model(val=decimal.Decimal("100")) + with pytest.raises(pydantic.ValidationError): + model(val=decimal.Decimal("-1")) + with pytest.raises(pydantic.ValidationError): + model(val=decimal.Decimal("101")) + + +# --------------------------------------- FLOAT -------------------------------------- # + + +def test_float_min_max() -> None: + col = dy.Float64(min=0.0, max=1.0) + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val=0.0) + model(val=1.0) + with pytest.raises(pydantic.ValidationError): + model(val=-0.1) + with pytest.raises(pydantic.ValidationError): + model(val=1.1) + + +def test_float_no_inf_nan() -> None: + col = dy.Float64(allow_inf=False, allow_nan=False) + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val=1.0) + with pytest.raises(pydantic.ValidationError): + model(val=float("inf")) + with pytest.raises(pydantic.ValidationError): + model(val=float("nan")) + + +def test_float_unequal_inf_nan_warning() -> None: + col = dy.Float64(allow_inf=False, allow_nan=True) + with pytest.warns( + match="Unequal settings of `allow_inf` and `allow_nan` cannot be translated" + ): + col.pydantic_field() + + +# -------------------------------------- INTEGER ------------------------------------- # + + +def test_integer_min_max() -> None: + col = dy.Int64(min=0, max=10) + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val=0) + model(val=10) + with pytest.raises(pydantic.ValidationError): + model(val=-1) + with pytest.raises(pydantic.ValidationError): + model(val=11) + + +def test_integer_min_max_exclusive() -> None: + col = dy.Int64(min_exclusive=0, max_exclusive=10) + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val=1) + model(val=9) + with pytest.raises(pydantic.ValidationError): + model(val=0) + with pytest.raises(pydantic.ValidationError): + model(val=10) + + +def test_integer_type_bounds() -> None: + """Int8 should enforce [-128, 127] even without explicit min/max.""" + col = dy.Int8() + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val=-128) + model(val=127) + with pytest.raises(pydantic.ValidationError): + model(val=-129) + with pytest.raises(pydantic.ValidationError): + model(val=128) + + +def test_uint8_type_bounds() -> None: + col = dy.UInt8() + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val=0) + model(val=255) + with pytest.raises(pydantic.ValidationError): + model(val=-1) + with pytest.raises(pydantic.ValidationError): + model(val=256) + + +def test_integer_is_in() -> None: + col = dy.Int64(is_in=[1, 2, 3]) + tp = col._python_type + assert get_origin(tp) is Literal + assert set(get_args(tp)) == {1, 2, 3} + + +# -------------------------------------- STRING -------------------------------------- # + + +def test_string_min_max_length() -> None: + col = dy.String(min_length=2, max_length=5) + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val="ab") + model(val="abcde") + with pytest.raises(pydantic.ValidationError): + model(val="a") + with pytest.raises(pydantic.ValidationError): + model(val="abcdef") + + +def test_string_regex() -> None: + col = dy.String(regex=r"^[a-z]+$") + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val="abc") + with pytest.raises(pydantic.ValidationError): + model(val="ABC") + + +# --------------------------------------- ENUM --------------------------------------- # + + +def test_enum_field() -> None: + col = dy.Enum(["a", "b", "c"]) + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val="a") + with pytest.raises(pydantic.ValidationError): + model(val="d") + + +# --------------------------------------- LIST --------------------------------------- # + + +def test_list_min_max_length() -> None: + col = dy.List(dy.Int64(), min_length=1, max_length=3) + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val=[1]) + model(val=[1, 2, 3]) + with pytest.raises(pydantic.ValidationError): + model(val=[]) + with pytest.raises(pydantic.ValidationError): + model(val=[1, 2, 3, 4]) + + +def test_list_inner_type_validation() -> None: + col = dy.List(dy.Int64(min=0)) + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val=[1, 2, 3]) + with pytest.raises(pydantic.ValidationError): + model(val=[-1]) + + +# --------------------------------------- ARRAY -------------------------------------- # + + +def test_python_type_array() -> None: + col = dy.Array(dy.Int64(), shape=3) + tp = col._python_type + assert get_origin(tp) is list + + +def test_array_fixed_size() -> None: + col = dy.Array(dy.Int64(), shape=3) + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val=[1, 2, 3]) + with pytest.raises(pydantic.ValidationError): + model(val=[1, 2]) + with pytest.raises(pydantic.ValidationError): + model(val=[1, 2, 3, 4]) + + +def test_array_multidim_warning() -> None: + col = dy.Array(dy.Int64(), shape=(2, 3)) + with pytest.warns(match="Multi-dimensional arrays are flattened for pydantic"): + col.pydantic_field() + + +# -------------------------------------- STRUCT -------------------------------------- # + + +def test_struct_field() -> None: + col = dy.Struct({"x": dy.Int64(nullable=False), "y": dy.String(nullable=True)}) + model = pydantic.create_model("Test", val=col.pydantic_field()) + model(val={"x": 1, "y": "hello"}) + model(val={"x": 1, "y": None}) + with pytest.raises(pydantic.ValidationError): + model(val={"x": None, "y": "hello"}) + + +# ------------------------------------- WARNINGS ------------------------------------- # + + +def test_custom_check_warning() -> None: + col = dy.Int64(check=lambda s: s > 0) + with pytest.warns( + match="Custom checks.*are not translated to pydantic constraints" + ): + col.pydantic_field() diff --git a/tests/schema/test_pydantic_model.py b/tests/schema/test_pydantic_model.py new file mode 100644 index 0000000..ef0b537 --- /dev/null +++ b/tests/schema/test_pydantic_model.py @@ -0,0 +1,63 @@ +# Copyright (c) QuantCo 2025-2026 +# SPDX-License-Identifier: BSD-3-Clause + +import polars as pl +import pytest + +import dataframely as dy +from dataframely._compat import pydantic +from dataframely._rule import Rule +from dataframely.testing import create_schema + +pytestmark = pytest.mark.with_optionals + + +def test_basic_model() -> None: + schema = create_schema( + "TestSchema", + {"x": dy.Int64(), "y": dy.String(nullable=True)}, + ) + model_cls = schema.to_pydantic_model() + assert issubclass(model_cls, pydantic.BaseModel) + assert model_cls.__name__ == "TestModel" + assert set(model_cls.model_fields.keys()) == {"x", "y"} + + +def test_validation_success() -> None: + schema = create_schema( + "TestSchema", + { + "x": dy.Int64(), + "name": dy.String(), + "active": dy.Bool(), + }, + ) + model_cls = schema.to_pydantic_model() + instance = model_cls(x=42, name="hello", active=True) + assert instance.x == 42 # type: ignore + assert instance.name == "hello" # type: ignore + assert instance.active is True # type: ignore + + +def test_validation_failure() -> None: + schema = create_schema( + "TestSchema", + { + "x": dy.Int64(), + "name": dy.String(), + "active": dy.Bool(), + }, + ) + model_cls = schema.to_pydantic_model() + with pytest.raises(pydantic.ValidationError): + model_cls(x="not an int", name="hello", active=True) + + +def test_schema_with_rules_warns() -> None: + schema = create_schema( + "TestSchema", + {"x": dy.Int64()}, + rules={"my_rule": Rule(pl.col("x") > 0)}, + ) + with pytest.warns(match="pydantic models do not include schema-level rules"): + schema.to_pydantic_model() From 88969f631b2ad1351804a574eb1305b3f1644657 Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Fri, 10 Apr 2026 17:59:17 +0200 Subject: [PATCH 08/12] Coverage --- tests/columns/test_pydantic.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/columns/test_pydantic.py b/tests/columns/test_pydantic.py index 653fe7b..7c65a03 100644 --- a/tests/columns/test_pydantic.py +++ b/tests/columns/test_pydantic.py @@ -131,6 +131,17 @@ def test_datetime_timezone_warning() -> None: col.pydantic_field() +# --------------------------------------- TIME --------------------------------------- # + + +def test_time_resolution_warning() -> None: + col = dy.Time(resolution="1h") + with pytest.warns( + match="Time column.*has a resolution constraint that cannot be translated" + ): + col.pydantic_field() + + # -------------------------------------- DECIMAL ------------------------------------- # @@ -151,6 +162,17 @@ def test_decimal_min_max() -> None: model(val=decimal.Decimal("101")) +# ------------------------------------- DURATION ------------------------------------- # + + +def test_duration_resolution_warning() -> None: + col = dy.Duration(resolution="1h") + with pytest.warns( + match="Duration resolution is not translated to a pydantic constraint" + ): + col.pydantic_field() + + # --------------------------------------- FLOAT -------------------------------------- # From afa230ec439904ca6e5301224c1ff32222bfce9f Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Fri, 10 Apr 2026 18:04:48 +0200 Subject: [PATCH 09/12] Add another test --- tests/schema/test_pydantic_model.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/schema/test_pydantic_model.py b/tests/schema/test_pydantic_model.py index ef0b537..0d9ac8e 100644 --- a/tests/schema/test_pydantic_model.py +++ b/tests/schema/test_pydantic_model.py @@ -14,7 +14,7 @@ def test_basic_model() -> None: schema = create_schema( - "TestSchema", + "test", {"x": dy.Int64(), "y": dy.String(nullable=True)}, ) model_cls = schema.to_pydantic_model() @@ -25,7 +25,7 @@ def test_basic_model() -> None: def test_validation_success() -> None: schema = create_schema( - "TestSchema", + "test", { "x": dy.Int64(), "name": dy.String(), @@ -41,7 +41,7 @@ def test_validation_success() -> None: def test_validation_failure() -> None: schema = create_schema( - "TestSchema", + "test", { "x": dy.Int64(), "name": dy.String(), @@ -55,9 +55,16 @@ def test_validation_failure() -> None: def test_schema_with_rules_warns() -> None: schema = create_schema( - "TestSchema", + "test", {"x": dy.Int64()}, rules={"my_rule": Rule(pl.col("x") > 0)}, ) with pytest.warns(match="pydantic models do not include schema-level rules"): schema.to_pydantic_model() + + +def test_schema_with_alias() -> None: + schema = create_schema("test", {"x": dy.Int64(alias="column with space")}) + model_cls = schema.to_pydantic_model() + assert model_cls.model_fields.keys() == {"column with space"} + model_cls(**{"column with space": 42}) From 6a06e139a0501833d7aae31e202392932844f39e Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Fri, 10 Apr 2026 18:32:58 +0200 Subject: [PATCH 10/12] Fix --- tests/schema/test_pydantic_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/schema/test_pydantic_model.py b/tests/schema/test_pydantic_model.py index 0d9ac8e..efdb64c 100644 --- a/tests/schema/test_pydantic_model.py +++ b/tests/schema/test_pydantic_model.py @@ -14,7 +14,7 @@ def test_basic_model() -> None: schema = create_schema( - "test", + "TestSchema", {"x": dy.Int64(), "y": dy.String(nullable=True)}, ) model_cls = schema.to_pydantic_model() From ae29d250d979f80d266f1d6d95ba6eb450d3d228 Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Mon, 13 Apr 2026 00:06:11 +0200 Subject: [PATCH 11/12] Update --- dataframely/columns/datetime.py | 5 +---- dataframely/schema.py | 12 +++++++++--- tests/schema/test_pydantic_model.py | 6 ++++++ 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/dataframely/columns/datetime.py b/dataframely/columns/datetime.py index 56b8df7..13e3212 100644 --- a/dataframely/columns/datetime.py +++ b/dataframely/columns/datetime.py @@ -278,10 +278,7 @@ def _python_type(self) -> Any: def _pydantic_field_kwargs(self) -> dict[str, Any]: if self.resolution is not None: - warnings.warn( - f"Time column '{self.name or self.__class__.__name__}' has a resolution " - "constraint that cannot be translated to pydantic." - ) + warnings.warn("Time resolution is not translated to a pydantic constraint.") return super()._pydantic_field_kwargs() diff --git a/dataframely/schema.py b/dataframely/schema.py index 87fa15a..0836974 100644 --- a/dataframely/schema.py +++ b/dataframely/schema.py @@ -1338,20 +1338,26 @@ def to_pyarrow_schema(cls) -> pa.Schema: ) @classmethod - def to_pydantic_model(cls) -> type[pydantic.BaseModel]: + def to_pydantic_model(cls, name: str | None = None) -> type[pydantic.BaseModel]: """Convert this schema to a pydantic model. The pydantic model includes all columns defined in the schema along with their (structured) constraints. Custom checks and schema-level rules are not included in the pydantic model. + Args: + name: The name of the returned pydantic model. If `None`, a default name is + generated based on the name of this schema. + Returns: A :mod:`pydantic` model class. """ if cls._schema_validation_rules(): - warnings.warn("pydantic models do not include schema-level rules.") + warnings.warn( + "Schema-level rules are not translated to pydantic validators." + ) - model_name = f"{cls.__name__.removesuffix('Schema')}Model" + model_name = name or f"{cls.__name__.removesuffix('Schema')}Model" fields = { col_name: col.pydantic_field() for col_name, col in cls.columns().items() } diff --git a/tests/schema/test_pydantic_model.py b/tests/schema/test_pydantic_model.py index efdb64c..f86bda1 100644 --- a/tests/schema/test_pydantic_model.py +++ b/tests/schema/test_pydantic_model.py @@ -23,6 +23,12 @@ def test_basic_model() -> None: assert set(model_cls.model_fields.keys()) == {"x", "y"} +def test_custom_name() -> None: + schema = create_schema("TestSchema", {"x": dy.Int64()}) + model_cls = schema.to_pydantic_model(name="CustomModelName") + assert model_cls.__name__ == "CustomModelName" + + def test_validation_success() -> None: schema = create_schema( "test", From a5a82806c7d451f3b082b959e0ddc1cd30a6895e Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Mon, 13 Apr 2026 00:16:34 +0200 Subject: [PATCH 12/12] Fix --- tests/columns/test_pydantic.py | 2 +- tests/schema/test_pydantic_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/columns/test_pydantic.py b/tests/columns/test_pydantic.py index 7c65a03..9c4121a 100644 --- a/tests/columns/test_pydantic.py +++ b/tests/columns/test_pydantic.py @@ -137,7 +137,7 @@ def test_datetime_timezone_warning() -> None: def test_time_resolution_warning() -> None: col = dy.Time(resolution="1h") with pytest.warns( - match="Time column.*has a resolution constraint that cannot be translated" + match="Time resolution is not translated to a pydantic constraint" ): col.pydantic_field() diff --git a/tests/schema/test_pydantic_model.py b/tests/schema/test_pydantic_model.py index f86bda1..e9372a1 100644 --- a/tests/schema/test_pydantic_model.py +++ b/tests/schema/test_pydantic_model.py @@ -65,7 +65,7 @@ def test_schema_with_rules_warns() -> None: {"x": dy.Int64()}, rules={"my_rule": Rule(pl.col("x") > 0)}, ) - with pytest.warns(match="pydantic models do not include schema-level rules"): + with pytest.warns(match="Schema-level rules are not translated to pydantic"): schema.to_pydantic_model()