Skip to content

Commit 1b4bace

Browse files
committed
Port system Nexus payload handling to WIT generation
1 parent 1bc2a59 commit 1b4bace

21 files changed

Lines changed: 1165 additions & 185 deletions

scripts/_nexus/deps/nexus-temporal-types/model.wit

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/// @nexus.support
2-
/// python="python/model_overrides.py"
3-
/// typescript="typescript/model_overrides.ts"
2+
/// python="python/temporal_model_converters.py"
3+
/// typescript="typescript/temporal_model_converters.ts"
44
package nexus:temporal-types@1.0.0;
55

66
interface model {

scripts/_nexus/deps/nexus-temporal-types/python/model_overrides.py renamed to scripts/_nexus/deps/nexus-temporal-types/python/temporal_model_converters.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
# pyright: reportAny=false, reportExplicitAny=false
22

33
import collections.abc
4-
from datetime import timedelta
54
import typing
5+
from datetime import timedelta
66

77
import google.protobuf.duration_pb2
8+
89
import temporalio.api.common.v1.message_pb2 as common_pb2
910
import temporalio.api.enums.v1.workflow_pb2 as workflow_enums_pb2
1011
import temporalio.api.taskqueue.v1.message_pb2 as taskqueue_pb2
1112
import temporalio.api.workflow.v1
12-
import temporalio.converter
1313
import temporalio.common
14-
import temporalio.workflow
14+
import temporalio.converter
1515

1616

1717
def retry_policy_from_proto(
@@ -31,14 +31,20 @@ def retry_policy_to_proto(
3131
def workflow_function_name(
3232
value: str | collections.abc.Callable[..., collections.abc.Awaitable[object]],
3333
) -> str:
34-
name, _result_type = temporalio.workflow._Definition.get_name_and_result_type(value) # pyright: ignore[reportPrivateUsage]
34+
from temporalio.workflow import _Definition # pyright: ignore[reportPrivateUsage]
35+
36+
name, _result_type = _Definition.get_name_and_result_type(value)
3537
return name
3638

3739

3840
def signal_function_to_proto(
3941
value: str | collections.abc.Callable[..., typing.Any],
4042
) -> str:
41-
return temporalio.workflow._SignalDefinition.must_name_from_fn_or_str(value) # pyright: ignore[reportPrivateUsage, reportUnknownMemberType]
43+
from temporalio.workflow import (
44+
_SignalDefinition, # pyright: ignore[reportPrivateUsage]
45+
)
46+
47+
return _SignalDefinition.must_name_from_fn_or_str(value) # pyright: ignore[reportUnknownMemberType]
4248

4349

4450
def workflow_type_to_proto(
@@ -61,13 +67,17 @@ def task_queue_to_proto(
6167

6268

6369
def workflow_namespace() -> str:
64-
return temporalio.workflow.info().namespace
70+
from temporalio.workflow import info
71+
72+
return info().namespace
6573

6674

6775
def payloads_to_proto(
6876
values: collections.abc.Sequence[typing.Any],
6977
) -> common_pb2.Payloads:
70-
return temporalio.workflow.payload_converter().to_payloads_wrapper(values)
78+
from temporalio.workflow import payload_converter
79+
80+
return payload_converter().to_payloads_wrapper(values)
7181

7282

7383
def _clone_payload(payload: common_pb2.Payload) -> common_pb2.Payload:
@@ -79,16 +89,20 @@ def _clone_payload(payload: common_pb2.Payload) -> common_pb2.Payload:
7989
def _value_to_payload(value: object | common_pb2.Payload) -> common_pb2.Payload:
8090
if isinstance(value, common_pb2.Payload):
8191
return _clone_payload(value)
82-
payloads = temporalio.workflow.payload_converter().to_payloads_wrapper([value])
92+
from temporalio.workflow import payload_converter
93+
94+
payloads = payload_converter().to_payloads_wrapper([value])
8395
return _clone_payload(payloads.payloads[0])
8496

8597

8698
def _payload_to_value(payload: common_pb2.Payload) -> object:
8799
wrapper = common_pb2.Payloads()
88100
wrapper.payloads.add().CopyFrom(payload)
101+
from temporalio.workflow import payload_converter
102+
89103
return typing.cast(
90104
object,
91-
temporalio.workflow.payload_converter().from_payloads_wrapper(wrapper)[0],
105+
payload_converter().from_payloads_wrapper(wrapper)[0],
92106
)
93107

94108

@@ -152,7 +166,9 @@ def workflow_id_conflict_policy_from_proto(
152166
def workflow_id_conflict_policy_to_proto(
153167
policy: temporalio.common.WorkflowIDConflictPolicy,
154168
) -> workflow_enums_pb2.WorkflowIdConflictPolicy.ValueType:
155-
return typing.cast(workflow_enums_pb2.WorkflowIdConflictPolicy.ValueType, int(policy))
169+
return typing.cast(
170+
workflow_enums_pb2.WorkflowIdConflictPolicy.ValueType, int(policy)
171+
)
156172

157173

158174
def search_attributes_to_proto(

scripts/_nexus/deps/nexus-temporal-types/typescript/model_overrides.ts renamed to scripts/_nexus/deps/nexus-temporal-types/typescript/temporal_model_converters.ts

File renamed without changes.

scripts/_nexus/temporal-system.wit

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ world system {
66

77
/// @nexus.endpoint "temporal-system"
88
/// @nexus.service-name "temporal.api.workflowservice.v1.WorkflowService"
9+
/// @nexus.delay-load-temporalio-workflow
10+
/// @nexus.experimental
911
interface workflow-service {
1012
use nexus:temporal-types/model@1.0.0.{
1113
duration,
@@ -25,6 +27,7 @@ interface workflow-service {
2527
};
2628

2729
/// @nexus.doc "Request fields for signaling a workflow, starting it first if needed."
30+
/// @nexus.experimental
2831
/// @nexus.proto "temporal.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest" typescript-package="@temporalio/proto"
2932
record signal-with-start-workflow-request {
3033
/// @nexus.doc
@@ -90,6 +93,7 @@ interface workflow-service {
9093
time-skipping-config: placeholder,
9194
}
9295

96+
/// @nexus.experimental
9397
/// @nexus.proto "temporal.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse" typescript-package="@temporalio/proto"
9498
record signal-with-start-workflow-response {
9599
run-id: option<string>,
@@ -108,6 +112,7 @@ interface workflow-service {
108112
/// typescript="workflow.getExternalWorkflowHandle(request.id, result.runId ?? undefined)"
109113
/// typescript-package="@temporalio/workflow"
110114
/// @nexus.operation name="SignalWithStartWorkflowExecution"
115+
/// @nexus.experimental
111116
signal-with-start-workflow: func(
112117
request: signal-with-start-workflow-request,
113118
) -> signal-with-start-workflow-response;

scripts/gen_nexus_system_api.py

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,19 @@
33
import subprocess
44
import sys
55
import tempfile
6+
from importlib.util import module_from_spec, spec_from_file_location
67
from pathlib import Path
8+
from typing import cast
79

810
import gen_protos
911

1012
base_dir = Path(__file__).parent.parent
13+
sys.path.insert(0, str(base_dir))
1114
wit_input_dir = base_dir / "scripts" / "_nexus"
1215
wit_path = wit_input_dir / "temporal-system.wit"
1316
wit_deps_dir = wit_input_dir / "deps"
1417
output_dir = base_dir / "temporalio" / "nexus" / "system" / "workflow_service"
15-
default_nex_gen_install_root = (
16-
Path(tempfile.gettempdir()) / "temporal-sdk-python-nex-gen"
17-
)
18+
workflow_init_path = base_dir / "temporalio" / "workflow" / "__init__.py"
1819
workflowservice_request_response_proto = (
1920
gen_protos.api_proto_dir
2021
/ "temporal"
@@ -29,30 +30,9 @@ def nex_gen_command() -> list[str]:
2930
if bin_path := os.environ.get("NEX_GEN_BIN"):
3031
return [bin_path]
3132

32-
return [str(install_published_nex_gen())]
33-
34-
35-
def install_published_nex_gen() -> Path:
36-
install_root = Path(
37-
os.environ.get(
38-
"NEX_GEN_INSTALL_ROOT",
39-
str(default_nex_gen_install_root),
40-
)
41-
)
42-
bin_name = "nex-gen.exe" if os.name == "nt" else "nex-gen"
43-
bin_path = install_root / "bin" / bin_name
44-
if not bin_path.exists():
45-
subprocess.check_call(
46-
[
47-
"cargo",
48-
"install",
49-
"--locked",
50-
"--root",
51-
str(install_root),
52-
"nex-gen",
53-
]
54-
)
55-
return bin_path
33+
if shutil.which("nex-gen") is None:
34+
subprocess.check_call(["cargo", "install", "--locked", "nex-gen"])
35+
return ["nex-gen"]
5636

5737

5838
def build_descriptor_set(descriptor_path: Path) -> None:
@@ -79,6 +59,42 @@ def strip_unsupported_pyright_comments() -> None:
7959
path.write_text(content)
8060

8161

62+
def generate_workflow_exports() -> None:
63+
spec = spec_from_file_location(
64+
"temporalio_nexus_system_workflow_service_exports",
65+
output_dir / "__init__.py",
66+
submodule_search_locations=[str(output_dir)],
67+
)
68+
if spec is None or spec.loader is None:
69+
raise RuntimeError(f"Cannot load generated workflow service from {output_dir}")
70+
module = module_from_spec(spec)
71+
sys.modules[spec.name] = module
72+
spec.loader.exec_module(module)
73+
exports = cast(list[str], module.__all__)
74+
75+
import_block = [
76+
"# BEGIN GENERATED NEXUS SYSTEM EXPORTS\n",
77+
"from temporalio.nexus.system.workflow_service import (\n",
78+
*[f" {export},\n" for export in exports],
79+
")\n",
80+
"# END GENERATED NEXUS SYSTEM EXPORTS\n",
81+
]
82+
all_block = [
83+
" # BEGIN GENERATED NEXUS SYSTEM __ALL__\n",
84+
*[f' "{export}",\n' for export in exports],
85+
" # END GENERATED NEXUS SYSTEM __ALL__\n",
86+
]
87+
content = workflow_init_path.read_text()
88+
start = content.index("# BEGIN GENERATED NEXUS SYSTEM EXPORTS")
89+
end = content.index("# END GENERATED NEXUS SYSTEM EXPORTS", start)
90+
end = content.index("\n", end) + 1
91+
content = content[:start] + "".join(import_block) + content[end:]
92+
start = content.index(" # BEGIN GENERATED NEXUS SYSTEM __ALL__")
93+
end = content.index(" # END GENERATED NEXUS SYSTEM __ALL__", start)
94+
end = content.index("\n", end) + 1
95+
workflow_init_path.write_text(content[:start] + "".join(all_block) + content[end:])
96+
97+
8298
def generate_nexus_system_api() -> None:
8399
if not wit_path.exists():
84100
raise RuntimeError(f"missing WIT source: {wit_path}")
@@ -111,6 +127,7 @@ def generate_nexus_system_api() -> None:
111127

112128
(output_dir.parent / "__init__.py").touch()
113129
strip_unsupported_pyright_comments()
130+
generate_workflow_exports()
114131
subprocess.check_call(
115132
[
116133
sys.executable,
@@ -121,6 +138,7 @@ def generate_nexus_system_api() -> None:
121138
"I",
122139
"--fix",
123140
str(output_dir),
141+
str(workflow_init_path),
124142
]
125143
)
126144
subprocess.check_call(
@@ -130,6 +148,7 @@ def generate_nexus_system_api() -> None:
130148
"ruff",
131149
"format",
132150
str(output_dir),
151+
str(workflow_init_path),
133152
]
134153
)
135154

0 commit comments

Comments
 (0)