Skip to content

Commit 54cbc75

Browse files
authored
feat: Allow to inline collection member columns for sampling (#5)
1 parent 2cf6a39 commit 54cbc75

3 files changed

Lines changed: 104 additions & 2 deletions

File tree

dataframely/_base_collection.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ def my_filter(self) -> pl.DataFrame:
4848

4949
#: Whether the member should be ignored in the filter method.
5050
ignored_in_filters: bool = False
51+
#: Whether the member's non-primary key columns should be inlined for sampling.
52+
#: This means that value overrides are supplied on the top-level rather than in
53+
#: a subkey with the member's name. Only valid if the member's primary key matches
54+
#: the collection's common primary key. Two members that share common column names
55+
#: may not both be inlined for sampling.
56+
inline_for_sampling: bool = False
5157

5258

5359
# --------------------------------------- UTILS -------------------------------------- #
@@ -136,6 +142,30 @@ def __new__(
136142
f"{len(intersection)} such filters: {sorted(intersection)}."
137143
)
138144

145+
# 3) Check that inlining for sampling is configured correctly.
146+
if len(non_ignored_member_schemas) > 0:
147+
common_primary_keys = _common_primary_keys(non_ignored_member_schemas)
148+
inlined_columns: set[str] = set()
149+
for member, info in result.members.items():
150+
if info.inline_for_sampling:
151+
if set(info.schema.primary_keys()) != common_primary_keys:
152+
raise ImplementationError(
153+
f"Member '{member}' is inlined for sampling but its primary "
154+
"key is a superset of the common primary key. Such a member "
155+
"must not be inlined to be able to provide multiple values "
156+
"for a single combination of the common primary key."
157+
)
158+
non_primary_key_columns = (
159+
set(info.schema.column_names()) - common_primary_keys
160+
)
161+
if len(inlined_columns & non_primary_key_columns):
162+
raise ImplementationError(
163+
f"At least one column name of member '{member}' clashes "
164+
"with a column name of another member that is inlined for "
165+
"sampling."
166+
)
167+
inlined_columns.update(non_primary_key_columns)
168+
139169
return super().__new__(mcs, name, bases, namespace, *args, **kwargs)
140170

141171
@staticmethod
@@ -201,6 +231,7 @@ def _get_metadata(source: dict[str, Any]) -> Metadata:
201231
schema=get_args(kls)[0],
202232
is_optional=False,
203233
ignored_in_filters=collection_member.ignored_in_filters,
234+
inline_for_sampling=collection_member.inline_for_sampling,
204235
)
205236
else:
206237
# Some other unknown annotation

dataframely/collection.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,10 @@ def sample(
123123
...
124124
}
125125
126-
_Any_ member/value can be left out and will be sampled automatically.
126+
*Any* member/value can be left out and will be sampled automatically.
127+
Note that overrides for columns of members that are annotated with
128+
``inline_for_sampling=True`` can be supplied on the top-level instead
129+
of in a nested dictionary.
127130
generator: The (seeded) generator to use for sampling data. If ``None``, a
128131
generator with random seed is automatically created.
129132
@@ -198,7 +201,11 @@ def sample(
198201
else _extract_keys_if_exist(sample, primary_keys)
199202
),
200203
**_extract_keys_if_exist(
201-
sample[member] if member in sample else {},
204+
(
205+
sample
206+
if member_infos[member].inline_for_sampling
207+
else (sample[member] if member in sample else {})
208+
),
202209
schema.column_names(),
203210
),
204211
}

tests/collection/test_sample.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import pytest
77

88
import dataframely as dy
9+
from dataframely.exc import ImplementationError
910
from dataframely.random import Generator
11+
from dataframely.testing.factory import create_collection_raw
1012

1113

1214
class MyFirstSchema(dy.Schema):
@@ -37,6 +39,21 @@ def _preprocess_sample(
3739
return sample
3840

3941

42+
class MyInlinedCollection(dy.Collection):
43+
first: Annotated[
44+
dy.LazyFrame[MyFirstSchema],
45+
dy.CollectionMember(inline_for_sampling=True),
46+
]
47+
second: dy.LazyFrame[MySecondSchema]
48+
49+
@classmethod
50+
def _preprocess_sample(
51+
cls, sample: dict[str, Any], index: int, generator: Generator
52+
) -> dict[str, Any]:
53+
sample["a"] = index
54+
return sample
55+
56+
4057
class SmallCollection(dy.Collection):
4158
first: dy.LazyFrame[MyFirstSchema]
4259

@@ -100,6 +117,22 @@ def test_sample_with_overrides() -> None:
100117
assert collection.second.collect()["c"].to_list() == [3, 4, 6]
101118

102119

120+
def test_sample_inline_with_overrides() -> None:
121+
collection = MyInlinedCollection.sample(
122+
overrides=[
123+
{"b": 4, "second": [{"c": 3}, {"c": 4}]},
124+
{"b": 8, "second": [{"c": 6}]},
125+
]
126+
)
127+
assert collection.first.collect()["a"].to_list() == [0, 1]
128+
assert collection.first.collect()["b"].to_list() == [4, 8]
129+
130+
assert collection.second is not None
131+
assert collection.second.collect()["a"].to_list() == [0, 0, 1]
132+
assert collection.second.collect()["b"].to_list() != [4, 4, 8]
133+
assert collection.second.collect()["c"].to_list() == [3, 4, 6]
134+
135+
103136
@pytest.mark.parametrize("n", [0, 1000])
104137
def test_sample_without_dependent_members(n: int) -> None:
105138
collection = SmallCollection.sample(n)
@@ -125,3 +158,34 @@ def test_sample_no_common_primary_key() -> None:
125158
def test_sample_no_overwrite() -> None:
126159
with pytest.raises(ValueError, match=r"`_preprocess_sample` must be overwritten"):
127160
IncompleteCollection.sample()
161+
162+
163+
def test_invalid_inline_for_sampling() -> None:
164+
with pytest.raises(ImplementationError, match=r"its primary key is a superset"):
165+
create_collection_raw(
166+
"test",
167+
{
168+
"first": dy.LazyFrame[MyFirstSchema],
169+
"second": Annotated[
170+
dy.LazyFrame[MySecondSchema],
171+
dy.CollectionMember(inline_for_sampling=True),
172+
],
173+
},
174+
)
175+
176+
177+
def test_duplicate_column_inlined_for_sampling() -> None:
178+
with pytest.raises(ImplementationError, match=r"clashes with a column name"):
179+
create_collection_raw(
180+
"test",
181+
{
182+
"first": Annotated[
183+
dy.LazyFrame[MyFirstSchema],
184+
dy.CollectionMember(inline_for_sampling=True),
185+
],
186+
"second": Annotated[
187+
dy.LazyFrame[MyFirstSchema],
188+
dy.CollectionMember(inline_for_sampling=True),
189+
],
190+
},
191+
)

0 commit comments

Comments
 (0)