Skip to content

Commit a7eaa70

Browse files
committed
Handle AIMessage.content properly
1 parent be76151 commit a7eaa70

2 files changed

Lines changed: 122 additions & 15 deletions

File tree

splunklib/ai/engines/langchain.py

Lines changed: 93 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@
7777
AgentResponse,
7878
AIMessage,
7979
BaseMessage,
80+
ContentBlock,
8081
HumanMessage,
82+
OpaqueBlock,
8183
OutputT,
8284
StructuredOutputCall,
8385
StructuredOutputMessage,
@@ -87,6 +89,7 @@
8789
SubagentStructuredResult,
8890
SubagentTextResult,
8991
SystemMessage,
92+
TextBlock,
9093
ToolCall,
9194
ToolFailureResult,
9295
ToolMessage,
@@ -122,7 +125,7 @@
122125
LC_ModelRequest = Langchain_ModelRequest["InvokeContext"]
123126

124127
# Set to True to enable debugging mode.
125-
_DEBUG = False
128+
_DEBUG = True
126129

127130
# Disallow _DEBUG == True in CI.
128131
# Github actions sets the CI env var.
@@ -951,7 +954,7 @@ async def awrap_tool_call(
951954
return LC_ToolMessage(
952955
name=_normalize_agent_name(call.name),
953956
tool_call_id=call.id,
954-
content=content,
957+
content=_map_content_to_langchain(content),
955958
status=status,
956959
artifact=sdk_result,
957960
)
@@ -1085,7 +1088,7 @@ def _convert_model_response_to_model_result(
10851088
# This invariant is asserted via ModelResponse.__post_init__
10861089
assert len(resp.message.structured_output_calls) <= 1
10871090

1088-
lc_message = LC_AIMessage(content=resp.message.content)
1091+
lc_message = LC_AIMessage(content=_map_content_to_langchain(resp.message.content))
10891092
# This field can't be set via __init__()
10901093
lc_message.tool_calls = [_map_tool_call_to_langchain(c) for c in resp.message.calls]
10911094

@@ -1160,7 +1163,7 @@ def _convert_tool_message_to_lc(
11601163
name=name,
11611164
tool_call_id=message.call_id,
11621165
status=status,
1163-
content=content,
1166+
content=_map_content_to_langchain(content),
11641167
artifact=artifact,
11651168
)
11661169

@@ -1245,7 +1248,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
12451248

12461249
return ModelResponse(
12471250
message=AIMessage(
1248-
content=ai_message.content.__str__(),
1251+
content=_map_content_from_langchain(ai_message.content), # pyright: ignore[reportUnknownArgumentType]
12491252
calls=[
12501253
_map_tool_call_from_langchain(tc)
12511254
for tc in ai_message.tool_calls
@@ -1433,7 +1436,10 @@ def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool:
14331436

14341437
async def invoke_agent(
14351438
message: HumanMessage, thread_id: str | None
1436-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1439+
) -> tuple[
1440+
OutputT | str | list[str | ContentBlock],
1441+
SubagentStructuredResult | SubagentTextResult,
1442+
]:
14371443
result = await agent.invoke([message], thread_id=thread_id)
14381444

14391445
if agent.output_schema:
@@ -1452,13 +1458,19 @@ async def invoke_agent(
14521458

14531459
async def _run( # pyright: ignore[reportRedeclaration]
14541460
content: str, thread_id: str
1455-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1461+
) -> tuple[
1462+
OutputT | str | list[str | ContentBlock],
1463+
SubagentStructuredResult | SubagentTextResult,
1464+
]:
14561465
return await invoke_agent(HumanMessage(content=content), thread_id)
14571466
else:
14581467

14591468
async def _run( # pyright: ignore[reportRedeclaration]
14601469
content: str,
1461-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1470+
) -> tuple[
1471+
OutputT | str | list[str | ContentBlock],
1472+
SubagentStructuredResult | SubagentTextResult,
1473+
]:
14621474
return await invoke_agent(HumanMessage(content=content), None)
14631475

14641476
return StructuredTool.from_function(
@@ -1471,7 +1483,10 @@ async def _run( # pyright: ignore[reportRedeclaration]
14711483

14721484
async def invoke_agent_structured(
14731485
content: BaseModel, thread_id: str | None
1474-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1486+
) -> tuple[
1487+
OutputT | str | list[str | ContentBlock],
1488+
SubagentStructuredResult | SubagentTextResult,
1489+
]:
14751490
result = await agent.invoke_with_data(
14761491
instructions="Follow the system prompt.",
14771492
data=content.model_dump(),
@@ -1492,7 +1507,10 @@ async def invoke_agent_structured(
14921507

14931508
async def _run(
14941509
**kwargs: Any, # noqa: ANN401
1495-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1510+
) -> tuple[
1511+
OutputT | str | list[str | ContentBlock],
1512+
SubagentStructuredResult | SubagentTextResult,
1513+
]:
14961514
content: BaseModel = kwargs["content"]
14971515
thread_id: str = kwargs["thread_id"]
14981516
return await invoke_agent_structured(content, thread_id)
@@ -1512,7 +1530,10 @@ async def _run(
15121530

15131531
async def _run(
15141532
**kwargs: Any, # noqa: ANN401
1515-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1533+
) -> tuple[
1534+
OutputT | str | list[str | ContentBlock],
1535+
SubagentStructuredResult | SubagentTextResult,
1536+
]:
15161537
content = InputSchema(**kwargs)
15171538
return await invoke_agent_structured(content, None)
15181539

@@ -1564,11 +1585,68 @@ def _map_tool_call_to_langchain(call: ToolCall | SubagentCall) -> LC_ToolCall:
15641585
return LC_ToolCall(id=call.id, name=name, args=args)
15651586

15661587

1588+
def _map_content_from_langchain(
    content: str | list[str | dict[str, Any]],
) -> str | list[str | ContentBlock]:
    """Convert LangChain message content into the SDK content representation.

    A plain string is returned unchanged. A list is converted element by
    element with `_map_content_block_from_langchain`, preserving order.
    """
    if not isinstance(content, str):
        return [_map_content_block_from_langchain(item) for item in content]

    return content
1597+
1598+
1599+
def _map_content_block_from_langchain(
    block: str | dict[str, Any],
) -> str | ContentBlock:
    """Convert a single LangChain content block into its SDK counterpart.

    Strings are returned unchanged. A dict whose "type" is "text" becomes a
    `TextBlock` (its optional "extras" entry is carried over); every other
    dict is wrapped in an `OpaqueBlock`.
    """
    if isinstance(block, str):
        return block

    if block.get("type") == "text":
        return TextBlock(text=block["text"], extras=block.get("extras"))

    # NOTE: block types we do not handle are kept verbatim as opaque
    # content blocks so they are preserved and sent back to the LLM.
    return OpaqueBlock(data=block)
1616+
1617+
1618+
def _map_content_to_langchain(
    content: str | list[str | ContentBlock],
) -> str | list[str | dict[str, Any]]:
    """Convert SDK message content into the LangChain content representation.

    A plain string is returned unchanged. A list is converted element by
    element with `_map_content_block_to_langchain`, preserving order.
    """
    if not isinstance(content, str):
        return [_map_content_block_to_langchain(item) for item in content]

    return content
1627+
1628+
1629+
def _map_content_block_to_langchain(block: str | ContentBlock) -> str | dict[str, Any]:
    """Convert a single SDK content block into a LangChain content block.

    Strings are returned unchanged. A `TextBlock` becomes a
    ``{"type": "text", "text": ...}`` dict, with an "extras" key added only
    when the block carries non-empty extras. An `OpaqueBlock` yields its
    stored `data` dict as-is.

    Raises:
        AssertionError: if `block` is none of the supported variants.
    """
    if isinstance(block, str):
        return block

    if isinstance(block, TextBlock):
        lc_block: dict[str, Any] = {"type": "text", "text": block.text}
        if block.extras:
            lc_block["extras"] = block.extras
        return lc_block

    if isinstance(block, OpaqueBlock):
        # The provider dict was stored untouched, so hand it back untouched.
        return block.data

    raise AssertionError("not supported yet")
1643+
1644+
15671645
def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
15681646
match message:
15691647
case LC_AIMessage():
15701648
return AIMessage(
1571-
content=message.content.__str__(),
1649+
content=_map_content_from_langchain(message.content), # pyright: ignore[reportUnknownArgumentType]
15721650
calls=[
15731651
_map_tool_call_from_langchain(tc)
15741652
for tc in message.tool_calls
@@ -1597,7 +1675,9 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
15971675
def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
15981676
match message:
15991677
case AIMessage():
1600-
lc_message = LC_AIMessage(content=message.content)
1678+
lc_message = LC_AIMessage(
1679+
content=_map_content_to_langchain(message.content)
1680+
)
16011681
# This field can't be set via constructor
16021682
lc_message.tool_calls = [
16031683
_map_tool_call_to_langchain(c) for c in message.calls

splunklib/ai/messages.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,30 @@
2121
from splunklib.ai.tools import ToolType
2222

2323

24+
@dataclass(frozen=True)
25+
class TextBlock:
26+
"""Plain text content block returned by a model."""
27+
28+
text: str
29+
# Provider-specific extras (e.g. Gemini thought signature on text blocks).
30+
extras: dict[str, Any] | None = field(default=None)
31+
32+
33+
@dataclass(frozen=True)
class OpaqueBlock:
    """A content block whose type is unrecognized or unsupported.

    `data` keeps the raw provider dict so that it can be sent back to the
    model unchanged on subsequent calls.
    """

    # Raw provider dict, stored as received.
    data: dict[str, Any]
42+
43+
44+
# Type alias covering every content block variant (text and opaque).
ContentBlock = TextBlock | OpaqueBlock
46+
47+
2448
@dataclass(frozen=True)
2549
class ToolCall:
2650
name: str
@@ -85,12 +109,15 @@ class AIMessage(BaseMessage):
85109
"""
86110

87111
role: Literal["assistant"] = field(default="assistant", init=False)
88-
content: str
112+
content: str | list[str | ContentBlock]
89113

90114
calls: Sequence[ToolCall | SubagentCall]
91115
structured_output_calls: Sequence[StructuredOutputCall] = field(
92116
default_factory=tuple
93117
)
118+
# Backend-specific metadata (e.g. provider additional_kwargs) not
119+
# representable in the standard fields. Opaque to callers.
120+
extras: dict[str, Any] | None = field(default=None)
94121

95122

96123
@dataclass(frozen=True)
@@ -120,7 +147,7 @@ class SubagentTextResult:
120147
Returned by subagent calls that don't have an output schema.
121148
"""
122149

123-
content: str
150+
content: str | list[str | ContentBlock]
124151

125152

126153
@dataclass(frozen=True)

0 commit comments

Comments
 (0)