From d7449cbb62b976ed00181d4218b653641cfab562 Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Sat, 19 Apr 2025 22:05:32 +0200 Subject: [PATCH 1/4] feat: Allow to inline collection member columns for sampling --- dataframely/_base_collection.py | 31 ++++++++++++++++ dataframely/collection.py | 11 ++++-- tests/collection/test_sample.py | 64 +++++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 2 deletions(-) diff --git a/dataframely/_base_collection.py b/dataframely/_base_collection.py index 5d2ae12..5c76bb3 100644 --- a/dataframely/_base_collection.py +++ b/dataframely/_base_collection.py @@ -48,6 +48,12 @@ def my_filter(self) -> pl.DataFrame: #: Whether the member should be ignored in the filter method. ignored_in_filters: bool = False + #: Whether the member's non-primary key columns should be inlined for sampling. + #: This means that value overrides are supplied on the top-level rather than in + #: a subkey with the member's name. Only valid if the member's primary key matches + #: the collection's common primary key. Two members that share common column names + #: may not both be inlined for sampling. + inline_for_sampling: bool = False # --------------------------------------- UTILS -------------------------------------- # @@ -136,6 +142,30 @@ def __new__( f"{len(intersection)} such filters: {sorted(intersection)}." ) + # 3) Check that inlining for sampling is configured correctly. + if len(non_ignored_member_schemas) > 0: + common_primary_keys = _common_primary_keys(non_ignored_member_schemas) + inlined_columns = set() + for member, info in result.members.items(): + if info.inline_for_sampling: + if set(info.schema.primary_keys()) != common_primary_keys: + raise ImplementationError( + f"Member '{member}' is inlined for sampling but its primary " + "key is a superset of the common primary key. Such a member " + "must not be inlined to be able to provide multiple values " + "for a single combination of the common primary key." + ) + non_primary_key_columns = ( + set(info.schema.column_names()) - common_primary_keys + ) + if len(inlined_columns & non_primary_key_columns): + raise ImplementationError( + f"At least one column name of member '{member}' clashes " + "with a column name of another member that is inlined for " + "sampling." + ) + inlined_columns.update(non_primary_key_columns) + return super().__new__(mcs, name, bases, namespace, *args, **kwargs) @staticmethod @@ -201,6 +231,7 @@ def _get_metadata(source: dict[str, Any]) -> Metadata: schema=get_args(kls)[0], is_optional=False, ignored_in_filters=collection_member.ignored_in_filters, + inline_for_sampling=collection_member.inline_for_sampling, ) else: # Some other unknown annotation diff --git a/dataframely/collection.py b/dataframely/collection.py index f80e889..e86b88b 100644 --- a/dataframely/collection.py +++ b/dataframely/collection.py @@ -123,7 +123,10 @@ def sample( ... } - _Any_ member/value can be left out and will be sampled automatically. + *Any* member/value can be left out and will be sampled automatically. + Note that overrides for columns of members that are annotated with + ``inline_for_sampling=True`` can be supplied on the top-level instead + of in a nested dictionary. generator: The (seeded) generator to use for sampling data. If ``None``, a generator with random seed is automatically created. @@ -198,7 +201,11 @@ def sample( else _extract_keys_if_exist(sample, primary_keys) ), **_extract_keys_if_exist( - sample[member] if member in sample else {}, + ( + sample + if member_infos[member].inline_for_sampling + else (sample[member] if member in sample else {}) + ), schema.column_names(), ), } diff --git a/tests/collection/test_sample.py b/tests/collection/test_sample.py index e439b83..1acce4a 100644 --- a/tests/collection/test_sample.py +++ b/tests/collection/test_sample.py @@ -6,7 +6,10 @@ import pytest import dataframely as dy +from dataframely.exc import ImplementationError from dataframely.random import Generator +from dataframely.testing import create_collection +from dataframely.testing.factory import create_collection_raw class MyFirstSchema(dy.Schema): @@ -37,6 +40,21 @@ def _preprocess_sample( return sample +class MyInlinedCollection(dy.Collection): + first: Annotated[ + dy.LazyFrame[MyFirstSchema], + dy.CollectionMember(inline_for_sampling=True), + ] + second: dy.LazyFrame[MySecondSchema] + + @classmethod + def _preprocess_sample( + cls, sample: dict[str, Any], index: int, generator: Generator + ) -> dict[str, Any]: + sample["a"] = index + return sample + + class SmallCollection(dy.Collection): first: dy.LazyFrame[MyFirstSchema] @@ -100,6 +118,21 @@ def test_sample_with_overrides(): assert collection.second.collect()["c"].to_list() == [3, 4, 6] +def test_sample_inline_with_overrides(): + collection = MyInlinedCollection.sample( + overrides=[ + {"b": 4, "second": [{"c": 3}, {"c": 4}]}, + {"b": 8, "second": [{"c": 6}]}, + ] + ) + assert collection.first.collect()["a"].to_list() == [0, 1] + assert collection.first.collect()["b"].to_list() == [4, 8] + + assert collection.second is not None + assert collection.second.collect()["a"].to_list() == [0, 0, 1] + assert collection.second.collect()["c"].to_list() == [3, 4, 6] + + @pytest.mark.parametrize("n", [0, 1000]) def test_sample_without_dependent_members(n: int): collection = SmallCollection.sample(n) @@ -125,3 +158,34 @@ def test_sample_no_common_primary_key(): def test_sample_no_overwrite(): with pytest.raises(ValueError, match=r"`_preprocess_sample` must be overwritten"): IncompleteCollection.sample() + + +def test_invalid_inline_for_sampling(): + with pytest.raises(ImplementationError, match=r"its primary key is a superset"): + create_collection_raw( + "test", + { + "first": dy.LazyFrame[MyFirstSchema], + "second": Annotated[ + dy.LazyFrame[MySecondSchema], + dy.CollectionMember(inline_for_sampling=True), + ], + }, + ) + + +def test_duplicate_column_inlined_for_sampling(): + with pytest.raises(ImplementationError, match=r"clashes with a column name"): + create_collection_raw( + "test", + { + "first": Annotated[ + dy.LazyFrame[MyFirstSchema], + dy.CollectionMember(inline_for_sampling=True), + ], + "second": Annotated[ + dy.LazyFrame[MyFirstSchema], + dy.CollectionMember(inline_for_sampling=True), + ], + }, + ) From 66f45e74543eef3bfc126c59ff9dd7cc451cee99 Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Sat, 19 Apr 2025 22:17:09 +0200 Subject: [PATCH 2/4] Fix pre-commit --- dataframely/_base_collection.py | 2 +- tests/collection/test_sample.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/dataframely/_base_collection.py b/dataframely/_base_collection.py index 5c76bb3..36082fc 100644 --- a/dataframely/_base_collection.py +++ b/dataframely/_base_collection.py @@ -145,7 +145,7 @@ def __new__( # 3) Check that inlining for sampling is configured correctly. if len(non_ignored_member_schemas) > 0: common_primary_keys = _common_primary_keys(non_ignored_member_schemas) - inlined_columns = set() + inlined_columns: set[str] = set() for member, info in result.members.items(): if info.inline_for_sampling: if set(info.schema.primary_keys()) != common_primary_keys: diff --git a/tests/collection/test_sample.py b/tests/collection/test_sample.py index 1acce4a..de8e277 100644 --- a/tests/collection/test_sample.py +++ b/tests/collection/test_sample.py @@ -8,7 +8,6 @@ import dataframely as dy from dataframely.exc import ImplementationError from dataframely.random import Generator -from dataframely.testing import create_collection from dataframely.testing.factory import create_collection_raw From c4f135b0b12bf100a7e85c68c1a0c4e85550390b Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Tue, 22 Apr 2025 12:05:18 +0200 Subject: [PATCH 3/4] Add test --- tests/collection/test_sample.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/collection/test_sample.py b/tests/collection/test_sample.py index de8e277..4477a69 100644 --- a/tests/collection/test_sample.py +++ b/tests/collection/test_sample.py @@ -129,6 +129,7 @@ def test_sample_inline_with_overrides(): assert collection.second is not None assert collection.second.collect()["a"].to_list() == [0, 0, 1] + assert collection.second.collect()["b"].to_list() != [4, 4, 8] assert collection.second.collect()["c"].to_list() == [3, 4, 6] From ba12c24d895c06b8a9880f2ce1cae1d235d8e1c1 Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Tue, 22 Apr 2025 12:10:29 +0200 Subject: [PATCH 4/4] Add missing type annotations --- tests/collection/test_sample.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/collection/test_sample.py b/tests/collection/test_sample.py index a0c15e3..42d388f 100644 --- a/tests/collection/test_sample.py +++ b/tests/collection/test_sample.py @@ -117,7 +117,7 @@ def test_sample_with_overrides() -> None: assert collection.second.collect()["c"].to_list() == [3, 4, 6] -def test_sample_inline_with_overrides(): +def test_sample_inline_with_overrides() -> None: collection = MyInlinedCollection.sample( overrides=[ {"b": 4, "second": [{"c": 3}, {"c": 4}]}, @@ -160,7 +160,7 @@ def test_sample_no_overwrite() -> None: IncompleteCollection.sample() -def test_invalid_inline_for_sampling(): +def test_invalid_inline_for_sampling() -> None: with pytest.raises(ImplementationError, match=r"its primary key is a superset"): create_collection_raw( "test", @@ -174,7 +174,7 @@ def test_invalid_inline_for_sampling(): ) -def test_duplicate_column_inlined_for_sampling(): +def test_duplicate_column_inlined_for_sampling() -> None: with pytest.raises(ImplementationError, match=r"clashes with a column name"): create_collection_raw( "test",