Commit cb77d85

refactor(llm): extract shared helpers in LangChain adapter
Extract _extract_tool_calls, _extract_usage, _extract_model_info, and _build_provider_metadata from the response/chunk converters. Remove the dead else branch in tool-call extraction (LangChain's tool_calls is always List[dict]). Type the response parameters.
1 parent: 3982772 · commit: cb77d85

1 file changed: nemoguardrails/integrations/langchain/llm_adapter.py (78 additions, 87 deletions)
@@ -263,7 +263,24 @@ def _build_usage_info(raw: Any) -> Optional[UsageInfo]:
     )
 
 
-def _extract_reasoning(response) -> Optional[str]:
+_EXTRACTED_METADATA_KEYS = frozenset(
+    {
+        "model_name",
+        "model",
+        "finish_reason",
+        "stop_reason",
+        "stop_sequence",
+        "id",
+        "request_id",
+        "token_usage",
+        "usage",
+    }
+)
+
+_REASONING_KEYS = frozenset({"reasoning_content"})
+
+
+def _extract_reasoning(response: Any) -> Optional[str]:
     content_blocks = getattr(response, "content_blocks", None)
     if content_blocks:
         for block in content_blocks:
@@ -281,94 +298,85 @@ def _extract_reasoning(response) -> Optional[str]:
     return None
 
 
-def _langchain_response_to_llm_response(response) -> LLMResponse:
-    content = getattr(response, "content", None)
-    if content is None:
-        content = str(response)
+def _extract_tool_calls(response: Any) -> Optional[List[ToolCall]]:
+    raw = getattr(response, "tool_calls", None)
+    if not raw:
+        return None
+    return [
+        ToolCall(
+            id=tc.get("id") or str(uuid.uuid4()),
+            type="function",
+            function=ToolCallFunction(
+                name=tc.get("name", ""),
+                arguments=tc.get("args", {}),
+            ),
+        )
+        for tc in raw
+    ]
 
-    reasoning = _extract_reasoning(response)
-
-    raw_tool_calls = getattr(response, "tool_calls", None)
-    tool_calls = None
-    if raw_tool_calls:
-        tool_calls = []
-        for tc in raw_tool_calls:
-            if isinstance(tc, dict):
-                tool_calls.append(
-                    ToolCall(
-                        id=tc.get("id") or str(uuid.uuid4()),
-                        type="function",
-                        function=ToolCallFunction(
-                            name=tc.get("name", ""),
-                            arguments=tc.get("args", {}),
-                        ),
-                    )
-                )
-            else:
-                tool_calls.append(
-                    ToolCall(
-                        id=getattr(tc, "id", None) or str(uuid.uuid4()),
-                        type="function",
-                        function=ToolCallFunction(
-                            name=getattr(tc, "name", ""),
-                            arguments=getattr(tc, "args", {}),
-                        ),
-                    )
-                )
 
-    response_metadata = getattr(response, "response_metadata", None) or {}
-    additional_kwargs = getattr(response, "additional_kwargs", None) or {}
+def _extract_usage(response: Any) -> Optional[UsageInfo]:
+    usage = _build_usage_info(getattr(response, "usage_metadata", None))
+    if usage is not None:
+        return usage
 
-    usage_metadata = getattr(response, "usage_metadata", None)
-    usage = _build_usage_info(usage_metadata)
-    if usage is None and response_metadata:
-        token_usage = response_metadata.get("token_usage") or response_metadata.get("usage")
+    for source in (
+        getattr(response, "response_metadata", None) or {},
+        getattr(response, "generation_info", None) or {},
+    ):
+        token_usage = source.get("token_usage") or source.get("usage")
         if token_usage:
             usage = _build_usage_info(token_usage)
+            if usage is not None:
+                return usage
+
+    return None
 
-    model = response_metadata.get("model_name") or response_metadata.get("model")
 
+def _extract_model_info(response_metadata: Dict[str, Any]) -> tuple:
+    model = response_metadata.get("model_name") or response_metadata.get("model")
     raw_finish = response_metadata.get("finish_reason") or response_metadata.get("stop_reason")
     finish_reason = _map_finish_reason(raw_finish)
-
     stop_sequence = response_metadata.get("stop_sequence")
-
     request_id = response_metadata.get("id") or response_metadata.get("request_id")
+    return model, finish_reason, stop_sequence, request_id
+
+
+def _build_provider_metadata(
+    response_metadata: Dict[str, Any],
+    additional_kwargs: Optional[Dict[str, Any]] = None,
+) -> Optional[Dict[str, Any]]:
+    result: Dict[str, Any] = {k: v for k, v in response_metadata.items() if k not in _EXTRACTED_METADATA_KEYS}
+    if additional_kwargs:
+        for k, v in additional_kwargs.items():
+            if k not in _REASONING_KEYS and k not in result:
+                result[k] = v
+    return result or None
 
-    extracted_keys = {
-        "model_name",
-        "model",
-        "finish_reason",
-        "stop_reason",
-        "stop_sequence",
-        "id",
-        "request_id",
-        "token_usage",
-        "usage",
-    }
-    reasoning_keys = {"reasoning_content"}
-    provider_metadata: Dict[str, Any] = {}
-    for k, v in response_metadata.items():
-        if k not in extracted_keys:
-            provider_metadata[k] = v
-    for k, v in additional_kwargs.items():
-        if k not in reasoning_keys and k not in provider_metadata:
-            provider_metadata[k] = v
+
+def _langchain_response_to_llm_response(response: Any) -> LLMResponse:
+    content = getattr(response, "content", None)
+    if content is None:
+        content = str(response)
+
+    response_metadata = getattr(response, "response_metadata", None) or {}
+    additional_kwargs = getattr(response, "additional_kwargs", None) or {}
+    model, finish_reason, stop_sequence, request_id = _extract_model_info(response_metadata)
 
     return LLMResponse(
         content=content,
-        reasoning=reasoning,
-        tool_calls=tool_calls,
+        reasoning=_extract_reasoning(response),
+        tool_calls=_extract_tool_calls(response),
        model=model,
         finish_reason=finish_reason,
         stop_sequence=stop_sequence,
         request_id=request_id,
-        usage=usage,
-        provider_metadata=provider_metadata if provider_metadata else None,
+        usage=_extract_usage(response),
+        provider_metadata=_build_provider_metadata(response_metadata, additional_kwargs),
     )
 
 
-def _langchain_chunk_to_llm_response_chunk(chunk) -> LLMResponseChunk:
+def _langchain_chunk_to_llm_response_chunk(chunk: Any) -> LLMResponseChunk:
     content = getattr(chunk, "content", None)
     if content is None:
         content = getattr(chunk, "text", None)
@@ -377,27 +385,10 @@ def _langchain_chunk_to_llm_response_chunk(chunk) -> LLMResponseChunk:
 
     response_metadata = getattr(chunk, "response_metadata", None) or {}
     generation_info = getattr(chunk, "generation_info", None) or {}
-
-    usage_metadata = getattr(chunk, "usage_metadata", None)
-    usage = _build_usage_info(usage_metadata)
-    if usage is None and response_metadata:
-        token_usage = response_metadata.get("token_usage") or response_metadata.get("usage")
-        if token_usage:
-            usage = _build_usage_info(token_usage)
-    if usage is None and generation_info:
-        token_usage = generation_info.get("token_usage") or generation_info.get("usage")
-        if token_usage:
-            usage = _build_usage_info(token_usage)
-
-    provider_metadata: Dict[str, Any] = {}
-    for k, v in response_metadata.items():
-        provider_metadata[k] = v
-    for k, v in generation_info.items():
-        if k not in provider_metadata:
-            provider_metadata[k] = v
+    merged_metadata = {**response_metadata, **generation_info}
 
     return LLMResponseChunk(
         delta_content=content,
-        usage=usage,
-        provider_metadata=provider_metadata if provider_metadata else None,
+        usage=_extract_usage(chunk),
+        provider_metadata=merged_metadata or None,
     )
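
To illustrate how the extracted helpers compose, here is a minimal usage sketch. It is not part of the commit: the AIMessage field values are invented, and the token_usage shape is assumed to be an OpenAI-style dict that _build_usage_info accepts.

# Hypothetical usage sketch (not from the commit). Drives the refactored
# converter with a hand-built LangChain AIMessage; the metadata values are
# invented, and the token_usage shape is an assumption.
from langchain_core.messages import AIMessage

from nemoguardrails.integrations.langchain.llm_adapter import (
    _langchain_response_to_llm_response,
)

message = AIMessage(
    content="",
    # LangChain normalizes tool calls to a list of dicts, which is why the
    # commit could drop the dead `else` branch for object-style tool calls.
    tool_calls=[{"id": "call_1", "name": "get_weather", "args": {"city": "Paris"}}],
    response_metadata={
        "model_name": "gpt-4o",          # consumed by _extract_model_info
        "finish_reason": "tool_calls",   # mapped via _map_finish_reason
        "token_usage": {"prompt_tokens": 12, "completion_tokens": 7, "total_tokens": 19},
        "system_fingerprint": "fp_abc",  # not an extracted key
    },
)

result = _langchain_response_to_llm_response(message)
assert result.model == "gpt-4o"
assert result.tool_calls[0].function.name == "get_weather"
# Keys in _EXTRACTED_METADATA_KEYS are filtered out of the passthrough
# metadata, so only the leftover key survives.
assert result.provider_metadata == {"system_fingerprint": "fp_abc"}

Note that _extract_usage is shared by both converters, so the chunk path now falls back through response_metadata and generation_info the same way the full-response path does.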
