Skip to content

Commit afe344e

Browse files
committed
simplify agent loop to use built-ins
1 parent b1712dc commit afe344e

4 files changed

Lines changed: 139 additions & 122 deletions

File tree

Lines changed: 31 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,104 +1,63 @@
11
"""LangChain agent runner for LaunchDarkly AI SDK."""
22

3-
from typing import Any, List
3+
from typing import Any
44

5-
from langchain_core.messages import (
6-
AIMessage,
7-
BaseMessage,
8-
HumanMessage,
9-
SystemMessage,
10-
ToolMessage,
11-
)
125
from ldai import log
13-
from ldai.providers import AgentResult, AgentRunner, ToolRegistry
6+
from ldai.providers import AgentResult, AgentRunner
147
from ldai.providers.types import LDAIMetrics
158

16-
from ldai_langchain.langchain_helper import get_ai_metrics_from_response
9+
from ldai_langchain.langchain_helper import sum_token_usage_from_messages
1710

1811

1912
class LangChainAgentRunner(AgentRunner):
2013
"""
2114
AgentRunner implementation for LangChain.
2215
23-
Executes a single-agent loop using a LangChain BaseChatModel with tool calling.
24-
The model is expected to have tools already bound to it.
16+
Wraps a compiled LangChain agent graph (from ``langchain.agents.create_agent``)
17+
and delegates execution to it. Tool calling and loop management are handled
18+
internally by the graph.
2519
Returned by LangChainRunnerFactory.create_agent(config, tools).
2620
"""
2721

28-
def __init__(
29-
self,
30-
llm: Any,
31-
instructions: str,
32-
tools: ToolRegistry,
33-
):
34-
self._llm = llm
35-
self._instructions = instructions
36-
self._tools = tools
22+
def __init__(self, agent: Any):
23+
self._agent = agent
3724

3825
async def run(self, input: Any) -> AgentResult:
3926
"""
4027
Run the agent with the given input string.
4128
42-
Executes an agentic loop: calls the model, handles tool calls,
43-
and continues until the model produces a final response.
29+
Delegates to the compiled LangChain agent, which handles
30+
the tool-calling loop internally.
4431
4532
:param input: The user prompt or input to the agent
4633
:return: AgentResult with output, raw response, and aggregated metrics
4734
"""
48-
messages: List[BaseMessage] = []
49-
if self._instructions:
50-
messages.append(SystemMessage(content=self._instructions))
51-
messages.append(HumanMessage(content=str(input)))
52-
53-
raw_response = None
54-
5535
try:
56-
while True:
57-
response: AIMessage = await self._llm.ainvoke(messages)
58-
raw_response = response
59-
messages.append(response)
60-
61-
tool_calls = getattr(response, 'tool_calls', None)
62-
63-
if not tool_calls:
64-
metrics = get_ai_metrics_from_response(response)
65-
content = response.content if isinstance(response.content, str) else ""
66-
return AgentResult(
67-
output=content,
68-
raw=raw_response,
69-
metrics=metrics,
70-
)
71-
72-
# Execute tool calls and append results
73-
for tool_call in tool_calls:
74-
tool_name = tool_call["name"]
75-
tool_args = tool_call.get("args", {})
76-
tool_id = tool_call.get("id", "")
77-
78-
tool_fn = self._tools.get(tool_name)
79-
if tool_fn:
80-
try:
81-
result = tool_fn(**tool_args)
82-
if hasattr(result, "__await__"):
83-
result = await result
84-
result_str = str(result)
85-
except Exception as error:
86-
log.warning(f"Tool '{tool_name}' execution failed: {error}")
87-
result_str = f"Tool execution failed: {error}"
88-
else:
89-
log.warning(f"Tool '{tool_name}' not found in registry")
90-
result_str = f"Tool '{tool_name}' not found"
91-
92-
messages.append(ToolMessage(content=result_str, tool_call_id=tool_id))
93-
36+
result = await self._agent.ainvoke({
37+
"messages": [{"role": "user", "content": str(input)}]
38+
})
39+
messages = result.get("messages", [])
40+
output = ""
41+
if messages:
42+
last = messages[-1]
43+
if hasattr(last, 'content') and isinstance(last.content, str):
44+
output = last.content
45+
return AgentResult(
46+
output=output,
47+
raw=result,
48+
metrics=LDAIMetrics(
49+
success=True,
50+
usage=sum_token_usage_from_messages(messages),
51+
),
52+
)
9453
except Exception as error:
9554
log.warning(f"LangChain agent run failed: {error}")
9655
return AgentResult(
9756
output="",
98-
raw=raw_response,
57+
raw=None,
9958
metrics=LDAIMetrics(success=False, usage=None),
10059
)
10160

