Skip to content

Commit a3c552b

Browse files
authored
fix: Improved LangChain serialization (NVIDIA#165)
#### Overview Fix LangChain serialization by adding a LangChain-specific codec - [x] I confirm this contribution is my own work, or I have the right to submit it under this project's license. - [x] I searched existing issues and open pull requests, and this does not duplicate existing work. #### Details * This fixes the ability to use LLM intercepts * First pass at documenting the immutability of `LLMRequest` and the immutability of `AnnotatedLLMRequest` (I need to go back and update this for the other language bindings) #### Where should the reviewer start? `python/nemo_relay/integrations/langchain/_serialization.py` #### Related Issues: (use one of the action keywords Closes / Fixes / Resolves / Relates to) - Closes # ## Summary by CodeRabbit * **Documentation** * Clarified that LLM request objects are immutable; examples now show returning new request instances instead of mutating originals. * **Integrations** * LangChain integration reworked to use a dedicated codec for reliable request/response translation, role normalization, tool-call handling, and preservation of extra fields. * **Tests** * Expanded LangChain and middleware tests, including codec round-trip and interceptor behavior; added an end-to-end agent integration test. [![Review Change Stack](https://storage.googleapis.com/coderabbit_public_assets/review-stack-in-coderabbit-ui.svg)](https://app.coderabbit.ai/change-stack/NVIDIA/NeMo-Relay/pull/165?utm_source=github_walkthrough&utm_medium=github&utm_campaign=change_stack) Authors: - David Gardner (https://github.com/dagardner-nv) Approvers: - Will Killian (https://github.com/willkill07) URL: NVIDIA#165
1 parent 2d9b4f4 commit a3c552b

9 files changed

Lines changed: 276 additions & 99 deletions

File tree

docs/build-plugins/code-examples.mdx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ This page collects concrete examples for the surrounding guide area.
1313

1414
Use an LLM request intercept when a plugin needs to inject tenant or routing metadata into every provider request.
1515

16+
LLM request intercepts receive three arguments: `name`, `request`, and `annotated`. The `request` object is immutable; however, an intercept can return a new instance of the request that carries the edits. The exception to this is when the intercept is written in Rust.
17+
1618
<Tabs>
1719
<Tab title="Python" language="python">
1820
```python
@@ -30,8 +32,9 @@ class HeaderPlugin:
3032

3133
def register(self, plugin_config, context):
3234
def add_header(name, request, annotated):
33-
request.headers[plugin_config["header_name"]] = plugin_config["value"]
34-
return request, annotated
35+
headers = request.headers.copy()
36+
headers[plugin_config["header_name"]] = plugin_config["value"]
37+
return nemo_relay.LLMRequest(headers=headers, content=request.content), annotated
3538

3639
context.register_llm_request_intercept("inject-header", 100, False, add_header)
3740

docs/build-plugins/register-behavior.mdx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,9 @@ class HeaderPlugin:
122122

123123
def register(self, plugin_config, context):
124124
def add_header(name, request, annotated):
125-
request.headers[plugin_config["header_name"]] = plugin_config["value"]
126-
return request, annotated
125+
headers = request.headers.copy()
126+
headers[plugin_config["header_name"]] = plugin_config["value"]
127+
return nemo_relay.LLMRequest(headers=headers, content=request.content), annotated
127128

128129
context.register_llm_request_intercept("inject-header", 100, False, add_header)
129130

docs/integrate-into-frameworks/provider-codecs.mdx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ def add_system_message(_name, request, annotated):
8989
if annotated is None:
9090
return request, annotated
9191

92+
# Attributes of the annotated request can be re-assigned, but cannot be modified in-place.
93+
# For example `annotated.messages.append(...)` would not work, but re-assigning
94+
# `annotated.messages = annotated.messages + [...]` does work.
9295
annotated.messages = [
9396
{"role": "system", "content": "Answer with concise technical detail."},
9497
*annotated.messages,

python/nemo_relay/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ def redact_args(tool_name, args):
3838
return {**args, "api_key": "***"}
3939
4040
def add_header(name, request, annotated):
41-
request.headers["Authorization"] = "Bearer test-token"
42-
return request, annotated
41+
headers = request.headers.copy()
42+
headers["Authorization"] = "Bearer test-token"
43+
return nemo_relay.LLMRequest(headers=headers, content=request.content), annotated
4344
4445
async def tool_impl(args):
4546
return {"echo": args["query"]}

python/nemo_relay/integrations/langchain/_serialization.py

Lines changed: 178 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -5,51 +5,33 @@
55

66
from __future__ import annotations
77

8+
import json
89
from typing import TYPE_CHECKING, Any
910

1011
from langchain.agents.middleware import ModelResponse
1112
from langchain_core.messages import (
13+
AIMessage,
1214
BaseMessage,
15+
HumanMessage,
1316
SystemMessage,
1417
ToolMessage,
1518
messages_from_dict,
1619
messages_to_dict,
1720
)
1821
from langgraph.types import Command, Send
1922

20-
from nemo_relay.codecs import AnthropicMessagesCodec, LlmCodec, OpenAIChatCodec, OpenAIResponsesCodec
23+
from nemo_relay import AnnotatedLLMRequest, LLMRequest
24+
from nemo_relay.codecs import LlmCodec
2125

2226
if TYPE_CHECKING:
2327
from langchain.agents.middleware import ModelRequest
2428

25-
26-
# In order to infer codec support from LangChain chat model types, we need to import them here.
27-
# However these may not be installed in the user's environment.
28-
_HAS_ANTHROPIC = False
29-
_HAS_OPENAI = False
30-
_HAS_NVIDIA = False
31-
try:
32-
from langchain_anthropic import ChatAnthropic
33-
34-
_HAS_ANTHROPIC = True
35-
except ImportError:
36-
pass
37-
38-
try:
39-
from langchain_openai import ChatOpenAI
40-
41-
_HAS_OPENAI = True
42-
except ImportError:
43-
pass
44-
45-
try:
46-
from langchain_nvidia_ai_endpoints import ChatNVIDIA
47-
48-
_HAS_NVIDIA = True
49-
except ImportError:
50-
pass
51-
5229
LANGCHAIN_MODEL_RESPONSE_KEY = "__nemo_relay_integrations_langchain_model_response"
30+
_LANGCHAIN_MODELED_REQUEST_KEYS = {"messages", "model", "tool_choice", "tools"}
31+
_LC_TO_RELAY_MESSAGE_ROLE = {
32+
"human": "user",
33+
"ai": "assistant",
34+
}
5335

5436

5537
def get_model_name(model: Any) -> str | None:
@@ -61,24 +43,156 @@ def get_model_name(model: Any) -> str | None:
6143
return None
6244

6345

64-
def infer_codec_from_model(model: Any) -> LlmCodec | None:
65-
"""Infer a NeMo Relay codec name from a LangChain chat model."""
66-
if _HAS_ANTHROPIC:
67-
if isinstance(model, ChatAnthropic):
68-
return AnthropicMessagesCodec()
69-
70-
if _HAS_NVIDIA:
71-
if isinstance(model, ChatNVIDIA):
72-
return OpenAIChatCodec()
73-
74-
if _HAS_OPENAI:
75-
if isinstance(model, ChatOpenAI):
76-
if getattr(model, "use_responses_api", None) is True:
77-
return OpenAIResponsesCodec()
78-
79-
return OpenAIChatCodec()
80-
81-
return None
46+
class LangChainCodec(LlmCodec):
47+
"""Translate LangChain ``ModelRequest`` payloads for request intercepts."""
48+
49+
@classmethod
50+
def _langchain_tool_calls_to_annotated(cls, tool_calls: list[Any]) -> list[dict[str, Any]]:
51+
annotated_tool_calls = []
52+
for tool_call in tool_calls:
53+
args = tool_call["args"]
54+
arguments = args if isinstance(args, str) else json.dumps(args)
55+
annotated_tool_calls.append(
56+
{
57+
"id": tool_call.get("id") or "",
58+
"type": "function",
59+
"function": {
60+
"name": tool_call["name"],
61+
"arguments": arguments,
62+
},
63+
}
64+
)
65+
66+
return annotated_tool_calls
67+
68+
@classmethod
69+
def _annotated_tool_calls_to_langchain(cls, tool_calls: Any) -> list[dict[str, Any]] | None:
70+
if not isinstance(tool_calls, list) or not tool_calls:
71+
return None
72+
73+
langchain_tool_calls = []
74+
for tool_call in tool_calls:
75+
if not isinstance(tool_call, dict):
76+
continue
77+
function = tool_call.get("function")
78+
if isinstance(function, dict):
79+
name = str(function.get("name") or "")
80+
arguments = function.get("arguments", {})
81+
else:
82+
name = str(tool_call.get("name") or "")
83+
arguments = tool_call.get("args", {})
84+
85+
if isinstance(arguments, str):
86+
try:
87+
args = json.loads(arguments)
88+
except json.JSONDecodeError:
89+
args = {"arguments": arguments}
90+
elif isinstance(arguments, dict):
91+
args = arguments
92+
else:
93+
args = {}
94+
95+
langchain_tool_calls.append(
96+
{
97+
"name": name,
98+
"args": args,
99+
"id": str(tool_call.get("id") or ""),
100+
"type": "tool_call",
101+
}
102+
)
103+
104+
return langchain_tool_calls or None
105+
106+
@classmethod
107+
def _langchain_message_to_annotated(cls, message: BaseMessage) -> list[dict[str, Any]]:
108+
content = message.content
109+
if content is None:
110+
content = []
111+
elif isinstance(content, str):
112+
content = [content]
113+
114+
name = message.name
115+
role = _LC_TO_RELAY_MESSAGE_ROLE.get(message.type, message.type)
116+
117+
messages = []
118+
for msg in content:
119+
relay_message: dict[str, Any] = {"role": role}
120+
if isinstance(msg, str):
121+
relay_message["content"] = msg
122+
elif isinstance(msg, dict):
123+
relay_message.update(msg)
124+
if "content" not in relay_message:
125+
relay_message["content"] = relay_message.pop("text", "")
126+
else:
127+
raise ValueError(f"Unsupported LangChain message content type: {type(content)}")
128+
129+
if name is not None:
130+
relay_message["name"] = name
131+
132+
# Using getattr as we are inferring subclasses of BaseMessage based upon the role
133+
if role == "assistant":
134+
tool_calls = getattr(message, "tool_calls", [])
135+
relay_message["tool_calls"] = cls._langchain_tool_calls_to_annotated(tool_calls)
136+
elif role == "tool":
137+
relay_message["tool_call_id"] = getattr(message, "tool_call_id", "")
138+
139+
messages.append(relay_message)
140+
141+
return messages
142+
143+
@classmethod
144+
def _annotated_message_to_langchain(cls, message: dict[str, Any]) -> BaseMessage:
145+
role = message.get("role")
146+
content = message.get("content", "")
147+
name = message.get("name")
148+
149+
if role == "system":
150+
return SystemMessage(content=content, name=name)
151+
if role == "user":
152+
return HumanMessage(content=content, name=name)
153+
if role == "assistant":
154+
tool_calls = cls._annotated_tool_calls_to_langchain(message.get("tool_calls"))
155+
return AIMessage(content=content, name=name, tool_calls=tool_calls or [])
156+
if role == "tool":
157+
return ToolMessage(content=content, name=name, tool_call_id=str(message.get("tool_call_id") or ""))
158+
raise ValueError(f"Unsupported annotated LangChain message role: {role!r}")
159+
160+
def decode(self, request: LLMRequest) -> AnnotatedLLMRequest:
161+
"""Decode a LangChain-shaped request payload into an annotated request."""
162+
payload = request.content
163+
raw_messages = payload.get("messages", [])
164+
messages: list[dict[str, Any]] = []
165+
if isinstance(raw_messages, list):
166+
for message in messages_from_dict(raw_messages):
167+
messages.extend(self._langchain_message_to_annotated(message))
168+
169+
model = payload.get("model")
170+
tools = payload.get("tools")
171+
tool_choice = payload.get("tool_choice")
172+
extra = {key: value for key, value in payload.items() if key not in _LANGCHAIN_MODELED_REQUEST_KEYS}
173+
174+
return AnnotatedLLMRequest(
175+
messages,
176+
model=model if isinstance(model, str) else None,
177+
tools=tools if isinstance(tools, list) else None,
178+
tool_choice=tool_choice if isinstance(tool_choice, str | dict) else None,
179+
extra=extra or None,
180+
)
181+
182+
def encode(self, annotated: AnnotatedLLMRequest, original: LLMRequest) -> LLMRequest:
183+
"""Encode annotated request edits back into a LangChain-shaped payload."""
184+
payload = dict(original.content)
185+
payload.update(annotated.extra)
186+
payload["messages"] = messages_to_dict(
187+
[self._annotated_message_to_langchain(message) for message in annotated.messages]
188+
)
189+
if annotated.model is not None:
190+
payload["model"] = annotated.model
191+
if annotated.tools is not None:
192+
payload["tools"] = annotated.tools
193+
if annotated.tool_choice is not None:
194+
payload["tool_choice"] = annotated.tool_choice
195+
return LLMRequest(dict(original.headers), payload)
82196

83197

84198
def split_system_message(messages: list[BaseMessage]) -> tuple[SystemMessage | None, list[BaseMessage]]:
@@ -109,12 +223,12 @@ def model_request_to_payload(model_name: str | None, request: ModelRequest[Any])
109223

110224
def payload_to_model_request(
111225
original: ModelRequest[Any],
112-
payload: dict[str, Any],
226+
llm_request: LLMRequest,
113227
) -> ModelRequest[Any]:
114228
"""Apply supported NeMo Relay request-intercept edits back to ``ModelRequest``."""
115229
overrides: dict[str, Any] = {}
116230

117-
raw_messages = payload.get("messages")
231+
raw_messages = llm_request.content.get("messages")
118232
if isinstance(raw_messages, list) and len(raw_messages) > 0:
119233
try:
120234
system_message, messages = split_system_message(messages_from_dict(raw_messages))
@@ -123,12 +237,24 @@ def payload_to_model_request(
123237
except Exception:
124238
pass
125239

126-
model_settings = payload.get("model_settings")
240+
model_settings = llm_request.content.get("model_settings")
127241
if isinstance(model_settings, dict):
128-
overrides["model_settings"] = model_settings
242+
# Using dict() to ensure we have a copy
243+
model_settings_copy = dict(model_settings)
244+
extra_headers = model_settings_copy.get("extra_headers")
245+
if not isinstance(extra_headers, dict):
246+
extra_headers = {}
247+
overrides["model_settings"] = model_settings_copy
248+
else:
249+
overrides["model_settings"] = {}
250+
extra_headers = {}
251+
252+
if len(llm_request.headers) > 0:
253+
extra_headers.update(llm_request.headers)
254+
overrides["model_settings"]["extra_headers"] = extra_headers
129255

130-
if "tool_choice" in payload:
131-
overrides["tool_choice"] = payload["tool_choice"]
256+
if "tool_choice" in llm_request.content:
257+
overrides["tool_choice"] = llm_request.content["tool_choice"]
132258

133259
return original.override(**overrides) if overrides else original
134260

python/nemo_relay/integrations/langchain/middleware.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313
import nemo_relay
1414
from nemo_relay.integrations.langchain._serialization import (
15+
LangChainCodec,
1516
get_model_name,
16-
infer_codec_from_model,
1717
model_request_to_payload,
1818
model_response_from_json,
1919
model_response_to_json,
@@ -72,7 +72,7 @@ def _prepare_model_call(self, request: ModelRequest[Any]) -> tuple:
7272
object_codec = nemo_relay.typed.BestEffortAnyCodec()
7373
model_name = get_model_name(request.model)
7474
llm_request = nemo_relay.LLMRequest({}, model_request_to_payload(model_name, request))
75-
model_codec = infer_codec_from_model(request.model)
75+
model_codec = LangChainCodec()
7676
return (object_codec, llm_request, model_name, model_codec)
7777

7878
def wrap_model_call(
@@ -83,8 +83,8 @@ def wrap_model_call(
8383
"""Wrap a sync LangChain agent model call in NeMo Relay LLM execution."""
8484
(object_codec, llm_request, model_name, model_codec) = self._prepare_model_call(request)
8585

86-
async def _call(req: Any) -> Any:
87-
response = handler(payload_to_model_request(request, req.content))
86+
async def _call(req: nemo_relay.LLMRequest) -> Any:
87+
response = handler(payload_to_model_request(request, req))
8888
return model_response_to_json(response, object_codec)
8989

9090
result = run_sync(
@@ -106,8 +106,8 @@ async def awrap_model_call(
106106
"""Wrap an async LangChain agent model call in NeMo Relay LLM execution."""
107107
(object_codec, llm_request, model_name, model_codec) = self._prepare_model_call(request)
108108

109-
async def _call(req: Any) -> Any:
110-
response = await handler(payload_to_model_request(request, req.content))
109+
async def _call(req: nemo_relay.LLMRequest) -> Any:
110+
response = await handler(payload_to_model_request(request, req))
111111
return model_response_to_json(response, object_codec)
112112

113113
result = await self._llm_execute(

0 commit comments

Comments
 (0)