Skip to content

Commit 46b00e0

Browse files
committed
Handle AIMessage.content properly
1 parent be76151 commit 46b00e0

3 files changed

Lines changed: 231 additions & 15 deletions

File tree

splunklib/ai/engines/langchain.py

Lines changed: 91 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@
7777
AgentResponse,
7878
AIMessage,
7979
BaseMessage,
80+
ContentBlock,
8081
HumanMessage,
82+
OpaqueBlock,
8183
OutputT,
8284
StructuredOutputCall,
8385
StructuredOutputMessage,
@@ -87,6 +89,7 @@
8789
SubagentStructuredResult,
8890
SubagentTextResult,
8991
SystemMessage,
92+
TextBlock,
9093
ToolCall,
9194
ToolFailureResult,
9295
ToolMessage,
@@ -122,7 +125,7 @@
122125
LC_ModelRequest = Langchain_ModelRequest["InvokeContext"]
123126

124127
# Set to True to enable debugging mode.
125-
_DEBUG = False
128+
_DEBUG = True
126129

127130
# Disallow _DEBUG == True in CI.
128131
# Github actions sets the CI env var.
@@ -951,7 +954,7 @@ async def awrap_tool_call(
951954
return LC_ToolMessage(
952955
name=_normalize_agent_name(call.name),
953956
tool_call_id=call.id,
954-
content=content,
957+
content=_map_content_to_langchain(content),
955958
status=status,
956959
artifact=sdk_result,
957960
)
@@ -1085,7 +1088,7 @@ def _convert_model_response_to_model_result(
10851088
# This invariant is asserted via ModelResponse.__post_init__
10861089
assert len(resp.message.structured_output_calls) <= 1
10871090

1088-
lc_message = LC_AIMessage(content=resp.message.content)
1091+
lc_message = LC_AIMessage(content=_map_content_to_langchain(resp.message.content))
10891092
# This field can't be set via __init__()
10901093
lc_message.tool_calls = [_map_tool_call_to_langchain(c) for c in resp.message.calls]
10911094

@@ -1160,7 +1163,7 @@ def _convert_tool_message_to_lc(
11601163
name=name,
11611164
tool_call_id=message.call_id,
11621165
status=status,
1163-
content=content,
1166+
content=_map_content_to_langchain(content),
11641167
artifact=artifact,
11651168
)
11661169

@@ -1245,7 +1248,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
12451248

12461249
return ModelResponse(
12471250
message=AIMessage(
1248-
content=ai_message.content.__str__(),
1251+
content=_map_content_from_langchain(ai_message.content), # pyright: ignore[reportUnknownArgumentType]
12491252
calls=[
12501253
_map_tool_call_from_langchain(tc)
12511254
for tc in ai_message.tool_calls
@@ -1433,7 +1436,10 @@ def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool:
14331436

14341437
async def invoke_agent(
14351438
message: HumanMessage, thread_id: str | None
1436-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1439+
) -> tuple[
1440+
OutputT | str | list[str | ContentBlock],
1441+
SubagentStructuredResult | SubagentTextResult,
1442+
]:
14371443
result = await agent.invoke([message], thread_id=thread_id)
14381444

14391445
if agent.output_schema:
@@ -1452,13 +1458,19 @@ async def invoke_agent(
14521458

14531459
async def _run( # pyright: ignore[reportRedeclaration]
14541460
content: str, thread_id: str
1455-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1461+
) -> tuple[
1462+
OutputT | str | list[str | ContentBlock],
1463+
SubagentStructuredResult | SubagentTextResult,
1464+
]:
14561465
return await invoke_agent(HumanMessage(content=content), thread_id)
14571466
else:
14581467

14591468
async def _run( # pyright: ignore[reportRedeclaration]
14601469
content: str,
1461-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1470+
) -> tuple[
1471+
OutputT | str | list[str | ContentBlock],
1472+
SubagentStructuredResult | SubagentTextResult,
1473+
]:
14621474
return await invoke_agent(HumanMessage(content=content), None)
14631475

14641476
return StructuredTool.from_function(
@@ -1471,7 +1483,10 @@ async def _run( # pyright: ignore[reportRedeclaration]
14711483

14721484
async def invoke_agent_structured(
14731485
content: BaseModel, thread_id: str | None
1474-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1486+
) -> tuple[
1487+
OutputT | str | list[str | ContentBlock],
1488+
SubagentStructuredResult | SubagentTextResult,
1489+
]:
14751490
result = await agent.invoke_with_data(
14761491
instructions="Follow the system prompt.",
14771492
data=content.model_dump(),
@@ -1492,7 +1507,10 @@ async def invoke_agent_structured(
14921507

14931508
async def _run(
14941509
**kwargs: Any, # noqa: ANN401
1495-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1510+
) -> tuple[
1511+
OutputT | str | list[str | ContentBlock],
1512+
SubagentStructuredResult | SubagentTextResult,
1513+
]:
14961514
content: BaseModel = kwargs["content"]
14971515
thread_id: str = kwargs["thread_id"]
14981516
return await invoke_agent_structured(content, thread_id)
@@ -1512,7 +1530,10 @@ async def _run(
15121530

15131531
async def _run(
15141532
**kwargs: Any, # noqa: ANN401
1515-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1533+
) -> tuple[
1534+
OutputT | str | list[str | ContentBlock],
1535+
SubagentStructuredResult | SubagentTextResult,
1536+
]:
15161537
content = InputSchema(**kwargs)
15171538
return await invoke_agent_structured(content, None)
15181539

@@ -1564,11 +1585,66 @@ def _map_tool_call_to_langchain(call: ToolCall | SubagentCall) -> LC_ToolCall:
15641585
return LC_ToolCall(id=call.id, name=name, args=args)
15651586

15661587

1588+
def _map_content_from_langchain(
    content: str | list[str | dict[str, Any]],
) -> str | list[str | ContentBlock]:
    """Convert LangChain message content into this package's content form.

    A plain string is returned unchanged. A list is converted element by
    element, in order, via `_map_content_block_from_langchain`.
    """
    if not isinstance(content, str):
        return [_map_content_block_from_langchain(item) for item in content]

    return content
1597+
1598+
1599+
def _map_content_block_from_langchain(
    block: str | dict[str, Any],
) -> str | ContentBlock:
    """Convert a single LangChain content block.

    Strings pass through unchanged. A dict whose "type" is "text" becomes a
    `TextBlock` built from its "text" entry and its optional "extras" entry.
    Every other dict is wrapped, as-is, in an `OpaqueBlock`.
    """
    if isinstance(block, str):
        return block

    if block.get("type") == "text":
        return TextBlock(text=block["text"], extras=block.get("extras"))

    # NOTE: data we're not handling is returned as an opaque content block
    # so it is preserved and sent back to the LLM.
    return OpaqueBlock(data=block)
1616+
1617+
1618+
def _map_content_to_langchain(
    content: str | list[str | ContentBlock],
) -> str | list[str | dict[str, Any]]:
    """Convert this package's message content into LangChain content.

    A plain string is returned unchanged. A list is converted element by
    element, in order, via `_map_content_block_to_langchain`.
    """
    if not isinstance(content, str):
        return [_map_content_block_to_langchain(item) for item in content]

    return content
1627+
1628+
1629+
def _map_content_block_to_langchain(block: str | ContentBlock) -> str | dict[str, Any]:
1630+
if isinstance(block, str):
1631+
return block
1632+
1633+
match block:
1634+
case TextBlock():
1635+
result: dict[str, Any] = {"type": "text", "text": block.text}
1636+
if block.extras:
1637+
result["extras"] = block.extras
1638+
return result
1639+
case OpaqueBlock():
1640+
return block.data
1641+
1642+
15671643
def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
15681644
match message:
15691645
case LC_AIMessage():
15701646
return AIMessage(
1571-
content=message.content.__str__(),
1647+
content=_map_content_from_langchain(message.content), # pyright: ignore[reportUnknownArgumentType]
15721648
calls=[
15731649
_map_tool_call_from_langchain(tc)
15741650
for tc in message.tool_calls
@@ -1597,7 +1673,9 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
15971673
def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
15981674
match message:
15991675
case AIMessage():
1600-
lc_message = LC_AIMessage(content=message.content)
1676+
lc_message = LC_AIMessage(
1677+
content=_map_content_to_langchain(message.content)
1678+
)
16011679
# This field can't be set via constructor
16021680
lc_message.tool_calls = [
16031681
_map_tool_call_to_langchain(c) for c in message.calls

splunklib/ai/messages.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,31 @@
2121
from splunklib.ai.tools import ToolType
2222

2323

24+
@dataclass(frozen=True)
25+
class TextBlock:
26+
"""Plain text content block returned by a model."""
27+
28+
text: str
29+
# TODO: should we have the id here as well?
30+
# Provider-specific extras (e.g. Gemini thought signature on text blocks).
31+
extras: dict[str, Any] | None = field(default=None)
32+
33+
34+
@dataclass(frozen=True)
class OpaqueBlock:
    """A content block whose type is unrecognized or unsupported.

    `data` holds the raw provider dict untouched, which lets the block be
    sent back to the model unchanged on subsequent calls.
    """

    # Raw provider dict for the block, kept as received.
    data: dict[str, Any]
43+
44+
45+
# Type alias covering every content block variant: a recognized text block
# (TextBlock) or a preserved, unrecognized provider block (OpaqueBlock).
ContentBlock = TextBlock | OpaqueBlock
47+
48+
2449
@dataclass(frozen=True)
2550
class ToolCall:
2651
name: str
@@ -85,12 +110,15 @@ class AIMessage(BaseMessage):
85110
"""
86111

87112
role: Literal["assistant"] = field(default="assistant", init=False)
88-
content: str
113+
content: str | list[str | ContentBlock]
89114

90115
calls: Sequence[ToolCall | SubagentCall]
91116
structured_output_calls: Sequence[StructuredOutputCall] = field(
92117
default_factory=tuple
93118
)
119+
# Backend-specific metadata (e.g. provider additional_kwargs) not
120+
# representable in the standard fields. Opaque to callers.
121+
extras: dict[str, Any] | None = field(default=None)
94122

95123

96124
@dataclass(frozen=True)
@@ -120,7 +148,7 @@ class SubagentTextResult:
120148
Returned by subagent calls that don't have an output schema.
121149
"""
122150

123-
content: str
151+
content: str | list[str | ContentBlock]
124152

125153

126154
@dataclass(frozen=True)

0 commit comments

Comments
 (0)