Skip to content

Commit f5f3340

Browse files
committed
Clean up astream_events tests for v1 & v2
1 parent 1f2b421 commit f5f3340

10 files changed

Lines changed: 2300 additions & 326 deletions

File tree

newrelic/common/llm_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,19 @@ def __init__(self, wrapped, on_stop_iteration, on_error, on_stream_chunk=None):
120120
self._nr_on_stream_chunk = on_stream_chunk or noop
121121
# Track if we've sent the LLM events yet to avoid sending them multiple times
122122
self._nr_closed = False
123+
# Lazily established by __aiter__ or __anext__. With LangChain's
124+
# astream_events, __anext__ may be called before __aiter__.
125+
self._nr_wrapped_iter = None
123126

124127
def __aiter__(self):
125128
self._nr_wrapped_iter = self.__wrapped__.__aiter__()
126129
return self
127130

128131
async def __anext__(self):
132+
# Lazily establish the wrapped iterator. With astream_events,
133+
# __anext__ may be called before __aiter__.
134+
if self._nr_wrapped_iter is None:
135+
self._nr_wrapped_iter = self.__wrapped__.__aiter__()
129136
try:
130137
return_val = await self._nr_wrapped_iter.__anext__()
131138
self._nr_on_stream_chunk(self, return_val)

newrelic/hooks/mlmodel_langchain.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,33 @@ def astream(self, *args, **kwargs):
267267

268268
return return_val
269269

270+
def astream_events(self, *args, **kwargs):
271+
transaction = current_transaction()
272+
if not transaction:
273+
return self.__wrapped__.astream_events(*args, **kwargs)
274+
275+
agent_name = getattr(self.__wrapped__, "name", "agent")
276+
agent_id = str(uuid.uuid4())
277+
agent_event_dict = _construct_base_agent_event_dict(agent_name, agent_id, transaction)
278+
function_trace_name = f"astream_events/{agent_name}"
279+
agentic_subcomponent_data = {"type": "APM-AI_AGENT", "name": agent_name}
280+
281+
ft = FunctionTrace(name=function_trace_name, group="Llm/agent/LangChain")
282+
ft.__enter__()
283+
ft._add_agent_attribute("subcomponent", json.dumps(agentic_subcomponent_data))
284+
try:
285+
return_val = self.__wrapped__.astream_events(*args, **kwargs)
286+
return_val = AsyncLLMStreamProxy(
287+
return_val,
288+
on_stop_iteration=self._nr_on_stop_iteration(ft, agent_event_dict),
289+
on_error=self._nr_on_error(ft, agent_event_dict, agent_id),
290+
)
291+
except Exception:
292+
self._nr_on_error(ft, agent_event_dict, agent_id)(transaction)
293+
raise
294+
295+
return return_val
296+
270297
def transform(self, *args, **kwargs):
271298
transaction = current_transaction()
272299
if not transaction:

tests/mlmodel_langchain/_test_tools.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,22 @@ async def add_exclamation_async(message: str) -> str:
3232
return f"{message}!"
3333

3434

35-
@pytest.fixture(scope="session", params=["sync_tool", "async_tool"])
35+
@pytest.fixture(params=["sync_tool", "async_tool"])
3636
def tool_type(request):
3737
return request.param
3838

3939

40-
@pytest.fixture(scope="session")
40+
@pytest.fixture
4141
def tool_method_name(tool_type):
4242
return "run" if tool_type == "sync_tool" else "arun"
4343

4444

45-
@pytest.fixture(scope="session")
45+
@pytest.fixture
4646
def add_exclamation(tool_type, exercise_agent):
4747
if tool_type == "sync_tool":
4848
return add_exclamation_sync
4949
elif tool_type == "async_tool":
50-
if exercise_agent._called_method in {"invoke", "stream"}:
50+
if exercise_agent._called_method in ("invoke", "stream"):
5151
pytest.skip("Async tools cannot be invoked synchronously.")
5252
return add_exclamation_async
5353
else:

0 commit comments

Comments
 (0)