Skip to content

Commit 8ef5106

Browse files
dougborgclaude
andcommitted
feat(mcp): rewire @cache_read decorator to Cached* class keys (#472 Phase C)
The @cache_read decorator now keys off the typed-cache Cached* row classes (CachedVariant, CachedProduct, ...) instead of the legacy EntityType StrEnum. Each registered class maps to a small wrapper that fans out to BOTH the legacy cache_sync.ensure_<entity>_synced(services) AND the typed typed_cache.ensure_<entity>_synced(client, typed_cache) helpers via asyncio.gather, so tool bodies see fresh data on either path during the Phase C → Phase D transition. The 11 catalog @cache_read(EntityType.X) call sites in tools/foundation/{customers,inventory,items,reference,reporting}.py shift to @cache_read(CachedX). Tool bodies that still call services.cache.<method>(EntityType.X, ...) are untouched — Phase D will migrate them onto services.typed_cache.catalog.* and remove the legacy cache. Tests that mock decorators._sync_fns (test_inventory, test_items, test_customers, test_reference, integration/test_error_scenarios) flip their dict keys from EntityType to Cached* classes. test_reporting's {}-returning patch is unaffected. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 39c2744 commit 8ef5106

15 files changed

Lines changed: 289 additions & 120 deletions

File tree

katana_mcp_server/docs/architecture.md

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,15 @@ sync when tool surface changes.
8282

8383
### Cache-aware decorators
8484

85-
`tools/decorators.py` provides `@cache_read("entity")` and
86-
`@cache_write("entity_a", "entity_b")` decorators. `cache_read` triggers an incremental
87-
sync of the named entity before invoking the tool; `cache_write` invalidates the listed
88-
entities after a mutating call so the next list/get returns fresh data. Tool
89-
implementations stay focused on business logic; sync orchestration lives in the
90-
decorator.
85+
`tools/decorators.py` provides `@cache_read(CachedVariant, ...)` and
86+
`@cache_write("entity_a", "entity_b")` decorators. `cache_read` keys off the typed-cache
87+
`Cached*` row class (e.g. `CachedVariant`, `CachedProduct`) and triggers an incremental
88+
sync of the named entity before invoking the tool. During the #472 unification rollout
89+
(Phase C) the decorator fans each sync out to BOTH the legacy `CatalogCache` helper and
90+
the typed-cache helper so tool bodies see fresh data on either path; Phase D drops the
91+
legacy half. `cache_write` invalidates the listed entities after a mutating call so the
92+
next list/get returns fresh data. Tool implementations stay focused on business logic;
93+
sync orchestration lives in the decorator.
9194

9295
## Resources
9396

@@ -210,9 +213,9 @@ bugs at the client/generator layer").
210213
integration; use elicitation for any state-changing operation.
211214
1. **Follow ADR-0019** for naming (`<entity>_<field>s` for batch list filters, singular
212215
for `get_*`) and the docstring opening sentence.
213-
1. **If the tool reads from cache,** add `@cache_read("entity")`. If it writes, add
214-
`@cache_write("entity_a", "entity_b")` listing every entity whose cache should be
215-
invalidated.
216+
1. **If the tool reads from cache,** add `@cache_read(CachedEntity)` keyed by the typed
217+
`Cached*` row class. If it writes, add `@cache_write("entity_a", "entity_b")` listing
218+
every entity whose cache should be invalidated.
216219
1. **For new transactional list tools backed by typed cache:** add an `EntitySpec`
217220
literal in `typed_cache/sync.py` and a thin `ensure_<entity>_synced` wrapper. The
218221
`Cached<Entity>` row class is auto-generated from the spec by the next regen.

katana_mcp_server/src/katana_mcp/tools/decorators.py

Lines changed: 102 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
66
Usage::
77
8-
@cache_read("variant")
8+
@cache_read(CachedVariant)
99
async def _search_items_impl(request, context):
1010
services = get_services(context)
1111
return await services.cache.smart_search("variant", request.query)
@@ -15,65 +15,124 @@ async def _search_items_impl(request, context):
1515
async def _create_item_impl(request, context):
1616
services = get_services(context)
1717
return await services.client.products.create(...)
18+
19+
The ``cache_read`` decorator now keys off the typed-cache ``Cached*``
20+
classes (``CachedVariant``, ``CachedProduct``, …) instead of the legacy
21+
``EntityType`` enum. During the #472 unification rollout the decorator
22+
runs **both** the legacy ``cache_sync.ensure_*_synced`` helper and the
23+
typed ``typed_cache.ensure_*_synced`` helper for each registered class
24+
so tool bodies see fresh data on either path. Phase D drops the legacy
25+
half along with the call sites that read from ``services.cache``.
1826
"""
1927

2028
from __future__ import annotations
2129

22-
from collections.abc import Callable
30+
import asyncio
31+
from collections.abc import Awaitable, Callable
2332
from functools import wraps
24-
from typing import Any, cast
33+
from typing import TYPE_CHECKING, Any, cast
2534

26-
from katana_mcp.cache import EntityType
2735
from katana_mcp.services import get_services
2836

29-
# Lazy-initialized cache of sync functions (avoids circular imports)
30-
_sync_fns: dict[str, Any] | None = None
31-
32-
33-
def _get_sync_fns() -> dict[str, Any]:
34-
"""Get the entity-type → sync function mapping (initialized once)."""
37+
if TYPE_CHECKING:
38+
from sqlmodel import SQLModel
39+
40+
from katana_mcp.services.dependencies import Services
41+
42+
43+
# Lazy-initialized cache of sync functions (avoids circular imports).
44+
# Keys are typed-cache ``Cached*`` classes; values are async wrappers
45+
# that fan out to the legacy ``cache_sync.ensure_*_synced(services)``
46+
# AND the typed ``typed_cache.ensure_*_synced(client, typed_cache)``
47+
# helpers (#472 Phase C). Phase D drops the legacy half.
48+
_sync_fns: dict[type[SQLModel], Callable[[Services], Awaitable[None]]] | None = None
49+
50+
# (Cached* class, ensure-helper stem) — both ``cache_sync`` and ``typed_cache``
51+
# expose ``ensure_<stem>_synced``, differing only in argument shape (legacy
52+
# takes ``services``; typed takes ``client, cache``). The dual wrapper in
53+
# ``_get_sync_fns`` resolves the stem against both modules.
54+
_DUAL_SYNC_REGISTRY: tuple[tuple[str, str], ...] = (
55+
("CachedVariant", "variants"),
56+
("CachedProduct", "products"),
57+
("CachedMaterial", "materials"),
58+
("CachedService", "services"),
59+
("CachedSupplier", "suppliers"),
60+
("CachedCustomer", "customers"),
61+
("CachedLocation", "locations"),
62+
("CachedTaxRate", "tax_rates"),
63+
("CachedOperator", "operators"),
64+
("CachedFactory", "factory"),
65+
("CachedAdditionalCost", "additional_costs"),
66+
)
67+
68+
69+
def _get_sync_fns() -> dict[type[SQLModel], Callable[[Services], Awaitable[None]]]:
70+
"""Get the ``Cached*`` class → dual-sync wrapper mapping (initialized once).
71+
72+
Each wrapper runs the legacy ``cache_sync.ensure_<stem>_synced(services)``
73+
and the typed ``typed_cache.ensure_<stem>_synced(client, typed_cache)``
74+
concurrently via ``asyncio.gather`` so both caches stay populated during
75+
the Phase C → Phase D transition without serializing two API fetches.
76+
The registry is tiny on purpose — Phase D removes the legacy half.
77+
"""
3578
global _sync_fns # noqa: PLW0603
3679
if _sync_fns is None:
37-
from katana_mcp.cache_sync import (
38-
ensure_additional_costs_synced,
39-
ensure_customers_synced,
40-
ensure_factory_synced,
41-
ensure_locations_synced,
42-
ensure_materials_synced,
43-
ensure_operators_synced,
44-
ensure_products_synced,
45-
ensure_services_synced,
46-
ensure_suppliers_synced,
47-
ensure_tax_rates_synced,
48-
ensure_variants_synced,
49-
)
80+
from katana_mcp import cache_sync, typed_cache
81+
from katana_public_api_client.models_pydantic import _generated as cached_models
82+
83+
def _dual(
84+
legacy: Callable[[Services], Awaitable[None]],
85+
typed: Callable[..., Awaitable[None]],
86+
) -> Callable[[Services], Awaitable[None]]:
87+
async def _wrapped(services: Services) -> None:
88+
await asyncio.gather(
89+
legacy(services),
90+
typed(services.client, services.typed_cache),
91+
)
92+
93+
return _wrapped
5094

5195
_sync_fns = {
52-
EntityType.VARIANT: ensure_variants_synced,
53-
EntityType.PRODUCT: ensure_products_synced,
54-
EntityType.MATERIAL: ensure_materials_synced,
55-
EntityType.SERVICE: ensure_services_synced,
56-
EntityType.SUPPLIER: ensure_suppliers_synced,
57-
EntityType.CUSTOMER: ensure_customers_synced,
58-
EntityType.LOCATION: ensure_locations_synced,
59-
EntityType.TAX_RATE: ensure_tax_rates_synced,
60-
EntityType.OPERATOR: ensure_operators_synced,
61-
EntityType.FACTORY: ensure_factory_synced,
62-
EntityType.ADDITIONAL_COST: ensure_additional_costs_synced,
96+
getattr(cached_models, cls_name): _dual(
97+
getattr(cache_sync, f"ensure_{stem}_synced"),
98+
getattr(typed_cache, f"ensure_{stem}_synced"),
99+
)
100+
for cls_name, stem in _DUAL_SYNC_REGISTRY
63101
}
64102
return _sync_fns
65103

66104

67-
def cache_read(*entity_types: str) -> Callable:
68-
"""Sync cache for entity types before executing the tool.
105+
def cache_read(*entity_classes: type[SQLModel]) -> Callable:
106+
"""Sync cache for the given typed ``Cached*`` classes before running the tool.
69107
70-
Calls ``ensure_{type}_synced(services)`` for each entity type before
71-
running the decorated function. The function receives a context with
72-
a guaranteed-fresh cache.
108+
For each class the decorator looks up the registered sync wrapper in
109+
``_get_sync_fns()`` and awaits it. Each wrapper currently fans out to
110+
BOTH the legacy ``CatalogCache`` sync helper AND the typed-cache
111+
``ensure_*_synced`` helper so tool bodies see fresh data on either
112+
path during the #472 unification rollout. Phase D drops the legacy
113+
half along with the ``services.cache`` call sites.
114+
115+
Unknown classes raise ``ValueError`` at decoration time so a typo
116+
fails at import, not silently as a stale-cache read at first call.
73117
74118
Args:
75-
*entity_types: Entity type names to sync (e.g., "variant", "product").
119+
*entity_classes: Typed ``Cached*`` classes to sync (e.g.,
120+
``CachedVariant``, ``CachedProduct``).
76121
"""
122+
# Fail fast at decoration time so a typo blows up at import, not as
123+
# a silent stale-cache read on the first request. Tests that swap
124+
# ``decorators._sync_fns`` in autouse fixtures still get the live
125+
# mocks at call time — the wrapper re-resolves through
126+
# ``_get_sync_fns()``.
127+
registered = _get_sync_fns()
128+
unknown = [cls for cls in entity_classes if cls not in registered]
129+
if unknown:
130+
names = ", ".join(cls.__name__ for cls in unknown)
131+
known = ", ".join(sorted(c.__name__ for c in registered))
132+
raise ValueError(
133+
f"@cache_read: unregistered Cached* class(es): {names}. "
134+
f"Registered classes: {known}."
135+
)
77136

78137
def decorator[F: Callable[..., Any]](fn: F) -> F:
79138
@wraps(fn)
@@ -82,11 +141,9 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
82141
services = get_services(context)
83142

84143
sync_fns = _get_sync_fns()
85-
for et in entity_types:
86-
# EntityType is a StrEnum — normalize to ensure dict lookup works
87-
key = EntityType(et) if not isinstance(et, EntityType) else et
88-
sync_fn = sync_fns.get(key)
89-
if sync_fn:
144+
for cls in entity_classes:
145+
sync_fn = sync_fns.get(cls)
146+
if sync_fn is not None:
90147
await sync_fn(services)
91148

92149
return await fn(*args, **kwargs)

katana_mcp_server/src/katana_mcp/tools/foundation/customers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from katana_mcp.tools.tool_result_utils import make_simple_result
2121
from katana_mcp.unpack import Unpack, unpack_pydantic_params
2222
from katana_mcp.web_urls import katana_web_url
23+
from katana_public_api_client.models_pydantic._generated import CachedCustomer
2324

2425
# ============================================================================
2526
# Tool 1: search_customers
@@ -74,7 +75,7 @@ def _customer_from_dict(d: dict) -> CustomerInfo:
7475
)
7576

