Skip to content

Commit ce0c1e4

Browse files
committed
cleanup code
1 parent 37575be commit ce0c1e4

4 files changed

Lines changed: 27 additions & 194 deletions

File tree

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py

Lines changed: 1 addition & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,6 @@
1717
from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler
1818

1919

20-
def _tool_call_id_from_entry(tc: Any) -> Any:
21-
"""Return tool_call id from a dict or ToolCall-like object."""
22-
if isinstance(tc, dict):
23-
return tc.get('id')
24-
return getattr(tc, 'id', None)
25-
26-
2720
def _make_handoff_tool(child_key: str, description: str) -> Any:
2821
"""
2922
Create a tool that transfers control to ``child_key``.
@@ -58,114 +51,6 @@ def handoff(
5851
return handoff
5952

6053

61-
def _coalesce_tool_messages_for_openai(msgs: List[Any]) -> List[Any]:
62-
"""
63-
Rewind shared LangGraph message state into OpenAI's required shape.
64-
65-
Multi-agent graphs append multiple AIMessages before tool outputs are written, so
66-
ToolMessages can appear *after* later assistant turns. OpenAI requires each
67-
assistant ``tool_calls`` block to be followed immediately by matching ToolMessages.
68-
69-
This walks non-tool messages in order; after each AIMessage with ``tool_calls``,
70-
appends the corresponding ToolMessages (looked up by ``tool_call_id``). Any
71-
ToolMessage not referenced by some assistant ``tool_calls`` in the list is dropped.
72-
73-
An AIMessage whose tool_calls have *no* matching ToolMessages in the state is also
74-
dropped. This handles the parallel fan-out case where a sibling branch's tool
75-
execution hasn't run yet (or routes to END) by the time a downstream node reads the
76-
accumulated state — sending such an AIMessage to OpenAI would cause a 400 error.
77-
"""
78-
from langchain_core.messages import AIMessage, ToolMessage
79-
80-
pending: Dict[str, Any] = {}
81-
for m in msgs:
82-
if isinstance(m, ToolMessage):
83-
tid = getattr(m, 'tool_call_id', None)
84-
if tid:
85-
if tid in pending:
86-
log.warning(
87-
'LangGraphAgentGraphRunner: duplicate ToolMessage for tool_call_id=%r; keeping last',
88-
tid,
89-
)
90-
pending[tid] = m
91-
92-
output: List[Any] = []
93-
for m in msgs:
94-
if isinstance(m, ToolMessage):
95-
continue
96-
if isinstance(m, AIMessage) and getattr(m, 'tool_calls', None):
97-
call_ids = [_tool_call_id_from_entry(tc) for tc in m.tool_calls]
98-
resolved = [tid for tid in call_ids if tid and tid in pending]
99-
if not resolved:
100-
# None of this AIMessage's tool_calls have responses in the current
101-
# state — it belongs to a sibling branch whose ToolMessages aren't
102-
# available here. Drop it to avoid an OpenAI 400.
103-
log.warning(
104-
'LangGraphAgentGraphRunner: dropping AIMessage with unresolvable '
105-
'tool_calls %s (no matching ToolMessages in state — likely from a '
106-
'sibling branch in a parallel fan-out)',
107-
call_ids,
108-
)
109-
continue
110-
output.append(m)
111-
for tid in resolved:
112-
output.append(pending.pop(tid))
113-
else:
114-
output.append(m)
115-
116-
if pending:
117-
log.warning(
118-
'LangGraphAgentGraphRunner: dropping %s orphan ToolMessage(s) (no assistant '
119-
'tool_calls in history): %s',
120-
len(pending),
121-
list(pending.keys())[:32],
122-
)
123-
124-
return output
125-
126-
127-
def _message_content_len(msg: Any) -> int:
128-
c = getattr(msg, 'content', None)
129-
if c is None:
130-
return 0
131-
if isinstance(c, str):
132-
return len(c)
133-
return len(str(c))
134-
135-
136-
def _format_chat_messages_for_log(msgs: List[Any]) -> str:
137-
"""
138-
One line per message index — matches OpenAI error indices (e.g. messages.[5]).
139-
Logged at DEBUG before each model ainvoke.
140-
"""
141-
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
142-
143-
lines: List[str] = []
144-
for idx, m in enumerate(msgs):
145-
cls = type(m).__name__
146-
if isinstance(m, AIMessage):
147-
tc = getattr(m, 'tool_calls', None) or []
148-
pairs = []
149-
for x in tc:
150-
if isinstance(x, dict):
151-
pairs.append((x.get('name'), x.get('id')))
152-
else:
153-
pairs.append((getattr(x, 'name', None), getattr(x, 'id', None)))
154-
lines.append(
155-
f'[{idx}] {cls} tool_calls={pairs!r} content_len={_message_content_len(m)}'
156-
)
157-
elif isinstance(m, ToolMessage):
158-
lines.append(
159-
f'[{idx}] {cls} tool_call_id={getattr(m, "tool_call_id", None)!r} '
160-
f'name={getattr(m, "name", None)!r} content_len={_message_content_len(m)}'
161-
)
162-
elif isinstance(m, (HumanMessage, SystemMessage)):
163-
lines.append(f'[{idx}] {cls} content_len={_message_content_len(m)}')
164-
else:
165-
lines.append(f'[{idx}] {cls}')
166-
return '\n'.join(lines)
167-
168-
16954
class LangGraphAgentGraphRunner(AgentGraphRunner):
17055
"""
17156
CAUTION:
@@ -268,15 +153,9 @@ def make_node_fn(bound_model: Any, node_instructions: Any, nk: str):
268153
async def invoke(state: WorkflowState) -> dict:
269154
if not bound_model:
270155
return {'messages': []}
271-
msgs = _coalesce_tool_messages_for_openai(list(state['messages']))
156+
msgs = list(state['messages'])
272157
if node_instructions:
273158
msgs = [SystemMessage(content=node_instructions)] + msgs
274-
log.debug(
275-
'LangGraphAgentGraphRunner node=%s: CHAT_INPUT (%s messages)\n%s',
276-
nk,
277-
len(msgs),
278-
_format_chat_messages_for_log(msgs),
279-
)
280159
response = await bound_model.ainvoke(msgs)
281160
return {'messages': [response]}
282161

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py

