Skip to content

Commit ea191bd

Browse files
authored
fix(litestar): unique provider signatures stop filter cross-binding (#435) (#436)
## Summary - Filter providers in `create_filter_dependencies()` shared Python parameter names (`values`, `is_null`, `is_not_null`, `before`, `after`) across distinct `Provide()` instances. Litestar collapses bindings by parameter name across providers, ignoring per-instance `Parameter(query=...)` aliases — so siblings in the same family cross-bound (silent value bleed when types matched, `400 Invalid UUID` when they didn't). Reported as #435. - Each affected closure (`in_fields`, `not_in_fields`, `null_fields`, `not_null_fields`, `created_at`+`updated_at`) is now built via small helpers that synthesize a unique `__signature__` per field (`{field}_values`, `{field}_is_null`, `{field}_is_not_null`, `{field}_before`/`{field}_after`). - FastAPI extension was already correct — `Depends()` creates isolated sub-dependency scopes — confirmed with a new mixed-type regression test. Closes #435.
1 parent 2db85a4 commit ea191bd

6 files changed

Lines changed: 320 additions & 91 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ maintainers = [{ name = "Litestar Developers", email = "hello@litestar.dev" }]
2424
name = "sqlspec"
2525
readme = "README.md"
2626
requires-python = ">=3.10, <4.0"
27-
version = "0.46.0"
27+
version = "0.46.1"
2828

2929
[project.urls]
3030
Discord = "https://discord.gg/litestar"
@@ -264,7 +264,7 @@ opt_level = "3" # Maximum optimization (0-3)
264264
allow_dirty = true
265265
commit = false
266266
commit_args = "--no-verify"
267-
current_version = "0.46.0"
267+
current_version = "0.46.1"
268268
ignore_missing_files = false
269269
ignore_missing_version = false
270270
message = "chore(release): bump to v{new_version}"

sqlspec/extensions/litestar/providers.py

Lines changed: 117 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,103 @@ def _resolve_sort_fields(sort_field: SortField) -> tuple[str, set[str]]:
171171
return fields[0], set(fields)
172172

173173

174-
def _create_statement_filters( # noqa: C901
174+
def _build_in_collection_provider(field: FieldNameType, *, negated: bool) -> Callable[..., Any]:
175+
"""Build a per-field `IN` / `NOT IN` filter provider with a unique parameter name.
176+
177+
Litestar's dependency resolver collapses parameters across distinct ``Provide()``
178+
instances by Python parameter name, ignoring per-instance ``Parameter(query=...)``
179+
aliases. Giving each provider a unique parameter name (via ``__signature__``)
180+
prevents siblings in the same family from cross-binding (issue #435).
181+
"""
182+
type_hint = field.type_hint
183+
field_name = field.name
184+
param_name = f"{field_name}_values"
185+
alias = camelize(f"{field_name}_{'not_in' if negated else 'in'}")
186+
filter_cls: Any = NotInCollectionFilter if negated else InCollectionFilter
187+
annotation = list[type_hint] | None # type: ignore[valid-type]
188+
return_annotation = filter_cls[type_hint] | None
189+
190+
def provide(**kwargs: Any) -> Any:
191+
values = kwargs.get(param_name)
192+
return filter_cls[type_hint](field_name=field_name, values=values) if values else None
193+
194+
provide.__signature__ = inspect.Signature( # type: ignore[attr-defined]
195+
parameters=[
196+
inspect.Parameter(
197+
param_name,
198+
kind=inspect.Parameter.KEYWORD_ONLY,
199+
default=Parameter(query=alias, default=None, required=False),
200+
annotation=annotation,
201+
)
202+
],
203+
return_annotation=return_annotation,
204+
)
205+
provide.__annotations__ = {param_name: annotation, "return": return_annotation}
206+
return provide
207+
208+
209+
def _build_null_provider(field_name: str, *, negated: bool) -> Callable[..., Any]:
210+
"""Build a per-field ``IS NULL`` / ``IS NOT NULL`` provider with a unique parameter name (issue #435)."""
211+
suffix = "is_not_null" if negated else "is_null"
212+
param_name = f"{field_name}_{suffix}"
213+
alias = camelize(f"{field_name}_{suffix}")
214+
filter_cls: type[Any] = NotNullFilter if negated else NullFilter
215+
annotation = bool | None
216+
return_annotation = filter_cls | None
217+
218+
def provide(**kwargs: Any) -> Any:
219+
flag = kwargs.get(param_name)
220+
return filter_cls(field_name=field_name) if flag else None
221+
222+
provide.__signature__ = inspect.Signature( # type: ignore[attr-defined]
223+
parameters=[
224+
inspect.Parameter(
225+
param_name,
226+
kind=inspect.Parameter.KEYWORD_ONLY,
227+
default=Parameter(query=alias, default=None, required=False),
228+
annotation=annotation,
229+
)
230+
],
231+
return_annotation=return_annotation,
232+
)
233+
provide.__annotations__ = {param_name: annotation, "return": return_annotation}
234+
return provide
235+
236+
237+
def _build_before_after_provider(field_name: str, before_alias: str, after_alias: str) -> Callable[..., Any]:
238+
"""Build a ``BeforeAfterFilter`` provider with unique ``before``/``after`` parameter names (issue #435).
239+
240+
``created_at`` and ``updated_at`` providers both used ``before``/``after``, so when
241+
enabled together they cross-bound to whichever query alias Litestar resolved first.
242+
"""
243+
before_param = f"{field_name}_before"
244+
after_param = f"{field_name}_after"
245+
246+
def provide(**kwargs: Any) -> BeforeAfterFilter:
247+
return BeforeAfterFilter(field_name, kwargs.get(before_param), kwargs.get(after_param))
248+
249+
provide.__signature__ = inspect.Signature( # type: ignore[attr-defined]
250+
parameters=[
251+
inspect.Parameter(
252+
before_param,
253+
kind=inspect.Parameter.KEYWORD_ONLY,
254+
default=Parameter(query=before_alias, default=None, required=False),
255+
annotation=DTorNone,
256+
),
257+
inspect.Parameter(
258+
after_param,
259+
kind=inspect.Parameter.KEYWORD_ONLY,
260+
default=Parameter(query=after_alias, default=None, required=False),
261+
annotation=DTorNone,
262+
),
263+
],
264+
return_annotation=BeforeAfterFilter,
265+
)
266+
provide.__annotations__ = {before_param: DTorNone, after_param: DTorNone, "return": BeforeAfterFilter}
267+
return provide
268+
269+
270+
def _create_statement_filters(
175271
config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
176272
) -> dict[str, Provide]:
177273
"""Create filter dependencies based on configuration.
@@ -195,24 +291,14 @@ def provide_id_filter( # pyright: ignore[reportUnknownParameterType]
195291
filters[dep_defaults.ID_FILTER_DEPENDENCY_KEY] = Provide(provide_id_filter, sync_to_thread=False) # pyright: ignore[reportUnknownArgumentType]
196292

197293
if config.get("created_at", False):
198-
199-
def provide_created_filter(
200-
before: DTorNone = Parameter(query="createdBefore", default=None, required=False),
201-
after: DTorNone = Parameter(query="createdAfter", default=None, required=False),
202-
) -> BeforeAfterFilter:
203-
return BeforeAfterFilter("created_at", before, after)
204-
205-
filters[dep_defaults.CREATED_FILTER_DEPENDENCY_KEY] = Provide(provide_created_filter, sync_to_thread=False)
294+
filters[dep_defaults.CREATED_FILTER_DEPENDENCY_KEY] = Provide(
295+
_build_before_after_provider("created_at", "createdBefore", "createdAfter"), sync_to_thread=False
296+
)
206297

207298
if config.get("updated_at", False):
208-
209-
def provide_updated_filter(
210-
before: DTorNone = Parameter(query="updatedBefore", default=None, required=False),
211-
after: DTorNone = Parameter(query="updatedAfter", default=None, required=False),
212-
) -> BeforeAfterFilter:
213-
return BeforeAfterFilter("updated_at", before, after)
214-
215-
filters[dep_defaults.UPDATED_FILTER_DEPENDENCY_KEY] = Provide(provide_updated_filter, sync_to_thread=False)
299+
filters[dep_defaults.UPDATED_FILTER_DEPENDENCY_KEY] = Provide(
300+
_build_before_after_provider("updated_at", "updatedBefore", "updatedAfter"), sync_to_thread=False
301+
)
216302

217303
if config.get("pagination_type") == "limit_offset":
218304

@@ -276,85 +362,33 @@ def provide_order_by(
276362
not_in_fields = {not_in_fields} if isinstance(not_in_fields, (str, FieldNameType)) else not_in_fields
277363

278364
for field_def in not_in_fields:
279-
field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
280-
281-
def create_not_in_filter_provider( # pyright: ignore
282-
field_name: FieldNameType,
283-
) -> Callable[..., NotInCollectionFilter[field_def.type_hint] | None]: # type: ignore
284-
def provide_not_in_filter( # pyright: ignore
285-
values: list[field_name.type_hint] | None = Parameter( # type: ignore
286-
query=camelize(f"{field_name.name}_not_in"), default=None, required=False
287-
),
288-
) -> NotInCollectionFilter[field_name.type_hint] | None: # type: ignore
289-
return (
290-
NotInCollectionFilter[field_name.type_hint](field_name=field_name.name, values=values) # type: ignore
291-
if values
292-
else None
293-
)
294-
295-
return provide_not_in_filter # pyright: ignore
296-
297-
provider = create_not_in_filter_provider(field_def) # pyright: ignore
298-
filters[f"{field_def.name}_not_in_filter"] = Provide(provider, sync_to_thread=False) # pyright: ignore
365+
resolved = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
366+
filters[f"{resolved.name}_not_in_filter"] = Provide(
367+
_build_in_collection_provider(resolved, negated=True), sync_to_thread=False
368+
)
299369

300370
if in_fields := config.get("in_fields"):
301371
in_fields = {in_fields} if isinstance(in_fields, (str, FieldNameType)) else in_fields
302372

303373
for field_def in in_fields:
304-
field_def = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
305-
306-
def create_in_filter_provider( # pyright: ignore
307-
field_name: FieldNameType,
308-
) -> Callable[..., InCollectionFilter[field_def.type_hint] | None]: # type: ignore # pyright: ignore
309-
def provide_in_filter( # pyright: ignore
310-
values: list[field_name.type_hint] | None = Parameter( # type: ignore # pyright: ignore
311-
query=camelize(f"{field_name.name}_in"), default=None, required=False
312-
),
313-
) -> InCollectionFilter[field_name.type_hint] | None: # type: ignore # pyright: ignore
314-
return (
315-
InCollectionFilter[field_name.type_hint](field_name=field_name.name, values=values) # type: ignore # pyright: ignore
316-
if values
317-
else None
318-
)
319-
320-
return provide_in_filter # pyright: ignore
321-
322-
provider = create_in_filter_provider(field_def) # type: ignore
323-
filters[f"{field_def.name}_in_filter"] = Provide(provider, sync_to_thread=False) # pyright: ignore
374+
resolved = FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
375+
filters[f"{resolved.name}_in_filter"] = Provide(
376+
_build_in_collection_provider(resolved, negated=False), sync_to_thread=False
377+
)
324378

325379
if null_fields := config.get("null_fields"):
326380
null_fields = {null_fields} if isinstance(null_fields, str) else set(null_fields)
327-
328381
for field_name in null_fields:
329-
330-
def create_null_filter_provider(fname: str) -> Callable[..., NullFilter | None]:
331-
def provide_null_filter(
332-
is_null: bool | None = Parameter(query=camelize(f"{fname}_is_null"), default=None, required=False),
333-
) -> NullFilter | None:
334-
return NullFilter(field_name=fname) if is_null else None
335-
336-
return provide_null_filter
337-
338-
null_provider = create_null_filter_provider(field_name)
339-
filters[f"{field_name}_null_filter"] = Provide(null_provider, sync_to_thread=False)
382+
filters[f"{field_name}_null_filter"] = Provide(
383+
_build_null_provider(field_name, negated=False), sync_to_thread=False
384+
)
340385

341386
if not_null_fields := config.get("not_null_fields"):
342387
not_null_fields = {not_null_fields} if isinstance(not_null_fields, str) else set(not_null_fields)
343-
344388
for field_name in not_null_fields:
345-
346-
def create_not_null_filter_provider(fname: str) -> Callable[..., NotNullFilter | None]:
347-
def provide_not_null_filter(
348-
is_not_null: bool | None = Parameter(
349-
query=camelize(f"{fname}_is_not_null"), default=None, required=False
350-
),
351-
) -> NotNullFilter | None:
352-
return NotNullFilter(field_name=fname) if is_not_null else None
353-
354-
return provide_not_null_filter
355-
356-
not_null_provider = create_not_null_filter_provider(field_name)
357-
filters[f"{field_name}_not_null_filter"] = Provide(not_null_provider, sync_to_thread=False)
389+
filters[f"{field_name}_not_null_filter"] = Provide(
390+
_build_null_provider(field_name, negated=True), sync_to_thread=False
391+
)
358392

359393
if filters:
360394
filters[dep_defaults.FILTERS_DEPENDENCY_KEY] = Provide(

tests/integration/extensions/fastapi/test_filters_integration.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,42 @@ async def list_users(
486486
assert set(data["values"]) == {"deleted", "archived"}
487487

488488

489+
def test_fastapi_multi_in_fields_mixed_types_do_not_cross_bind() -> None:
490+
"""FastAPI sub-`Depends()` are scoped, so siblings with overlapping inner param names must not collide (#435)."""
491+
sqlspec = SQLSpec()
492+
config = AiosqliteConfig(
493+
connection_config={"database": ":memory:"}, extension_config={"fastapi": {"commit_mode": "manual"}}
494+
)
495+
sqlspec.add_config(config)
496+
497+
app = FastAPI()
498+
db_ext = SQLSpecPlugin(sqlspec, app=app)
499+
500+
from sqlspec.extensions.fastapi.providers import FieldNameType, InCollectionFilter
501+
502+
@app.get("/x")
503+
async def handler(
504+
filters: Annotated[
505+
list[FilterTypes],
506+
Depends(
507+
db_ext.provide_filters({"in_fields": [FieldNameType("role", str), FieldNameType("owner_id", UUID)]})
508+
),
509+
],
510+
) -> dict[str, Any]:
511+
return {
512+
"got": [
513+
(f.field_name, [str(v) for v in (f.values or ())]) for f in filters if isinstance(f, InCollectionFilter)
514+
]
515+
}
516+
517+
valid_uuid = "11111111-2222-3333-4444-555555555555"
518+
with TestClient(app) as client:
519+
response = client.get("/x", params={"roleIn": "HR", "ownerIdIn": valid_uuid})
520+
assert response.status_code == 200, response.text
521+
got = dict(response.json()["got"])
522+
assert got == {"role": ["HR"], "owner_id": [valid_uuid]}
523+
524+
489525
def test_fastapi_in_fields_with_query_execution() -> None:
490526
"""Test in_fields filter applied to actual SQL query execution (issue #405)."""
491527
sqlspec = SQLSpec()

0 commit comments

Comments
 (0)