Skip to content

Commit 2db85a4

Browse files
authored
fix(service): overload narrowing + serializer DEFAULT_TYPE_ENCODERS registry (#433)
## Summary - Restore async/sync service overload specificity for `paginate()` and `get_one()` so `schema_type` narrows the same way it does on the driver APIs. - Extract `DEFAULT_TYPE_ENCODERS` from `_normalize_supported_value` in `sqlspec.utils.serializers._json`. Adds IPv4/IPv6 + asyncpg `pgproto.UUID` coverage that the legacy isinstance chain never had; preserves sqlspec's strict `TypeError` on unmapped types and `Decimal -> float` precedent. - `SQLSpecPlugin` now merges `DEFAULT_TYPE_ENCODERS` into `AppConfig.type_encoders` (user-precedence) and registers Litestar-specific `type_decoders` for `numpy.ndarray` and `uuid_utils.UUID`. Legacy NumPy-only block removed; coverage is broader. - `OffsetPagination` now flows through the dataclass tail probe (no special case).
1 parent d85ace0 commit 2db85a4

19 files changed

Lines changed: 1074 additions & 154 deletions

File tree

docs/changelog.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,23 @@ SQLSpec Changelog
1010
Recent Updates
1111
==============
1212

13+
schema_dump wire_format Opt-Out (Unreleased)
14+
---------------------------------------------
15+
16+
**Added:**
17+
18+
* ``sqlspec.utils.serializers.schema_dump`` (and its helpers
19+
``serialize_collection`` / ``get_collection_serializer``) now accept a
20+
``wire_format: bool = True`` keyword. The default preserves existing output:
21+
``msgspec.Struct`` instances continue to emit wire-aligned names (honouring
22+
``rename=`` via ``field.encode_name``); Pydantic, dataclass, and attrs
23+
branches continue to emit Python attribute names. Pass ``wire_format=False``
24+
to opt the msgspec branch into Python attribute names (``field.name``) for
25+
cross-library consistency. The kwarg is a no-op for non-msgspec inputs.
26+
27+
* The internal serializer cache key now includes ``wire_format`` so that
28+
``True`` and ``False`` calls for the same Struct type cannot collide.
29+
1330
Schema Wire Correctness (Unreleased)
1431
-------------------------------------
1532

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ maintainers = [{ name = "Litestar Developers", email = "hello@litestar.dev" }]
2424
name = "sqlspec"
2525
readme = "README.md"
2626
requires-python = ">=3.10, <4.0"
27-
version = "0.45.0"
27+
version = "0.46.0"
2828

2929
[project.urls]
3030
Discord = "https://discord.gg/litestar"
@@ -264,7 +264,7 @@ opt_level = "3" # Maximum optimization (0-3)
264264
allow_dirty = true
265265
commit = false
266266
commit_args = "--no-verify"
267-
current_version = "0.45.0"
267+
current_version = "0.46.0"
268268
ignore_missing_files = false
269269
ignore_missing_version = false
270270
message = "chore(release): bump to v{new_version}"

sqlspec/adapters/oracledb/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,7 @@ def build_profile() -> "DriverParameterProfile":
800800
allow_mixed_parameter_styles=False,
801801
preserve_original_params_for_many=False,
802802
json_serializer_strategy="driver",
803-
custom_type_coercions={**build_uuid_coercions()},
803+
custom_type_coercions={**build_uuid_coercions(native=True)},
804804
default_dialect="oracle",
805805
)
806806

sqlspec/extensions/litestar/plugin.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from sqlspec.typing import NUMPY_INSTALLED, ConnectionT, PoolT, SchemaT
4040
from sqlspec.utils.correlation import CorrelationContext
4141
from sqlspec.utils.logging import get_logger, log_with_context
42-
from sqlspec.utils.serializers import numpy_array_dec_hook, numpy_array_enc_hook, numpy_array_predicate
42+
from sqlspec.utils.serializers import DEFAULT_TYPE_ENCODERS, numpy_array_dec_hook, numpy_array_predicate
4343