102-
def get_llm(self) -> Any:
103-
"""Return the underlying LangChain LLM."""
104-
return self._llm
61+
def get_agent(self) -> Any:
62+
"""Return the underlying compiled LangChain agent."""
63+
return self._agent

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_helper.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def create_langchain_model(ai_config: AIConfigKind, tool_registry: Optional[Tool
8787
**parameters,
8888
)
8989

90-
if tool_definitions:
91-
bindable = _resolve_tools_for_langchain(tool_definitions, tool_registry or {})
90+
if tool_definitions and tool_registry is not None:
91+
bindable = _resolve_tools_for_langchain(tool_definitions, tool_registry)
9292
if bindable:
9393
model = model.bind_tools(bindable)
9494

@@ -138,6 +138,55 @@ def _resolve_tools_for_langchain(
138138
return bindable
139139

140140

141+
def build_structured_tools(ai_config: AIConfigKind, tool_registry: ToolRegistry) -> List[Any]:
142+
"""
143+
Build a list of LangChain StructuredTool instances from LD tool definitions and a registry.
144+
145+
Tools found in the registry are wrapped as StructuredTool with the name and description
146+
from the LD config. Built-in provider tools and tools missing from the registry are
147+
skipped with a warning.
148+
149+
:param ai_config: The LaunchDarkly AI configuration
150+
:param tool_registry: Registry mapping tool names to callable implementations
151+
:return: List of StructuredTool instances ready to pass to langchain.agents.create_agent
152+
"""
153+
from langchain_core.tools import StructuredTool
154+
155+
config_dict = ai_config.to_dict()
156+
model_dict = config_dict.get('model') or {}
157+
parameters = dict(model_dict.get('parameters') or {})
158+
tool_definitions = parameters.pop('tools', []) or []
159+
160+
structured = []
161+
for td in tool_definitions:
162+
if not isinstance(td, dict):
163+
continue
164+
165+
tool_type = td.get('type')
166+
if tool_type and tool_type != 'function':
167+
log.warning(
168+
f"Built-in tool '{tool_type}' is not reliably supported via LangChain and will be skipped. "
169+
"Use a provider-specific runner to use built-in provider tools."
170+
)
171+
continue
172+
173+
name = td.get('name')
174+
if not name:
175+
continue
176+
177+
if name not in tool_registry:
178+
log.warning(f"Tool '{name}' is defined in the AI config but was not found in the tool registry; skipping.")
179+
continue
180+
181+
structured.append(StructuredTool.from_function(
182+
func=tool_registry[name],
183+
name=name,
184+
description=td.get('description', ''),
185+
))
186+
187+
return structured
188+
189+
141190
def get_ai_usage_from_response(response: Any) -> Optional[TokenUsage]:
142191
"""
143192
Extract token usage from a LangChain response.

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_runner_factory.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
from typing import TYPE_CHECKING, Any, Optional
1+
from typing import Any, Optional
22

3+
from langchain.agents import create_agent as lc_create_agent
34
from ldai.models import AIConfigKind
45
from ldai.providers import AIProvider, ToolRegistry
56

6-
if TYPE_CHECKING:
7-
from ldai_langchain.langchain_agent_runner import LangChainAgentRunner
8-
9-
from ldai_langchain.langchain_helper import create_langchain_model
7+
from ldai_langchain.langchain_agent_runner import LangChainAgentRunner
8+
from ldai_langchain.langchain_helper import (
9+
build_structured_tools,
10+
create_langchain_model,
11+
)
1012
from ldai_langchain.langchain_model_runner import LangChainModelRunner
1113

1214

@@ -36,16 +38,21 @@ def create_model(self, config: AIConfigKind) -> LangChainModelRunner:
3638
llm = create_langchain_model(config)
3739
return LangChainModelRunner(llm)
3840

39-
def create_agent(self, config: Any, tools: Optional[ToolRegistry] = None) -> 'LangChainAgentRunner':
41+
def create_agent(self, config: Any, tools: Optional[ToolRegistry] = None) -> LangChainAgentRunner:
4042
"""
4143
Create a configured LangChainAgentRunner for the given AI agent config.
4244
4345
:param config: The LaunchDarkly AI agent configuration
4446
:param tools: ToolRegistry mapping tool names to callables
4547
:return: LangChainAgentRunner ready to run the agent
4648
"""
47-
from ldai_langchain.langchain_agent_runner import LangChainAgentRunner
48-
4949
instructions = (config.instructions or '') if hasattr(config, 'instructions') else ''
50-
llm = create_langchain_model(config, tool_registry=tools or {})
51-
return LangChainAgentRunner(llm, instructions, tools or {})
50+
llm = create_langchain_model(config)
51+
lc_tools = build_structured_tools(config, tools or {})
52+
53+
agent = lc_create_agent(
54+
llm,
55+
tools=lc_tools or None,
56+
system_prompt=instructions or None,
57+
)
58+
return LangChainAgentRunner(agent)

packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ class TestCreateAgent:
402402
"""Tests for LangChainRunnerFactory.create_agent."""
403403

404404
def test_creates_agent_runner_with_instructions_and_tool_definitions(self):
405-
"""Should create LangChainAgentRunner with instructions and tool definitions."""
405+
"""Should create LangChainAgentRunner wrapping a compiled graph."""
406406
from unittest.mock import patch
407407
from ldai_langchain import LangChainAgentRunner
408408

@@ -420,15 +420,18 @@ def test_creates_agent_runner_with_instructions_and_tool_definitions(self):
420420
'provider': {'name': 'openai'},
421421
}
422422

423-
with patch('ldai_langchain.langchain_runner_factory.create_langchain_model') as mock_create:
424-
mock_llm = MagicMock()
425-
mock_create.return_value = mock_llm
423+
mock_agent = MagicMock()
424+
with patch('ldai_langchain.langchain_runner_factory.create_langchain_model') as mock_create, \
425+
patch('ldai_langchain.langchain_runner_factory.build_structured_tools') as mock_tools, \
426+
patch('ldai_langchain.langchain_runner_factory.lc_create_agent', return_value=mock_agent):
427+
mock_create.return_value = MagicMock()
428+
mock_tools.return_value = [MagicMock()]
426429

427430
factory = LangChainRunnerFactory()
428431
result = factory.create_agent(mock_ai_config, {'get-weather': lambda loc: 'sunny'})
429432

430433
assert isinstance(result, LangChainAgentRunner)
431-
assert result._instructions == "You are a helpful assistant."
434+
assert result._agent is mock_agent
432435

433436
def test_creates_agent_runner_with_no_tools(self):
434437
"""Should create LangChainAgentRunner with no tool definitions."""
@@ -442,73 +445,72 @@ def test_creates_agent_runner_with_no_tools(self):
442445
'provider': {'name': 'openai'},
443446
}
444447

445-
with patch('ldai_langchain.langchain_runner_factory.create_langchain_model') as mock_create:
448+
mock_agent = MagicMock()
449+
with patch('ldai_langchain.langchain_runner_factory.create_langchain_model') as mock_create, \
450+
patch('ldai_langchain.langchain_runner_factory.build_structured_tools', return_value=[]), \
451+
patch('ldai_langchain.langchain_runner_factory.lc_create_agent', return_value=mock_agent):
446452
mock_create.return_value = MagicMock()
447453

448454
factory = LangChainRunnerFactory()
449455
result = factory.create_agent(mock_ai_config, {})
450456

451457
assert isinstance(result, LangChainAgentRunner)
452-
assert result._tools == {}
458+
assert result._agent is mock_agent
453459

454460

455461
class TestLangChainAgentRunner:
456462
"""Tests for LangChainAgentRunner.run."""
457463

458464
@pytest.mark.asyncio
459-
async def test_runs_agent_and_returns_result_with_no_tool_calls(self):
460-
"""Should return AgentResult when model responds with no tool calls."""
465+
async def test_runs_agent_and_returns_result(self):
466+
"""Should return AgentResult with the last message content from the graph."""
461467
from ldai_langchain import LangChainAgentRunner
462-
from langchain_core.messages import AIMessage
463468

464-
mock_llm = MagicMock()
465-
mock_response = AIMessage(content="The answer is 42.")
466-
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
469+
final_msg = AIMessage(content="The answer is 42.")
470+
mock_agent = MagicMock()
471+
mock_agent.ainvoke = AsyncMock(return_value={"messages": [final_msg]})
467472

468-
runner = LangChainAgentRunner(mock_llm, "You are helpful.", {})
473+
runner = LangChainAgentRunner(mock_agent)
469474
result = await runner.run("What is the answer?")
470475

471476
assert result.output == "The answer is 42."
472477
assert result.metrics.success is True
478+
mock_agent.ainvoke.assert_called_once_with(
479+
{"messages": [{"role": "user", "content": "What is the answer?"}]}
480+
)
473481

474482
@pytest.mark.asyncio
475-
async def test_executes_tool_calls_and_returns_final_response(self):
476-
"""Should execute tool calls and continue loop until final response."""
483+
async def test_aggregates_token_usage_across_messages(self):
484+
"""Should sum token usage from all messages in the graph result."""
477485
from ldai_langchain import LangChainAgentRunner
478-
from langchain_core.messages import AIMessage
479-
480-
# First response: has a tool call
481-
first_response = AIMessage(content="")
482-
first_response.tool_calls = [
483-
{"name": "get-weather", "args": {"location": "Paris"}, "id": "call_123"}
484-
]
485486

486-
# Second response: final answer
487-
second_response = AIMessage(content="It is sunny in Paris.")
487+
msg1 = AIMessage(content="intermediate")
488+
msg1.usage_metadata = {'total_tokens': 10, 'input_tokens': 6, 'output_tokens': 4}
489+
msg2 = AIMessage(content="final answer")
490+
msg2.usage_metadata = {'total_tokens': 20, 'input_tokens': 12, 'output_tokens': 8}
488491

489-
mock_llm = MagicMock()
490-
mock_llm.ainvoke = AsyncMock(side_effect=[first_response, second_response])
492+
mock_agent = MagicMock()
493+
mock_agent.ainvoke = AsyncMock(return_value={"messages": [msg1, msg2]})
491494

492-
weather_fn = MagicMock(return_value="Sunny, 25°C")
493-
runner = LangChainAgentRunner(
494-
mock_llm, "You are helpful.",
495-
{'get-weather': weather_fn},
496-
)
497-
result = await runner.run("What is the weather in Paris?")
495+
runner = LangChainAgentRunner(mock_agent)
496+
result = await runner.run("Hello")
498497

499-
assert result.output == "It is sunny in Paris."
498+
assert result.output == "final answer"
500499
assert result.metrics.success is True
501-
weather_fn.assert_called_once_with(location="Paris")
500+
assert result.metrics.usage is not None
501+
assert result.metrics.usage.total == 30
502+
assert result.metrics.usage.input == 18
503+
assert result.metrics.usage.output == 12
502504

503505
@pytest.mark.asyncio
504506
async def test_returns_failure_when_exception_thrown(self):
505507
"""Should return unsuccessful AgentResult when exception is thrown."""
506508
from ldai_langchain import LangChainAgentRunner
507509

508-
mock_llm = MagicMock()
509-
mock_llm.ainvoke = AsyncMock(side_effect=Exception("LLM Error"))
510+
mock_agent = MagicMock()
511+
mock_agent.ainvoke = AsyncMock(side_effect=Exception("Graph Error"))
510512

511-
runner = LangChainAgentRunner(mock_llm, "", {})
513+
runner = LangChainAgentRunner(mock_agent)
512514
result = await runner.run("Hello")
513515

514516
assert result.output == ""

0 commit comments

Comments
 (0)