7677

77-
@cache_read(EntityType.CUSTOMER)
78+
@cache_read(CachedCustomer)
7879
async def _search_customers_impl(
7980
request: SearchCustomersRequest, context: Context
8081
) -> SearchCustomersResponse:
@@ -274,7 +275,7 @@ async def _fetch_customer_addresses(
274275
return result
275276

276277

277-
@cache_read(EntityType.CUSTOMER)
278+
@cache_read(CachedCustomer)
278279
async def _get_customer_impl(
279280
request: GetCustomerRequest, context: Context
280281
) -> GetCustomerResponse:

katana_mcp_server/src/katana_mcp/tools/foundation/inventory.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from katana_public_api_client.api.stock_adjustment import get_all_stock_adjustments
3838
from katana_public_api_client.client_types import UNSET, Unset
3939
from katana_public_api_client.domain.converters import to_unset, unwrap_unset
40+
from katana_public_api_client.models_pydantic._generated import CachedVariant
4041
from katana_public_api_client.utils import unwrap_data
4142

4243
logger = get_logger(__name__)
@@ -415,7 +416,7 @@ class LowStockResponse(BaseModel):
415416
total_count: int
416417

417418

418-
@cache_read(EntityType.VARIANT)
419+
@cache_read(CachedVariant)
419420
async def _list_low_stock_items_impl(
420421
request: LowStockRequest, context: Context
421422
) -> LowStockResponse:

katana_mcp_server/src/katana_mcp/tools/foundation/items.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@
9191
UpdateVariantRequestConfigAttributesItem as APIUpdateVariantConfigItem,
9292
Variant,
9393
)
94+
from katana_public_api_client.models_pydantic._generated import (
95+
CachedMaterial,
96+
CachedProduct,
97+
CachedSupplier,
98+
CachedVariant,
99+
)
94100

