Skip to content

Commit 98fa59e

Browse files
fix: Propagate column overrides through grandchild schemas (#330)
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 7c73bb1 commit 98fa59e

2 files changed

Lines changed: 39 additions & 11 deletions

File tree

dataframely/_base_schema.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sys
77
import textwrap
88
from abc import ABCMeta
9+
from collections.abc import Mapping
910
from copy import copy
1011
from dataclasses import dataclass, field
1112
from typing import TYPE_CHECKING, Any
@@ -116,12 +117,7 @@ def __new__(
116117
*args: Any,
117118
**kwargs: Any,
118119
) -> SchemaMeta:
119-
result = Metadata()
120-
for base in bases:
121-
result.update(mcs._get_metadata_recursively(base))
122-
namespace_metadata = mcs._get_metadata(namespace)
123-
mcs._remove_overridden_columns(result, namespace, bases)
124-
result.update(namespace_metadata)
120+
result = mcs._collect_metadata(bases, namespace)
125121
namespace[_COLUMN_ATTR] = result.columns
126122
cls = super().__new__(mcs, name, bases, namespace, *args, **kwargs)
127123

@@ -212,7 +208,7 @@ def __getattribute__(cls, name: str) -> Any:
212208
@staticmethod
213209
def _remove_overridden_columns(
214210
result: Metadata,
215-
namespace: dict[str, Any],
211+
namespace: Mapping[str, Any],
216212
bases: tuple[type[object], ...],
217213
) -> None:
218214
"""Remove inherited columns that the child namespace explicitly overrides.
@@ -238,15 +234,23 @@ def _remove_overridden_columns(
238234
result.columns.pop(parent_key, None)
239235

240236
@staticmethod
241-
def _get_metadata_recursively(kls: type[object]) -> Metadata:
237+
def _collect_metadata(
238+
bases: tuple[type[object], ...],
239+
namespace: Mapping[str, Any],
240+
) -> Metadata:
242241
result = Metadata()
243-
for base in kls.__bases__:
242+
for base in bases:
244243
result.update(SchemaMeta._get_metadata_recursively(base))
245-
result.update(SchemaMeta._get_metadata(kls.__dict__)) # type: ignore
244+
SchemaMeta._remove_overridden_columns(result, namespace, bases)
245+
result.update(SchemaMeta._get_metadata(namespace))
246246
return result
247247

248248
@staticmethod
249-
def _get_metadata(source: dict[str, Any]) -> Metadata:
249+
def _get_metadata_recursively(kls: type[object]) -> Metadata:
250+
return SchemaMeta._collect_metadata(kls.__bases__, kls.__dict__)
251+
252+
@staticmethod
253+
def _get_metadata(source: Mapping[str, Any]) -> Metadata:
250254
result = Metadata()
251255
for attr, value in {
252256
k: v for k, v in source.items() if not k.startswith("__")

tests/schema/test_inheritance.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,27 @@ def test_columns() -> None:
2020
assert ParentSchema.column_names() == ["a"]
2121
assert ChildSchema.column_names() == ["a", "b"]
2222
assert GrandchildSchema.column_names() == ["a", "b", "c"]
23+
24+
25+
class OverrideBase(dy.Schema):
26+
amt = dy.Float64(nullable=True)
27+
28+
29+
class OverrideChild(OverrideBase):
30+
amt = dy.Float64(nullable=False)
31+
32+
33+
class OverrideGrandchild(OverrideChild):
34+
pass
35+
36+
37+
class OverrideGreatGrandchild(OverrideGrandchild):
38+
other = dy.Integer()
39+
40+
41+
def test_column_override_propagates_to_grandchild() -> None:
42+
assert OverrideBase.columns()["amt"].nullable is True
43+
assert OverrideChild.columns()["amt"].nullable is False
44+
assert OverrideGrandchild.columns()["amt"].nullable is False
45+
assert OverrideGreatGrandchild.columns()["amt"].nullable is False
46+
assert OverrideGreatGrandchild.column_names() == ["amt", "other"]

0 commit comments

Comments
 (0)