|
| 1 | +"""Shared filter alias helpers for framework extensions.""" |
| 2 | + |
| 3 | +from collections.abc import Mapping |
| 4 | +from typing import NamedTuple |
| 5 | + |
| 6 | +from sqlspec.utils.text import camelize |
| 7 | + |
| 8 | +__all__ = ("SortField", "SortFieldResolution", "resolve_sort_field_aliases") |
| 9 | + |
| 10 | +SortField = str | set[str] | list[str] |
| 11 | + |
| 12 | + |
| 13 | +class SortFieldResolution(NamedTuple): |
| 14 | + """Resolved sort-field alias metadata. |
| 15 | +
|
| 16 | + Args: |
| 17 | + default_field: Internal SQL-facing field used when no query value is supplied. |
| 18 | + default_query_value: API-facing default value exposed on the query parameter. |
| 19 | + allowed_fields: Internal SQL-facing allowlist. |
| 20 | + inbound_aliases: API-facing query values mapped to internal field names. |
| 21 | + field_display_names: Internal field names mapped to their preferred API-facing display names. |
| 22 | + allowed_display_names: Display names ordered to match the configured sort fields. |
| 23 | + """ |
| 24 | + |
| 25 | + default_field: str |
| 26 | + default_query_value: str |
| 27 | + allowed_fields: frozenset[str] |
| 28 | + inbound_aliases: dict[str, str] |
| 29 | + field_display_names: dict[str, str] |
| 30 | + allowed_display_names: tuple[str, ...] |
| 31 | + |
| 32 | + def normalize(self, value: str | None) -> str | None: |
| 33 | + """Normalize a query value to an internal field name. |
| 34 | +
|
| 35 | + Args: |
| 36 | + value: API-facing query value, or ``None`` to request the default field. |
| 37 | +
|
| 38 | + Returns: |
| 39 | + The internal field name if the value is configured, otherwise ``None``. |
| 40 | + """ |
| 41 | + if value is None: |
| 42 | + return self.default_field |
| 43 | + return self.inbound_aliases.get(value) |
| 44 | + |
| 45 | + |
| 46 | +def resolve_sort_field_aliases( |
| 47 | + sort_field: SortField, sort_field_aliases: Mapping[str, str] | None = None, sort_field_camelize: bool = True |
| 48 | +) -> SortFieldResolution: |
| 49 | + """Resolve sort-field aliases to a closed allowlist map. |
| 50 | +
|
| 51 | + Args: |
| 52 | + sort_field: Configured SQL-facing sort field or fields. |
| 53 | + sort_field_aliases: Optional API-facing alias to SQL-facing field mapping. |
| 54 | + sort_field_camelize: Whether to generate camel-case aliases for configured fields. Defaults to ``True``. |
| 55 | +
|
| 56 | + Returns: |
| 57 | + Precomputed alias metadata for framework filter providers. |
| 58 | +
|
| 59 | + Raises: |
| 60 | + ValueError: If an alias targets an unknown field or collides with a different field. |
| 61 | + """ |
| 62 | + fields = _coerce_fields(sort_field) |
| 63 | + allowed_fields = frozenset(fields) |
| 64 | + inbound_aliases: dict[str, str] = {} |
| 65 | + field_display_names = {field: field for field in fields} |
| 66 | + |
| 67 | + for field in fields: |
| 68 | + _add_alias(inbound_aliases, alias=field, field=field) |
| 69 | + |
| 70 | + if sort_field_camelize: |
| 71 | + for field in fields: |
| 72 | + alias = camelize(field) |
| 73 | + _add_alias(inbound_aliases, alias=alias, field=field) |
| 74 | + field_display_names[field] = alias |
| 75 | + |
| 76 | + if sort_field_aliases: |
| 77 | + for alias, field in sort_field_aliases.items(): |
| 78 | + if field not in allowed_fields: |
| 79 | + msg = f"sort field alias '{alias}' targets unknown sort field '{field}'" |
| 80 | + raise ValueError(msg) |
| 81 | + _add_alias(inbound_aliases, alias=alias, field=field) |
| 82 | + field_display_names[field] = alias |
| 83 | + |
| 84 | + allowed_display_names = tuple(field_display_names[field] for field in fields) |
| 85 | + return SortFieldResolution( |
| 86 | + default_field=fields[0], |
| 87 | + default_query_value=field_display_names[fields[0]], |
| 88 | + allowed_fields=allowed_fields, |
| 89 | + inbound_aliases=inbound_aliases, |
| 90 | + field_display_names=field_display_names, |
| 91 | + allowed_display_names=allowed_display_names, |
| 92 | + ) |
| 93 | + |
| 94 | + |
| 95 | +def _coerce_fields(sort_field: SortField) -> tuple[str, ...]: |
| 96 | + if isinstance(sort_field, str): |
| 97 | + return (sort_field,) |
| 98 | + fields = tuple(sorted(sort_field)) if isinstance(sort_field, set) else tuple(sort_field) |
| 99 | + if not fields: |
| 100 | + msg = "sort_field must include at least one field" |
| 101 | + raise ValueError(msg) |
| 102 | + return fields |
| 103 | + |
| 104 | + |
| 105 | +def _add_alias(inbound_aliases: dict[str, str], *, alias: str, field: str) -> None: |
| 106 | + existing_field = inbound_aliases.get(alias) |
| 107 | + if existing_field is None or existing_field == field: |
| 108 | + inbound_aliases[alias] = field |
| 109 | + return |
| 110 | + |
| 111 | + msg = f"ambiguous sort field alias '{alias}' maps to both '{existing_field}' and '{field}'" |
| 112 | + raise ValueError(msg) |
0 commit comments