-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathisc_pathways.py
More file actions
284 lines (252 loc) · 9.83 KB
/
isc_pathways.py
File metadata and controls
284 lines (252 loc) · 9.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
"""Module for connecting to a Pathways server for interactive supercomputing."""
from collections.abc import Iterator, Mapping
import contextlib
import gc
import logging
import os
import random
import string
import subprocess
from typing import Any
import jax
import jax.extend.backend as jax_backend
import pathwaysutils
from pathwaysutils.experimental.shared_pathways_service import gke_utils
from pathwaysutils.experimental.shared_pathways_service import validators
# Kubernetes manifest template for the proxy job; it lives in a `yamls/`
# directory next to this module and is filled in with string.Template.
PROXY_FILEPATH = os.path.join(os.path.dirname(__file__), "yamls/pw-proxy.yaml")

# TODO(b/459935429): Hardcoding the port and using hostNetwork: true in the
# proxy YAML limits us to one proxy server pod per node. Consider alternative
# networking configurations to allow multiple proxies per node if needed.
PROXY_SERVER_PORT = 29000

# jax.config keys and values used to route JAX through the local proxy port.
_JAX_PLATFORMS_KEY = "jax_platforms"
_JAX_PLATFORM_PROXY = "proxy"
_JAX_BACKEND_TARGET_KEY = "jax_backend_target"
_JAX_BACKEND_TARGET_HOSTNAME = "grpc://127.0.0.1"

# Proxy server image used when the caller does not supply one.
_DEFAULT_PROXY_IMAGE = (
    "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/"
    "proxy_server:latest"
)

