|
2 | 2 |
|
3 | 3 | from abc import ABC |
4 | 4 | from collections import UserDict |
5 | | -from collections.abc import Mapping, MutableMapping, Sequence |
| 5 | +from collections.abc import Iterator, Mapping, MutableMapping, Sequence |
6 | 6 | from dataclasses import dataclass, field |
7 | 7 | from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar |
8 | 8 |
|
@@ -85,6 +85,9 @@ def _load_list( |
85 | 85 | class SelectBase(CogniteResource, ABC): |
86 | 86 | sources: list[SourceSelector] = field(default_factory=list) |
87 | 87 |
|
| 88 | + def _iter_sorts(self) -> Iterator[InstanceSort]: |
| 89 | + yield from () |
| 90 | + |
88 | 91 | def dump(self, camel_case: bool = True) -> dict[str, Any]: |
89 | 92 | output: dict[str, Any] = {} |
90 | 93 | if self.sources: |
@@ -133,6 +136,9 @@ class Select(SelectBase): |
133 | 136 | sort: list[InstanceSort] = field(default_factory=list) |
134 | 137 | limit: int | None = None |
135 | 138 |
|
| 139 | + def _iter_sorts(self) -> Iterator[InstanceSort]: |
| 140 | + yield from self.sort |
| 141 | + |
136 | 142 | def dump(self, camel_case: bool = True) -> dict[str, Any]: |
137 | 143 | output = super().dump(camel_case) |
138 | 144 | if self.sort: |
@@ -179,6 +185,16 @@ def instance_type_by_result_expression(self) -> dict[str, type[NodeListWithCurso |
179 | 185 | for k, v in self.with_.items() |
180 | 186 | } |
181 | 187 |
|
| 188 | + def _iter_sorts(self) -> Iterator[InstanceSort]: |
| 189 | + for expr in self.with_.values(): |
| 190 | + yield from expr._iter_sorts() |
| 191 | + for sel in self.select.values(): |
| 192 | + yield from sel._iter_sorts() |
| 193 | + |
| 194 | + def _prepare_sorts(self) -> None: |
| 195 | + for sort in self._iter_sorts(): |
| 196 | + sort._apply_postgres_defaults_or_maybe_warn() |
| 197 | + |
182 | 198 | def dump(self, camel_case: bool = True) -> dict[str, Any]: |
183 | 199 | output: dict[str, Any] = { |
184 | 200 | "with": {k: v.dump(camel_case) for k, v in self.with_.items()}, |
@@ -281,6 +297,9 @@ class ResultSetExpressionBase(CogniteResource, ABC): |
281 | 297 | def _load_sort(resource: dict[str, Any], name: str) -> list[InstanceSort]: |
282 | 298 | return [InstanceSort.load(sort) for sort in resource.get(name, [])] |
283 | 299 |
|
| 300 | + def _iter_sorts(self) -> Iterator[InstanceSort]: |
| 301 | + yield from () |
| 302 | + |
284 | 303 | @staticmethod |
285 | 304 | def _init_through(through: list[str] | tuple[str, str, str] | PropertyId | None) -> PropertyId | None: |
286 | 305 | def error() -> Never: |
@@ -336,6 +355,9 @@ def __eq__(self, other: object) -> bool: |
336 | 355 | return NotImplemented |
337 | 356 | return type(self) is type(other) and self.dump() == other.dump() |
338 | 357 |
|
| 358 | + def _iter_sorts(self) -> Iterator[InstanceSort]: |
| 359 | + yield from self.sort |
| 360 | + |
339 | 361 |
|
340 | 362 | @dataclass(eq=False) # Prevents @dataclass from generating its own __eq__, so the parent's is used |
341 | 363 | class NodeResultSetExpression(NodeOrEdgeResultSetExpression): |
@@ -429,6 +451,10 @@ class EdgeResultSetExpression(NodeOrEdgeResultSetExpression): |
429 | 451 | limit_each: int | None = None |
430 | 452 | post_sort: list[InstanceSort] = field(default_factory=list) |
431 | 453 |
|
| 454 | + def _iter_sorts(self) -> Iterator[InstanceSort]: |
| 455 | + yield from self.sort |
| 456 | + yield from self.post_sort |
| 457 | + |
432 | 458 | @classmethod |
433 | 459 | def _load(cls, resource: dict[str, Any]) -> Self: |
434 | 460 | query_edge = resource["edges"] |
@@ -494,6 +520,9 @@ def __eq__(self, other: object) -> bool: |
494 | 520 | return NotImplemented |
495 | 521 | return type(self) is type(other) and self.dump() == other.dump() |
496 | 522 |
|
| 523 | + def _iter_sorts(self) -> Iterator[InstanceSort]: |
| 524 | + yield from self.backfill_sort |
| 525 | + |
497 | 526 | @classmethod |
498 | 527 | def _load(cls, resource: dict[str, Any]) -> ResultSetExpressionSync: |
499 | 528 | if "nodes" in resource: |
|
0 commit comments