Skip to content

Commit dad55e8

Browse files
Pathways-on-Cloud Teamcopybara-github
authored andcommitted
Allow specifying a custom proxy job name in _ISCPathways
The `_ISCPathways` constructor now accepts an optional `proxy_job_name` argument, allowing users to provide a specific name for the proxy job instead of always generating one. PiperOrigin-RevId: 857304421
1 parent 8b23013 commit dad55e8

1 file changed

Lines changed: 8 additions & 1 deletion

File tree

pathwaysutils/experimental/shared_pathways_service/isc_pathways.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __init__(
9999
gcs_bucket: str,
100100
pathways_service: str,
101101
expected_tpu_instances: Mapping[Any, Any],
102+
proxy_job_name: str | None,
102103
):
103104
"""Initializes the TPU manager."""
104105
self.cluster = cluster
@@ -111,7 +112,7 @@ def __init__(
111112
random.choices(string.ascii_lowercase + string.digits, k=5)
112113
)
113114
user = os.environ.get("USER", "user")
114-
self._proxy_job_name = f"isc-proxy-{user}-{suffix}"
115+
self._proxy_job_name = proxy_job_name or f"isc-proxy-{user}-{suffix}"
115116
self._port_forward_process = None
116117
self._proxy_port = None
117118

@@ -194,6 +195,7 @@ def connect(
194195
gcs_bucket: str,
195196
pathways_service: str,
196197
expected_tpu_instances: Mapping[str, int],
198+
proxy_job_name: str | None = None,
197199
) -> Iterator["_ISCPathways"]:
198200
"""Connects to a Pathways server if the cluster exists. If not, creates it.
199201
@@ -205,12 +207,16 @@ def connect(
205207
pathways_service: The service name and port of the Pathways head pod.
206208
expected_tpu_instances: A dictionary mapping TPU machine types to the number
207209
of instances. For example: {"tpuv6e:2x2": 2}
210+
proxy_job_name: The name to use for the deployed proxy. If not provided, a
211+
random name will be generated.
208212
209213
Yields:
210214
The Pathways manager.
211215
"""
216+
_logger.info("Validating Pathways service and TPU instances...")
212217
validators.validate_pathways_service(pathways_service)
213218
validators.validate_tpu_instances(expected_tpu_instances)
219+
_logger.info("Validation complete.")
214220
gke_utils.fetch_cluster_credentials(
215221
cluster_name=cluster, project_id=project, location=region
216222
)
@@ -222,5 +228,6 @@ def connect(
222228
gcs_bucket=gcs_bucket,
223229
pathways_service=pathways_service,
224230
expected_tpu_instances=expected_tpu_instances,
231+
proxy_job_name=proxy_job_name,
225232
) as t:
226233
yield t

0 commit comments

Comments
 (0)