Skip to content

Commit 9778bf3

Browse files
apartsin and claude committed
feat: fix 4 broken doc samples, add middleware param, CallbackObservability, custom policy example
Code fixes (verified by audit):
- Add `middleware` parameter to `create()` in Python and TypeScript
- Create `CallbackObservability` CDK class (was referenced in FAQ but missing)
- Add `router` getter on TS ModelMesh, `setMiddleware()` on TS Router

Doc fixes:
- FAQ Q8: Fix `explain()` call (add required `model` arg, fix snake_case key)
- FAQ Q10: Add custom rotation policy example with `select()` override
- FAQ Q10: Add YAML registration pattern for custom policies

All 1,879 tests pass (1,166 Python + 713 TypeScript).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ac7e9cc commit 9778bf3

File tree

8 files changed

+191
-8
lines changed

8 files changed

+191
-8
lines changed

docs/guides/FAQ.md

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,9 +310,9 @@ expect(client.calls.length).toBe(1);
310310
Debug routing decisions without making API calls:
311311

312312
```python
313-
explanation = client.explain()
314-
print(explanation["selectedModel"]) # Which model would be selected
315-
print(explanation["reason"]) # Why
313+
explanation = client.explain(model="chat-completion")
314+
print(explanation["selected_model"]) # Which model would be selected
315+
print(explanation["reason"]) # Why
316316
```
317317