95101
logger = get_logger(__name__)
96102

@@ -177,7 +183,7 @@ def _search_response_to_tool_result(
177183
return make_tool_result(filtered_response, ui=ui)
178184

179185

180-
@cache_read(EntityType.VARIANT)
186+
@cache_read(CachedVariant)
181187
async def _search_items_impl(
182188
request: SearchItemsRequest, context: Context
183189
) -> SearchItemsResponse:
@@ -1755,10 +1761,10 @@ def _partition_variant_lookups(
17551761

17561762

17571763
@cache_read(
1758-
EntityType.VARIANT,
1759-
EntityType.PRODUCT,
1760-
EntityType.MATERIAL,
1761-
EntityType.SUPPLIER,
1764+
CachedVariant,
1765+
CachedProduct,
1766+
CachedMaterial,
1767+
CachedSupplier,
17621768
)
17631769
async def _get_variant_details_impl(
17641770
request: GetVariantDetailsRequest, context: Context

katana_mcp_server/src/katana_mcp/tools/foundation/reference.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@
2020
from katana_mcp.tools.decorators import cache_read
2121
from katana_mcp.tools.tool_result_utils import make_simple_result
2222
from katana_mcp.unpack import Unpack, unpack_pydantic_params
23+
from katana_public_api_client.models_pydantic._generated import (
24+
CachedAdditionalCost,
25+
CachedLocation,
26+
CachedOperator,
27+
CachedSupplier,
28+
CachedTaxRate,
29+
)
2330

2431
# ============================================================================
2532
# Shared helpers
@@ -140,7 +147,7 @@ def _supplier_summary_from_dict(d: dict[str, Any]) -> SupplierInfo:
140147
)
141148

142149

143-
@cache_read(EntityType.SUPPLIER)
150+
@cache_read(CachedSupplier)
144151
async def _list_suppliers_impl(
145152
request: ListSuppliersRequest, context: Context
146153
) -> ListSuppliersResponse:
@@ -230,7 +237,7 @@ def _iso_or_none(value: Any) -> str | None:
230237
return str(value)
231238

232239

233-
@cache_read(EntityType.SUPPLIER)
240+
@cache_read(CachedSupplier)
234241
async def _get_supplier_impl(
235242
request: GetSupplierRequest, context: Context
236243
) -> GetSupplierResponse:
@@ -383,7 +390,7 @@ def _location_from_dict(d: dict[str, Any]) -> LocationInfo:
383390
)
384391

385392

386-
@cache_read(EntityType.LOCATION)
393+
@cache_read(CachedLocation)
387394
async def _list_locations_impl(
388395
request: ListLocationsRequest, context: Context
389396
) -> ListLocationsResponse:
@@ -482,7 +489,7 @@ def _tax_rate_from_dict(d: dict[str, Any]) -> TaxRateInfo:
482489
)
483490

