Skip to content

Commit a89e626

Browse files
committed
fix(llm): separate provider metadata from streaming handler metadata
Split accumulated_metadata into streaming_handler_metadata (wrapped structure for StreamingHandler) and accumulated_provider_metadata (flat dict matching non-streaming path). Fixes schema parity for llm_response_metadata_var. Follow-up: streaming_handler_metadata still uses the wrapped structure from _extract_chunk_metadata ({"provider_metadata": ..., "usage": ...}). If output rails or streaming consumers need raw provider metadata, that function should be revisited.
1 parent ef0253a commit a89e626

2 files changed

Lines changed: 73 additions & 5 deletions

File tree

nemoguardrails/actions/llm/utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ async def _stream_llm_call(
9898
llm_params: Optional[dict] = None,
9999
) -> LLMResponse:
100100
handler.stop = stop or []
101-
accumulated_metadata: Dict[str, Any] = {}
101+
streaming_handler_metadata: Dict[str, Any] = {}
102+
accumulated_provider_metadata: Dict[str, Any] = {}
102103
accumulated_reasoning: List[str] = []
103104
tool_calls = None
104105
model_name: Optional[str] = None
@@ -122,14 +123,16 @@ async def _stream_llm_call(
122123
request_id = chunk.request_id
123124
if chunk.usage:
124125
usage = chunk.usage
126+
if chunk.provider_metadata:
127+
accumulated_provider_metadata.update(chunk.provider_metadata)
125128

126129
chunk_metadata = _extract_chunk_metadata(chunk)
127130
if chunk_metadata:
128-
accumulated_metadata.update(chunk_metadata)
131+
streaming_handler_metadata.update(chunk_metadata)
129132

130133
await handler.push_chunk(content, chunk_metadata)
131134

132-
llm_response_metadata_var.set(accumulated_metadata or None)
135+
llm_response_metadata_var.set(accumulated_provider_metadata or None)
133136

134137
await handler.finish()
135138

@@ -160,7 +163,7 @@ async def _stream_llm_call(
160163
finish_reason=finish_reason,
161164
request_id=request_id,
162165
usage=usage,
163-
provider_metadata=accumulated_metadata if accumulated_metadata else None,
166+
provider_metadata=accumulated_provider_metadata or None,
164167
)
165168

166169
except Exception as e:

tests/test_actions_llm_utils.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@
2525
_update_token_stats_from_chunk,
2626
llm_call,
2727
)
28-
from nemoguardrails.context import llm_call_info_var, llm_stats_var, reasoning_trace_var, tool_calls_var
28+
from nemoguardrails.context import (
29+
llm_call_info_var,
30+
llm_response_metadata_var,
31+
llm_stats_var,
32+
reasoning_trace_var,
33+
tool_calls_var,
34+
)
2935
from nemoguardrails.exceptions import LLMCallException
3036
from nemoguardrails.integrations.langchain.llm_adapter import (
3137
LangChainLLMAdapter,
@@ -572,3 +578,62 @@ async def test_request_id_accumulated(self):
572578
result = await _stream_llm_call(model, "test", StreamingHandler(), stop=None)
573579

574580
assert result.request_id == "req-123"
581+
582+
@pytest.mark.asyncio
583+
async def test_clears_tool_calls_var_when_none(self):
584+
tool_calls_var.set([{"id": "stale", "type": "function", "function": {"name": "old", "arguments": {}}}])
585+
586+
model = _make_chunk_model(
587+
[
588+
LLMResponseChunk(delta_content="no tools here", finish_reason="stop"),
589+
]
590+
)
591+
592+
await _stream_llm_call(model, "test", StreamingHandler(), stop=None)
593+
594+
assert tool_calls_var.get() is None
595+
596+
@pytest.mark.asyncio
597+
async def test_clears_reasoning_var_when_none(self):
598+
reasoning_trace_var.set("stale reasoning")
599+
600+
model = _make_chunk_model(
601+
[
602+
LLMResponseChunk(delta_content="no reasoning", finish_reason="stop"),
603+
]
604+
)
605+
606+
await _stream_llm_call(model, "test", StreamingHandler(), stop=None)
607+
608+
assert reasoning_trace_var.get() is None
609+
610+
@pytest.mark.asyncio
611+
async def test_provider_metadata_stored_flat(self):
612+
model = _make_chunk_model(
613+
[
614+
LLMResponseChunk(
615+
delta_content="hi",
616+
provider_metadata={"system_fingerprint": "fp_abc"},
617+
finish_reason="stop",
618+
),
619+
]
620+
)
621+
622+
await _stream_llm_call(model, "test", StreamingHandler(), stop=None)
623+
624+
metadata = llm_response_metadata_var.get()
625+
assert metadata == {"system_fingerprint": "fp_abc"}
626+
627+
@pytest.mark.asyncio
628+
async def test_clears_metadata_var_when_none(self):
629+
llm_response_metadata_var.set({"stale": True})
630+
631+
model = _make_chunk_model(
632+
[
633+
LLMResponseChunk(delta_content="no metadata", finish_reason="stop"),
634+
]
635+
)
636+
637+
await _stream_llm_call(model, "test", StreamingHandler(), stop=None)
638+
639+
assert llm_response_metadata_var.get() is None

0 commit comments

Comments
 (0)