Commit f2d4492

feat(llm): add streaming tool call accumulation and LLMResponse parity (#1789)
Addresses #1760 (comment); part of #1760.
1 parent 5c2c1ea commit f2d4492

4 files changed: 372 additions & 13 deletions

nemoguardrails/actions/llm/utils.py

Lines changed: 52 additions & 11 deletions
@@ -29,7 +29,7 @@
 from nemoguardrails.exceptions import LLMCallException
 from nemoguardrails.logging.explain import LLMCallInfo
 from nemoguardrails.logging.llm_tracker import track_llm_call
-from nemoguardrails.types import ChatMessage, LLMModel, LLMResponse, LLMResponseChunk
+from nemoguardrails.types import ChatMessage, LLMModel, LLMResponse, LLMResponseChunk, UsageInfo
 
 if TYPE_CHECKING:
     from nemoguardrails.streaming import StreamingHandler
@@ -98,43 +98,84 @@ async def _stream_llm_call(
     llm_params: Optional[dict] = None,
 ) -> LLMResponse:
     handler.stop = stop or []
-    accumulated_metadata: Dict[str, Any] = {}
-    last_chunk: Optional[LLMResponseChunk] = None
+    streaming_handler_metadata: Dict[str, Any] = {}
+    accumulated_provider_metadata: Dict[str, Any] = {}
+    accumulated_reasoning: List[str] = []
+    tool_calls = None
+    model_name: Optional[str] = None
+    finish_reason: Optional[str] = None
+    request_id: Optional[str] = None
+    usage: Optional[UsageInfo] = None
 
     try:
         async for chunk in model.stream_async(prompt, stop=stop, **(llm_params or {})):
-            last_chunk = chunk
             content = chunk.delta_content or ""
 
+            if chunk.delta_reasoning:
+                accumulated_reasoning.append(chunk.delta_reasoning)
+            if chunk.delta_tool_calls:
+                tool_calls = chunk.delta_tool_calls
+            if chunk.model:
+                model_name = chunk.model
+            if chunk.finish_reason:
+                finish_reason = chunk.finish_reason
+            if chunk.request_id:
+                request_id = chunk.request_id
+            if chunk.usage:
+                usage = chunk.usage
+            if chunk.provider_metadata:
+                accumulated_provider_metadata.update(chunk.provider_metadata)
+
             chunk_metadata = _extract_chunk_metadata(chunk)
             if chunk_metadata:
-                accumulated_metadata.update(chunk_metadata)
+                streaming_handler_metadata.update(chunk_metadata)
 
             await handler.push_chunk(content, chunk_metadata)
 
-        if accumulated_metadata:
-            llm_response_metadata_var.set(accumulated_metadata)
+        llm_response_metadata_var.set(accumulated_provider_metadata or None)
 
         await handler.finish()
 
         llm_call_info = llm_call_info_var.get()
         if llm_call_info:
            llm_call_info.completion = handler.completion
 
-        if last_chunk is not None:
-            _update_token_stats_from_chunk(last_chunk)
+        if usage:
+            fake_chunk = LLMResponseChunk(usage=usage)
+            _update_token_stats_from_chunk(fake_chunk)
+
+        if tool_calls:
+            tool_calls_var.set([tc.to_dict() for tc in tool_calls])
+        else:
+            tool_calls_var.set(None)
+
+        reasoning_content = "".join(accumulated_reasoning) if accumulated_reasoning else None
+        # TODO: call _extract_and_remove_think_tags on the completed response
+        # to handle models that stream reasoning via <think> tags in content
+        # rather than via delta_reasoning. Pre-existing gap, not introduced here.
+        reasoning_trace_var.set(reasoning_content)
 
         return LLMResponse(
             content=handler.completion,
-            usage=last_chunk.usage if last_chunk else None,
-            provider_metadata=accumulated_metadata if accumulated_metadata else None,
+            reasoning=reasoning_content,
+            tool_calls=tool_calls,
+            model=model_name,
+            finish_reason=finish_reason,
+            request_id=request_id,
+            usage=usage,
+            provider_metadata=accumulated_provider_metadata or None,
         )
 
     except Exception as e:
        _raise_llm_call_exception(e, model)
 
 
 def _extract_chunk_metadata(chunk: LLMResponseChunk) -> Optional[Dict[str, Any]]:
+    # This feeds handler.push_chunk() for the StreamingHandler consumer path
+    # (API responses, output rails). Separate from the field accumulation in
+    # _stream_llm_call which builds the returned LLMResponse for the pipeline.
+    # TODO(Pouyanpi): consider pushing tool_calls and reasoning through the handler too,
+    # so output rails and streaming consumers can see them in real-time.
     metadata: Dict[str, Any] = {}
     if chunk.provider_metadata:
         metadata["provider_metadata"] = chunk.provider_metadata

nemoguardrails/integrations/langchain/llm_adapter.py

Lines changed: 50 additions & 1 deletion
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import logging
 import uuid
 from typing import Any, AsyncIterator, Dict, List, NamedTuple, Optional, Union
@@ -205,8 +206,33 @@ async def stream_async(
     ) -> AsyncIterator[LLMResponseChunk]:
         llm = self._prepare_llm(kwargs)
         messages = self._to_langchain_input(prompt)
+
+        tool_call_acc: Dict[int, Dict[str, Any]] = {}
+
         async for chunk in llm.astream(messages, stop=stop):
-            yield _langchain_chunk_to_llm_response_chunk(chunk)
+            for tc_chunk in getattr(chunk, "tool_call_chunks", None) or []:
+                idx = tc_chunk.get("index", 0)
+                if idx not in tool_call_acc:
+                    tool_call_acc[idx] = {
+                        "id": tc_chunk.get("id") or "",
+                        "name": tc_chunk.get("name") or "",
+                        "arguments_buffer": "",
+                    }
+                else:
+                    if tc_chunk.get("id"):
+                        tool_call_acc[idx]["id"] = tc_chunk["id"]
+                    if tc_chunk.get("name"):
+                        tool_call_acc[idx]["name"] = tc_chunk["name"]
+                arg_fragment = tc_chunk.get("args") or ""
+                if arg_fragment:
+                    tool_call_acc[idx]["arguments_buffer"] += arg_fragment
+
+            response_chunk = _langchain_chunk_to_llm_response_chunk(chunk)
+
+            if response_chunk.finish_reason == "tool_calls" and tool_call_acc:
+                response_chunk.delta_tool_calls = _finalize_tool_call_acc(tool_call_acc)
+
+            yield response_chunk
 
 
 class LangChainFramework:
@@ -354,6 +380,29 @@ def _extract_tool_calls(response: Any) -> Optional[List[ToolCall]]:
     ]
 
 
+def _finalize_tool_call_acc(acc: Dict[int, Dict[str, Any]]) -> List[ToolCall]:
+    result = []
+    for idx in sorted(acc.keys()):
+        entry = acc[idx]
+        raw_args = entry["arguments_buffer"]
+        try:
+            args_dict = json.loads(raw_args) if raw_args else {}
+        except json.JSONDecodeError:
+            log.warning("Failed to parse tool call arguments for '%s' (index %d): %r", entry["name"], idx, raw_args)
+            args_dict = {}
+        result.append(
+            ToolCall(
+                id=entry["id"] or str(uuid.uuid4()),
+                type="function",
+                function=ToolCallFunction(
+                    name=entry["name"],
+                    arguments=args_dict,
+                ),
+            )
+        )
+    return result
+
+
 def _extract_usage(response: Any) -> Optional[UsageInfo]:
     usage = _build_usage_info(getattr(response, "usage_metadata", None))
     if usage is not None:
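
The trickiest part of the adapter change is that LangChain streams a tool call's arguments as fragments of a JSON string on the chunk's tool_call_chunks, keyed by index, so the arguments only parse once every fragment for that index has been concatenated. A minimal illustration of the merge follows; the fragment dicts are made-up examples for illustration, not captured provider output.

import json

# Hypothetical fragments for a single streamed tool call (index 0). Only the
# first fragment carries id/name; the "args" pieces are partial JSON strings.
fragments = [
    {"index": 0, "id": "call_1", "name": "get_weather", "args": ""},
    {"index": 0, "id": None, "name": None, "args": '{"city": '},
    {"index": 0, "id": None, "name": None, "args": '"Paris"}'},
]

call_id, name, buffer = "", "", ""
for frag in fragments:
    call_id = frag.get("id") or call_id
    name = frag.get("name") or name
    buffer += frag.get("args") or ""

# Parsing is deferred until the whole buffer is available.
print(name, call_id, json.loads(buffer))  # get_weather call_1 {'city': 'Paris'}

Deferring json.loads to _finalize_tool_call_acc, as the diff does, is what lets a malformed buffer degrade to empty arguments with a logged warning instead of raising mid-stream.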

tests/test_actions_llm_utils.py

Lines changed: 156 additions & 1 deletion
@@ -21,17 +21,25 @@
     _log_completion,
     _store_reasoning_traces,
     _store_tool_calls,
+    _stream_llm_call,
     _update_token_stats_from_chunk,
     llm_call,
 )
-from nemoguardrails.context import llm_call_info_var, llm_stats_var, reasoning_trace_var, tool_calls_var
+from nemoguardrails.context import (
+    llm_call_info_var,
+    llm_response_metadata_var,
+    llm_stats_var,
+    reasoning_trace_var,
+    tool_calls_var,
+)
 from nemoguardrails.exceptions import LLMCallException
 from nemoguardrails.integrations.langchain.llm_adapter import (
     LangChainLLMAdapter,
     _infer_provider_from_module,
 )
 from nemoguardrails.logging.explain import LLMCallInfo
 from nemoguardrails.logging.stats import LLMStats
+from nemoguardrails.streaming import StreamingHandler
 from nemoguardrails.types import ChatMessage, LLMResponse, LLMResponseChunk, Role, ToolCall, ToolCallFunction, UsageInfo
 
 
@@ -482,3 +490,150 @@ def provider_url(self):
         await llm_call(model, [])
 
     assert received_prompt == []
+
+
+def _make_chunk_model(chunks):
+    class _Model:
+        model_name = "test-model"
+        provider_name = "test"
+        provider_url = None
+
+        async def generate_async(self, prompt, *, stop=None, **kwargs):
+            return LLMResponse(content="")
+
+        async def stream_async(self, prompt, *, stop=None, **kwargs):
+            for c in chunks:
+                yield c
+
+    return _Model()
+
+
+class TestStreamLlmCallAccumulation:
+    @pytest.mark.asyncio
+    async def test_accumulates_tool_calls(self):
+        tc = [ToolCall(id="call_1", function=ToolCallFunction(name="get_weather", arguments={"city": "Paris"}))]
+        model = _make_chunk_model(
+            [
+                LLMResponseChunk(model="gpt-4o"),
+                LLMResponseChunk(delta_tool_calls=tc, finish_reason="tool_calls"),
+                LLMResponseChunk(usage=UsageInfo(input_tokens=10, output_tokens=5, total_tokens=15)),
+            ]
+        )
+
+        result = await _stream_llm_call(model, "test", StreamingHandler(), stop=None)
+
+        assert result.tool_calls == tc
+        assert result.model == "gpt-4o"
+        assert result.finish_reason == "tool_calls"
+        assert result.usage.total_tokens == 15
+        assert tool_calls_var.get() is not None
+
+    @pytest.mark.asyncio
+    async def test_accumulates_reasoning(self):
+        model = _make_chunk_model(
+            [
+                LLMResponseChunk(delta_reasoning="Let me ", model="gpt-4o"),
+                LLMResponseChunk(delta_reasoning="think..."),
+                LLMResponseChunk(delta_content="42", finish_reason="stop"),
+                LLMResponseChunk(usage=UsageInfo(input_tokens=5, output_tokens=3, total_tokens=8)),
+            ]
+        )
+
+        result = await _stream_llm_call(model, "test", StreamingHandler(), stop=None)
+
+        assert result.content == "42"
+        assert result.reasoning == "Let me think..."
+        assert result.model == "gpt-4o"
+        assert result.finish_reason == "stop"
+        assert reasoning_trace_var.get() == "Let me think..."
+
+    @pytest.mark.asyncio
+    async def test_text_only(self):
+        model = _make_chunk_model(
+            [
+                LLMResponseChunk(delta_content="Hello", model="gpt-4o"),
+                LLMResponseChunk(delta_content=" world", finish_reason="stop"),
+                LLMResponseChunk(usage=UsageInfo(input_tokens=5, output_tokens=2, total_tokens=7)),
+            ]
+        )
+
+        result = await _stream_llm_call(model, "test", StreamingHandler(), stop=None)
+
+        assert result.content == "Hello world"
+        assert result.tool_calls is None
+        assert result.reasoning is None
+        assert result.model == "gpt-4o"
+        assert result.finish_reason == "stop"
+        assert result.usage.total_tokens == 7
+
+    @pytest.mark.asyncio
+    async def test_request_id_accumulated(self):
+        model = _make_chunk_model(
+            [
+                LLMResponseChunk(delta_content="hi", request_id="req-123", model="gpt-4o"),
+                LLMResponseChunk(finish_reason="stop"),
+            ]
+        )
+
+        result = await _stream_llm_call(model, "test", StreamingHandler(), stop=None)
+
+        assert result.request_id == "req-123"
+
+    @pytest.mark.asyncio
+    async def test_clears_tool_calls_var_when_none(self):
+        tool_calls_var.set([{"id": "stale", "type": "function", "function": {"name": "old", "arguments": {}}}])
+
+        model = _make_chunk_model(
+            [
+                LLMResponseChunk(delta_content="no tools here", finish_reason="stop"),
+            ]
+        )
+
+        await _stream_llm_call(model, "test", StreamingHandler(), stop=None)
+
+        assert tool_calls_var.get() is None
+
+    @pytest.mark.asyncio
+    async def test_clears_reasoning_var_when_none(self):
+        reasoning_trace_var.set("stale reasoning")
+
+        model = _make_chunk_model(
+            [
+                LLMResponseChunk(delta_content="no reasoning", finish_reason="stop"),
+            ]
+        )
+
+        await _stream_llm_call(model, "test", StreamingHandler(), stop=None)
+
+        assert reasoning_trace_var.get() is None
+
+    @pytest.mark.asyncio
+    async def test_provider_metadata_stored_flat(self):
+        model = _make_chunk_model(
+            [
+                LLMResponseChunk(
+                    delta_content="hi",
+                    provider_metadata={"system_fingerprint": "fp_abc"},
+                    finish_reason="stop",
+                ),
+            ]
+        )
+
+        await _stream_llm_call(model, "test", StreamingHandler(), stop=None)
+
+        metadata = llm_response_metadata_var.get()
+        assert metadata == {"system_fingerprint": "fp_abc"}
+
+    @pytest.mark.asyncio
+    async def test_clears_metadata_var_when_none(self):
+        llm_response_metadata_var.set({"stale": True})
+
+        model = _make_chunk_model(
+            [
+                LLMResponseChunk(delta_content="no metadata", finish_reason="stop"),
+            ]
+        )
+
+        await _stream_llm_call(model, "test", StreamingHandler(), stop=None)
+
+        assert llm_response_metadata_var.get() is None
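
Assuming the repository's usual pytest setup (pytest plus pytest-asyncio, which the asyncio markers imply), the new test class can be run in isolation via a node-id selection such as:

pytest tests/test_actions_llm_utils.py::TestStreamLlmCallAccumulation -q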
