Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions dataframely/_base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def my_filter(self) -> pl.DataFrame:

#: Whether the member should be ignored in the filter method.
ignored_in_filters: bool = False
#: Whether the member's non-primary key columns should be inlined for sampling.
#: This means that value overrides are supplied on the top-level rather than in
#: a subkey with the member's name. Only valid if the member's primary key matches
#: the collection's common primary key. Two members that share common column names
#: may not both be inlined for sampling.
inline_for_sampling: bool = False


# --------------------------------------- UTILS -------------------------------------- #
Expand Down Expand Up @@ -136,6 +142,30 @@ def __new__(
f"{len(intersection)} such filters: {sorted(intersection)}."
)

# 3) Check that inlining for sampling is configured correctly.
if len(non_ignored_member_schemas) > 0:
common_primary_keys = _common_primary_keys(non_ignored_member_schemas)
inlined_columns: set[str] = set()
for member, info in result.members.items():
if info.inline_for_sampling:
if set(info.schema.primary_keys()) != common_primary_keys:
raise ImplementationError(
f"Member '{member}' is inlined for sampling but its primary "
"key is a superset of the common primary key. Such a member "
"must not be inlined to be able to provide multiple values "
"for a single combination of the common primary key."
)
non_primary_key_columns = (
set(info.schema.column_names()) - common_primary_keys
)
if len(inlined_columns & non_primary_key_columns):
raise ImplementationError(
f"At least one column name of member '{member}' clashes "
"with a column name of another member that is inlined for "
"sampling."
)
inlined_columns.update(non_primary_key_columns)

return super().__new__(mcs, name, bases, namespace, *args, **kwargs)

@staticmethod
Expand Down Expand Up @@ -201,6 +231,7 @@ def _get_metadata(source: dict[str, Any]) -> Metadata:
schema=get_args(kls)[0],
is_optional=False,
ignored_in_filters=collection_member.ignored_in_filters,
inline_for_sampling=collection_member.inline_for_sampling,
)
else:
# Some other unknown annotation
Expand Down
11 changes: 9 additions & 2 deletions dataframely/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ def sample(
...
}

_Any_ member/value can be left out and will be sampled automatically.
*Any* member/value can be left out and will be sampled automatically.
Note that overrides for columns of members that are annotated with
``inline_for_sampling=True`` can be supplied on the top-level instead
of in a nested dictionary.
generator: The (seeded) generator to use for sampling data. If ``None``, a
generator with random seed is automatically created.

Expand Down Expand Up @@ -198,7 +201,11 @@ def sample(
else _extract_keys_if_exist(sample, primary_keys)
),
**_extract_keys_if_exist(
sample[member] if member in sample else {},
(
sample
if member_infos[member].inline_for_sampling
else (sample[member] if member in sample else {})
),
schema.column_names(),
),
}
Expand Down
64 changes: 64 additions & 0 deletions tests/collection/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import pytest

import dataframely as dy
from dataframely.exc import ImplementationError
from dataframely.random import Generator
from dataframely.testing.factory import create_collection_raw


class MyFirstSchema(dy.Schema):
Expand Down Expand Up @@ -37,6 +39,21 @@ def _preprocess_sample(
return sample


class MyInlinedCollection(dy.Collection):
    """Collection whose ``first`` member is inlined for sampling.

    Overrides for the non-primary-key columns of ``first`` are therefore given
    at the top level of each sample dictionary, while overrides for ``second``
    stay nested under the ``"second"`` key.
    """

    first: Annotated[
        dy.LazyFrame[MyFirstSchema],
        dy.CollectionMember(inline_for_sampling=True),
    ]
    second: dy.LazyFrame[MySecondSchema]

    @classmethod
    def _preprocess_sample(
        cls, sample: dict[str, Any], index: int, generator: Generator
    ) -> dict[str, Any]:
        """Set column ``a`` of the sample to the sample's index.

        The incoming dictionary is modified in place and returned; the
        generator is not used.
        """
        sample.update(a=index)
        return sample


class SmallCollection(dy.Collection):
first: dy.LazyFrame[MyFirstSchema]

Expand Down Expand Up @@ -100,6 +117,22 @@ def test_sample_with_overrides() -> None:
assert collection.second.collect()["c"].to_list() == [3, 4, 6]


def test_sample_inline_with_overrides() -> None:
    """Check that top-level overrides reach a member inlined for sampling.

    ``b`` is supplied at the top level of each override and must show up in
    ``first``; the values for ``second`` are supplied in the nested list.
    """
    overrides = [
        {"b": 4, "second": [{"c": 3}, {"c": 4}]},
        {"b": 8, "second": [{"c": 6}]},
    ]
    collection = MyInlinedCollection.sample(overrides=overrides)

    # `a` is set to the sample index by `MyInlinedCollection._preprocess_sample`.
    first_frame = collection.first.collect()
    assert first_frame["a"].to_list() == [0, 1]
    assert first_frame["b"].to_list() == [4, 8]

    assert collection.second is not None
    second_frame = collection.second.collect()
    assert second_frame["a"].to_list() == [0, 0, 1]
    # The top-level `b` override must not leak into the non-inlined member
    # (presumably `MySecondSchema` has its own `b` column -- confirm).
    assert second_frame["b"].to_list() != [4, 4, 8]
    assert second_frame["c"].to_list() == [3, 4, 6]


@pytest.mark.parametrize("n", [0, 1000])
def test_sample_without_dependent_members(n: int) -> None:
collection = SmallCollection.sample(n)
Expand All @@ -125,3 +158,34 @@ def test_sample_no_common_primary_key() -> None:
def test_sample_no_overwrite() -> None:
    """Sampling ``IncompleteCollection`` must raise the overwrite ``ValueError``."""
    expected_message = r"`_preprocess_sample` must be overwritten"
    with pytest.raises(ValueError, match=expected_message):
        IncompleteCollection.sample()


def test_invalid_inline_for_sampling() -> None:
    """Creating a collection must fail when ``second`` is inlined for sampling.

    The expected error is the primary-key mismatch (presumably the primary key
    of ``MySecondSchema`` extends the common primary key -- confirm against the
    schema definition).
    """
    members = {
        "first": dy.LazyFrame[MyFirstSchema],
        "second": Annotated[
            dy.LazyFrame[MySecondSchema],
            dy.CollectionMember(inline_for_sampling=True),
        ],
    }
    expected_message = r"its primary key is a superset"
    with pytest.raises(ImplementationError, match=expected_message):
        create_collection_raw("test", members)


def test_duplicate_column_inlined_for_sampling() -> None:
    """Two inlined members with the same schema must be rejected as a clash.

    Both members use ``MyFirstSchema`` and are inlined for sampling, so their
    column names coincide.
    """
    inlined_member = Annotated[
        dy.LazyFrame[MyFirstSchema],
        dy.CollectionMember(inline_for_sampling=True),
    ]
    members = {"first": inlined_member, "second": inlined_member}
    expected_message = r"clashes with a column name"
    with pytest.raises(ImplementationError, match=expected_message):
        create_collection_raw("test", members)
Loading