318318
See the [Mock Client and Testing](Testing.md) guide.
@@ -459,6 +459,49 @@ provider = CorpLLMProvider(BaseProviderConfig(
459459

460460
Override only what differs: `_get_completion_endpoint()` for the URL path, `_build_headers()` for authentication, `_build_request_payload()` to translate the request format, and `_parse_response()` to translate the response back. For streaming, also override `_parse_sse_chunk()`.
461461

462+
**Custom rotation policy:**
463+
464+
Inherit from `BaseRotationPolicy` and override `select()` to control how models are chosen, `should_deactivate()` to control when a model is taken offline, or `should_recover()` to control when it comes back.
465+
466+
```python
467+
from modelmesh.cdk import BaseRotationPolicy, BaseRotationConfig
468+
from modelmesh.interfaces.rotation import ModelState
469+
from modelmesh.interfaces.provider import CompletionRequest
470+
from typing import Optional
471+
472+
class CostAwarePolicy(BaseRotationPolicy):
473+
"""Pick the cheapest model that hasn't exceeded its error threshold."""
474+
475+
def select(
476+
self,
477+
candidates: list[ModelState],
478+
request: CompletionRequest,
479+
) -> Optional[ModelState]:
480+
if not candidates:
481+
return None
482+
# Sort by cost (lowest first), break ties by error rate
483+
return min(candidates, key=lambda c: (c.total_cost, c.error_rate))
484+
```
485+
486+
Register the policy in YAML by pointing `strategy` at your custom class, or pass it programmatically:
487+
488+
```yaml
489+
pools:
490+
chat:
491+
capability: generation.text-generation.chat-completion
492+
strategy: my_app.policies.CostAwarePolicy
493+
```
494+
495+
```python
496+
# Or register programmatically
497+
from modelmesh.cdk import ThresholdRotationPolicy, ThresholdRotationConfig
498+
499+
policy = CostAwarePolicy(BaseRotationConfig(
500+
failure_threshold=5,
501+
cooldown_seconds=120,
502+
))
503+
```
504+
462505
Six connector types are extensible this way: providers, rotation policies, secret stores, storage backends, observability sinks, and discovery connectors.
463506

464507
See the [CDK](../ConnectorCatalogue.md) reference and [CDK Developer Guide](../cdk/DeveloperGuide.md).

src/python/modelmesh/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def create(
7878
strategy: str = "stick-until-failure",
7979
api_keys: dict[str, str] | None = None,
8080
config: str | dict | MeshConfig | None = None,
81+
middleware: list[Middleware] | None = None,
8182
) -> MeshClient:
8283
"""Create an OpenAI SDK-compatible client with ModelMesh routing.
8384
@@ -109,6 +110,8 @@ def create(
109110
config: Full configuration -- YAML file path, dict, or
110111
``MeshConfig`` object. When provided, auto-detection is
111112
skipped.
113+
middleware: List of :class:`Middleware` instances to attach to the
114+
router. Middleware runs before and after each provider call.
112115
113116
Returns:
114117
``MeshClient``: OpenAI SDK-compatible client with ModelMesh
@@ -158,7 +161,10 @@ def create(
158161
f"{type(config).__name__}"
159162
)
160163
mesh.initialize(mesh_config)
161-
return mesh.get_client()
164+
client = mesh.get_client()
165+
if middleware:
166+
mesh._router._middleware = MiddlewareStack(middleware)
167+
return client
162168

163169
if not capabilities and pool is None:
164170
raise ValueError(
@@ -183,7 +189,10 @@ def create(
183189
strategy=strategy,
184190
)
185191
mesh.initialize(MeshConfig(raw=raw_config))
186-
return mesh.get_client()
192+
client = mesh.get_client()
193+
if middleware:
194+
mesh._router._middleware = MiddlewareStack(middleware)
195+
return client
187196

188197

189198
# Well-known short names mapped to full capability tree paths.

src/python/modelmesh/cdk/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
# -- Specialized classes -----------------------------------------------------
6161

6262
from modelmesh.cdk.specialized import (
63+
CallbackObservability,
64+
CallbackObservabilityConfig,
6365
ConsoleObservability,
6466
ConsoleObservabilityConfig,
6567
FileObservability,
@@ -132,6 +134,8 @@
132134
"FileSecretStore",
133135
"KeyValueStorageConfig",
134136
"KeyValueStorage",
137+
"CallbackObservabilityConfig",
138+
"CallbackObservability",
135139
"ConsoleObservabilityConfig",
136140
"ConsoleObservability",
137141
"FileObservabilityConfig",

src/python/modelmesh/cdk/specialized/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
"""
1818
from __future__ import annotations
1919

20+
from modelmesh.cdk.specialized.callback_observability import (
21+
CallbackObservability,
22+
CallbackObservabilityConfig,
23+
)
2024
from modelmesh.cdk.specialized.console_observability import (
2125
ConsoleObservability,
2226
ConsoleObservabilityConfig,
@@ -74,6 +78,8 @@
7478
"KeyValueStorageConfig",
7579
"KeyValueStorage",
7680
# Observability
81+
"CallbackObservabilityConfig",
82+
"CallbackObservability",
7783
"ConsoleObservabilityConfig",
7884
"ConsoleObservability",
7985
"FileObservabilityConfig",

src/python/modelmesh/cdk/specialized/callback_observability.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""Callback-based observability sink.
2+
3+
Routes events, logs, and traces to a user-supplied callback function,
4+
enabling integration with custom dashboards, message queues, or alerting
5+
systems without subclassing.
6+
7+
Usage::
8+
9+
from modelmesh.cdk import CallbackObservability, CallbackObservabilityConfig
10+
11+
def on_event(event):
12+
my_dashboard.send(event.event_type, event.model_id, event.timestamp)
13+
14+
obs = CallbackObservability(CallbackObservabilityConfig(
15+
callback=on_event,
16+
))
17+
"""
18+
from __future__ import annotations
19+
20+
from dataclasses import dataclass, field
21+
from typing import Any, Callable, Optional
22+
23+
from modelmesh.cdk.base_observability import (
24+
BaseObservability,
25+
BaseObservabilityConfig,
26+
)
27+
from modelmesh.interfaces.observability import (
28+
RequestLogEntry,
29+
RoutingEvent,
30+
Severity,
31+
TraceEntry,
32+
)
33+
34+
__all__ = [
35+
"CallbackObservabilityConfig",
36+
"CallbackObservability",
37+
]
38+
39+
40+
@dataclass
41+
class CallbackObservabilityConfig(BaseObservabilityConfig):
42+
"""Configuration for CallbackObservability.
43+
44+
Attributes:
45+
callback: Function called with each event, log entry, or trace.
46+
Receives the original object (RoutingEvent, RequestLogEntry,
47+
or TraceEntry), not a formatted string.
48+
"""
49+
50+
callback: Optional[Callable[[Any], None]] = None
51+
52+
53+
class CallbackObservability(BaseObservability):
54+
"""Observability sink that routes events to a user-supplied callback.
55+
56+
The callback receives the original event/log/trace object (not a
57+
formatted string), enabling integration with custom dashboards,
58+
message queues, or alerting systems.
59+
60+
Applies event_filter in emit() and min_severity in trace() before
61+
invoking the callback; log() entries are forwarded unfiltered, and no secret redaction is applied.
62+
"""
63+
64+
def __init__(self, config: CallbackObservabilityConfig) -> None:
65+
super().__init__(config)
66+
self._callback_fn = config.callback
67+
68+
def emit(self, event: RoutingEvent) -> None:
69+
"""Emit a routing event to the callback.
70+
71+
Applies event_filter from config before invoking.
72+
"""
73+
if self._config.event_filter:
74+
if event.event_type.value not in self._config.event_filter:
75+
return
76+
if self._callback_fn:
77+
self._callback_fn(event)
78+
79+
def log(self, entry: RequestLogEntry) -> None:
80+
"""Route a request/response log entry to the callback."""
81+
if self._callback_fn:
82+
self._callback_fn(entry)
83+
84+
def trace(self, entry: TraceEntry) -> None:
85+
"""Route a trace entry to the callback, filtering by min_severity."""
86+
min_level = self._SEVERITY_ORDER.get(
87+
Severity(self._config.min_severity), 1
88+
)
89+
entry_level = self._SEVERITY_ORDER.get(entry.severity, 0)
90+
if entry_level < min_level:
91+
return
92+
if self._callback_fn:
93+
self._callback_fn(entry)
94+
95+
def _write(self, line: str) -> None:
96+
"""No-op: output goes through callback, not formatted _write."""
97+
pass

src/typescript/src/core/mesh.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,14 @@ export class ModelMesh {
5151
private _observability: ObservabilityConnector | null = null;
5252
private _initialized = false;
5353

54+
/** Access the underlying Router instance. */
55+
get router(): Router {
56+
if (!this._router) {
57+
throw new Error('ModelMesh not initialized. Call initialize() first.');
58+
}
59+
return this._router;
60+
}
61+
5462
// -- Lifecycle -----------------------------------------------------------
5563

5664
/**

src/typescript/src/core/router.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ export class Router {
5151
private readonly _emitter: EventEmitter;
5252
private readonly _observability: ObservabilityConnector | null;
5353
private readonly _maxRetries: number;
54-
private readonly _middleware: MiddlewareStack | null;
54+
private _middleware: MiddlewareStack | null;
5555

5656
constructor(
5757
pools: Record<string, CapabilityPool>,
@@ -71,6 +71,11 @@ export class Router {
7171
this._middleware = middleware ?? null;
7272
}
7373

74+
/** Replace the middleware stack. Called by create() when middleware is provided. */
75+
setMiddleware(stack: MiddlewareStack): void {
76+
this._middleware = stack;
77+
}
78+
7479
private _trace(
7580
severity: Severity | string,
7681
component: string,

src/typescript/src/index.ts

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ export interface CreateOptions {
199199
strategy?: string;
200200
apiKeys?: Record<string, string>;
201201
config?: string | Record<string, unknown> | MeshConfig;
202+
/** Middleware instances to attach to the router. */
203+
middleware?: Middleware[];
202204
}
203205

204206
/**
@@ -247,6 +249,7 @@ export function create(...args: unknown[]): MeshClient {
247249
strategy = 'stick-until-failure',
248250
apiKeys,
249251
config,
252+
middleware: middlewareList,
250253
} = options;
251254

252255
const mesh = new ModelMesh();
@@ -262,7 +265,11 @@ export function create(...args: unknown[]): MeshClient {
262265
meshConfig = MeshConfig.fromDict(config);
263266
}
264267
mesh.initialize(meshConfig);
265-
return mesh.getClient();
268+
const client = mesh.getClient();
269+
if (middlewareList && middlewareList.length > 0) {
270+
mesh.router.setMiddleware(new MiddlewareStack(middlewareList));
271+
}
272+
return client;
266273
}
267274

268275
if (capabilities.length === 0 && pool === undefined) {
@@ -325,5 +332,9 @@ export function create(...args: unknown[]): MeshClient {
325332
pools: poolsSection,
326333
observability: { connector: 'modelmesh.null.v1' },
327334
}));
328-
return mesh.getClient();
335+
const client = mesh.getClient();
336+
if (middlewareList && middlewareList.length > 0) {
337+
mesh.router.setMiddleware(new MiddlewareStack(middlewareList));
338+
}
339+
return client;
329340
}

0 commit comments

Comments
 (0)