_logger = logging.getLogger(__name__)
class _ISCPathways:
  """Manages a Pathways proxy deployment for interactive supercomputing.

  Intended to be used as a context manager. Entering deploys a proxy job to the
  GKE cluster, port-forwards a local port to the proxy pod, and points JAX at
  that port. Exiting (or a failed enter) clears JAX state, stops the port
  forward, and deletes the proxy job unless the job already existed before
  this instance tried to deploy it.

  Attributes:
    cluster: The name of the GKE cluster.
    project: The GCP project ID.
    region: The GCP region.
    bucket: The Google Cloud Storage bucket used as scratch space.
    pathways_service: The service name and port of the Pathways head pod, in
      "host:port" form.
    expected_tpu_instances: A dictionary mapping TPU machine types to the number
      of instances.
    proxy_server_image: The image to use for the proxy server.
  """

  # Seconds to wait for the port-forwarding process to exit, applied once after
  # terminate() and once more after kill().
  _PORT_FORWARD_STOP_TIMEOUT_S = 10

  def __init__(
      self,
      *,
      cluster: str,
      project: str,
      region: str,
      gcs_bucket: str,
      pathways_service: str,
      expected_tpu_instances: Mapping[Any, Any],
      proxy_job_name: str,
      proxy_server_image: str,
  ):
    """Initializes the manager. No cluster resources are created here.

    Args:
      cluster: The name of the GKE cluster.
      project: The GCP project ID.
      region: The GCP region.
      gcs_bucket: The Google Cloud Storage bucket used as scratch space.
      pathways_service: The "host:port" of the Pathways head service.
      expected_tpu_instances: A mapping of TPU machine type to instance count.
      proxy_job_name: The name to use for the deployed proxy job.
      proxy_server_image: The image to use for the proxy server.
    """
    self.cluster = cluster
    self.project = project
    self.region = region
    self.bucket = gcs_bucket
    self.pathways_service = pathways_service
    self.expected_tpu_instances = expected_tpu_instances
    self._proxy_job_name = proxy_job_name
    # Both populated in __enter__ once port forwarding is established.
    self._port_forward_process = None
    self._proxy_port = None
    self.proxy_server_image = proxy_server_image
    # Whether _cleanup() may delete the proxy job. Cleared when a job with the
    # same name already existed, since that proxy may be in use elsewhere.
    self._run_cleanup = True

  def __repr__(self):
    return (
        f"_ISCPathways(cluster='{self.cluster}', project='{self.project}', "
        f"region='{self.region}', bucket='{self.bucket}', "
        f"pathways_service='{self.pathways_service}', "
        f"expected_tpu_instances={self.expected_tpu_instances}, "
        f"_proxy_job_name='{self._proxy_job_name}', "
        f"proxy_server_image='{self.proxy_server_image}')"
    )

  def _deploy_pathways_proxy_server(
      self,
      pathways_service: str,
      proxy_job_name: str,
      expected_instances: Mapping[Any, Any],
      gcs_scratch_location: str,
      proxy_server_image: str,
  ) -> None:
    """Deploys the Pathways proxy job to the GKE cluster.

    Args:
      pathways_service: The "host:port" of the Pathways head service.
      proxy_job_name: The name to use for the deployed proxy.
      expected_instances: A dictionary mapping instance types to the number of
        instances. Only the first entry is used.
      gcs_scratch_location: The Google Cloud Storage location to use.
      proxy_server_image: The image to use for the proxy server.

    Raises:
      ValueError: If the proxy YAML template cannot be read, or if
        `expected_instances` is empty.
      RuntimeError: If a proxy with the given name already exists.
      subprocess.CalledProcessError: If the kubectl command fails.
    """
    try:
      with open(PROXY_FILEPATH, "r", encoding="utf-8") as f:
        yaml_template = f.read()
    except OSError as err:
      raise ValueError("Could not read file: " + PROXY_FILEPATH) from err

    pathways_head_hostname, pathways_head_port = pathways_service.split(":")

    # Only a single instance type is supported for now, so take the first
    # entry. Raise explicitly on an empty mapping: a bare StopIteration from
    # next() would be turned into an opaque RuntimeError by the
    # @contextmanager generator that drives this class.
    if not expected_instances:
      raise ValueError("expected_instances must contain at least one entry.")
    instance_type, count = next(iter(expected_instances.items()))
    instances_str = ",".join([instance_type] * count)

    substituted_yaml = string.Template(yaml_template).substitute(
        PROXY_JOB_NAME=proxy_job_name,
        PROXY_SERVER_PORT=PROXY_SERVER_PORT,
        PATHWAYS_HEAD_HOSTNAME=pathways_head_hostname,
        PATHWAYS_HEAD_PORT=pathways_head_port,
        EXPECTED_INSTANCES=instances_str,
        GCS_SCRATCH_LOCATION=gcs_scratch_location,
        PROXY_SERVER_IMAGE=proxy_server_image,
    )

    if gke_utils.job_exists(proxy_job_name):
      _logger.info(
          "Proxy job '%s' already exists. Skipping deployment.", proxy_job_name
      )
      # Prevent cleanup since the existing proxy may be in use.
      self._run_cleanup = False
      raise RuntimeError(
          f"A proxy with the name '{proxy_job_name}' already exists. Please "
          "choose a different name and try again."
      )

    _logger.info("Deploying Pathways proxy: %s", proxy_job_name)
    try:
      gke_utils.deploy_gke_yaml(substituted_yaml)
    except subprocess.CalledProcessError:
      _logger.exception("Failed to deploy Pathways proxy.")
      raise
    _logger.info("Successfully deployed Pathways proxy.")

  def __enter__(self):
    """Deploys the proxy, forwards a local port to it, and configures JAX.

    Returns:
      This instance.

    Raises:
      Any exception from deployment, pod wait, port forwarding or JAX setup.
      On failure, resources are torn down first (unless the job pre-existed),
      and the original exception is re-raised even if that teardown also fails.
    """
    try:
      self._deploy_pathways_proxy_server(
          pathways_service=self.pathways_service,
          proxy_job_name=self._proxy_job_name,
          expected_instances=self.expected_tpu_instances,
          gcs_scratch_location=self.bucket,
          proxy_server_image=self.proxy_server_image,
      )
      # Print a link to Cloud Logging
      cloud_logging_link = gke_utils.get_log_link(
          cluster=self.cluster,
          project=self.project,
          job_name=self._proxy_job_name,
      )
      _logger.info("View proxy logs in Cloud Logging: %s", cloud_logging_link)

      proxy_pod = gke_utils.wait_for_pod(self._proxy_job_name)
      self._proxy_port, self._port_forward_process = (
          gke_utils.enable_port_forwarding(proxy_pod, PROXY_SERVER_PORT)
      )

      # Update the JAX backend to use the proxy.
      jax.config.update(_JAX_PLATFORMS_KEY, _JAX_PLATFORM_PROXY)
      jax.config.update(
          _JAX_BACKEND_TARGET_KEY,
          f"{_JAX_BACKEND_TARGET_HOSTNAME}:{self._proxy_port}",
      )
      pathwaysutils.initialize()
      _logger.info(
          "Interactive supercomputing proxy client ready for cluster '%s'.",
          self.cluster,
      )
      return self
    except Exception as e:
      _logger.exception("Error setting up Pathways proxy: %r", e)
      if self._run_cleanup:
        # If any part of setup fails after deployment, clean up. A failure
        # during cleanup is logged but must not mask the setup error.
        try:
          self._cleanup()
        except Exception:  # pylint: disable=broad-exception-caught
          _logger.exception("Cleanup after failed Pathways setup also failed.")
      raise

  def __exit__(self, exc_type, exc_value, traceback):
    """Exits the context manager, tearing down all proxy resources."""
    _logger.info("Exiting ISCPathways context.")
    self._cleanup()

  def _stop_port_forwarding(self):
    """Stops the port-forwarding process, escalating to kill() if needed.

    Assumes the stored process is a subprocess.Popen; the original code already
    relied on its terminate()/wait(timeout=...) API. Safe to call more than
    once.
    """
    process = self._port_forward_process
    if process is None:
      return
    self._port_forward_process = None
    _logger.info("Terminating port forwarding process...")
    process.terminate()
    try:
      process.wait(timeout=self._PORT_FORWARD_STOP_TIMEOUT_S)
      return
    except subprocess.TimeoutExpired:
      _logger.warning(
          "Port forwarding process did not exit after terminate(); killing it."
      )
    # Escalate so the forwarding process is not leaked.
    process.kill()
    try:
      process.wait(timeout=self._PORT_FORWARD_STOP_TIMEOUT_S)
    except subprocess.TimeoutExpired:
      _logger.warning(
          "Port forwarding process did not exit after kill(). Not treating as "
          "an error."
      )

  def _cleanup(self):
    """Cleans up resources created by the ISCPathways context.

    The steps run in order:
      1. Clear JAX backends and caches, then garbage-collect.
      2. Stop port forwarding.
      3. Delete the proxy GKE job, unless it pre-existed.
    Later steps run even if an earlier one raises, so a JAX-side failure cannot
    leak the port-forward process or the GKE job. The exception still
    propagates afterwards.
    """
    _logger.info("Starting Pathways proxy cleanup.")
    try:
      # 1. Clear JAX caches and run garbage collection.
      jax_backend.clear_backends()
      jax.clear_caches()
      gc.collect()
      _logger.info("Cleared JAX caches and ran garbage collection.")
    finally:
      try:
        # 2. Terminate the port forwarding process.
        self._stop_port_forwarding()
      finally:
        # 3. Delete the proxy GKE job unless it already existed beforehand.
        if self._run_cleanup:
          _logger.info("Deleting Pathways proxy...")
          gke_utils.delete_gke_job(self._proxy_job_name)
          _logger.info("Pathways proxy GKE job deletion complete.")