484491

485-
@cache_read(EntityType.TAX_RATE)
492+
@cache_read(CachedTaxRate)
486493
async def _list_tax_rates_impl(
487494
request: ListTaxRatesRequest, context: Context
488495
) -> ListTaxRatesResponse:
@@ -566,7 +573,7 @@ def _operator_from_dict(d: dict[str, Any]) -> OperatorInfo:
566573
return OperatorInfo(id=d.get("id") or 0, name=d.get("name") or "")
567574

568575

569-
@cache_read(EntityType.OPERATOR)
576+
@cache_read(CachedOperator)
570577
async def _list_operators_impl(
571578
request: ListOperatorsRequest, context: Context
572579
) -> ListOperatorsResponse:
@@ -641,7 +648,7 @@ def _additional_cost_from_dict(d: dict[str, Any]) -> AdditionalCostInfo:
641648
return AdditionalCostInfo(id=d.get("id") or 0, name=d.get("name") or "")
642649

643650

644-
@cache_read(EntityType.ADDITIONAL_COST)
651+
@cache_read(CachedAdditionalCost)
645652
async def _list_additional_costs_impl(
646653
request: ListAdditionalCostsRequest, context: Context
647654
) -> ListAdditionalCostsResponse:

0 commit comments

Comments
 (0)