diff --git a/dataframely/_base_collection.py b/dataframely/_base_collection.py
index a8ecee2..eddb1e9 100644
--- a/dataframely/_base_collection.py
+++ b/dataframely/_base_collection.py
@@ -48,6 +48,12 @@ def my_filter(self) -> pl.DataFrame:
 
     #: Whether the member should be ignored in the filter method.
     ignored_in_filters: bool = False
+    #: Whether the member's non-primary key columns should be inlined for sampling.
+    #: This means that value overrides are supplied on the top-level rather than in
+    #: a subkey with the member's name. Only valid if the member's primary key matches
+    #: the collection's common primary key. Two members that share common column names
+    #: may not both be inlined for sampling.
+    inline_for_sampling: bool = False
 
 
 # --------------------------------------- UTILS -------------------------------------- #
@@ -136,6 +142,30 @@ def __new__(
                     f"{len(intersection)} such filters: {sorted(intersection)}."
                 )
 
+        # 3) Check that inlining for sampling is configured correctly.
+        if len(non_ignored_member_schemas) > 0:
+            common_primary_keys = _common_primary_keys(non_ignored_member_schemas)
+            inlined_columns: set[str] = set()
+            for member, info in result.members.items():
+                if info.inline_for_sampling:
+                    if set(info.schema.primary_keys()) != common_primary_keys:
+                        raise ImplementationError(
+                            f"Member '{member}' is inlined for sampling but its primary "
+                            "key is a superset of the common primary key. Such a member "
+                            "must not be inlined to be able to provide multiple values "
+                            "for a single combination of the common primary key."
+                        )
+                    non_primary_key_columns = (
+                        set(info.schema.column_names()) - common_primary_keys
+                    )
+                    if inlined_columns & non_primary_key_columns:
+                        raise ImplementationError(
+                            f"At least one column name of member '{member}' clashes "
+                            "with a column name of another member that is inlined for "
+                            "sampling."
+                        )
+                    inlined_columns.update(non_primary_key_columns)
+
         return super().__new__(mcs, name, bases, namespace, *args, **kwargs)
 
     @staticmethod
@@ -201,6 +231,7 @@ def _get_metadata(source: dict[str, Any]) -> Metadata:
                         schema=get_args(kls)[0],
                         is_optional=False,
                         ignored_in_filters=collection_member.ignored_in_filters,
+                        inline_for_sampling=collection_member.inline_for_sampling,
                     )
                 else:
                     # Some other unknown annotation
diff --git a/dataframely/collection.py b/dataframely/collection.py
index 293e592..70d6595 100644
--- a/dataframely/collection.py
+++ b/dataframely/collection.py
@@ -123,7 +123,10 @@ def sample(
                         ...
                     }
 
-                _Any_ member/value can be left out and will be sampled automatically.
+                *Any* member/value can be left out and will be sampled automatically.
+                Note that overrides for columns of members that are annotated with
+                ``inline_for_sampling=True`` must be supplied on the top-level rather
+                than in a nested dictionary.
             generator: The (seeded) generator to use for sampling data. If ``None``, a
                 generator with random seed is automatically created.
 
@@ -198,7 +201,11 @@ def sample(
                         else _extract_keys_if_exist(sample, primary_keys)
                     ),
                     **_extract_keys_if_exist(
-                        sample[member] if member in sample else {},
+                        (
+                            sample
+                            if member_infos[member].inline_for_sampling
+                            else sample.get(member, {})
+                        ),
                         schema.column_names(),
                     ),
                 }
diff --git a/tests/collection/test_sample.py b/tests/collection/test_sample.py
index 4034387..42d388f 100644
--- a/tests/collection/test_sample.py
+++ b/tests/collection/test_sample.py
@@ -6,7 +6,9 @@
 import pytest
 
 import dataframely as dy
+from dataframely.exc import ImplementationError
 from dataframely.random import Generator
+from dataframely.testing.factory import create_collection_raw
 
 
 class MyFirstSchema(dy.Schema):
@@ -37,6 +39,21 @@ def _preprocess_sample(
         return sample
 
 
+class MyInlinedCollection(dy.Collection):
+    first: Annotated[
+        dy.LazyFrame[MyFirstSchema],
+        dy.CollectionMember(inline_for_sampling=True),
+    ]
+    second: dy.LazyFrame[MySecondSchema]
+
+    @classmethod
+    def _preprocess_sample(
+        cls, sample: dict[str, Any], index: int, generator: Generator
+    ) -> dict[str, Any]:
+        sample["a"] = index
+        return sample
+
+
 class SmallCollection(dy.Collection):
     first: dy.LazyFrame[MyFirstSchema]
 
@@ -100,6 +117,23 @@ def test_sample_with_overrides() -> None:
     assert collection.second.collect()["c"].to_list() == [3, 4, 6]
 
 
+def test_sample_inline_with_overrides() -> None:
+    collection = MyInlinedCollection.sample(
+        overrides=[
+            {"b": 4, "second": [{"c": 3}, {"c": 4}]},
+            {"b": 8, "second": [{"c": 6}]},
+        ]
+    )
+    assert collection.first.collect()["a"].to_list() == [0, 1]
+    assert collection.first.collect()["b"].to_list() == [4, 8]
+
+    assert collection.second is not None
+    assert collection.second.collect()["a"].to_list() == [0, 0, 1]
+    # The top-level `b` targets the inlined `first`; it must not carry over to `second`.
+    assert collection.second.collect()["b"].to_list() != [4, 4, 8]
+    assert collection.second.collect()["c"].to_list() == [3, 4, 6]
+
+
 @pytest.mark.parametrize("n", [0, 1000])
 def test_sample_without_dependent_members(n: int) -> None:
     collection = SmallCollection.sample(n)
@@ -125,3 +159,34 @@ def test_sample_no_common_primary_key() -> None:
 def test_sample_no_overwrite() -> None:
     with pytest.raises(ValueError, match=r"`_preprocess_sample` must be overwritten"):
         IncompleteCollection.sample()
+
+
+def test_invalid_inline_for_sampling() -> None:
+    with pytest.raises(ImplementationError, match=r"its primary key is a superset"):
+        create_collection_raw(
+            "test",
+            {
+                "first": dy.LazyFrame[MyFirstSchema],
+                "second": Annotated[
+                    dy.LazyFrame[MySecondSchema],
+                    dy.CollectionMember(inline_for_sampling=True),
+                ],
+            },
+        )
+
+
+def test_duplicate_column_inlined_for_sampling() -> None:
+    with pytest.raises(ImplementationError, match=r"clashes with a column name"):
+        create_collection_raw(
+            "test",
+            {
+                "first": Annotated[
+                    dy.LazyFrame[MyFirstSchema],
+                    dy.CollectionMember(inline_for_sampling=True),
+                ],
+                "second": Annotated[
+                    dy.LazyFrame[MyFirstSchema],
+                    dy.CollectionMember(inline_for_sampling=True),
+                ],
+            },
+        )