def _default_proxy_job_name() -> str:
  """Returns a random proxy job name of the form `isc-proxy-<user>-<suffix>`.

  The user segment comes from $USER (falling back to "user"). It is lowercased,
  and any character other than a-z, 0-9 is replaced with '-'. It is also
  truncated so the whole name stays within Kubernetes' 63-character DNS-label
  limit. Without this, user names such as "Jane.Doe" or "ci_bot" would produce
  job names the API server rejects.
  """
  alphabet = string.ascii_lowercase + string.digits
  allowed = frozenset(alphabet)
  raw_user = os.environ.get("USER", "user").lower()
  user = "".join(c if c in allowed else "-" for c in raw_user)
  suffix = "".join(random.choices(alphabet, k=5))
  # len("isc-proxy-") + len("-") + len(suffix) == 16; keep the total <= 63.
  user = user[: 63 - 16]
  return f"isc-proxy-{user}-{suffix}"


@contextlib.contextmanager
def connect(
    *,
    cluster: str,
    project: str,
    region: str,
    gcs_bucket: str,
    pathways_service: str,
    expected_tpu_instances: Mapping[str, int],
    proxy_job_name: str | None = None,
    proxy_server_image: str = _DEFAULT_PROXY_IMAGE,
) -> Iterator["_ISCPathways"]:
  """Deploys a Pathways proxy on a GKE cluster and connects JAX to it.

  The cluster is not created here. This function validates its inputs,
  fetches credentials for the given cluster, and then enters an `_ISCPathways`
  context. That context deploys the proxy job on entry and tears it down on
  exit.

  Args:
    cluster: The name of the GKE cluster.
    project: The GCP project ID.
    region: The GCP region.
    gcs_bucket: The Google Cloud Storage bucket to use for scratch space.
    pathways_service: The service name and port of the Pathways head pod.
    expected_tpu_instances: A dictionary mapping TPU machine types to the number
      of instances. For example: {"tpuv6e:2x2": 2}
    proxy_job_name: The name to use for the deployed proxy. If not provided, a
      random name of the form `isc-proxy-<user>-<suffix>` is generated.
    proxy_server_image: The proxy server image to use. If not provided, a
      default will be used.

  Yields:
    The Pathways manager.
  """
  _logger.info("Validating Pathways service and TPU instances...")
  validators.validate_pathways_service(pathways_service)
  validators.validate_tpu_instances(expected_tpu_instances)
  validators.validate_proxy_server_image(proxy_server_image)
  _logger.info("Validation complete.")

  gke_utils.fetch_cluster_credentials(
      cluster_name=cluster, project_id=project, location=region
  )

  # Computed outside an f-string: the original multi-line replacement field
  # only parsed on Python 3.12+ (PEP 701).
  proxy_job_name = proxy_job_name or _default_proxy_job_name()

  _logger.info("Starting ISCPathways context.")
  with _ISCPathways(
      cluster=cluster,
      project=project,
      region=region,
      gcs_bucket=gcs_bucket,
      pathways_service=pathways_service,
      expected_tpu_instances=expected_tpu_instances,
      proxy_job_name=proxy_job_name,
      proxy_server_image=proxy_server_image,
  ) as t:
    yield t