Skip to content

Commit f35524d

Browse files
YunchuWangCopilot
andcommitted
Fix on-demand sandbox CI checks
Use package-internal public helper names for cross-module calls, add strict typing around generated gRPC transport methods, and keep typed dataclass default factories without extra helper functions. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent e8ee577 commit f35524d

6 files changed

Lines changed: 73 additions & 55 deletions

File tree

durabletask-azuremanaged/durabletask/azuremanaged/preview/on_demand_sandbox/client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import grpc
77
from azure.core.credentials import TokenCredential
88

9-
from durabletask.azuremanaged.preview.on_demand_sandbox.helpers import _normalize_required
9+
from durabletask.azuremanaged.preview.on_demand_sandbox.helpers import normalize_required
1010
from durabletask.azuremanaged.preview.on_demand_sandbox.declarations import (
11-
_build_profile_on_demand_sandbox_activity_declarations,
11+
build_profile_on_demand_sandbox_activity_declarations,
1212
)
1313
from durabletask.azuremanaged.preview.on_demand_sandbox.transport import (
1414
OnDemandSandboxActivitiesGrpcTransport,
@@ -43,13 +43,13 @@ def close(self) -> None:
4343

4444
def enable_on_demand_sandbox_activities(self) -> None:
4545
"""Declare all configured on-demand sandbox worker profiles with Durable Task Scheduler."""
46-
declarations = _build_profile_on_demand_sandbox_activity_declarations()
46+
declarations = build_profile_on_demand_sandbox_activity_declarations()
4747
if not declarations:
4848
raise ValueError("No configured on-demand sandbox activities were found.")
4949

5050
for declaration in declarations:
5151
self._transport.declare_on_demand_sandbox_activities(declaration)
5252

5353
def remove_on_demand_sandbox_activity_declaration(self, worker_profile_id: str) -> None:
54-
worker_profile_id = _normalize_required(worker_profile_id, "Worker profile ID is required.")
54+
worker_profile_id = normalize_required(worker_profile_id, "Worker profile ID is required.")
5555
self._transport.remove_on_demand_sandbox_activity_declaration(worker_profile_id)

durabletask-azuremanaged/durabletask/azuremanaged/preview/on_demand_sandbox/declarations.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33

44
from dataclasses import dataclass, field
55
from decimal import Decimal, InvalidOperation
6-
from typing import Callable, Iterable, Optional
6+
from typing import Any, Callable, Iterable, Optional
77

88
from durabletask import task
99
from durabletask.azuremanaged.internal import on_demand_sandbox_activities_service_pb2 as pb
1010
from durabletask.azuremanaged.preview.on_demand_sandbox.helpers import (
11-
_normalize_required,
12-
_resolve_activity_names,
11+
normalize_required,
12+
resolve_activity_names,
1313
)
1414

1515

@@ -32,17 +32,17 @@ class OnDemandSandboxWorkerProfileOptions:
3232
scheduler_managed_identity_client_id: Optional[str] = None
3333
cpu: str = DEFAULT_CPU
3434
memory: str = DEFAULT_MEMORY
35-
environment_variables: dict[str, str] = field(default_factory=dict)
35+
environment_variables: dict[str, str] = field(default_factory=dict[str, str])
3636
max_concurrent_activities: int = DEFAULT_MAX_CONCURRENT_ACTIVITIES
37-
entrypoint: list[str] = field(default_factory=list)
38-
cmd: list[str] = field(default_factory=list)
39-
activity_names: list[str] = field(default_factory=list)
37+
entrypoint: list[str] = field(default_factory=list[str])
38+
cmd: list[str] = field(default_factory=list[str])
39+
activity_names: list[str] = field(default_factory=list[str])
4040

41-
def add_activity(self, activity: str | Callable) -> None:
41+
def add_activity(self, activity: str | Callable[..., Any]) -> None:
4242
"""Add an activity to the on-demand sandbox worker profile declaration."""
4343
activity_name = task.get_name(activity) if callable(activity) else activity
4444
self.activity_names.append(
45-
_normalize_required(activity_name, "On-demand sandbox activity name is required."))
45+
normalize_required(activity_name, "On-demand sandbox activity name is required."))
4646

4747

4848
class OnDemandSandboxWorkerProfile:
@@ -57,7 +57,7 @@ def configure(self, options: OnDemandSandboxWorkerProfileOptions) -> None:
5757

5858
def on_demand_sandbox_worker_profile(worker_profile_id: str) -> Callable[[type], type]:
5959
"""Declare an on-demand sandbox worker profile using a decorated marker class."""
60-
normalized_profile = _normalize_required(worker_profile_id, "On-demand sandbox worker profile ID is required.")
60+
normalized_profile = normalize_required(worker_profile_id, "On-demand sandbox worker profile ID is required.")
6161

