Skip to content

Commit 7b2d6a2

Browse files
authored
Set kw_only=True in all dataclasses (#758)
Make all public dataclasses keyword-only using `kw_only=True` to allow flexible field ordering. This enables adding new fields or reordering existing ones without breaking positional initialization. This is important in terms of backwards compatibility of the SDK.
1 parent 4ae8c4a commit 7b2d6a2

13 files changed

Lines changed: 76 additions & 62 deletions

splunklib/ai/engines/langchain.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,13 @@ async def awrap_tool_call(
286286
assert resp.artifact is None, "artifact is already populated"
287287

288288
if resp.name.startswith(AGENT_PREFIX):
289-
resp.artifact = SubagentFailureResult(str(resp.content)) # pyright: ignore[reportUnknownArgumentType]
289+
resp.artifact = SubagentFailureResult(
290+
error_message=str(resp.content) # pyright: ignore[reportUnknownArgumentType]
291+
)
290292
else:
291-
resp.artifact = ToolFailureResult(str(resp.content)) # pyright: ignore[reportUnknownArgumentType]
293+
resp.artifact = ToolFailureResult(
294+
error_message=str(resp.content) # pyright: ignore[reportUnknownArgumentType]
295+
)
292296

293297
return resp
294298

@@ -863,7 +867,9 @@ async def llm_handler(req: ModelRequest) -> ModelResponse:
863867
case LC_StructuredOutputValidationError():
864868
raise StructuredOutputGenerationException(
865869
message=msg,
866-
error=StructuredOutputValidationError(str(e.source)),
870+
error=StructuredOutputValidationError(
871+
validation_error=str(e.source)
872+
),
867873
)
868874
case LC_StructuredOutputError():
869875
# Langchain only returns the above handled exceptions, LC_StructuredOutputError
@@ -1013,7 +1019,7 @@ async def _sdk_handler(request: ToolRequest) -> ToolResponse:
10131019
assert isinstance(sdk_result, ToolMessage), (
10141020
"Expected tool response from tool middleware handler"
10151021
)
1016-
return ToolResponse(sdk_result.result)
1022+
return ToolResponse(result=sdk_result.result)
10171023

10181024
return _sdk_handler
10191025

@@ -1033,7 +1039,7 @@ async def _sdk_handler(
10331039
assert isinstance(sdk_result, SubagentMessage), (
10341040
"Expected subagent response from subagent middleware handler"
10351041
)
1036-
return SubagentResponse(sdk_result.result)
1042+
return SubagentResponse(result=sdk_result.result)
10371043

10381044
return _sdk_handler
10391045

@@ -1276,10 +1282,10 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
12761282

12771283
tool_strategy_messages = [
12781284
StructuredOutputMessage(
1279-
m.tool_call_id,
1280-
m.name.removeprefix(TOOL_STRATEGY_TOOL_PREFIX) if m.name else "",
1281-
m.status,
1282-
str(m.content), # pyright: ignore[reportUnknownArgumentType]
1285+
call_id=m.tool_call_id,
1286+
name=m.name.removeprefix(TOOL_STRATEGY_TOOL_PREFIX) if m.name else "",
1287+
status=m.status,
1288+
content=str(m.content), # pyright: ignore[reportUnknownArgumentType]
12831289
)
12841290
for m in model_response.result
12851291
if isinstance(m, LC_ToolMessage)
@@ -1404,7 +1410,9 @@ async def _tool_call(
14041410
"ToolException from LangChain should not be raised in tool.func"
14051411
)
14061412

1407-
artifact = ToolResult(result.content, result.structured_content)
1413+
artifact = ToolResult(
1414+
content=result.content, structured_content=result.structured_content
1415+
)
14081416

14091417
if result.structured_content:
14101418
# For both local tools and remote tools (Splunk MCP Server App), the primary
@@ -1721,9 +1729,9 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
17211729
],
17221730
structured_output_calls=[
17231731
StructuredOutputCall(
1724-
tc["id"] or "",
1725-
tc["name"].removeprefix(TOOL_STRATEGY_TOOL_PREFIX),
1726-
tc["args"],
1732+
id=tc["id"] or "",
1733+
name=tc["name"].removeprefix(TOOL_STRATEGY_TOOL_PREFIX),
1734+
args=tc["args"],
17271735
)
17281736
for tc in message.tool_calls
17291737
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)

splunklib/ai/messages.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from splunklib.ai.tools import ToolType
2222

2323

24-
@dataclass(frozen=True)
24+
@dataclass(frozen=True, kw_only=True)
2525
class TextBlock:
2626
"""Plain text content block returned by a model."""
2727

@@ -36,7 +36,7 @@ class TextBlock:
3636
"""
3737

3838

39-
@dataclass(frozen=True)
39+
@dataclass(frozen=True, kw_only=True)
4040
class OpaqueBlock:
4141
"""Content block of an unrecognized or unsupported type.
4242
@@ -62,30 +62,30 @@ class OpaqueBlock:
6262
ContentBlock = TextBlock | OpaqueBlock
6363

6464

65-
@dataclass(frozen=True)
65+
@dataclass(frozen=True, kw_only=True)
6666
class ToolCall:
6767
id: str
6868
name: str
6969
type: ToolType
7070
args: dict[str, Any]
7171

7272

73-
@dataclass(frozen=True)
73+
@dataclass(frozen=True, kw_only=True)
7474
class SubagentCall:
7575
id: str
7676
name: str
7777
args: str | dict[str, Any]
7878
thread_id: str | None
7979

8080

81-
@dataclass(frozen=True)
81+
@dataclass(frozen=True, kw_only=True)
8282
class StructuredOutputCall:
8383
id: str
8484
name: str
8585
args: dict[str, Any]
8686

8787

88-
@dataclass(frozen=True)
88+
@dataclass(frozen=True, kw_only=True)
8989
class BaseMessage:
9090
role: str = field(init=False)
9191

@@ -96,7 +96,7 @@ def __post_init__(self) -> None:
9696
)
9797

9898

99-
@dataclass(frozen=True)
99+
@dataclass(frozen=True, kw_only=True)
100100
class HumanMessage(BaseMessage):
101101
"""
102102
Message originating from a human user.
@@ -110,7 +110,7 @@ class HumanMessage(BaseMessage):
110110
content: str
111111

112112

113-
@dataclass(frozen=True)
113+
@dataclass(frozen=True, kw_only=True)
114114
class AIMessage(BaseMessage):
115115
"""
116116
Message produced by an LLM.
@@ -141,7 +141,7 @@ class AIMessage(BaseMessage):
141141
"""
142142

143143

144-
@dataclass(frozen=True)
144+
@dataclass(frozen=True, kw_only=True)
145145
class ToolResult:
146146
"""
147147
ToolResult represents a result of a successful tool call.
@@ -151,7 +151,7 @@ class ToolResult:
151151
structured_content: dict[str, Any] | None
152152

153153

154-
@dataclass(frozen=True)
154+
@dataclass(frozen=True, kw_only=True)
155155
class SubagentStructuredResult:
156156
"""
157157
SubagentStructuredResult represents a result of a successful subagent call.
@@ -161,7 +161,7 @@ class SubagentStructuredResult:
161161
structured_output: dict[str, Any]
162162

163163

164-
@dataclass(frozen=True)
164+
@dataclass(frozen=True, kw_only=True)
165165
class SubagentTextResult:
166166
"""
167167
SubagentTextResult represents a result of a successful subagent call.
@@ -171,7 +171,7 @@ class SubagentTextResult:
171171
content: str
172172

173173

174-
@dataclass(frozen=True)
174+
@dataclass(frozen=True, kw_only=True)
175175
class ToolFailureResult:
176176
"""
177177
Represents the result of a failed sub-agent call.
@@ -183,7 +183,7 @@ class ToolFailureResult:
183183
error_message: str
184184

185185

186-
@dataclass(frozen=True)
186+
@dataclass(frozen=True, kw_only=True)
187187
class SubagentFailureResult:
188188
"""
189189
Represents the result of a failed tool call.
@@ -195,7 +195,7 @@ class SubagentFailureResult:
195195
error_message: str
196196

197197

198-
@dataclass(frozen=True)
198+
@dataclass(frozen=True, kw_only=True)
199199
class ToolMessage(BaseMessage):
200200
"""ToolMessage represents a response of a tool call"""
201201

@@ -208,7 +208,7 @@ class ToolMessage(BaseMessage):
208208

209209

210210
# TODO: do we have a test that uses this?
211-
@dataclass(frozen=True)
211+
@dataclass(frozen=True, kw_only=True)
212212
class SystemMessage(BaseMessage):
213213
"""
214214
A message used to prime or control agent behavior.
@@ -218,7 +218,7 @@ class SystemMessage(BaseMessage):
218218
content: str
219219

220220

221-
@dataclass(frozen=True)
221+
@dataclass(frozen=True, kw_only=True)
222222
class SubagentMessage(BaseMessage):
223223
"""
224224
SubagentMessage represents a response of an agent invocation
@@ -231,7 +231,7 @@ class SubagentMessage(BaseMessage):
231231
result: SubagentStructuredResult | SubagentTextResult | SubagentFailureResult
232232

233233

234-
@dataclass(frozen=True)
234+
@dataclass(frozen=True, kw_only=True)
235235
class StructuredOutputMessage(BaseMessage):
236236
"""
237237
StructuredMessage represents a response to the StructuredOutputCall.
@@ -254,7 +254,7 @@ class StructuredOutputMessage(BaseMessage):
254254
# where developers might want to store messages in say KV store.
255255

256256

257-
@dataclass(frozen=True)
257+
@dataclass(frozen=True, kw_only=True)
258258
class AgentResponse(Generic[OutputT]):
259259
# in case output_schema is provided, this will hold the parsed structured output
260260
structured_output: OutputT

splunklib/ai/middleware.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
)
3131

3232

33-
@dataclass(frozen=True)
33+
@dataclass(frozen=True, kw_only=True)
3434
class AgentState:
3535
"""AgentState is available through certain middlewares and contains information about the current state of an agent execution."""
3636

@@ -42,27 +42,27 @@ class AgentState:
4242
token_count: int
4343

4444

45-
@dataclass(frozen=True)
45+
@dataclass(frozen=True, kw_only=True)
4646
class ToolRequest:
4747
call: ToolCall
4848
state: AgentState
4949

5050

51-
@dataclass(frozen=True)
51+
@dataclass(frozen=True, kw_only=True)
5252
class ToolResponse:
5353
result: ToolResult | ToolFailureResult
5454

5555

5656
ToolMiddlewareHandler = Callable[[ToolRequest], Awaitable[ToolResponse]]
5757

5858

59-
@dataclass(frozen=True)
59+
@dataclass(frozen=True, kw_only=True)
6060
class SubagentRequest:
6161
call: SubagentCall
6262
state: AgentState
6363

6464

65-
@dataclass(frozen=True)
65+
@dataclass(frozen=True, kw_only=True)
6666
class SubagentResponse:
6767
result: SubagentStructuredResult | SubagentTextResult | SubagentFailureResult
6868

@@ -73,13 +73,13 @@ class SubagentResponse:
7373
]
7474

7575

76-
@dataclass(frozen=True)
76+
@dataclass(frozen=True, kw_only=True)
7777
class ModelRequest:
7878
system_message: str
7979
state: AgentState
8080

8181

82-
@dataclass(frozen=True)
82+
@dataclass(frozen=True, kw_only=True)
8383
class ModelResponse:
8484
message: AIMessage
8585
structured_output: Any | None = None
@@ -94,7 +94,7 @@ def __post_init__(self) -> None:
9494
ModelMiddlewareHandler = Callable[[ModelRequest], Awaitable[ModelResponse]]
9595

9696

97-
@dataclass(frozen=True)
97+
@dataclass(frozen=True, kw_only=True)
9898
class AgentRequest:
9999
messages: Sequence[BaseMessage]
100100

splunklib/ai/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818
import httpx
1919

2020

21-
@dataclass(frozen=True)
21+
@dataclass(frozen=True, kw_only=True)
2222
class PredefinedModel:
2323
"""Base class for models that are predefined in the SDK"""
2424

2525
model: str
2626

2727

28-
@dataclass(frozen=True)
28+
@dataclass(frozen=True, kw_only=True)
2929
class OpenAIModel(PredefinedModel):
3030
"""Predefined OpenAI Model"""
3131

@@ -53,7 +53,7 @@ class OpenAIModel(PredefinedModel):
5353
"""
5454

5555

56-
@dataclass(frozen=True)
56+
@dataclass(frozen=True, kw_only=True)
5757
class AnthropicModel(PredefinedModel):
5858
"""Predefined Anthropic Model"""
5959

splunklib/ai/structured_output.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717
from splunklib.ai.messages import AIMessage
1818

1919

20-
@dataclass(frozen=True)
20+
@dataclass(frozen=True, kw_only=True)
2121
class StructuredOutputMultipleToolCallsError:
2222
pass
2323

2424

25-
@dataclass(frozen=True)
25+
@dataclass(frozen=True, kw_only=True)
2626
class StructuredOutputValidationError:
2727
validation_error: str
2828

splunklib/ai/tool_settings.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from splunklib.ai.tools import ToolMetadata
1919

2020

21-
@dataclass(frozen=True)
21+
@dataclass(frozen=True, kw_only=True)
2222
class ToolAllowlist:
2323
"""Holds tool names and tags allowed to be used by Agents.
2424
@@ -41,17 +41,17 @@ def is_allowed(self, tool: ToolMetadata) -> bool:
4141
return self.custom_predicate(tool) if self.custom_predicate else False
4242

4343

44-
@dataclass(frozen=True)
44+
@dataclass(frozen=True, kw_only=True)
4545
class RemoteToolSettings:
4646
allowlist: ToolAllowlist
4747

4848

49-
@dataclass(frozen=True)
49+
@dataclass(frozen=True, kw_only=True)
5050
class LocalToolSettings:
5151
allowlist: ToolAllowlist
5252

5353

54-
@dataclass(frozen=True)
54+
@dataclass(frozen=True, kw_only=True)
5555
class ToolSettings:
5656
local: LocalToolSettings | bool
5757
"""Controls local tool loading (via ``bin/tools.py``).

0 commit comments

Comments
 (0)