Skip to content

Commit 0e39b76

Browse files
committed
feat(query): add _iter_sorts traversal and _prepare_sorts to Query(Sync) class hierarchy
1 parent 8f0d7f5 commit 0e39b76

1 file changed

Lines changed: 30 additions & 1 deletion

File tree

  • cognite/client/data_classes/data_modeling

cognite/client/data_classes/data_modeling/query.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from abc import ABC
44
from collections import UserDict
5-
from collections.abc import Mapping, MutableMapping, Sequence
5+
from collections.abc import Iterator, Mapping, MutableMapping, Sequence
66
from dataclasses import dataclass, field
77
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
88

@@ -85,6 +85,9 @@ def _load_list(
8585
class SelectBase(CogniteResource, ABC):
8686
sources: list[SourceSelector] = field(default_factory=list)
8787

88+
def _iter_sorts(self) -> Iterator[InstanceSort]:
89+
yield from ()
90+
8891
def dump(self, camel_case: bool = True) -> dict[str, Any]:
8992
output: dict[str, Any] = {}
9093
if self.sources:
@@ -133,6 +136,9 @@ class Select(SelectBase):
133136
sort: list[InstanceSort] = field(default_factory=list)
134137
limit: int | None = None
135138

139+
def _iter_sorts(self) -> Iterator[InstanceSort]:
140+
yield from self.sort
141+
136142
def dump(self, camel_case: bool = True) -> dict[str, Any]:
137143
output = super().dump(camel_case)
138144
if self.sort:
@@ -179,6 +185,16 @@ def instance_type_by_result_expression(self) -> dict[str, type[NodeListWithCurso
179185
for k, v in self.with_.items()
180186
}
181187

188+
def _iter_sorts(self) -> Iterator[InstanceSort]:
189+
for expr in self.with_.values():
190+
yield from expr._iter_sorts()
191+
for sel in self.select.values():
192+
yield from sel._iter_sorts()
193+
194+
def _prepare_sorts(self) -> None:
195+
for sort in self._iter_sorts():
196+
sort._apply_postgres_defaults_or_maybe_warn()
197+
182198
def dump(self, camel_case: bool = True) -> dict[str, Any]:
183199
output: dict[str, Any] = {
184200
"with": {k: v.dump(camel_case) for k, v in self.with_.items()},
@@ -281,6 +297,9 @@ class ResultSetExpressionBase(CogniteResource, ABC):
281297
def _load_sort(resource: dict[str, Any], name: str) -> list[InstanceSort]:
282298
return [InstanceSort.load(sort) for sort in resource.get(name, [])]
283299

300+
def _iter_sorts(self) -> Iterator[InstanceSort]:
301+
yield from ()
302+
284303
@staticmethod
285304
def _init_through(through: list[str] | tuple[str, str, str] | PropertyId | None) -> PropertyId | None:
286305
def error() -> Never:
@@ -336,6 +355,9 @@ def __eq__(self, other: object) -> bool:
336355
return NotImplemented
337356
return type(self) is type(other) and self.dump() == other.dump()
338357

358+
def _iter_sorts(self) -> Iterator[InstanceSort]:
359+
yield from self.sort
360+
339361

340362
@dataclass(eq=False) # Prevents @dataclass from generating its own __eq__, so the parent's is used
341363
class NodeResultSetExpression(NodeOrEdgeResultSetExpression):
@@ -429,6 +451,10 @@ class EdgeResultSetExpression(NodeOrEdgeResultSetExpression):
429451
limit_each: int | None = None
430452
post_sort: list[InstanceSort] = field(default_factory=list)
431453

454+
def _iter_sorts(self) -> Iterator[InstanceSort]:
455+
yield from self.sort
456+
yield from self.post_sort
457+
432458
@classmethod
433459
def _load(cls, resource: dict[str, Any]) -> Self:
434460
query_edge = resource["edges"]
@@ -494,6 +520,9 @@ def __eq__(self, other: object) -> bool:
494520
return NotImplemented
495521
return type(self) is type(other) and self.dump() == other.dump()
496522

523+
def _iter_sorts(self) -> Iterator[InstanceSort]:
524+
yield from self.backfill_sort
525+
497526
@classmethod
498527
def _load(cls, resource: dict[str, Any]) -> ResultSetExpressionSync:
499528
if "nodes" in resource:

0 commit comments

Comments
 (0)