Skip to content

Commit 4bb2ddf

Browse files
committed
fix(types): harden ChatMessage.from_dict and fix LLMModel.stream protocol
1 parent 50e1d53 commit 4bb2ddf

2 files changed

Lines changed: 29 additions & 2 deletions

File tree

nemoguardrails/types.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,10 @@ def from_dict(cls, d: Dict[str, Any]) -> "ChatMessage":
151151
args_dict = json.loads(raw_args)
152152
except json.JSONDecodeError:
153153
raise ValueError(f"Tool call arguments are not valid JSON: {raw_args!r}")
154+
if not isinstance(args_dict, dict):
155+
raise ValueError(
156+
f"Tool call arguments must be a JSON object, got {type(args_dict).__name__}: {raw_args!r}"
157+
)
154158
else:
155159
args_dict = raw_args
156160

@@ -165,13 +169,17 @@ def from_dict(cls, d: Dict[str, Any]) -> "ChatMessage":
165169
)
166170
)
167171

172+
_standard_keys = {"role", "content", "tool_calls", "tool_call_id", "name", "provider_metadata"}
173+
extra = {k: v for k, v in d.items() if k not in _standard_keys}
174+
provider_metadata = {**d.get("provider_metadata", {}), **extra}
175+
168176
return cls(
169177
role=role,
170178
content=d.get("content"),
171179
tool_calls=tool_calls,
172180
tool_call_id=d.get("tool_call_id"),
173181
name=d.get("name"),
174-
provider_metadata=d.get("provider_metadata", {}),
182+
provider_metadata=provider_metadata,
175183
)
176184

177185

@@ -221,7 +229,7 @@ async def generate(
221229
**kwargs,
222230
) -> "LLMResponse": ...
223231

224-
async def stream(
232+
def stream(
225233
self,
226234
prompt: Union[str, List["ChatMessage"]],
227235
*,

tests/test_types.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,20 @@ def test_from_dict_with_malformed_json_arguments(self):
220220
with pytest.raises(ValueError, match="not valid JSON"):
221221
ChatMessage.from_dict(d)
222222

223+
def test_from_dict_with_non_object_json_arguments(self):
224+
d = {
225+
"role": "assistant",
226+
"tool_calls": [
227+
{
228+
"id": "tc_1",
229+
"type": "function",
230+
"function": {"name": "search", "arguments": "[]"},
231+
}
232+
],
233+
}
234+
with pytest.raises(ValueError, match="must be a JSON object"):
235+
ChatMessage.from_dict(d)
236+
223237
def test_from_dict_with_legacy_flat_tool_calls(self):
224238
d = {
225239
"role": "assistant",
@@ -235,6 +249,11 @@ def test_from_dict_captures_provider_metadata(self):
235249
msg = ChatMessage.from_dict(d)
236250
assert msg.provider_metadata == {"custom_field": "value", "model": "gpt-4"}
237251

252+
def test_from_dict_unknown_keys_captured_into_provider_metadata(self):
253+
d = {"role": "user", "content": "hi", "unexpected_key": "v"}
254+
msg = ChatMessage.from_dict(d)
255+
assert msg.provider_metadata["unexpected_key"] == "v"
256+
238257
def test_from_dict_missing_provider_metadata_defaults_to_empty(self):
239258
msg = ChatMessage.from_dict({"role": "user", "content": "hi"})
240259
assert msg.provider_metadata == {}

0 commit comments

Comments
 (0)