Skip to content

Commit f7a2498

Browse files
guptaak authored and copybara-github committed
Add a script to deploy Pathways service as a JobSet
PiperOrigin-RevId: 896092607
1 parent 2cb53bb commit f7a2498

2 files changed

Lines changed: 378 additions & 16 deletions

File tree

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
"""Deploys Pathways service to a Kubernetes cluster using a JobSet template."""
2+
3+
from collections.abc import Sequence
4+
import logging
5+
import math
6+
import os
7+
import string
8+
from absl import app
9+
from absl import flags
10+
from kubernetes import client
11+
from kubernetes import config
12+
import yaml
13+
14+
_logger = logging.getLogger(__name__)
15+
16+
# Flag definitions
17+
FLAGS = flags.FLAGS
18+
_JOBSET_NAME = flags.DEFINE_string(
19+
"jobset_name", "pathways-service", "Name of the JobSet"
20+
)
21+
_JAX_VERSION = flags.DEFINE_string(
22+
"jax_version", "0.9.0", "JAX version (e.g., 0.9.0)"
23+
)
24+
_TPU_TYPE = flags.DEFINE_enum(
25+
"tpu_type", "v6e", ["v5e", "v5p", "v6e", "tpu7x"], "TPU type"
26+
)
27+
_TOPOLOGY = flags.DEFINE_string(
28+
"topology", "2x2", "TPU topology (e.g., 4x8, 2x2x2)"
29+
)
30+
_NUM_SLICES = flags.DEFINE_integer(
31+
"num_slices", 2, "Number of TPU slices"
32+
)
33+
_GCS_BUCKET = flags.DEFINE_string(
34+
"gcs_bucket",
35+
"gs://pathways-test-bucket",
36+
"GCS bucket name for scratch space",
37+
)
38+
_TEMPLATE_FILE = flags.DEFINE_string(
39+
"template_file",
40+
os.path.join(
41+
os.path.dirname(__file__), "yamls/pw-service-example.yaml",
42+
),
43+
"Path to the JobSet YAML template file",
44+
)
45+
_DRY_RUN = flags.DEFINE_boolean(
46+
"dry_run",
47+
False,
48+
"If true, only print the generated YAML without deploying.",
49+
)
50+
51+
52+
def get_tpu_config(tpu_type):
  """Returns a dictionary containing TPU configuration details.

  Args:
    tpu_type: TPU generation, one of "v5e", "v5p", "v6e", "tpu7x".

  Returns:
    A dict with keys "machine_type", "chips_per_vm", "accelerator_label",
    and "instance_prefix".

  Raises:
    ValueError: If tpu_type is not a supported TPU generation.
  """
  # Per-generation tuples: (machine_type, chips_per_vm, accelerator_label,
  # instance_prefix).
  known_types = {
      "v5e": ("ct5lp-hightpu-4t", 4, "tpu-v5-lite-podslice", "tpuv5e"),
      "v5p": ("ct5p-hightpu-4t", 4, "tpu-v5p-slice", "tpuv5p"),
      "v6e": ("ct6e-standard-4t", 4, "tpu-v6e-slice", "tpuv6e"),
      "tpu7x": ("tpu7x-standard-4t", 4, "tpu-v7-slice", "tpu7x"),
  }
  entry = known_types.get(tpu_type)
  if entry is None:
    raise ValueError(
        f"Unsupported TPU type: {tpu_type}. Supported types are:"
        f" {list(known_types.keys())}"
    )
  machine_type, chips_per_vm, accelerator_label, instance_prefix = entry
  return {
      "machine_type": machine_type,
      "chips_per_vm": chips_per_vm,
      "accelerator_label": accelerator_label,
      "instance_prefix": instance_prefix,
  }
86+
87+
88+
def calculate_vms_per_slice(topology, chips_per_vm):
  """Calculates the number of VMs per slice based on the topology.

  Args:
    topology: TPU topology string such as "4x8" or "2x2x2"; the product of
      the dimensions is the total chip count of the slice.
    chips_per_vm: Number of TPU chips attached to each VM.

  Returns:
    The number of VMs needed per slice (total chips // chips_per_vm).

  Raises:
    ValueError: If the topology string is malformed, or the total chip count
      is not divisible by chips_per_vm.
  """
  # Only the int() parsing belongs in the try: previously the divisibility
  # ValueError below was also caught here and re-wrapped as a misleading
  # "Invalid topology format" error.
  try:
    dims = [int(d) for d in topology.split("x")]
  except ValueError as e:
    raise ValueError(
        f"Invalid topology format: {topology}. Expected format like 'AxB' or"
        f" 'AxBxC'. {e}"
    ) from e
  total_chips = math.prod(dims)
  if total_chips % chips_per_vm != 0:
    raise ValueError(
        f"Total chips ({total_chips}) in topology {topology} is not divisible"
        f" by chips_per_vm ({chips_per_vm})"
    )
  return total_chips // chips_per_vm
