Skip to content

Commit c845c69

Browse files
committed
chore(client): emit Mapped[T] field types on Cached* classes (drop col()/cast() ergonomics tax)
Closes #625-spike-revisit. Generator now wraps every cache-table field's type in ``Mapped[T]`` so type checkers see ``CachedX.field`` as ``InstrumentedAttribute[T]`` (with ``.in_/.is_/.desc/.ilike``). The 36 ``sqlmodel.col()`` wrappers and 4 ``cast(QueryableAttribute, ...)`` calls landed in #624 are no longer needed and have been removed. The runtime path is unchanged: a small ``_mapped_shim.py`` module defines ``Mapped`` as ``sqlalchemy.orm.Mapped`` under ``TYPE_CHECKING`` and as a runtime-identity class (``__class_getitem__`` returns ``T``) otherwise. So ``Annotated[Mapped[int], Field(...)]`` evaluates to ``Annotated[int, Field(...)]`` at class-definition time — exactly what SQLModel + pydantic 2.13 want. The earlier #625 spike concluded "won't do" because the only working form (``if TYPE_CHECKING / else``) doubled every field declaration. That spike missed the runtime-identity-shim option, which the research follow-up surfaced via SQLModel discussion #1016 and ``roman.pt/posts/pydantic-in-sqlalchemy-fields/``. Generator changes (``scripts/generate_pydantic_models.py``): - New pass ``wrap_cache_fields_in_mapped`` runs after every other field-rewrite pass (those passes match ``Annotated[T, Field(...)]`` literally — running this wrap before them would leave them unable to find their targets). It does two regex sweeps: - ``Annotated[<type>, ...]`` → ``Annotated[Mapped[<type>], ...]`` (column fields) - ``field: <type> = Relationship(...)`` → ``field: Mapped[<type>] = Relationship(...)`` (relationship fields, matching SQLAlchemy 2.0's canonical shape) - Operates on both ``Cached*`` siblings *and* the shared entity base classes (BaseEntity, UpdatableEntity, DeletableEntity, etc.) since ``Cached*`` inherits ``id`` / ``created_at`` / ``updated_at`` / ``deleted_at`` / ``archived_at`` from those bases. Wrapping only the Cached classes left those inherited fields as bare ``T`` for type checkers — and those are the fields most query call sites filter on. API classes that share the same bases inherit the wrapping too; runtime is unaffected (Mapped → identity), and no API consumer relies on class-level access (instance-level access stays as ``T`` per pydantic's normal field handling). - Mapped shim import added to cache-table modules and ``base.py``. Foundation-file cleanup (``katana_mcp_server/.../foundation/*.py``): - Dropped 36 ``col(CachedX.field)`` wrappers across 6 files (inventory, manufacturing_orders, sales_orders, purchase_orders, stock_transfers, reporting) — types now resolve correctly without the wrapper. - Dropped 4 ``cast("QueryableAttribute[Any]", CachedX.<rel>)`` wrappers around ``selectinload`` arguments — relationship fields are now ``Mapped[list[X]]`` so ``selectinload(CachedX.<rel>)`` type-checks natively. - Removed corresponding unused imports (``col``, ``QueryableAttribute``, ``cast``). Typecheck config (``pyproject.toml``): - ``ty check`` now also excludes ``katana_public_api_client/models_pydantic/_generated/`` (matches ``pyrightconfig.json`` exclusions). Ty rejects the trailing ``= None`` default on ``Annotated[Mapped[T | None], Field(...)] = None``; pyright accepts it. Generated files are auto-emitted, so excluding them from in-repo type-checking is consistent with how pyright already handles them. Validation: - ``uv run poe lint`` clean (ruff + ty + yamllint) - ``uv run poe check`` clean (full validation) - ``uv run pyright katana_mcp_server/src/`` 0 errors / 1 documented warning - All 3,019 tests pass This is forward-compatible: when SQLModel ships native ``Mapped[T]`` support, the shim becomes redundant and can be deleted in a one-line PR (the generator output stays the same). Refs: #625
1 parent b82abdb commit c845c69

