Skip to content

Commit 63a94ff

Browse files
committed
warn on non-index-aligned InstanceSort at PostgreSQL-backed API call sites
1 parent 5701853 commit 63a94ff

3 files changed

Lines changed: 164 additions & 5 deletions

File tree

cognite/client/_api/data_modeling/instances.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,7 @@ async def subscribe(
894894
>>> subscription_context.cancel()
895895
896896
"""
897+
query.warn_non_index_aligned_sorts()
897898
subscription_context = SubscriptionContext()
898899

899900
async def _poll_loop() -> None:
@@ -968,10 +969,11 @@ def _create_other_params(
968969
f"Received in `sources` argument for views: {with_properties}."
969970
)
970971
if sort:
971-
if isinstance(sort, (InstanceSort, dict)):
972-
other_params["sort"] = [cls._dump_instance_sort(sort)]
973-
else:
974-
other_params["sort"] = [cls._dump_instance_sort(s) for s in sort]
972+
sorts_seq = [sort] if isinstance(sort, (InstanceSort, dict)) else list(sort)
973+
for s in sorts_seq:
974+
if isinstance(s, InstanceSort):
975+
s.warn_if_not_index_aligned()
976+
other_params["sort"] = [cls._dump_instance_sort(s) for s in sorts_seq]
975977
if instance_type:
976978
other_params["instanceType"] = instance_type
977979
if debug:
@@ -1647,6 +1649,7 @@ async def query(
16471649
>>> res = client.data_modeling.instances.query(query, debug=debug_params)
16481650
>>> print(res.debug)
16491651
"""
1652+
query.warn_non_index_aligned_sorts()
16501653
return await self._query_or_sync(query, "query", include_typing=include_typing, debug=debug)
16511654

16521655
async def sync(
@@ -1749,6 +1752,7 @@ async def sync(
17491752
>>> res = client.data_modeling.instances.sync(query, debug=debug_params)
17501753
>>> print(res.debug)
17511754
"""
1755+
query.warn_non_index_aligned_sorts()
17521756
return await self._query_or_sync(query, "sync", include_typing=include_typing, debug=debug)
17531757

17541758
async def _query_or_sync(

cognite/client/_sync_api/data_modeling/instances.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
22
===============================================================================
3-
c0af4f83cd8ffa0aaed6fa641bde9a98
3+
28cd64f8d9cdd590b9601c562f1b69a1
44
This file is auto-generated from the Async API modules, - do not edit manually!
55
===============================================================================
66
"""

tests/tests_unit/test_data_classes/test_data_models/test_instances.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import warnings
34
from datetime import date, datetime
45
from typing import Any, cast
56

@@ -12,6 +13,7 @@
1213
EdgeApply,
1314
EdgeList,
1415
Float64,
16+
InstanceSort,
1517
Node,
1618
NodeApply,
1719
NodeId,
@@ -638,3 +640,156 @@ def test_to_pandas(self) -> None:
638640
df = info.to_pandas()
639641

640642
pd.testing.assert_frame_equal(df, expected)
643+
644+
645+
class TestInstanceSort:
646+
@pytest.mark.parametrize(
647+
"direction, expected_nulls_first",
648+
[
649+
("ascending", False),
650+
("descending", True),
651+
],
652+
)
653+
def test_nulls_first_auto_derived_from_direction(self, direction: str, expected_nulls_first: bool) -> None:
654+
sort = InstanceSort(["node", "externalId"], direction=direction) # type: ignore[arg-type]
655+
assert sort.nulls_first is expected_nulls_first
656+
657+
def test_no_warning_at_construction_for_any_combination(self) -> None:
658+
with warnings.catch_warnings(record=True) as w:
659+
warnings.simplefilter("always")
660+
InstanceSort(["node", "externalId"], direction="ascending")
661+
InstanceSort(["node", "externalId"], direction="descending")
662+
InstanceSort(["node", "externalId"], direction="ascending", nulls_first=False)
663+
InstanceSort(["node", "externalId"], direction="ascending", nulls_first=True)
664+
InstanceSort(["node", "externalId"], direction="descending", nulls_first=True)
665+
InstanceSort(["node", "externalId"], direction="descending", nulls_first=False)
666+
assert len(w) == 0
667+
668+
@pytest.mark.parametrize(
669+
"direction, nulls_first",
670+
[
671+
("ascending", True),
672+
("descending", False),
673+
],
674+
)
675+
def test_non_index_aligned_nulls_first_is_stored_as_given(self, direction: str, nulls_first: bool) -> None:
676+
sort = InstanceSort(["node", "externalId"], direction=direction, nulls_first=nulls_first) # type: ignore[arg-type]
677+
assert sort.nulls_first is nulls_first
678+
679+
def test_load_does_not_override_nulls_first(self) -> None:
680+
raw = {"property": ["node", "externalId"], "direction": "ascending", "nullsFirst": True}
681+
sort = InstanceSort._load(raw)
682+
assert sort.direction == "ascending"
683+
assert sort.nulls_first is True
684+
685+
def test_dump_round_trip(self) -> None:
686+
sort = InstanceSort(["space", "externalId"], direction="descending")
687+
dumped = sort.dump(camel_case=True)
688+
assert dumped["direction"] == "descending"
689+
assert dumped["nullsFirst"] is True
690+
691+
@pytest.mark.parametrize(
692+
"direction, nulls_first, expected",
693+
[
694+
("ascending", False, True),
695+
("descending", True, True),
696+
("ascending", True, False),
697+
("descending", False, False),
698+
],
699+
)
700+
def test_is_index_aligned(self, direction: str, nulls_first: bool, expected: bool) -> None:
701+
sort = InstanceSort(["node", "externalId"], direction=direction, nulls_first=nulls_first) # type: ignore[arg-type]
702+
assert sort.is_index_aligned is expected
703+
704+
def test_warn_if_not_index_aligned_fires(self) -> None:
705+
sort = InstanceSort(["node", "externalId"], direction="ascending", nulls_first=True)
706+
with warnings.catch_warnings(record=True) as w:
707+
warnings.simplefilter("always")
708+
sort.warn_if_not_index_aligned()
709+
assert len(w) == 1
710+
assert issubclass(w[0].category, UserWarning)
711+
assert "not index-aligned" in str(w[0].message)
712+
713+
def test_warn_if_not_index_aligned_silent_when_aligned(self) -> None:
714+
sort = InstanceSort(["node", "externalId"], direction="ascending")
715+
with warnings.catch_warnings(record=True) as w:
716+
warnings.simplefilter("always")
717+
sort.warn_if_not_index_aligned()
718+
assert len(w) == 0
719+
720+
@pytest.mark.parametrize("bad_direction", ["asc", "desc", "random"])
721+
def test_invalid_direction_raises(self, bad_direction: str) -> None:
722+
with pytest.raises(ValueError, match="direction must be"):
723+
InstanceSort(["node", "externalId"], direction=bad_direction) # type: ignore[arg-type]
724+
725+
def test_invalid_direction_error_shows_original_value(self) -> None:
726+
with pytest.raises(ValueError, match="'Asc'"):
727+
InstanceSort(["node", "externalId"], direction="Asc") # type: ignore[arg-type]
728+
729+
730+
class TestInstanceSortAPIWarning:
731+
@pytest.mark.parametrize(
732+
"direction, nulls_first",
733+
[
734+
("ascending", True),
735+
("descending", False),
736+
],
737+
)
738+
def test_warn_fires_for_non_index_aligned_sorts(self, direction: str, nulls_first: bool) -> None:
739+
sort = InstanceSort(["node", "externalId"], direction=direction, nulls_first=nulls_first) # type: ignore[arg-type]
740+
with warnings.catch_warnings(record=True) as w:
741+
warnings.simplefilter("always")
742+
sort.warn_if_not_index_aligned()
743+
assert len(w) == 1
744+
assert issubclass(w[0].category, UserWarning)
745+
assert "not index-aligned" in str(w[0].message)
746+
assert "PostgreSQL" in str(w[0].message)
747+
748+
@pytest.mark.parametrize(
749+
"direction, nulls_first",
750+
[
751+
("ascending", False),
752+
("descending", True),
753+
("ascending", None),
754+
("descending", None),
755+
],
756+
)
757+
def test_no_warn_for_aligned_or_auto_derived_sorts(self, direction: str, nulls_first: bool | None) -> None:
758+
sort = InstanceSort(["node", "externalId"], direction=direction, nulls_first=nulls_first) # type: ignore[arg-type]
759+
with warnings.catch_warnings(record=True) as w:
760+
warnings.simplefilter("always")
761+
sort.warn_if_not_index_aligned()
762+
assert len(w) == 0
763+
764+
def test_query__iter_sorts_collects_all_sorts(self) -> None:
765+
from cognite.client.data_classes.data_modeling import ViewId
766+
from cognite.client.data_classes.data_modeling.query import (
767+
NodeResultSetExpression,
768+
Query,
769+
Select,
770+
SourceSelector,
771+
)
772+
773+
view = ViewId("s", "v", "1")
774+
sort_a = InstanceSort(view.as_property_ref("a"))
775+
sort_b = InstanceSort(view.as_property_ref("b"))
776+
query = Query(
777+
with_={"nodes": NodeResultSetExpression(sort=[sort_a])},
778+
select={"nodes": Select([SourceSelector(view, ["a"])], sort=[sort_b])},
779+
)
780+
result = list(query._iter_sorts())
781+
assert sort_a in result
782+
assert sort_b in result
783+
assert len(result) == 2
784+
785+
def test_query_sync__iter_sorts_collects_backfill_sort(self) -> None:
786+
from cognite.client.data_classes.data_modeling.query import (
787+
NodeResultSetExpressionSync,
788+
QuerySync,
789+
)
790+
791+
sort = InstanceSort(["node", "externalId"])
792+
query = QuerySync(with_={"nodes": NodeResultSetExpressionSync(backfill_sort=[sort])}, select={})
793+
result = list(query._iter_sorts())
794+
assert sort in result
795+
assert len(result) == 1

0 commit comments

Comments
 (0)