104+
105+
106+
def load_and_substitute_template(template_path, context):
  """Loads and substitutes the string.Template from the given path.

  Args:
    template_path: Path to a YAML file containing string.Template
      placeholders (e.g. ${JOBSET_NAME}).
    context: Mapping from placeholder names to their replacement values.

  Returns:
    The substituted template parsed with yaml.safe_load.

  Raises:
    ValueError: If the template file cannot be read, or substitution fails
      (missing placeholder value or malformed "$" sequence).
  """
  try:
    # Explicit encoding so the read does not depend on the platform locale.
    with open(template_path, "r", encoding="utf-8") as f:
      template_str = f.read()
  except OSError as err:
    raise ValueError(
        f"Could not read template file: {template_path}: {err}"
    ) from err

  _logger.info("Template file: %s", template_path)
  _logger.info("Context: %s", context)
  template = string.Template(template_str)
  try:
    substituted_yaml = template.substitute(context)
  except (KeyError, ValueError) as err:
    # string.Template raises KeyError for a placeholder missing from
    # context and ValueError for a malformed "$" sequence; wrap both so
    # callers only have to handle ValueError.
    raise ValueError(
        f"Could not substitute template {template_path}: {err}"
    ) from err
  _logger.info("Substituted YAML: %s", substituted_yaml)
  return yaml.safe_load(substituted_yaml)
123+
124+
125+
def deploy_jobset(jobset_yaml):
  """Deploys the JobSet to the current Kubernetes cluster.

  Args:
    jobset_yaml: Parsed JobSet manifest (dict); must provide
      metadata.name and metadata.namespace.

  Note: API and kubeconfig errors are logged and swallowed rather than
  re-raised, so a failed deploy does not stop the caller.
  """
  try:
    # Uses the local kubeconfig, i.e. the same cluster/credentials kubectl
    # would use.
    config.load_kube_config()
    api = client.CustomObjectsApi()
    # JobSet is a custom resource, so it is created through the dynamic
    # CustomObjectsApi rather than a typed client.
    api.create_namespaced_custom_object(
        group="jobset.x-k8s.io",
        version="v1alpha2",
        namespace=jobset_yaml["metadata"]["namespace"],
        body=jobset_yaml,
        plural="jobsets",
    )
    _logger.info(
        "JobSet '%s' created successfully.", jobset_yaml["metadata"]["name"]
    )
  except client.rest.ApiException as e:
    _logger.error("Error creating JobSet: %s", e)
  except config.ConfigException as e:
    _logger.error("Error loading Kubernetes configuration: %s", e)
  # TODO idea -- keep checking until up -- surface logs.
145+
146+
def run_deployment(
    tpu_type,
    topology,
    num_slices,
    jobset_name,
    gcs_bucket,
    jax_version,
    template_file,
    dry_run,
    deploy_func=deploy_jobset,
):
  """Executes the deployment logic.

  Resolves the TPU configuration, fills the JobSet YAML template with the
  deployment parameters, logs the generated manifest, and hands it to
  deploy_func unless dry_run is set.

  Args:
    tpu_type: TPU generation (see get_tpu_config).
    topology: TPU topology string, e.g. "4x8".
    num_slices: Number of TPU slices.
    jobset_name: Name for the generated JobSet.
    gcs_bucket: GCS location used as Pathways scratch space.
    jax_version: JAX version tag for the Pathways container images.
    template_file: Path to the JobSet YAML template.
    dry_run: If true, only log the generated YAML without deploying.
    deploy_func: Callable that deploys the manifest (injectable for tests).
  """
  tpu_config = get_tpu_config(tpu_type)
  chips_per_vm = tpu_config["chips_per_vm"]
  vms_per_slice = calculate_vms_per_slice(topology, chips_per_vm)

  # Keys must match the ${...} placeholders in the YAML template.
  context = dict(
      JOBSET_NAME=jobset_name,
      JAX_VERSION=jax_version,
      GCS_SCRATCH_LOCATION=gcs_bucket,
      NUM_SLICES=num_slices,
      INSTANCE_TYPE=f"{tpu_config['instance_prefix']}:{topology}",
      VMS_PER_SLICE=vms_per_slice,
      CHIPS_PER_VM=chips_per_vm,
      ACCELERATOR_LABEL=tpu_config["accelerator_label"],
      TOPOLOGY=topology,
  )

  jobset_config = load_and_substitute_template(template_file, context)

  _logger.info("--- Generated JobSet YAML ---")
  _logger.info("\n%s", yaml.dump(jobset_config))
  _logger.info("---")

  if dry_run:
    _logger.info("Dry run mode, not deploying.")
  else:
    deploy_func(jobset_config)
183+
184+
185+
def main(argv: Sequence[str]) -> None:
  """Script entry point: validates argv, reads flags, runs the deployment."""
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  try:
    run_deployment(
        tpu_type=_TPU_TYPE.value,
        topology=_TOPOLOGY.value,
        num_slices=_NUM_SLICES.value,
        jobset_name=_JOBSET_NAME.value,
        gcs_bucket=_GCS_BUCKET.value,
        jax_version=_JAX_VERSION.value,
        template_file=_TEMPLATE_FILE.value,
        dry_run=_DRY_RUN.value,
    )
  except FileNotFoundError:
    # Disjoint from ValueError, so handler order does not affect behavior.
    _logger.exception(
        "Error: Template file not found at %s", _TEMPLATE_FILE.value
    )
  except ValueError as e:
    _logger.exception("Error: %s", e)