Lines changed: 7 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from ldai.agent_graph import AgentGraphDefinition
88
from ldai.tracker import TokenUsage
99

10+
from ldai_langchain.langchain_helper import get_ai_usage_from_response
11+
1012

1113
class LDMetricsCallbackHandler(BaseCallbackHandler):
1214
"""
@@ -137,7 +139,11 @@ def on_llm_end(
137139
if node_key is None:
138140
return
139141

140-
usage = self._extract_token_usage(response)
142+
try:
143+
message = response.generations[0][0].message
144+
except (IndexError, AttributeError, TypeError):
145+
return
146+
usage = get_ai_usage_from_response(message)
141147
if usage is None:
142148
return
143149

@@ -210,42 +216,3 @@ def flush(self, graph: AgentGraphDefinition, graph_tracker: Any) -> None:
210216
for tool_key in self._node_tool_calls.get(node_key, []):
211217
config_tracker.track_tool_call(tool_key, graph_key=gk)
212218

213-
# ------------------------------------------------------------------
214-
# Internal helpers
215-
# ------------------------------------------------------------------
216-
217-
@staticmethod
218-
def _extract_token_usage(response: LLMResult) -> Optional[TokenUsage]:
219-
"""Extract token usage from an LLMResult, trying multiple locations."""
220-
llm_output = response.llm_output or {}
221-
222-
# Primary: llm_output['token_usage'] or llm_output['tokenUsage']
223-
tu = llm_output.get('token_usage') or llm_output.get('tokenUsage')
224-
if tu and isinstance(tu, dict):
225-
total = tu.get('total_tokens') or tu.get('totalTokens') or 0
226-
inp = tu.get('prompt_tokens') or tu.get('promptTokens') or 0
227-
out = tu.get('completion_tokens') or tu.get('completionTokens') or 0
228-
if total or inp or out:
229-
return TokenUsage(total=total, input=inp, output=out)
230-
231-
# Fallback: first generation's generation_info
232-
try:
233-
gen_info = response.generations[0][0].generation_info or {}
234-
except (IndexError, AttributeError, TypeError):
235-
gen_info = {}
236-
237-
for key in ('usage_metadata', 'token_usage'):
238-
meta = gen_info.get(key)
239-
if meta and isinstance(meta, dict):
240-
total = (
241-
meta.get('total_tokens')
242-
or meta.get('totalTokens')
243-
or (meta.get('input_tokens', 0) + meta.get('output_tokens', 0))
244-
or 0
245-
)
246-
inp = meta.get('input_tokens') or meta.get('prompt_tokens') or meta.get('promptTokens') or 0
247-
out = meta.get('output_tokens') or meta.get('completion_tokens') or meta.get('completionTokens') or 0
248-
if total or inp or out:
249-
return TokenUsage(total=total, input=inp, output=out)
250-
251-
return None

packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -518,11 +518,11 @@ async def test_returns_failure_when_exception_thrown(self):
518518

519519

520520
class TestBuildTools:
521-
"""Tests for langchain_helper.build_tools (sync vs async registry callables)."""
521+
"""Tests for build_structured_tools (sync vs async registry callables)."""
522522

523523
def test_registers_sync_callable_as_structured_tool_func(self):
524524
from ldai.models import AIAgentConfig, ModelConfig, ProviderConfig
525-
from ldai_langchain.langchain_helper import build_tools
525+
from ldai_langchain.langchain_helper import build_structured_tools
526526

527527
def sync_tool(x: str = '') -> str:
528528
return 'ok'
@@ -538,14 +538,14 @@ def sync_tool(x: str = '') -> str:
538538
instructions='',
539539
tracker=MagicMock(),
540540
)
541-
tools = build_tools(cfg, {'my_tool': sync_tool})
541+
tools = build_structured_tools(cfg, {'my_tool': sync_tool})
542542
assert len(tools) == 1
543543
assert tools[0].func is sync_tool
544544
assert getattr(tools[0], 'coroutine', None) is None
545545

546546
def test_registers_async_callable_as_structured_tool_coroutine(self):
547547
from ldai.models import AIAgentConfig, ModelConfig, ProviderConfig
548-
from ldai_langchain.langchain_helper import build_tools
548+
from ldai_langchain.langchain_helper import build_structured_tools
549549

550550
async def async_tool(x: str = '') -> str:
551551
return 'ok'
@@ -561,6 +561,6 @@ async def async_tool(x: str = '') -> str:
561561
instructions='',
562562
tracker=MagicMock(),
563563
)
564-
tools = build_tools(cfg, {'my_tool': async_tool})
564+
tools = build_structured_tools(cfg, {'my_tool': async_tool})
565565
assert len(tools) == 1
566566
assert tools[0].coroutine is async_tool

packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,14 @@ def _make_graph(mock_ld_client: MagicMock, node_key: str = 'root-agent', graph_k
6969

7070
def _llm_result(total: int, prompt: int, completion: int) -> LLMResult:
7171
return LLMResult(
72-
generations=[[ChatGeneration(message=AIMessage(content='ok'), text='ok')]],
73-
llm_output={
74-
'token_usage': {
75-
'total_tokens': total,
76-
'prompt_tokens': prompt,
77-
'completion_tokens': completion,
78-
}
79-
},
72+
generations=[[ChatGeneration(
73+
message=AIMessage(
74+
content='ok',
75+
usage_metadata={'total_tokens': total, 'input_tokens': prompt, 'output_tokens': completion},
76+
),
77+
text='ok',
78+
)]],
79+
llm_output={},
8080
)
8181

8282

@@ -242,14 +242,17 @@ def test_on_llm_end_unknown_parent_run_id_ignored():
242242

243243

244244
def test_on_llm_end_camel_case_token_keys():
245-
"""camelCase token_usage keys (tokenUsage, totalTokens, etc.) are parsed."""
245+
"""camelCase token keys in response_metadata (e.g. some AWS Bedrock models) are parsed."""
246246
handler = LDMetricsCallbackHandler({'root-agent'}, {})
247247
node_run_id = uuid4()
248248
handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent')
249249

250+
msg = AIMessage(content='ok', response_metadata={
251+
'tokenUsage': {'totalTokens': 20, 'promptTokens': 12, 'completionTokens': 8}
252+
})
250253
result = LLMResult(
251-
generations=[[ChatGeneration(message=AIMessage(content='ok'), text='ok')]],
252-
llm_output={'tokenUsage': {'totalTokens': 20, 'promptTokens': 12, 'completionTokens': 8}},
254+
generations=[[ChatGeneration(message=msg, text='ok')]],
255+
llm_output={},
253256
)
254257
handler.on_llm_end(result, run_id=uuid4(), parent_run_id=node_run_id)
255258

@@ -260,22 +263,6 @@ def test_on_llm_end_camel_case_token_keys():
260263
assert tokens.output == 8
261264

262265

263-
def test_on_llm_end_fallback_generation_info():
264-
"""Token usage in generation_info is used as fallback when llm_output is absent."""
265-
handler = LDMetricsCallbackHandler({'root-agent'}, {})
266-
node_run_id = uuid4()
267-
handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent')
268-
269-
gen = ChatGeneration(message=AIMessage(content='ok'), text='ok')
270-
gen.generation_info = {'usage_metadata': {'input_tokens': 4, 'output_tokens': 2, 'total_tokens': 6}}
271-
result = LLMResult(generations=[[gen]], llm_output={})
272-
handler.on_llm_end(result, run_id=uuid4(), parent_run_id=node_run_id)
273-
274-
tokens = handler.node_tokens.get('root-agent')
275-
assert tokens is not None
276-
assert tokens.total == 6
277-
278-
279266
# ---------------------------------------------------------------------------
280267
# on_tool_end tests
281268
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)