Skip to content

Commit 243ad5c

Browse files
authored
Merge branch 'main' into gzip-compression
2 parents 30d71e6 + 53ae9fc commit 243ad5c

9 files changed

Lines changed: 273 additions & 23 deletions

File tree

temporalio/client/_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,8 +1196,6 @@ async def _start_update_with_start(
11961196
args=temporalio.common._arg_or_args(arg, args),
11971197
headers={},
11981198
ret_type=result_type or result_type_from_type_hint,
1199-
rpc_metadata=rpc_metadata,
1200-
rpc_timeout=rpc_timeout,
12011199
wait_for_stage=wait_for_stage,
12021200
)
12031201

@@ -1222,6 +1220,8 @@ def on_start_error(
12221220
input = StartWorkflowUpdateWithStartInput(
12231221
start_workflow_input=start_workflow_operation._start_workflow_input,
12241222
update_workflow_input=update_input,
1223+
rpc_metadata=rpc_metadata,
1224+
rpc_timeout=rpc_timeout,
12251225
_on_start=on_start,
12261226
_on_start_error=on_start_error,
12271227
)

temporalio/client/_impl.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,11 @@ def on_start(
852852

853853
try:
854854
return await self._start_workflow_update_with_start(
855-
input.start_workflow_input, input.update_workflow_input, on_start
855+
input.start_workflow_input,
856+
input.update_workflow_input,
857+
input.rpc_metadata,
858+
input.rpc_timeout,
859+
on_start,
856860
)
857861
except asyncio.CancelledError as _err:
858862
err = _err
@@ -914,6 +918,8 @@ async def _start_workflow_update_with_start(
914918
self,
915919
start_input: UpdateWithStartStartWorkflowInput,
916920
update_input: UpdateWithStartUpdateWorkflowInput,
921+
rpc_metadata: Mapping[str, str | bytes],
922+
rpc_timeout: timedelta | None,
917923
on_start: Callable[
918924
[temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse], None
919925
],
@@ -941,7 +947,12 @@ async def _start_workflow_update_with_start(
941947
# Repeatedly try to invoke ExecuteMultiOperation until the update is durable
942948
while True:
943949
multiop_response = (
944-
await self._client.workflow_service.execute_multi_operation(multiop_req)
950+
await self._client.workflow_service.execute_multi_operation(
951+
multiop_req,
952+
retry=True,
953+
metadata=rpc_metadata,
954+
timeout=rpc_timeout,
955+
)
945956
)
946957
start_response = multiop_response.responses[0].start_workflow
947958
update_response = multiop_response.responses[1].update_workflow

temporalio/client/_interceptor.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,6 @@ class UpdateWithStartUpdateWorkflowInput:
334334
wait_for_stage: WorkflowUpdateStage
335335
headers: Mapping[str, temporalio.api.common.v1.Payload]
336336
ret_type: type | None
337-
rpc_metadata: Mapping[str, str | bytes]
338-
rpc_timeout: timedelta | None
339337

340338

341339
@dataclass
@@ -366,18 +364,23 @@ class UpdateWithStartStartWorkflowInput:
366364
static_details: str | None
367365
# Type may be absent
368366
ret_type: type | None
369-
rpc_metadata: Mapping[str, str | bytes]
370-
rpc_timeout: timedelta | None
371367
priority: temporalio.common.Priority
372368
versioning_override: temporalio.common.VersioningOverride | None = None
373369

374370

375371
@dataclass
376372
class StartWorkflowUpdateWithStartInput:
377-
"""Input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`."""
373+
"""Input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`.
374+
375+
The ``rpc_metadata`` and ``rpc_timeout`` fields are authoritative for the
376+
``execute_multi_operation`` gRPC call. Interceptors that wish to set RPC
377+
metadata should modify :py:attr:`rpc_metadata` on this object.
378+
"""
378379

379380
start_workflow_input: UpdateWithStartStartWorkflowInput
380381
update_workflow_input: UpdateWithStartUpdateWorkflowInput
382+
rpc_metadata: Mapping[str, str | bytes]
383+
rpc_timeout: timedelta | None
381384
_on_start: Callable[
382385
[temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse], None
383386
]

temporalio/client/_workflow.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,8 +1065,6 @@ def __init__(
10651065
static_summary: str | None = None,
10661066
static_details: str | None = None,
10671067
start_delay: timedelta | None = None,
1068-
rpc_metadata: Mapping[str, str | bytes] = {},
1069-
rpc_timeout: timedelta | None = None,
10701068
priority: temporalio.common.Priority = temporalio.common.Priority.default,
10711069
versioning_override: temporalio.common.VersioningOverride | None = None,
10721070
) -> None: ...
@@ -1095,8 +1093,6 @@ def __init__(
10951093
static_summary: str | None = None,
10961094
static_details: str | None = None,
10971095
start_delay: timedelta | None = None,
1098-
rpc_metadata: Mapping[str, str | bytes] = {},
1099-
rpc_timeout: timedelta | None = None,
11001096
priority: temporalio.common.Priority = temporalio.common.Priority.default,
11011097
versioning_override: temporalio.common.VersioningOverride | None = None,
11021098
) -> None: ...
@@ -1127,8 +1123,6 @@ def __init__(
11271123
static_summary: str | None = None,
11281124
static_details: str | None = None,
11291125
start_delay: timedelta | None = None,
1130-
rpc_metadata: Mapping[str, str | bytes] = {},
1131-
rpc_timeout: timedelta | None = None,
11321126
priority: temporalio.common.Priority = temporalio.common.Priority.default,
11331127
versioning_override: temporalio.common.VersioningOverride | None = None,
11341128
) -> None: ...
@@ -1159,8 +1153,6 @@ def __init__(
11591153
static_summary: str | None = None,
11601154
static_details: str | None = None,
11611155
start_delay: timedelta | None = None,
1162-
rpc_metadata: Mapping[str, str | bytes] = {},
1163-
rpc_timeout: timedelta | None = None,
11641156
priority: temporalio.common.Priority = temporalio.common.Priority.default,
11651157
versioning_override: temporalio.common.VersioningOverride | None = None,
11661158
) -> None: ...
@@ -1189,8 +1181,6 @@ def __init__(
11891181
static_summary: str | None = None,
11901182
static_details: str | None = None,
11911183
start_delay: timedelta | None = None,
1192-
rpc_metadata: Mapping[str, str | bytes] = {},
1193-
rpc_timeout: timedelta | None = None,
11941184
priority: temporalio.common.Priority = temporalio.common.Priority.default,
11951185
versioning_override: temporalio.common.VersioningOverride | None = None,
11961186
stack_level: int = 2,
@@ -1228,8 +1218,6 @@ def __init__(
12281218
start_delay=start_delay,
12291219
headers={},
12301220
ret_type=result_type or result_type_from_run_fn,
1231-
rpc_metadata=rpc_metadata,
1232-
rpc_timeout=rpc_timeout,
12331221
priority=priority,
12341222
versioning_override=versioning_override,
12351223
)

temporalio/contrib/openai_agents/_invoke_model_activity.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from agents.items import TResponseStreamEvent
3131
from agents.tool import (
3232
ApplyPatchTool,
33+
CustomTool,
3334
LocalShellTool,
3435
ShellTool,
3536
ShellToolEnvironment,
@@ -39,6 +40,7 @@
3940
APIStatusError,
4041
AsyncOpenAI,
4142
)
43+
from openai.types.responses import CustomToolParam
4244
from openai.types.responses.tool_param import Mcp
4345
from typing_extensions import Required, TypedDict
4446

@@ -112,6 +114,15 @@ class ApplyPatchToolInput:
112114
name: str = "apply_patch"
113115

114116

117+
@dataclass
118+
class CustomToolInput:
119+
"""Data conversion friendly representation of a CustomTool. Contains only the fields which are needed by the model
120+
execution to determine what tool to call, not the actual tool invocation, which remains in the workflow context.
121+
"""
122+
123+
tool_config: CustomToolParam
124+
125+
115126
ToolInput = (
116127
FunctionToolInput
117128
| FileSearchTool
@@ -122,6 +133,7 @@ class ApplyPatchToolInput:
122133
| ShellToolInput
123134
| LocalShellTool
124135
| ApplyPatchToolInput
136+
| CustomToolInput
125137
| ToolSearchTool
126138
)
127139

@@ -235,6 +247,14 @@ def _build_tool(tool: ToolInput) -> Tool:
235247
return ApplyPatchTool(name=tool.name, editor=_NoopApplyPatchEditor())
236248
elif isinstance(tool, HostedMCPToolInput):
237249
return HostedMCPTool(tool_config=tool.tool_config)
250+
elif isinstance(tool, CustomToolInput):
251+
return CustomTool(
252+
name=tool.tool_config["name"],
253+
description=tool.tool_config.get("description", ""),
254+
on_invoke_tool=_empty_on_invoke_tool,
255+
format=tool.tool_config.get("format"),
256+
defer_loading=tool.tool_config.get("defer_loading", False),
257+
)
238258
elif isinstance(tool, FunctionToolInput):
239259
return FunctionTool(
240260
name=tool.name,

temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,21 @@
2222
WebSearchTool,
2323
)
2424
from agents.items import TResponseStreamEvent
25-
from agents.tool import ApplyPatchTool, LocalShellTool, ShellTool, ToolSearchTool
25+
from agents.tool import (
26+
ApplyPatchTool,
27+
CustomTool,
28+
LocalShellTool,
29+
ShellTool,
30+
ToolSearchTool,
31+
)
2632
from openai.types.responses.response_prompt_param import ResponsePromptParam
2733

2834
from temporalio import workflow
2935
from temporalio.contrib.openai_agents._invoke_model_activity import (
3036
ActivityModelInput,
3137
AgentOutputSchemaInput,
3238
ApplyPatchToolInput,
39+
CustomToolInput,
3340
FunctionToolInput,
3441
HandoffInput,
3542
HostedMCPToolInput,
@@ -92,6 +99,8 @@ def make_tool_info(tool: Tool) -> ToolInput:
9299
return ApplyPatchToolInput(name=tool.name)
93100
elif isinstance(tool, HostedMCPTool):
94101
return HostedMCPToolInput(tool_config=tool.tool_config)
102+
elif isinstance(tool, CustomTool):
103+
return CustomToolInput(tool_config=tool.tool_config)
95104
elif isinstance(tool, FunctionTool):
96105
return FunctionToolInput(
97106
name=tool.name,

temporalio/contrib/opentelemetry/_otel_interceptor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def workflow_interceptor_class(
200200
provider = get_tracer_provider()
201201
if not isinstance(provider, ReplaySafeTracerProvider):
202202
raise ValueError(
203-
"When using OpenTelemetryPlugin, the global trace provider must be a ReplaySafeTracerProvider. Use init_tracer_provider to create one."
203+
"When using OpenTelemetryPlugin, the global trace provider must be a ReplaySafeTracerProvider. Use create_tracer_provider to create one."
204204
)
205205

206206
class InterceptorWithState(_TracingWorkflowInboundInterceptor):

0 commit comments

Comments
 (0)