4444
if TYPE_CHECKING:
4545
from collections.abc import AsyncGenerator, Callable
@@ -179,6 +179,23 @@ def _build_correlation_headers(*, primary: str, configured: list[str], auto_trac
179179
return tuple(_dedupe_headers(header_order))
180180

181181

182+
def _build_litestar_type_decoders() -> "list[tuple[Callable[[Any], bool], Callable[[type, Any], Any]]]":
183+
"""Build the Litestar-specific ``type_decoders`` list.
184+
185+
Decoders are predicate-tuples consumed by Litestar's request-body parsing,
186+
not part of sqlspec's serializer registry — so they live here rather than
187+
in :data:`sqlspec.utils.serializers.DEFAULT_TYPE_ENCODERS`.
188+
"""
189+
decoders: list[tuple[Callable[[Any], bool], Callable[[type, Any], Any]]] = []
190+
if NUMPY_INSTALLED:
191+
decoders.append((numpy_array_predicate, numpy_array_dec_hook))
192+
with suppress(ImportError):
193+
import uuid_utils # pyright: ignore[reportMissingImports]
194+
195+
decoders.append((lambda t: t is uuid_utils.UUID, lambda t, v: t(str(v))))
196+
return decoders
197+
198+
182199
class CorrelationMiddleware:
183200
__slots__ = ("_app", "_headers")
184201

@@ -467,28 +484,20 @@ def store_sqlspec_in_state() -> None:
467484
existing_plugins = list(app_config.plugins or [])
468485
if not any(isinstance(p, _OffsetPaginationSchemaPlugin) for p in existing_plugins):
469486
existing_plugins.append(_OffsetPaginationSchemaPlugin())
470-
app_config.plugins = existing_plugins
487+
app_config.plugins = existing_plugins
471488

472489
if app_config.exception_handlers is None:
473490
app_config.exception_handlers = {}
474491
app_config.exception_handlers.setdefault(NotFoundError, not_found_error_handler)
475492

476-
if NUMPY_INSTALLED:
477-
import numpy as np
478-
479-
if app_config.type_encoders is None:
480-
app_config.type_encoders = {np.ndarray: numpy_array_enc_hook}
481-
else:
482-
encoders_dict = dict(app_config.type_encoders)
483-
encoders_dict[np.ndarray] = numpy_array_enc_hook
484-
app_config.type_encoders = encoders_dict
485-
486-
if app_config.type_decoders is None:
487-
app_config.type_decoders = [(numpy_array_predicate, numpy_array_dec_hook)]
488-
else:
489-
decoders_list = list(app_config.type_decoders)
490-
decoders_list.append((numpy_array_predicate, numpy_array_dec_hook))
491-
app_config.type_decoders = decoders_list
493+
# Inject sqlspec's DEFAULT_TYPE_ENCODERS into Litestar's response serializer
494+
# (user-supplied encoders win on conflict). Litestar's per-handler
495+
# resolve_type_encoders() merges these with route/controller/router-level
496+
# overrides automatically — no bidirectional thread needed.
497+
app_config.type_encoders = {**DEFAULT_TYPE_ENCODERS, **(app_config.type_encoders or {})}
498+
sqlspec_decoders = _build_litestar_type_decoders()
499+
if sqlspec_decoders:
500+
app_config.type_decoders = [*(app_config.type_decoders or []), *sqlspec_decoders]
492501

493502
if self._correlation_headers:
494503
middleware = DefineMiddleware(CorrelationMiddleware, headers=self._correlation_headers)

sqlspec/service.py

Lines changed: 129 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Service base classes for SQLSpec application services."""
22

33
from contextlib import asynccontextmanager, contextmanager
4-
from typing import TYPE_CHECKING, Any, Generic, cast
4+
from typing import TYPE_CHECKING, Any, Generic, cast, overload
55

66
from typing_extensions import TypeVar
77

@@ -51,6 +51,28 @@ def driver(self) -> AsyncDriverT:
5151
"""Alias for :attr:`session` matching the recipe-doc terminology."""
5252
return self._session
5353

54+
@overload
55+
async def paginate(
56+
self,
57+
statement: "Statement | QueryBuilder",
58+
/,
59+
*parameters: "StatementParameters | StatementFilter",
60+
schema_type: "type[SchemaT]",
61+
count_with_window: bool = False,
62+
**kwargs: Any,
63+
) -> OffsetPagination[SchemaT]: ...
64+
65+
@overload
66+
async def paginate(
67+
self,
68+
statement: "Statement | QueryBuilder",
69+
/,
70+
*parameters: "StatementParameters | StatementFilter",
71+
schema_type: None = None,
72+
count_with_window: bool = False,
73+
**kwargs: Any,
74+
) -> OffsetPagination[dict[str, Any]]: ...
75+
5476
async def paginate(
5577
self,
5678
statement: "Statement | QueryBuilder",
@@ -59,7 +81,7 @@ async def paginate(
5981
schema_type: "type[SchemaT] | None" = None,
6082
count_with_window: bool = False,
6183
**kwargs: Any,
62-
) -> OffsetPagination[SchemaT]:
84+
) -> "OffsetPagination[SchemaT] | OffsetPagination[dict[str, Any]]":
6385
"""Execute a paginated query and return an OffsetPagination container.
6486
6587
Args:
@@ -78,13 +100,54 @@ async def paginate(
78100
statement, *parameters, schema_type=schema_type, count_with_window=count_with_window, **kwargs
79101
)
80102

103+
if schema_type is None:
104+
return OffsetPagination(
105+
items=cast("list[dict[str, Any]]", items),
106+
limit=limit_offset.limit if limit_offset is not None else len(items),
107+
offset=limit_offset.offset if limit_offset is not None else 0,
108+
total=total,
109+
)
110+
81111
return OffsetPagination(
82112
items=cast("list[SchemaT]", items),
83113
limit=limit_offset.limit if limit_offset is not None else len(items),
84114
offset=limit_offset.offset if limit_offset is not None else 0,
85115
total=total,
86116
)
87117

118+
@overload
119+
async def get_one(
120+
self,
121+
statement: "Statement | QueryBuilder",
122+
/,
123+
*parameters: "StatementParameters | StatementFilter",
124+
schema_type: "type[SchemaT]",
125+
error_message: str | None = None,
126+
**kwargs: Any,
127+
) -> SchemaT: ...
128+
129+
@overload
130+
async def get_one(
131+
self,
132+
statement: "Statement | QueryBuilder",
133+
/,
134+
*parameters: "StatementParameters | StatementFilter",
135+
schema_type: None = None,
136+
error_message: str | None = None,
137+
**kwargs: Any,
138+
) -> dict[str, Any]: ...
139+
140+
@overload
141+
async def get_one(
142+
self,
143+
statement: "Statement | QueryBuilder",
144+
/,
145+
*parameters: "StatementParameters | StatementFilter",
146+
schema_type: "type[SchemaT] | None" = None,
147+
error_message: str | None = None,
148+
**kwargs: Any,
149+
) -> "SchemaT | dict[str, Any]": ...
150+
88151
async def get_one(
89152
self,
90153
statement: "Statement | QueryBuilder",
@@ -191,6 +254,28 @@ def driver(self) -> SyncDriverT:
191254
"""Alias for :attr:`session` matching the recipe-doc terminology."""
192255
return self._session
193256

257+
@overload
258+
def paginate(
259+
self,
260+
statement: "Statement | QueryBuilder",
261+
/,
262+
*parameters: "StatementParameters | StatementFilter",
263+
schema_type: "type[SchemaT]",
264+
count_with_window: bool = False,
265+
**kwargs: Any,
266+
) -> OffsetPagination[SchemaT]: ...
267+
268+
@overload
269+
def paginate(
270+
self,
271+
statement: "Statement | QueryBuilder",
272+
/,
273+
*parameters: "StatementParameters | StatementFilter",
274+
schema_type: None = None,
275+
count_with_window: bool = False,
276+
**kwargs: Any,
277+
) -> OffsetPagination[dict[str, Any]]: ...
278+
194279
def paginate(
195280
self,
196281
statement: "Statement | QueryBuilder",
@@ -199,7 +284,7 @@ def paginate(
199284
schema_type: "type[SchemaT] | None" = None,
200285
count_with_window: bool = False,
201286
**kwargs: Any,
202-
) -> OffsetPagination[SchemaT]:
287+
) -> "OffsetPagination[SchemaT] | OffsetPagination[dict[str, Any]]":
203288
"""Execute a paginated query and return an OffsetPagination container.
204289
205290
Args:
@@ -218,13 +303,54 @@ def paginate(
218303
statement, *parameters, schema_type=schema_type, count_with_window=count_with_window, **kwargs
219304
)
220305

306+
if schema_type is None:
307+
return OffsetPagination(
308+
items=cast("list[dict[str, Any]]", items),
309+
limit=limit_offset.limit if limit_offset is not None else len(items),
310+
offset=limit_offset.offset if limit_offset is not None else 0,
311+
total=total,
312+
)
313+
221314
return OffsetPagination(
222315
items=cast("list[SchemaT]", items),
223316
limit=limit_offset.limit if limit_offset is not None else len(items),
224317
offset=limit_offset.offset if limit_offset is not None else 0,
225318
total=total,
226319
)
227320

321+
@overload
322+
def get_one(
323+
self,
324+
statement: "Statement | QueryBuilder",
325+
/,
326+
*parameters: "StatementParameters | StatementFilter",
327+
schema_type: "type[SchemaT]",
328+
error_message: str | None = None,
329+
**kwargs: Any,
330+
) -> SchemaT: ...
331+
332+
@overload
333+
def get_one(
334+
self,
335+
statement: "Statement | QueryBuilder",
336+
/,
337+
*parameters: "StatementParameters | StatementFilter",
338+
schema_type: None = None,
339+
error_message: str | None = None,
340+
**kwargs: Any,
341+
) -> dict[str, Any]: ...
342+
343+
@overload
344+
def get_one(
345+
self,
346+
statement: "Statement | QueryBuilder",
347+
/,
348+
*parameters: "StatementParameters | StatementFilter",
349+
schema_type: "type[SchemaT] | None" = None,
350+
error_message: str | None = None,
351+
**kwargs: Any,
352+
) -> "SchemaT | dict[str, Any]": ...
353+
228354
def get_one(
229355
self,
230356
statement: "Statement | QueryBuilder",

0 commit comments

Comments
 (0)