6262
def decorator(cls: type) -> type:
6363
if normalized_profile in _worker_profiles:
@@ -73,7 +73,7 @@ def decorator(cls: type) -> type:
7373
if callable(configure):
7474
configure(options)
7575

76-
if not _resolve_activity_names(options.activity_names):
76+
if not resolve_activity_names(options.activity_names):
7777
raise ValueError(
7878
f"On-demand sandbox worker profile '{normalized_profile}' must declare at least one activity.")
7979

@@ -86,7 +86,7 @@ def decorator(cls: type) -> type:
8686
def _build_on_demand_sandbox_activity_declaration(
8787
*,
8888
activity_names: str | Iterable[str],
89-
scheduler_managed_identity_client_id: str,
89+
scheduler_managed_identity_client_id: Optional[str],
9090
worker_profile_id: str = DEFAULT_WORKER_PROFILE_ID,
9191
container_image: Optional[str] = None,
9292
image_pull_managed_identity_client_id: Optional[str] = None,
@@ -103,7 +103,7 @@ def _build_on_demand_sandbox_activity_declaration(
103103
such as "myregistry.azurecr.io/workers/hello:1.0" or
104104
"myregistry.azurecr.io/workers/hello@sha256:0123456789abcdef...".
105105
"""
106-
resolved_activity_names = _resolve_activity_names(activity_names)
106+
resolved_activity_names = resolve_activity_names(activity_names)
107107
if not resolved_activity_names:
108108
raise ValueError("On-demand sandbox activity declaration requires at least one activity name.")
109109

@@ -113,16 +113,16 @@ def _build_on_demand_sandbox_activity_declaration(
113113
if max_concurrent_activities <= 0:
114114
raise ValueError("On-demand sandbox activity max concurrent activities must be greater than zero.")
115115

116-
image_ref = _normalize_required(
116+
image_ref = normalize_required(
117117
container_image,
118118
"On-demand sandbox activity image metadata requires a container image reference like "
119119
"'myregistry.azurecr.io/workers/hello:1.0' or "
120120
"'myregistry.azurecr.io/workers/hello@sha256:...'.")
121121

122-
resolved_scheduler_managed_identity_client_id = _normalize_required(
122+
resolved_scheduler_managed_identity_client_id = normalize_required(
123123
scheduler_managed_identity_client_id,
124124
"On-demand sandbox activity declaration requires the managed identity client ID workers use to connect to Durable Task Scheduler.")
125-
resolved_image_pull_managed_identity_client_id = _normalize_required(
125+
resolved_image_pull_managed_identity_client_id = normalize_required(
126126
image_pull_managed_identity_client_id,
127127
"On-demand sandbox activity declaration requires the managed identity client ID ADC uses to pull the worker image.")
128128

@@ -146,19 +146,19 @@ def _build_on_demand_sandbox_activity_declaration(
146146
return declaration
147147

148148

149-
def _build_profile_on_demand_sandbox_activity_declarations() -> list[pb.OnDemandSandboxActivityDeclaration]:
149+
def build_profile_on_demand_sandbox_activity_declarations() -> list[pb.OnDemandSandboxActivityDeclaration]:
150150
"""Build on-demand sandbox declarations from worker profile configuration."""
151151
declarations: list[pb.OnDemandSandboxActivityDeclaration] = []
152152
activity_owners: dict[str, str] = {}
153153
for profile in _worker_profiles.values():
154-
activity_names = _resolve_activity_names(profile.activity_names)
154+
activity_names = resolve_activity_names(profile.activity_names)
155155

156156
for activity_name in activity_names:
157157
existing_profile = activity_owners.get(activity_name)
158158
if existing_profile and existing_profile != profile.worker_profile_id:
159159
raise ValueError(
160160
f"On-demand sandbox activity '{activity_name}' is assigned to both worker profile "
161-
f"'{existing_profile}' and '{profile.worker_profile_id}'.")
161+
f"'{existing_profile}' and '{profile.worker_profile_id}'.")
162162
activity_owners[activity_name] = profile.worker_profile_id
163163

164164
declarations.append(_build_on_demand_sandbox_activity_declaration(
@@ -177,7 +177,7 @@ def _build_profile_on_demand_sandbox_activity_declarations() -> list[pb.OnDemand
177177
return declarations
178178

179179

180-
def _build_on_demand_sandbox_worker_start(
180+
def build_on_demand_sandbox_worker_start(
181181
*,
182182
taskhub: str,
183183
worker_profile_id: str,
@@ -194,7 +194,7 @@ def _build_on_demand_sandbox_worker_start(
194194
if max_activities_count <= 0:
195195
raise ValueError("On-demand sandbox activity worker max activity count must be greater than zero.")
196196

197-
resolved_activity_names = _resolve_activity_names(activity_names)
197+
resolved_activity_names = resolve_activity_names(activity_names)
198198
if not resolved_activity_names:
199199
raise ValueError("On-demand sandbox activity worker registration requires at least one registered activity.")
200200

@@ -209,7 +209,7 @@ def _build_on_demand_sandbox_worker_start(
209209
return message
210210

211211

212-
def _build_on_demand_sandbox_worker_heartbeat(active_activities_count: int) -> pb.OnDemandSandboxActivityWorkerMessage:
212+
def build_on_demand_sandbox_worker_heartbeat(active_activities_count: int) -> pb.OnDemandSandboxActivityWorkerMessage:
213213
if active_activities_count < 0:
214214
raise ValueError("On-demand sandbox activity worker active activity count cannot be negative.")
215215

@@ -223,7 +223,7 @@ def _normalize_optional_strings(values: Iterable[str]) -> list[str]:
223223

224224

225225
def _normalize_cpu(value: str) -> str:
226-
normalized = _normalize_required(value, "On-demand sandbox activity declaration requires CPU resources.")
226+
normalized = normalize_required(value, "On-demand sandbox activity declaration requires CPU resources.")
227227
milli_cpu = _try_parse_cpu_millicores(normalized)
228228
if milli_cpu is None or milli_cpu <= 0:
229229
raise ValueError(
@@ -233,7 +233,7 @@ def _normalize_cpu(value: str) -> str:
233233

234234

235235
def _normalize_memory(value: str) -> str:
236-
normalized = _normalize_required(value, "On-demand sandbox activity declaration requires memory resources.")
236+
normalized = normalize_required(value, "On-demand sandbox activity declaration requires memory resources.")
237237
memory_mib = _try_parse_memory_mib(normalized)
238238
if memory_mib is None or memory_mib <= 0:
239239
raise ValueError(

durabletask-azuremanaged/durabletask/azuremanaged/preview/on_demand_sandbox/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
from typing import Iterable, Optional
55

66

7-
def _normalize_required(value: Optional[str], message: str) -> str:
7+
def normalize_required(value: Optional[str], message: str) -> str:
88
if not value or not value.strip():
99
raise ValueError(message)
1010
return value.strip()
1111

1212

13-
def _resolve_activity_names(activity_names: str | Iterable[str]) -> list[str]:
13+
def resolve_activity_names(activity_names: str | Iterable[str]) -> list[str]:
1414
resolved: list[str] = []
1515
seen: set[str] = set()
1616
names = [activity_names] if isinstance(activity_names, str) else activity_names

durabletask-azuremanaged/durabletask/azuremanaged/preview/on_demand_sandbox/transport.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33

4-
from typing import Iterable, Optional, Sequence
4+
from typing import Callable, Iterable, Optional, Protocol, Sequence, cast
55

66
import grpc
77
from azure.core.credentials import TokenCredential
@@ -15,6 +15,21 @@
1515
import durabletask.internal.shared as shared
1616

1717

18+
class _OnDemandSandboxActivitiesStub(Protocol):
19+
DeclareOnDemandSandboxActivities: Callable[
20+
[pb.OnDemandSandboxActivityDeclaration],
21+
pb.OnDemandSandboxActivityDeclarationResult,
22+
]
23+
RemoveOnDemandSandboxActivityDeclaration: Callable[
24+
[pb.RemoveOnDemandSandboxActivityDeclarationRequest],
25+
pb.RemoveOnDemandSandboxActivityDeclarationResult,
26+
]
27+
ConnectOnDemandSandboxActivityWorker: Callable[
28+
[Iterable[pb.OnDemandSandboxActivityWorkerMessage]],
29+
pb.OnDemandSandboxActivityWorkerSessionResult,
30+
]
31+
32+
1833
class OnDemandSandboxActivitiesGrpcTransport:
1934
"""Internal gRPC transport for on-demand sandbox activity RPCs."""
2035

@@ -42,7 +57,7 @@ def __init__(
4257
interceptors=resolved_interceptors,
4358
channel_options=channel_options)
4459
self._channel = channel
45-
self._stub = stubs.OnDemandSandboxActivitiesStub(channel)
60+
self._stub = cast(_OnDemandSandboxActivitiesStub, stubs.OnDemandSandboxActivitiesStub(channel))
4661

4762
def close(self) -> None:
4863
if self._owns_channel:

durabletask-azuremanaged/durabletask/azuremanaged/preview/on_demand_sandbox/worker.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,19 @@
55
import random
66
import threading
77

8-
from typing import Iterator, Optional
8+
from typing import Any, Iterator, Optional
99

10+
from azure.core.credentials import TokenCredential
1011
from azure.identity import ManagedIdentityCredential
1112

12-
from durabletask.azuremanaged.preview.on_demand_sandbox.helpers import _resolve_activity_names
13+
from durabletask import task
14+
from durabletask.azuremanaged.internal import on_demand_sandbox_activities_service_pb2 as pb
15+
from durabletask.azuremanaged.preview.on_demand_sandbox.helpers import resolve_activity_names
1316
from durabletask.azuremanaged.preview.on_demand_sandbox.declarations import (
1417
DEFAULT_MAX_CONCURRENT_ACTIVITIES,
1518
DEFAULT_WORKER_PROFILE_ID,
16-
_build_on_demand_sandbox_worker_heartbeat,
17-
_build_on_demand_sandbox_worker_start,
19+
build_on_demand_sandbox_worker_heartbeat,
20+
build_on_demand_sandbox_worker_start,
1821
)
1922
from durabletask.azuremanaged.preview.on_demand_sandbox.transport import (
2023
OnDemandSandboxActivitiesGrpcTransport,
@@ -35,7 +38,7 @@ class OnDemandSandboxWorker(DurableTaskSchedulerWorker):
3538
restricts dispatch to the activities registered on this worker.
3639
"""
3740

38-
def __init__(self):
41+
def __init__(self) -> None:
3942
resolved_host_address = _resolve_host_address()
4043
resolved_taskhub = _resolve_taskhub()
4144
resolved_secure_channel = _resolve_secure_channel(resolved_host_address)
@@ -68,7 +71,7 @@ def __init__(self):
6871
self._on_demand_sandbox_active_activities = 0
6972
self._on_demand_sandbox_active_activities_lock = threading.Lock()
7073

71-
def add_activity(self, fn) -> str:
74+
def add_activity(self, fn: task.Activity[Any, Any]) -> str:
7275
activity_name = super().add_activity(fn)
7376
self._on_demand_sandbox_activity_names.append(activity_name)
7477
return activity_name
@@ -82,16 +85,16 @@ def stop(self) -> None:
8285
self._stop_on_demand_sandbox_registration()
8386
super().stop()
8487

85-
def _durabletask_on_activity_execution_started(self, req) -> None:
88+
def _durabletask_on_activity_execution_started(self, req: object) -> None:
8689
with self._on_demand_sandbox_active_activities_lock:
8790
self._on_demand_sandbox_active_activities += 1
8891

89-
def _durabletask_on_activity_execution_completed(self, req) -> None:
92+
def _durabletask_on_activity_execution_completed(self, req: object) -> None:
9093
with self._on_demand_sandbox_active_activities_lock:
9194
self._on_demand_sandbox_active_activities = max(0, self._on_demand_sandbox_active_activities - 1)
9295

9396
def _configure_on_demand_sandbox_activity_filters(self) -> None:
94-
activity_names = _resolve_activity_names(self._on_demand_sandbox_activity_names)
97+
activity_names = resolve_activity_names(self._on_demand_sandbox_activity_names)
9598
if not activity_names:
9699
raise RuntimeError(
97100
"On-demand sandbox worker requires at least one registered activity before it can register.")
@@ -143,8 +146,8 @@ def _run_on_demand_sandbox_registration_loop(self) -> None:
143146
self._on_demand_sandbox_registration_stop.wait(delay)
144147
retry_delay = min(retry_delay * 2, 30.0)
145148

146-
def _registration_messages(self) -> Iterator:
147-
yield _build_on_demand_sandbox_worker_start(
149+
def _registration_messages(self) -> Iterator[pb.OnDemandSandboxActivityWorkerMessage]:
150+
yield build_on_demand_sandbox_worker_start(
148151
taskhub=self._on_demand_sandbox_taskhub,
149152
worker_profile_id=self._on_demand_sandbox_worker_profile_id,
150153
max_activities_count=self._on_demand_sandbox_max_activities,
@@ -156,7 +159,7 @@ def _registration_messages(self) -> Iterator:
156159
self._on_demand_sandbox_heartbeat_interval_seconds):
157160
with self._on_demand_sandbox_active_activities_lock:
158161
active_count = self._on_demand_sandbox_active_activities
159-
yield _build_on_demand_sandbox_worker_heartbeat(active_count)
162+
yield build_on_demand_sandbox_worker_heartbeat(active_count)
160163

161164

162165
def _resolve_taskhub() -> str:
@@ -193,7 +196,7 @@ def _resolve_worker_profile_id() -> str:
193196
return resolved_worker_profile_id.strip()
194197

195198

196-
def _resolve_token_credential():
199+
def _resolve_token_credential() -> TokenCredential | None:
197200
authentication = os.getenv("DTS_AUTHENTICATION", "")
198201
if authentication.lower() != "managedidentity":
199202
return None

0 commit comments

Comments
 (0)