Commit 1ca7510

feat(types): add framework-agnostic LLM type system (#1745)
Introduce nemoguardrails/types.py with provider-independent data types and protocols that replace direct LangChain type dependencies in core code.

- Add `nemoguardrails/types.py` with framework-agnostic data types (`ChatMessage`, `Role`, `ToolCall`, `ToolCallFunction`, `LLMResponse`, `LLMResponseChunk`, `UsageInfo`, `FinishReason`) and protocols (`LLMModel`, `LLMFramework`)
- `ChatMessage.from_dict()` handles both OpenAI nested and legacy flat tool call formats, JSON string argument parsing, and role aliases (`bot`, `human`, `developer`)
- `ChatMessage.to_dict()` / `from_dict()` roundtrip preserves all fields, including `provider_metadata`
- `LLMModel` protocol defines the `generate()` / `stream()` contract that all LLM adapters must implement
- `LLMFramework` protocol defines the `create_model()` factory contract for pluggable backends
1 parent b78f48b commit 1ca7510
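
The sketch below illustrates the `from_dict()` / `to_dict()` roundtrip described in the message. It is a hypothetical usage example, not part of the diff; it assumes the module is importable as `nemoguardrails.types` and all values (`call_1`, `get_weather`, `trace_id`) are made up for illustration:

```python
from nemoguardrails.types import ChatMessage, Role

# Legacy flat tool call format with JSON-string arguments; "bot" is an
# alias that from_dict() maps to Role.ASSISTANT.
msg = ChatMessage.from_dict(
    {
        "role": "bot",
        "tool_calls": [
            {"id": "call_1", "name": "get_weather", "args": '{"city": "Paris"}'}
        ],
        "trace_id": "abc123",  # unknown key, captured into provider_metadata
    }
)

assert msg.role is Role.ASSISTANT
assert msg.tool_calls[0].function.arguments == {"city": "Paris"}
assert msg.provider_metadata == {"trace_id": "abc123"}

# to_dict() emits the canonical nested format, and from_dict() on that
# output reproduces an equal message, provider_metadata included.
assert ChatMessage.from_dict(msg.to_dict()) == msg
```

Note how the flat `{"name": ..., "args": ...}` entry is normalized into the nested `{"function": {...}}` shape, so serializing and re-parsing yields an equal message.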

2 files changed

Lines changed: 758 additions & 0 deletions

File tree

nemoguardrails/types.py

Lines changed: 276 additions & 0 deletions
```python
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncIterator, Dict, List, Literal, Optional, Protocol, Union, runtime_checkable


class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL = "tool"


@dataclass
class ToolCallFunction:
    name: str
    arguments: Dict[str, Any]


@dataclass
class ToolCall:
    id: str
    type: str = "function"
    function: ToolCallFunction = field(default_factory=lambda: ToolCallFunction(name="", arguments={}))

    def to_dict(self) -> Dict[str, Any]:
        return {
            "id": self.id,
            "type": self.type,
            "function": {
                "name": self.function.name,
                "arguments": self.function.arguments,
            },
        }


@dataclass
class UsageInfo:
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    reasoning_tokens: Optional[int] = None
    cached_tokens: Optional[int] = None


FinishReason = Literal["stop", "length", "tool_calls", "content_filter", "error", "other"]


_STANDARD_MESSAGE_KEYS = {"role", "content", "tool_calls", "tool_call_id", "name", "provider_metadata"}

_ROLE_ALIASES = {
    "bot": Role.ASSISTANT,
    "assistant": Role.ASSISTANT,
    "human": Role.USER,
    "user": Role.USER,
    "developer": Role.SYSTEM,
    "system": Role.SYSTEM,
    "tool": Role.TOOL,
}


@dataclass
class ChatMessage:
    role: Role
    content: Optional[Union[str, List[Dict[str, Any]]]] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None
    name: Optional[str] = None
    provider_metadata: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_user(cls, content: str, **kwargs) -> "ChatMessage":
        return cls(role=Role.USER, content=content, **kwargs)

    @classmethod
    def from_assistant(cls, content: str, **kwargs) -> "ChatMessage":
        return cls(role=Role.ASSISTANT, content=content, **kwargs)

    @classmethod
    def from_system(cls, content: str, **kwargs) -> "ChatMessage":
        return cls(role=Role.SYSTEM, content=content, **kwargs)

    @classmethod
    def from_tool(cls, content: str, tool_call_id: str, **kwargs) -> "ChatMessage":
        return cls(role=Role.TOOL, content=content, tool_call_id=tool_call_id, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        payload: Dict[str, Any] = {"role": self.role.value}

        if self.content is not None:
            payload["content"] = self.content

        if self.tool_calls is not None:
            payload["tool_calls"] = [tc.to_dict() for tc in self.tool_calls]

        if self.tool_call_id is not None:
            payload["tool_call_id"] = self.tool_call_id

        if self.name is not None:
            payload["name"] = self.name

        if self.provider_metadata:
            payload["provider_metadata"] = self.provider_metadata

        return payload

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "ChatMessage":
        """Create a ChatMessage from a dict.

        Accepts both the canonical nested tool call format
        (``{"function": {"name": ..., "arguments": ...}}``) and the legacy
        flat format (``{"name": ..., "args": ...}``). JSON string arguments
        are parsed automatically. Role aliases like "bot", "human", and
        "developer" are mapped to canonical Role values. Unknown keys are
        captured into ``provider_metadata``.
        """

        raw_role = d.get("role")
        if raw_role is None:
            raise ValueError("Missing required key: 'role'")
        role = _ROLE_ALIASES.get(raw_role)
        if role is None:
            raise ValueError(f"Unknown role: {raw_role}")

        tool_calls = None
        raw_tool_calls = d.get("tool_calls")
        if raw_tool_calls is not None:
            tool_calls = []
            for tc in raw_tool_calls:
                func_data = tc.get("function")
                if func_data is not None:
                    raw_args = func_data.get("arguments", {})
                else:
                    raw_args = tc.get("args", {})
                    func_data = {"name": tc.get("name", "")}

                if isinstance(raw_args, str):
                    try:
                        args_dict = json.loads(raw_args)
                    except json.JSONDecodeError:
                        raise ValueError(f"Tool call arguments are not valid JSON: {raw_args!r}")
                    if not isinstance(args_dict, dict):
                        raise ValueError(
                            f"Tool call arguments must be a JSON object, got {type(args_dict).__name__}: {raw_args!r}"
                        )
                else:
                    if not isinstance(raw_args, dict):
                        raise ValueError(
                            f"Tool call arguments must be a dict, got {type(raw_args).__name__}: {raw_args!r}"
                        )
                    args_dict = raw_args

                tool_calls.append(
                    ToolCall(
                        id=tc.get("id", ""),
                        type=tc.get("type", "function"),
                        function=ToolCallFunction(
                            name=func_data.get("name", ""),
                            arguments=args_dict,
                        ),
                    )
                )

        extra = {k: v for k, v in d.items() if k not in _STANDARD_MESSAGE_KEYS}
        provider_metadata = {**extra, **d.get("provider_metadata", {})}

        return cls(
            role=role,
            content=d.get("content"),
            tool_calls=tool_calls,
            tool_call_id=d.get("tool_call_id"),
            name=d.get("name"),
            provider_metadata=provider_metadata,
        )


@dataclass
class LLMResponse:
    content: str
    reasoning: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
    model: Optional[str] = None
    finish_reason: Optional[FinishReason] = None
    stop_sequence: Optional[str] = None
    request_id: Optional[str] = None
    usage: Optional[UsageInfo] = None
    provider_metadata: Optional[Dict[str, Any]] = None


@dataclass
class LLMResponseChunk:
    delta_content: Optional[str] = None
    delta_reasoning: Optional[str] = None
    delta_tool_calls: Optional[List[ToolCall]] = None
    model: Optional[str] = None
    finish_reason: Optional[FinishReason] = None
    request_id: Optional[str] = None
    usage: Optional[UsageInfo] = None
    provider_metadata: Optional[Dict[str, Any]] = None


@runtime_checkable
class LLMModel(Protocol):
    """Protocol that all LLM backends must implement.

    Adapters wrap provider-specific SDKs (LangChain, LiteLLM, OpenAI, etc.)
    behind this interface so the core pipeline remains framework-agnostic.

    ``prompt`` accepts either a plain string or a list of ``ChatMessage``
    objects. Adapters convert ``ChatMessage`` to whatever their SDK expects.
    ``**kwargs`` are forwarded to the underlying SDK (e.g. temperature,
    max_tokens).
    """

    async def generate(
        self,
        prompt: Union[str, List["ChatMessage"]],
        *,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> "LLMResponse": ...

    def stream(
        self,
        prompt: Union[str, List["ChatMessage"]],
        *,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> AsyncIterator["LLMResponseChunk"]:
        """Implementations must be async generator functions (use ``yield``)."""
        ...

    @property
    def model_name(self) -> str: ...

    @property
    def provider_name(self) -> Optional[str]: ...

    @property
    def provider_url(self) -> Optional[str]: ...


@runtime_checkable
class LLMFramework(Protocol):
    """Protocol for pluggable LLM framework backends.

    Each framework (LangChain, LiteLLM, etc.) implements this protocol to
    provide a factory for creating ``LLMModel`` instances.

    ``model_kwargs`` carries all provider-specific configuration. Framework
    implementations extract what they need (e.g. LangChain pops ``mode``
    to choose between chat and text completion models).
    """

    def create_model(
        self,
        model_name: str,
        provider_name: str,
        model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> LLMModel: ...
```
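
To make the two protocol contracts concrete, here is a minimal sketch of a backend that satisfies both `LLMModel` and `LLMFramework`. `EchoModel` and `EchoFramework` are hypothetical names invented for illustration; only the protocol shapes come from the commit:

```python
import asyncio
from typing import Any, AsyncIterator, Dict, List, Optional, Union

from nemoguardrails.types import (
    ChatMessage,
    LLMFramework,
    LLMModel,
    LLMResponse,
    LLMResponseChunk,
)


class EchoModel:
    """Toy adapter that echoes the prompt back; stands in for a real SDK."""

    def __init__(self, model_name: str):
        self._model_name = model_name

    async def generate(
        self,
        prompt: Union[str, List[ChatMessage]],
        *,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> LLMResponse:
        text = prompt if isinstance(prompt, str) else str(prompt[-1].content)
        return LLMResponse(content=text, model=self._model_name, finish_reason="stop")

    async def stream(
        self,
        prompt: Union[str, List[ChatMessage]],
        *,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> AsyncIterator[LLMResponseChunk]:
        # Per the protocol docstring, stream() is an async generator function.
        response = await self.generate(prompt, stop=stop, **kwargs)
        for token in response.content.split():
            yield LLMResponseChunk(delta_content=token + " ")
        yield LLMResponseChunk(finish_reason="stop")

    @property
    def model_name(self) -> str:
        return self._model_name

    @property
    def provider_name(self) -> Optional[str]:
        return "echo"

    @property
    def provider_url(self) -> Optional[str]:
        return None


class EchoFramework:
    """Factory that produces EchoModel instances."""

    def create_model(
        self,
        model_name: str,
        provider_name: str,
        model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> LLMModel:
        return EchoModel(model_name)


async def main() -> None:
    framework = EchoFramework()
    # Both protocols are @runtime_checkable, so structural isinstance() works.
    assert isinstance(framework, LLMFramework)
    model = framework.create_model("echo-1", "echo")
    assert isinstance(model, LLMModel)

    result = await model.generate([ChatMessage.from_user("hello world")])
    print(result.content)  # -> hello world

    async for chunk in model.stream("hello world"):
        print(chunk.delta_content or chunk.finish_reason)


asyncio.run(main())
```

Because both protocols are `@runtime_checkable`, adapters need no inheritance: structural `isinstance()` checks suffice to validate a backend at registration time.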
