Skip to content

Commit ebd3322

Browse files
nabinchha, johnnygreco, claude
authored
refactor: Decouple ModelFacade from LiteLLM via ModelClient adapter (#373)
* plans for model facade overhaul * update plan * add review * address feedback + add more details after several self reviews * update plan doc * address nits * Add canonical objects * self-review feedback + address * add LiteLLMRouter protocol to strongly type bridge router param Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * simplify some things * add a protocol for http response like object * move HttpResponse * update PR-1 architecture notes for lifecycle and router protocol Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Address PR #359 feedback: exception wrapping, shared parsing, test improvements - Wrap all LiteLLM router calls in try/except to normalize raw exceptions into canonical ProviderError at the bridge boundary (blocking review item) - Extract reusable response-parsing helpers into clients/parsing.py for shared use across future native adapters - Add async image parsing path using httpx.AsyncClient to avoid blocking the event loop in agenerate_image - Add retry_after field to ProviderError for future retry engine support - Fix _to_int_or_none to parse numeric strings from providers - Create test conftest.py with shared mock_router/bridge_client fixtures - Parametrize duplicate image generation and error mapping tests - Add tests for exception wrapping across all bridge methods * Use contextlib to dry out some code * Address Greptile feedback: HTTP-date retry-after parsing, docstring clarity - Parse RFC 7231 HTTP-date strings in Retry-After header (used by Azure and Anthropic during rate-limiting) in addition to numeric delay-seconds - Clarify collect_non_none_optional_fields docstring explaining why f.default is None is the correct check for optional field forwarding - Add tests for HTTP-date and garbage Retry-After values * Address Greptile feedback: FastAPI detail parsing, comment fixes - Fix misleading comment about prompt field defaults in _IMAGE_EXCLUDE - Handle list-format detail arrays in _extract_structured_message for 
FastAPI/Pydantic validation errors - Document scope boundary for vision content in collect_raw_image_candidates * add PR-2 architecture notes for model facade overhaul * save progress on pr2 * small refactor * address feedback * Address greptile comment in pr1 * refactor ProviderError from dataclass to regular Exception - Replace @dataclass + __post_init__ with explicit __init__ that calls super().__init__ properly, avoiding brittle field-ordering dependency - Store cause via __cause__ only, removing the redundant .cause attr - Update match pattern in handle_llm_exceptions for non-dataclass type - Rename shadowed local `fields` to `optional_fields` in TransportKwargs * Address greptile feedback * PR feedback * track usage in finally block for images * pr feedback * wrap facade close in try/catch * clean up stray params * fix stray inclusion of metadata * small regression fix * address more feedback --------- Co-authored-by: Johnny Greco <jogreco@nvidia.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 7de879a commit ebd3322

22 files changed

Lines changed: 1595 additions & 1511 deletions

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,6 @@ packages/data-designer/README.md
107107
.cursor/rules/cerebro.mdc
108108
.cursor/mcp.json
109109
.claude/rules/cerebro.md
110+
111+
# Claude worktrees
112+
.claude/worktrees/

packages/data-designer-engine/src/data_designer/engine/mcp/facade.py

Lines changed: 59 additions & 229 deletions
Large diffs are not rendered by default.

packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
map_http_error_to_provider_error,
1111
map_http_status_to_provider_error_kind,
1212
)
13+
from data_designer.engine.models.clients.factory import create_model_client
1314
from data_designer.engine.models.clients.types import (
1415
AssistantMessage,
1516
ChatCompletionRequest,
@@ -25,12 +26,12 @@
2526
)
2627

2728
__all__ = [
28-
"HttpResponse",
2929
"AssistantMessage",
3030
"ChatCompletionRequest",
3131
"ChatCompletionResponse",
3232
"EmbeddingRequest",
3333
"EmbeddingResponse",
34+
"HttpResponse",
3435
"ImageGenerationRequest",
3536
"ImageGenerationResponse",
3637
"ImagePayload",
@@ -39,6 +40,7 @@
3940
"ProviderErrorKind",
4041
"ToolCall",
4142
"Usage",
43+
"create_model_client",
4244
"map_http_error_to_provider_error",
4345
"map_http_status_to_provider_error_kind",
4446
]

packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
from data_designer.engine.models.clients.errors import (
1313
ProviderError,
1414
ProviderErrorKind,
15+
extract_message_from_exception_string,
1516
map_http_status_to_provider_error_kind,
1617
)
1718
from data_designer.engine.models.clients.parsing import (
1819
aextract_images_from_chat_response,
1920
aextract_images_from_image_response,
2021
aparse_chat_completion_response,
21-
collect_non_none_optional_fields,
2222
extract_embedding_vector,
2323
extract_images_from_chat_response,
2424
extract_images_from_image_response,
@@ -32,6 +32,7 @@
3232
EmbeddingResponse,
3333
ImageGenerationRequest,
3434
ImageGenerationResponse,
35+
TransportKwargs,
3536
)
3637

3738
logger = logging.getLogger(__name__)
@@ -75,57 +76,67 @@ def supports_image_generation(self) -> bool:
7576
return True
7677

7778
def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
79+
transport = TransportKwargs.from_request(request)
7880
with _handle_non_provider_errors(self.provider_name):
7981
response = self._router.completion(
8082
model=request.model,
8183
messages=request.messages,
82-
**collect_non_none_optional_fields(request),
84+
extra_headers=transport.headers or None,
85+
**transport.body,
8386
)
8487
return parse_chat_completion_response(response)
8588

8689
async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
90+
transport = TransportKwargs.from_request(request)
8791
with _handle_non_provider_errors(self.provider_name):
8892
response = await self._router.acompletion(
8993
model=request.model,
9094
messages=request.messages,
91-
**collect_non_none_optional_fields(request),
95+
extra_headers=transport.headers or None,
96+
**transport.body,
9297
)
9398
return await aparse_chat_completion_response(response)
9499

95100
def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse:
101+
transport = TransportKwargs.from_request(request)
96102
with _handle_non_provider_errors(self.provider_name):
97103
response = self._router.embedding(
98104
model=request.model,
99105
input=request.inputs,
100-
**collect_non_none_optional_fields(request),
106+
extra_headers=transport.headers or None,
107+
**transport.body,
101108
)
102109
vectors = [extract_embedding_vector(item) for item in getattr(response, "data", [])]
103110
return EmbeddingResponse(vectors=vectors, usage=extract_usage(getattr(response, "usage", None)), raw=response)
104111

105112
async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse:
113+
transport = TransportKwargs.from_request(request)
106114
with _handle_non_provider_errors(self.provider_name):
107115
response = await self._router.aembedding(
108116
model=request.model,
109117
input=request.inputs,
110-
**collect_non_none_optional_fields(request),
118+
extra_headers=transport.headers or None,
119+
**transport.body,
111120
)
112121
vectors = [extract_embedding_vector(item) for item in getattr(response, "data", [])]
113122
return EmbeddingResponse(vectors=vectors, usage=extract_usage(getattr(response, "usage", None)), raw=response)
114123

115124
def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse:
116-
image_kwargs = collect_non_none_optional_fields(request, exclude=self._IMAGE_EXCLUDE)
125+
transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE)
117126
with _handle_non_provider_errors(self.provider_name):
118127
if request.messages is not None:
119128
response = self._router.completion(
120129
model=request.model,
121130
messages=request.messages,
122-
**image_kwargs,
131+
extra_headers=transport.headers or None,
132+
**transport.body,
123133
)
124134
else:
125135
response = self._router.image_generation(
126136
prompt=request.prompt,
127137
model=request.model,
128-
**image_kwargs,
138+
extra_headers=transport.headers or None,
139+
**transport.body,
129140
)
130141

131142
if request.messages is not None:
@@ -137,19 +148,21 @@ def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResp
137148
return ImageGenerationResponse(images=images, usage=usage, raw=response)
138149