16 files changed

Lines changed: 609 additions & 264 deletions

File tree

katana_mcp_server/src/katana_mcp/tools/foundation/inventory.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import json
1111
import time
1212
from datetime import datetime
13-
from typing import Annotated, Any, Literal, cast
13+
from typing import Annotated, Any, Literal
1414

1515
from fastmcp import Context, FastMCP
1616
from fastmcp.tools import ToolResult
@@ -1192,7 +1192,7 @@ def _apply_stock_adjustment_filters(
11921192
this function lets the paginated path avoid re-parsing on the COUNT
11931193
query.
11941194
"""
1195-
from sqlmodel import col, exists, select
1195+
from sqlmodel import exists, select
11961196

11971197
from katana_public_api_client.models_pydantic._generated import (
11981198
CachedStockAdjustment,
@@ -1202,14 +1202,14 @@ def _apply_stock_adjustment_filters(
12021202
if request.location_id is not None:
12031203
stmt = stmt.where(CachedStockAdjustment.location_id == request.location_id)
12041204
if request.ids is not None:
1205-
stmt = stmt.where(col(CachedStockAdjustment.id).in_(request.ids))
1205+
stmt = stmt.where(CachedStockAdjustment.id.in_(request.ids))
12061206
if request.stock_adjustment_number is not None:
12071207
stmt = stmt.where(
12081208
CachedStockAdjustment.stock_adjustment_number
12091209
== request.stock_adjustment_number
12101210
)
12111211
if not request.include_deleted:
1212-
stmt = stmt.where(col(CachedStockAdjustment.deleted_at).is_(None))
1212+
stmt = stmt.where(CachedStockAdjustment.deleted_at.is_(None))
12131213

12141214
# ``variant_id`` is a row-level field — EXISTS subquery scans the
12151215
# indexed FK directly so a match on any row of any adjustment is
@@ -1229,7 +1229,7 @@ def _apply_stock_adjustment_filters(
12291229
if request.reason is not None:
12301230
needle = request.reason.strip()
12311231
if needle:
1232-
stmt = stmt.where(col(CachedStockAdjustment.reason).ilike(f"%{needle}%"))
1232+
stmt = stmt.where(CachedStockAdjustment.reason.ilike(f"%{needle}%"))
12331233

12341234
return apply_date_window_filters(
12351235
stmt,
@@ -1256,8 +1256,8 @@ async def _list_stock_adjustments_impl(
12561256
subquery against the row table so a match on any row is found
12571257
regardless of how many adjustments precede it. See ADR-0018.
12581258
"""
1259-
from sqlalchemy.orm import QueryableAttribute, selectinload
1260-
from sqlmodel import col, func, select
1259+
from sqlalchemy.orm import selectinload
1260+
from sqlmodel import func, select
12611261

12621262
from katana_mcp.typed_cache import ensure_stock_adjustments_synced
12631263
from katana_public_api_client.models_pydantic._generated import (
@@ -1276,16 +1276,11 @@ async def _list_stock_adjustments_impl(
12761276
# materialization time and we skip the correlated COUNT subquery.
12771277
if request.include_rows:
12781278
stmt = select(CachedStockAdjustment).options(
1279-
selectinload(
1280-
cast(
1281-
"QueryableAttribute[Any]",
1282-
CachedStockAdjustment.stock_adjustment_rows,
1283-
)
1284-
)
1279+
selectinload(CachedStockAdjustment.stock_adjustment_rows)
12851280
)
12861281
else:
12871282
row_count_subq = (
1288-
select(func.count(col(CachedStockAdjustmentRow.id)))
1283+
select(func.count(CachedStockAdjustmentRow.id))
12891284
.where(
12901285
CachedStockAdjustmentRow.stock_adjustment_id == CachedStockAdjustment.id
12911286
)
@@ -1296,8 +1291,8 @@ async def _list_stock_adjustments_impl(
12961291
stmt = select(CachedStockAdjustment, row_count_subq)
12971292
stmt = _apply_stock_adjustment_filters(stmt, request, parsed_dates)
12981293
stmt = stmt.order_by(
1299-
col(CachedStockAdjustment.created_at).desc(),
1300-
col(CachedStockAdjustment.id).desc(),
1294+
CachedStockAdjustment.created_at.desc(),
1295+
CachedStockAdjustment.id.desc(),
13011296
)
13021297
if request.page is not None:
13031298
stmt = stmt.offset((request.page - 1) * request.limit).limit(request.limit)

katana_mcp_server/src/katana_mcp/tools/foundation/manufacturing_orders.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1777,27 +1777,24 @@ def _apply_manufacturing_order_filters(
17771777
reflect exactly the same filter set as the data rows. ``parsed_dates``
17781778
must come from :func:`parse_request_dates`.
17791779
"""
1780-
from sqlmodel import col
17811780

17821781
from katana_public_api_client.models_pydantic._generated import (
17831782
CachedManufacturingOrder,
17841783
)
17851784

17861785
if request.ids is not None:
1787-
stmt = stmt.where(col(CachedManufacturingOrder.id).in_(request.ids))
1786+
stmt = stmt.where(CachedManufacturingOrder.id.in_(request.ids))
17881787
if request.order_no is not None:
17891788
stmt = stmt.where(CachedManufacturingOrder.order_no == request.order_no)
17901789
if request.status is not None:
17911790
stmt = stmt.where(CachedManufacturingOrder.status == request.status)
17921791
if request.location_id is not None:
17931792
stmt = stmt.where(CachedManufacturingOrder.location_id == request.location_id)
17941793
if request.variant_ids is not None:
1795-
stmt = stmt.where(
1796-
col(CachedManufacturingOrder.variant_id).in_(request.variant_ids)
1797-
)
1794+
stmt = stmt.where(CachedManufacturingOrder.variant_id.in_(request.variant_ids))
17981795
if request.sales_order_ids is not None:
17991796
stmt = stmt.where(
1800-
col(CachedManufacturingOrder.sales_order_id).in_(request.sales_order_ids)
1797+
CachedManufacturingOrder.sales_order_id.in_(request.sales_order_ids)
18011798
)
18021799
if request.ingredient_availability is not None:
18031800
stmt = stmt.where(
@@ -1814,7 +1811,7 @@ def _apply_manufacturing_order_filters(
18141811
== request.is_linked_to_sales_order
18151812
)
18161813
if not request.include_deleted:
1817-
stmt = stmt.where(col(CachedManufacturingOrder.deleted_at).is_(None))
1814+
stmt = stmt.where(CachedManufacturingOrder.deleted_at.is_(None))
18181815

18191816
return apply_date_window_filters(
18201817
stmt,
@@ -1842,7 +1839,7 @@ async def _list_manufacturing_orders_impl(
18421839
Filters (including ``production_deadline_*``) translate to indexed
18431840
SQL. See ADR-0018.
18441841
"""
1845-
from sqlmodel import col, func, select
1842+
from sqlmodel import func, select
18461843

18471844
from katana_mcp.typed_cache import ensure_manufacturing_orders_synced
18481845
from katana_public_api_client.models_pydantic._generated import (
@@ -1858,8 +1855,8 @@ async def _list_manufacturing_orders_impl(
18581855
stmt = select(CachedManufacturingOrder)
18591856
stmt = _apply_manufacturing_order_filters(stmt, request, parsed_dates)
18601857
stmt = stmt.order_by(
1861-
col(CachedManufacturingOrder.created_at).desc(),
1862-
col(CachedManufacturingOrder.id).desc(),
1858+
CachedManufacturingOrder.created_at.desc(),
1859+
CachedManufacturingOrder.id.desc(),
18631860
)
18641861
if request.page is not None:
18651862
stmt = stmt.offset((request.page - 1) * request.limit).limit(request.limit)
@@ -2182,7 +2179,7 @@ async def _list_blocking_ingredients_impl(
21822179
hasn't run for, which would otherwise surface as a false-positive
21832180
blocking entry.
21842181
"""
2185-
from sqlmodel import col, select
2182+
from sqlmodel import select
21862183

21872184
from katana_mcp.typed_cache import (
21882185
MANUFACTURING_ORDER_RECIPE_ROW_SPEC,
@@ -2264,24 +2261,22 @@ async def _list_blocking_ingredients_impl(
22642261
select(CachedManufacturingOrderRecipeRow, CachedManufacturingOrder)
22652262
.join(
22662263
CachedManufacturingOrder,
2267-
col(CachedManufacturingOrder.id)
2264+
CachedManufacturingOrder.id
22682265
== CachedManufacturingOrderRecipeRow.manufacturing_order_id,
22692266
)
2270-
.where(col(CachedManufacturingOrder.deleted_at).is_(None))
2271-
.where(col(CachedManufacturingOrderRecipeRow.deleted_at).is_(None))
2272-
.where(col(CachedManufacturingOrder.status).in_(statuses))
2267+
.where(CachedManufacturingOrder.deleted_at.is_(None))
2268+
.where(CachedManufacturingOrderRecipeRow.deleted_at.is_(None))
2269+
.where(CachedManufacturingOrder.status.in_(statuses))
22732270
.where(
2274-
col(CachedManufacturingOrderRecipeRow.ingredient_availability).in_(
2271+
CachedManufacturingOrderRecipeRow.ingredient_availability.in_(
22752272
list(_BLOCKING_AVAILABILITY)
22762273
)
22772274
)
22782275
)
22792276
if request.mo_ids is not None:
2280-
stmt = stmt.where(col(CachedManufacturingOrder.id).in_(request.mo_ids))
2277+
stmt = stmt.where(CachedManufacturingOrder.id.in_(request.mo_ids))
22812278
if request.mo_order_nos is not None:
2282-
stmt = stmt.where(
2283-
col(CachedManufacturingOrder.order_no).in_(request.mo_order_nos)
2284-
)
2279+
stmt = stmt.where(CachedManufacturingOrder.order_no.in_(request.mo_order_nos))
22852280
if request.location_id is not None:
22862281
stmt = stmt.where(CachedManufacturingOrder.location_id == request.location_id)
22872282
stmt = apply_date_window_filters(

katana_mcp_server/src/katana_mcp/tools/foundation/purchase_orders.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import asyncio
1414
from datetime import UTC, datetime
1515
from enum import StrEnum
16-
from typing import Annotated, Any, Literal, cast
16+
from typing import Annotated, Any, Literal
1717

1818
from fastmcp import Context, FastMCP
1919
from fastmcp.tools import ToolResult
@@ -2013,7 +2013,6 @@ def _apply_purchase_order_filters(
20132013
Shared by the data SELECT and the COUNT SELECT so pagination totals
20142014
reflect exactly the same filter set as the data rows.
20152015
"""
2016-
from sqlmodel import col
20172016

20182017
from katana_public_api_client.models_pydantic._generated import (
20192018
CachedPurchaseOrder,
@@ -2023,7 +2022,7 @@ def _apply_purchase_order_filters(
20232022
)
20242023

20252024
if request.ids is not None:
2026-
stmt = stmt.where(col(CachedPurchaseOrder.id).in_(request.ids))
2025+
stmt = stmt.where(CachedPurchaseOrder.id.in_(request.ids))
20272026
if request.order_no is not None:
20282027
stmt = stmt.where(CachedPurchaseOrder.order_no == request.order_no)
20292028
if request.entity_type is not None:
@@ -2056,7 +2055,7 @@ def _apply_purchase_order_filters(
20562055
if request.supplier_id is not None:
20572056
stmt = stmt.where(CachedPurchaseOrder.supplier_id == request.supplier_id)
20582057
if not request.include_deleted:
2059-
stmt = stmt.where(col(CachedPurchaseOrder.deleted_at).is_(None))
2058+
stmt = stmt.where(CachedPurchaseOrder.deleted_at.is_(None))
20602059

20612060
return apply_date_window_filters(
20622061
stmt,
@@ -2085,8 +2084,8 @@ async def _list_purchase_orders_impl(
20852084
outsourced-only ``tracking_location_id``) translate to indexed SQL.
20862085
See ADR-0018.
20872086
"""
2088-
from sqlalchemy.orm import QueryableAttribute, selectinload
2089-
from sqlmodel import col, func, select
2087+
from sqlalchemy.orm import selectinload
2088+
from sqlmodel import func, select
20902089

20912090
from katana_mcp.typed_cache import ensure_purchase_orders_synced
20922091
from katana_public_api_client.models_pydantic._generated import (
@@ -2105,16 +2104,11 @@ async def _list_purchase_orders_impl(
21052104
# time and we skip the correlated COUNT subquery.
21062105
if request.include_rows:
21072106
stmt = select(CachedPurchaseOrder).options(
2108-
selectinload(
2109-
cast(
2110-
"QueryableAttribute[Any]",
2111-
CachedPurchaseOrder.purchase_order_rows,
2112-
)
2113-
)
2107+
selectinload(CachedPurchaseOrder.purchase_order_rows)
21142108
)
21152109
else:
21162110
row_count_subq = (
2117-
select(func.count(col(CachedPurchaseOrderRow.id)))
2111+
select(func.count(CachedPurchaseOrderRow.id))
21182112
.where(CachedPurchaseOrderRow.purchase_order_id == CachedPurchaseOrder.id)
21192113
.correlate(CachedPurchaseOrder)
21202114
.scalar_subquery()
@@ -2123,8 +2117,8 @@ async def _list_purchase_orders_impl(
21232117
stmt = select(CachedPurchaseOrder, row_count_subq)
21242118
stmt = _apply_purchase_order_filters(stmt, request, parsed_dates)
21252119
stmt = stmt.order_by(
2126-
col(CachedPurchaseOrder.created_at).desc(),
2127-
col(CachedPurchaseOrder.id).desc(),
2120+
CachedPurchaseOrder.created_at.desc(),
2121+
CachedPurchaseOrder.id.desc(),
21282122
)
21292123
if request.page is not None:
21302124
stmt = stmt.offset((request.page - 1) * request.limit).limit(request.limit)

katana_mcp_server/src/katana_mcp/tools/foundation/reporting.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,7 @@ async def _fetch_completed_mo_recipe_rows_in_window(
793793
window datetimes must also be naive UTC — callers are responsible for
794794
stripping tzinfo before passing.
795795
"""
796-
from sqlmodel import col, select
796+
from sqlmodel import select
797797

798798
from katana_mcp.typed_cache import ensure_manufacturing_orders_synced
799799
from katana_public_api_client.models_pydantic._generated import (
@@ -815,13 +815,13 @@ async def _fetch_completed_mo_recipe_rows_in_window(
815815
select(CachedManufacturingOrderRecipeRow)
816816
.join(
817817
CachedManufacturingOrder,
818-
col(CachedManufacturingOrder.id)
818+
CachedManufacturingOrder.id
819819
== CachedManufacturingOrderRecipeRow.manufacturing_order_id,
820820
)
821821
.where(CachedManufacturingOrder.status == ManufacturingOrderStatus.done)
822-
.where(col(CachedManufacturingOrder.done_date).is_not(None))
823-
.where(col(CachedManufacturingOrder.deleted_at).is_(None))
824-
.where(col(CachedManufacturingOrderRecipeRow.deleted_at).is_(None))
822+
.where(CachedManufacturingOrder.done_date.is_not(None))
823+
.where(CachedManufacturingOrder.deleted_at.is_(None))
824+
.where(CachedManufacturingOrderRecipeRow.deleted_at.is_(None))
825825
)
826826
row_stmt = apply_date_window_filters(
827827
row_stmt,

katana_mcp_server/src/katana_mcp/tools/foundation/sales_orders.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import asyncio
1717
from datetime import datetime
1818
from enum import StrEnum
19-
from typing import Annotated, Any, Literal, cast
19+
from typing import Annotated, Any, Literal
2020

2121
from fastmcp import Context, FastMCP
2222
from fastmcp.tools import ToolResult
@@ -730,7 +730,6 @@ def _apply_sales_order_filters(
730730
this function lets the paginated path avoid re-parsing on the COUNT
731731
query.
732732
"""
733-
from sqlmodel import col
734733

735734
from katana_public_api_client.models_pydantic._generated import (
736735
CachedSalesOrder,
@@ -745,7 +744,7 @@ def _apply_sales_order_filters(
745744
if request.order_no is not None:
746745
stmt = stmt.where(CachedSalesOrder.order_no == request.order_no)
747746
if request.ids is not None:
748-
stmt = stmt.where(col(CachedSalesOrder.id).in_(request.ids))
747+
stmt = stmt.where(CachedSalesOrder.id.in_(request.ids))
749748
if request.customer_id is not None:
750749
stmt = stmt.where(CachedSalesOrder.customer_id == request.customer_id)
751750
if request.location_id is not None:
@@ -767,7 +766,7 @@ def _apply_sales_order_filters(
767766
if request.currency is not None:
768767
stmt = stmt.where(CachedSalesOrder.currency == request.currency)
769768
if not request.include_deleted:
770-
stmt = stmt.where(col(CachedSalesOrder.deleted_at).is_(None))
769+
stmt = stmt.where(CachedSalesOrder.deleted_at.is_(None))
771770

772771
return apply_date_window_filters(
773772
stmt,
@@ -795,8 +794,8 @@ async def _list_sales_orders_impl(
795794
translates request filters into indexed SQL and returns results
796795
directly. See ADR-0018.
797796
"""
798-
from sqlalchemy.orm import QueryableAttribute, selectinload
799-
from sqlmodel import col, func, select
797+
from sqlalchemy.orm import selectinload
798+
from sqlmodel import func, select
800799

801800
from katana_mcp.typed_cache import ensure_sales_orders_synced
802801
from katana_public_api_client.models_pydantic._generated import (
@@ -815,23 +814,19 @@ async def _list_sales_orders_impl(
815814
# time and we skip the correlated COUNT subquery entirely.
816815
if request.include_rows:
817816
stmt = select(CachedSalesOrder).options(
818-
selectinload(
819-
cast("QueryableAttribute[Any]", CachedSalesOrder.sales_order_rows)
820-
)
817+
selectinload(CachedSalesOrder.sales_order_rows)
821818
)
822819
else:
823820
row_count_subq = (
824-
select(func.count(col(CachedSalesOrderRow.id)))
821+
select(func.count(CachedSalesOrderRow.id))
825822
.where(CachedSalesOrderRow.sales_order_id == CachedSalesOrder.id)
826823
.correlate(CachedSalesOrder)
827824
.scalar_subquery()
828825
.label("row_count")
829826
)
830827
stmt = select(CachedSalesOrder, row_count_subq)
831828
stmt = _apply_sales_order_filters(stmt, request, parsed_dates)
832-
stmt = stmt.order_by(
833-
col(CachedSalesOrder.created_at).desc(), col(CachedSalesOrder.id).desc()
834-
)
829+
stmt = stmt.order_by(CachedSalesOrder.created_at.desc(), CachedSalesOrder.id.desc())
835830
if request.page is not None:
836831
stmt = stmt.offset((request.page - 1) * request.limit).limit(request.limit)
837832
else:

0 commit comments

Comments
 (0)