if __name__ == "__main__":
  app.run(main)
Lines changed: 169 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,171 @@
1-
# JobSet template for the Pathways service. ${...} placeholders are filled
# by the deployment script via string.Template; "$$" escapes a literal "$"
# so $(PATHWAYS_HEAD) survives substitution for Kubernetes env expansion.
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: ${JOBSET_NAME}
  namespace: default
spec:
  coordinator:
    replicatedJob: pathways-head
  failurePolicy:
    maxRestarts: 1
    restartStrategy: Recreate
  network:
    enableDNSHostnames: true
    publishNotReadyAddresses: true
  replicatedJobs:
  # Head job: Pathways resource manager + IFRT proxy on a single pod.
  - name: pathways-head
    replicas: 1
    template:
      metadata:
        annotations:
          alpha.jobset.sigs.k8s.io/exclusive-topology: kubernetes.io/hostname
      spec:
        backoffLimit: 3
        completionMode: Indexed
        completions: 1
        parallelism: 1
        template:
          metadata:
            annotations:
              alpha.jobset.sigs.k8s.io/exclusive-topology: kubernetes.io/hostname
          spec:
            containers:
            - name: pathways-rm
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-${JAX_VERSION}
              imagePullPolicy: Always
              args:
              - --server_port=29001
              - --gcs_scratch_location=${GCS_SCRATCH_LOCATION}
              - --node_type=resource_manager
              - --instance_count=${NUM_SLICES}
              - --instance_type=${INSTANCE_TYPE}
              env:
              - name: REPLICATED_JOB_NAME
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
              - name: JOBSET_NAME
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
              - name: HOST_ADDRESS
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator']
              - name: TPU_SKIP_MDS_QUERY
                value: "true"
              ports:
              - containerPort: 29001
                protocol: TCP
              - containerPort: 29002
                protocol: TCP
              resources:
                limits:
                  cpu: "8"
                  memory: 32G
            - name: pathways-proxy
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-${JAX_VERSION}
              imagePullPolicy: Always
              args:
              - --server_port=29000
              # $$ -> $ after template substitution; resolved by k8s at runtime.
              - --resource_manager_address=$$(PATHWAYS_HEAD):29001
              - --gcs_scratch_location=${GCS_SCRATCH_LOCATION}
              env:
              - name: PATHWAYS_HEAD
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator']
              ports:
              - containerPort: 29000
                protocol: TCP
              resources:
                limits:
                  cpu: "16"
                  memory: 100G
            dnsPolicy: ClusterFirstWithHostNet
            hostNetwork: true
            restartPolicy: OnFailure
  # Worker jobs: one Job per TPU slice, one pod per TPU VM.
  - name: worker
    replicas: ${NUM_SLICES}
    template:
      spec:
        backoffLimit: 1000000
        completionMode: Indexed
        completions: ${VMS_PER_SLICE}
        parallelism: ${VMS_PER_SLICE}
        template:
          metadata:
            annotations:
              alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
          spec:
            containers:
            - name: pathways-worker
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-${JAX_VERSION}
              imagePullPolicy: Always
              args:
              - --server_port=29005
              - --resource_manager_address=$$(PATHWAYS_HEAD):29001
              - --gcs_scratch_location=${GCS_SCRATCH_LOCATION}
              env:
              - name: TPU_MIN_LOG_LEVEL
                value: "0"
              - name: TF_CPP_MIN_LOG_LEVEL
                value: "0"
              - name: XCLOUD_ENVIRONMENT
                value: GCP
              - name: MEGASCALE_GRPC_ENABLE_XOR_TRACER
                value: "false"
              - name: MEGASCALE_NUM_SLICES
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas']
              - name: JOBSET_NAME
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
              - name: REPLICATED_JOB_NAME
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
              - name: MEGASCALE_SLICE_ID
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.labels['jobset.sigs.k8s.io/job-index']
              - name: PATHWAYS_HEAD
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator']
              - name: MEGASCALE_COORDINATOR_ADDRESS
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator']
              ports:
              - containerPort: 29005
                protocol: TCP
              - containerPort: 29006
                protocol: TCP
              - containerPort: 8471
                protocol: TCP
              - containerPort: 8080
                protocol: TCP
              resources:
                limits:
                  google.com/tpu: "${CHIPS_PER_VM}"
              volumeMounts:
              - mountPath: /tmp
                name: shared-tmp
            dnsPolicy: ClusterFirstWithHostNet
            hostNetwork: true
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: ${ACCELERATOR_LABEL}
              cloud.google.com/gke-tpu-topology: ${TOPOLOGY}
            restartPolicy: OnFailure
            volumes:
            - name: shared-tmp
              hostPath:
                path: /tmp
                type: DirectoryOrCreate
  startupPolicy:
    startupPolicyOrder: InOrder
  successPolicy:
    operator: All

0 commit comments

Comments
 (0)