139150
async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse:
140-
image_kwargs = collect_non_none_optional_fields(request, exclude=self._IMAGE_EXCLUDE)
151+
transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE)
141152
with _handle_non_provider_errors(self.provider_name):
142153
if request.messages is not None:
143154
response = await self._router.acompletion(
144155
model=request.model,
145156
messages=request.messages,
146-
**image_kwargs,
157+
extra_headers=transport.headers or None,
158+
**transport.body,
147159
)
148160
else:
149161
response = await self._router.aimage_generation(
150162
prompt=request.prompt,
151163
model=request.model,
152-
**image_kwargs,
164+
extra_headers=transport.headers or None,
165+
**transport.body,
153166
)
154167

155168
if request.messages is not None:
@@ -183,7 +196,7 @@ def _handle_non_provider_errors(provider_name: str) -> Iterator[None]:
183196

184197
raise ProviderError(
185198
kind=kind,
186-
message=str(exc),
199+
message=extract_message_from_exception_string(str(exc)),
187200
status_code=status_code if isinstance(status_code, int) else None,
188201
provider_name=provider_name,
189202
cause=exc,

packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
import calendar
77
import email.utils
8+
import json
89
import time
9-
from dataclasses import dataclass
1010
from enum import Enum
1111

1212
from data_designer.engine.models.clients.types import HttpResponse
@@ -28,20 +28,26 @@ class ProviderErrorKind(str, Enum):
2828
UNSUPPORTED_CAPABILITY = "unsupported_capability"
2929

3030

31-
@dataclass
3231
class ProviderError(Exception):
33-
kind: ProviderErrorKind
34-
message: str
35-
status_code: int | None = None
36-
provider_name: str | None = None
37-
model_name: str | None = None
38-
retry_after: float | None = None
39-
cause: Exception | None = None
40-
41-
def __post_init__(self) -> None:
42-
Exception.__init__(self, self.message)
43-
if self.cause is not None:
44-
self.__cause__ = self.cause
32+
def __init__(
33+
self,
34+
kind: ProviderErrorKind,
35+
message: str,
36+
status_code: int | None = None,
37+
provider_name: str | None = None,
38+
model_name: str | None = None,
39+
retry_after: float | None = None,
40+
cause: Exception | None = None,
41+
) -> None:
42+
super().__init__(message)
43+
self.kind = kind
44+
self.message = message
45+
self.status_code = status_code
46+
self.provider_name = provider_name
47+
self.model_name = model_name
48+
self.retry_after = retry_after
49+
if cause is not None:
50+
self.__cause__ = cause
4551

4652
def __str__(self) -> str:
4753
return self.message
@@ -118,6 +124,31 @@ def map_http_error_to_provider_error(
118124
)
119125

120126

127+
def extract_message_from_exception_string(raw: str) -> str:
    """Extract a human-readable message from a stringified LiteLLM exception.

    LiteLLM often formats errors as ``"Error code: 400 - {json}"``. The payload
    that starts at the first ``{`` is decoded and searched for a message under
    the ``message``, ``error`` and ``detail`` keys, in that order. When the
    value under a key is itself a dict, its own ``message`` entry is used.

    Two payload spellings are accepted:

    * strict JSON, optionally followed by trailing text;
    * a Python literal ``dict`` repr (single-quoted strings, ``None``), which
      is what appears when an already-parsed body is interpolated into the
      exception string.

    Args:
        raw: The stringified exception.

    Returns:
        The stripped message when one is found, otherwise ``raw`` unchanged.
    """
    import ast  # local import: only used by the non-JSON fallback below

    json_start = raw.find("{")
    if json_start == -1:
        return raw
    candidate = raw[json_start:]

    try:
        # raw_decode stops at the end of the first JSON value, so text that
        # follows the object no longer defeats the parse.
        payload, _ = json.JSONDecoder().raw_decode(candidate)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        try:
            # literal_eval only evaluates literals; it never executes code.
            payload = ast.literal_eval(candidate)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return raw

    if not isinstance(payload, dict):
        return raw

    for key in ("message", "error", "detail"):
        value = payload.get(key)
        if isinstance(value, str) and value.strip():
            return value.strip()
        if isinstance(value, dict):
            nested = value.get("message")
            if isinstance(nested, str) and nested.strip():
                return nested.strip()
    return raw
150+
151+
121152
def _extract_response_text(response: HttpResponse) -> str:
122153
# Try structured JSON extraction first — most providers return structured error
123154
# bodies and we want the human-readable message, not raw JSON.
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
import data_designer.lazy_heavy_imports as lazy
7+
from data_designer.config.models import ModelConfig
8+
from data_designer.engine.model_provider import ModelProviderRegistry
9+
from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient
10+
from data_designer.engine.models.clients.base import ModelClient
11+
from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs
12+
from data_designer.engine.secret_resolver import SecretResolver
13+
14+
15+
def create_model_client(
    model_config: ModelConfig,
    secret_resolver: SecretResolver,
    model_provider_registry: ModelProviderRegistry,
) -> ModelClient:
    """Build the ModelClient that serves ``model_config``.

    The provider is looked up in the registry by ``model_config.provider``.
    When the provider declares an API key, it is resolved through
    ``secret_resolver``; an absent or empty key is replaced by a placeholder
    string. A single-deployment LiteLLM router is then constructed and wrapped
    in a ``LiteLLMBridgeClient``.

    Args:
        model_config: The model configuration to create a client for.
        secret_resolver: Resolver for secrets referenced in provider configs.
        model_provider_registry: Registry of model provider configurations.

    Returns:
        A ``LiteLLMBridgeClient`` named after the provider and backed by the
        newly created router.
    """
    provider = model_provider_registry.get_provider(model_config.provider)

    resolved_key = secret_resolver.resolve(provider.api_key) if provider.api_key else None
    if not resolved_key:
        # A key value is always passed to LiteLLM, even when none is configured.
        resolved_key = "not-used-but-required"

    params = lazy.litellm.LiteLLM_Params(
        model=f"{provider.provider_type}/{model_config.model}",
        api_base=provider.endpoint,
        api_key=resolved_key,
        max_parallel_requests=model_config.inference_parameters.max_parallel_requests,
    )
    deployments = [{"model_name": model_config.model, "litellm_params": params.model_dump()}]
    router_kwargs = LiteLLMRouterDefaultKwargs().model_dump()
    return LiteLLMBridgeClient(
        provider_name=provider.name,
        router=CustomRouter(deployments, **router_kwargs),
    )

packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
from __future__ import annotations
77

8-
import dataclasses
98
import json
109
import logging
10+
import uuid
1111
from typing import Any
1212

1313
from data_designer.config.utils.image_helpers import (
@@ -206,7 +206,7 @@ def extract_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:
206206

207207
normalized_tool_calls: list[ToolCall] = []
208208
for raw_tool_call in raw_tool_calls:
209-
tool_call_id = get_value_from(raw_tool_call, "id") or ""
209+
tool_call_id = get_value_from(raw_tool_call, "id") or uuid.uuid4().hex
210210
function = get_value_from(raw_tool_call, "function")
211211
name = get_value_from(function, "name") or ""
212212
arguments_value = get_value_from(function, "arguments")
@@ -333,17 +333,3 @@ def get_first_value_or_none(values: Any) -> Any | None:
333333
if isinstance(values, list) and values:
334334
return values[0]
335335
return None
336-
337-
338-
def collect_non_none_optional_fields(request: Any, *, exclude: frozenset[str] = frozenset()) -> dict[str, Any]:
339-
"""Extract non-None optional fields from a request dataclass, skipping *exclude*.
340-
341-
The ``f.default is None`` check intentionally targets fields whose default is
342-
``None`` — i.e. truly optional kwargs the caller may or may not set. Fields with
343-
non-``None`` defaults are not "optional" in this forwarding sense and are excluded.
344-
"""
345-
return {
346-
f.name: v
347-
for f in dataclasses.fields(request)
348-
if f.name not in exclude and f.default is None and (v := getattr(request, f.name)) is not None
349-
}

0 commit comments

Comments
 (0)