Skip to content

Commit 8cf248b

Browse files
committed
Create gRPC channel in the caller's event loop
Signed-off-by: Sergio Herrera <627709+seherv@users.noreply.github.com>
1 parent 6eb9ce0 commit 8cf248b

2 files changed

Lines changed: 49 additions & 19 deletions

File tree

ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/aio/client.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -70,18 +70,32 @@ def __init__(
7070
else:
7171
interceptors = None
7272

73-
channel = get_grpc_aio_channel(
74-
host_address=host_address,
75-
secure_channel=secure_channel,
76-
interceptors=interceptors,
77-
options=channel_options,
78-
)
79-
self._channel = channel
80-
self._stub = stubs.TaskHubSidecarServiceStub(channel)
73+
self._host_address = host_address
74+
self._secure_channel = secure_channel
75+
self._interceptors = interceptors
76+
self._channel_options = channel_options
77+
self._channel: grpc.aio.Channel | None = None
78+
self._stub: stubs.TaskHubSidecarServiceStub | None = None
8179
self._logger = shared.get_logger('client', log_handler, log_formatter)
8280

81+
def _get_stub(self) -> stubs.TaskHubSidecarServiceStub:
82+
"""Lazily create the channel and stub on first use.
83+
84+
Async grpc binds a channel to the loop active at creation, deferring it avoids binding to the wrong loop.
85+
"""
86+
if self._stub is None:
87+
self._channel = get_grpc_aio_channel(
88+
host_address=self._host_address,
89+
secure_channel=self._secure_channel,
90+
interceptors=self._interceptors,
91+
options=self._channel_options,
92+
)
93+
self._stub = stubs.TaskHubSidecarServiceStub(self._channel)
94+
return self._stub
95+
8396
async def aclose(self):
84-
await self._channel.close()
97+
if self._channel is not None:
98+
await self._channel.close()
8599

86100
async def __aenter__(self):
87101
return self
@@ -112,14 +126,14 @@ async def schedule_new_orchestration(
112126
)
113127

114128
self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.")
115-
res: pb.CreateInstanceResponse = await self._stub.StartInstance(req)
129+
res: pb.CreateInstanceResponse = await self._get_stub().StartInstance(req)
116130
return res.instanceId
117131

118132
async def get_orchestration_state(
119133
self, instance_id: str, *, fetch_payloads: bool = True
120134
) -> Optional[WorkflowState]:
121135
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
122-
res: pb.GetInstanceResponse = await self._stub.GetInstance(req)
136+
res: pb.GetInstanceResponse = await self._get_stub().GetInstance(req)
123137
return new_orchestration_state(req.instanceId, res)
124138

125139
async def wait_for_orchestration_start(
@@ -131,7 +145,7 @@ async def wait_for_orchestration_start(
131145
)
132146

133147
async def _call(grpc_timeout):
134-
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(
148+
res: pb.GetInstanceResponse = await self._get_stub().WaitForInstanceStart(
135149
req, timeout=grpc_timeout
136150
)
137151
return new_orchestration_state(req.instanceId, res)
@@ -150,7 +164,7 @@ async def wait_for_orchestration_completion(
150164
)
151165

152166
async def _call(grpc_timeout):
153-
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(
167+
res: pb.GetInstanceResponse = await self._get_stub().WaitForInstanceCompletion(
154168
req, timeout=grpc_timeout
155169
)
156170
state = new_orchestration_state(req.instanceId, res)
@@ -261,7 +275,7 @@ async def raise_orchestration_event(
261275
)
262276

263277
self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
264-
await self._stub.RaiseEvent(req)
278+
await self._get_stub().RaiseEvent(req)
265279

266280
async def terminate_orchestration(
267281
self, instance_id: str, *, output: Optional[Any] = None, recursive: bool = True
@@ -273,19 +287,19 @@ async def terminate_orchestration(
273287
)
274288

275289
self._logger.info(f"Terminating instance '{instance_id}'.")
276-
await self._stub.TerminateInstance(req)
290+
await self._get_stub().TerminateInstance(req)
277291

278292
async def suspend_orchestration(self, instance_id: str):
279293
req = pb.SuspendRequest(instanceId=instance_id)
280294
self._logger.info(f"Suspending instance '{instance_id}'.")
281-
await self._stub.SuspendInstance(req)
295+
await self._get_stub().SuspendInstance(req)
282296

283297
async def resume_orchestration(self, instance_id: str):
284298
req = pb.ResumeRequest(instanceId=instance_id)
285299
self._logger.info(f"Resuming instance '{instance_id}'.")
286-
await self._stub.ResumeInstance(req)
300+
await self._get_stub().ResumeInstance(req)
287301

288302
async def purge_orchestration(self, instance_id: str, recursive: bool = True):
289303
req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
290304
self._logger.info(f"Purging instance '{instance_id}'.")
291-
await self._stub.PurgeInstances(req)
305+
await self._get_stub().PurgeInstances(req)

ext/dapr-ext-workflow/tests/durabletask/test_client_async.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,11 @@ def test_async_client_construct_with_metadata():
165165
with patch(
166166
'dapr.ext.workflow._durabletask.aio.internal.shared.grpc_aio.insecure_channel'
167167
) as mock_channel:
168-
AsyncTaskHubGrpcClient(host_address=HOST_ADDRESS, metadata=METADATA)
168+
client = AsyncTaskHubGrpcClient(host_address=HOST_ADDRESS, metadata=METADATA)
169+
assert mock_channel.call_count == 0 # channel is built lazily, not at construction
170+
171+
client._get_stub()
172+
169173
# Ensure channel created with an interceptor that has the expected metadata
170174
args, kwargs = mock_channel.call_args
171175
assert args[0] == HOST_ADDRESS
@@ -175,6 +179,18 @@ def test_async_client_construct_with_metadata():
175179
assert interceptors[0]._metadata == METADATA
176180

177181

182+
def test_async_client_channel_is_lazy():
183+
with patch(
184+
'dapr.ext.workflow._durabletask.aio.internal.shared.grpc_aio.insecure_channel'
185+
) as mock_channel:
186+
client = AsyncTaskHubGrpcClient(host_address=HOST_ADDRESS)
187+
assert mock_channel.call_count == 0 # not built at construction
188+
189+
client._get_stub()
190+
client._get_stub()
191+
assert mock_channel.call_count == 1 # built once on first use, then cached
192+
193+
178194
def test_aio_channel_passes_base_options_and_max_lengths():
179195
base_options = [
180196
('grpc.max_send_message_length', 4321),

0 commit comments

Comments
 (0)