Skip to content

Commit f78888c

Browse files
guptaaka and copybara-github
authored and committed
Add support for passing environment variables to the Pathways proxy
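This change introduces a `proxy_options` parameter (a `ProxyOptions` dataclass) to isc_pathways.connect, allowing users to configure the proxy server pod. Enabled options are passed to the proxy container as environment variables; currently the only option is `use_insecure_credentials`, which sets IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS="true". PiperOrigin-RevId: 876024144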
1 parent c16512c commit f78888c

3 files changed

Lines changed: 63 additions & 10 deletions

File tree

pathwaysutils/experimental/shared_pathways_service/isc_pathways.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from collections.abc import Iterator, Mapping
44
import contextlib
5+
import dataclasses
56
import gc
67
import logging
78
import os
@@ -29,17 +30,41 @@
2930
_JAX_PLATFORM_PROXY = "proxy"
3031
_JAX_BACKEND_TARGET_KEY = "jax_backend_target"
3132
_JAX_BACKEND_TARGET_HOSTNAME = "grpc://127.0.0.1"
32-
_DEFAULT_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest"
33+
DEFAULT_PROXY_IMAGE = (
34+
"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest"
35+
)
3336

3437
_logger = logging.getLogger(__name__)
3538

3639

40+
@dataclasses.dataclass
41+
class ProxyOptions:
42+
"""Configuration options for the Pathways proxy.
43+
44+
Attributes:
45+
use_insecure_credentials: Whether to use insecure gRPC credentials for the
46+
proxy server.
47+
"""
48+
use_insecure_credentials: bool = False
49+
50+
@classmethod
51+
def from_dict(cls, options: Mapping[str, str] | None) -> "ProxyOptions":
52+
"""Creates a ProxyOptions object from a dictionary of options."""
53+
options = options or {}
54+
use_insecure = (
55+
options.get("use_insecure_credentials", "false").lower() == "true"
56+
)
57+
return cls(use_insecure_credentials=use_insecure)
58+
59+
3760
def _deploy_pathways_proxy_server(
38-
*, pathways_service: str,
61+
*,
62+
pathways_service: str,
3963
proxy_job_name: str,
4064
expected_instances: Mapping[Any, Any],
4165
gcs_scratch_location: str,
4266
proxy_server_image: str,
67+
proxy_options: ProxyOptions | None = None,
4368
) -> None:
4469
"""Deploys the Pathways proxy pods to the GKE cluster.
4570
@@ -50,6 +75,8 @@ def _deploy_pathways_proxy_server(
5075
instances.
5176
gcs_scratch_location: The Google Cloud Storage location to use.
5277
proxy_server_image: The image to use for the proxy server.
78+
proxy_options: Configuration options for the Pathways proxy. If not
79+
provided, no extra options will be used.
5380
5481
Raises:
5582
subprocess.CalledProcessError: If the kubectl command fails.
@@ -67,6 +94,13 @@ def _deploy_pathways_proxy_server(
6794
instance_type, count = next(iter(expected_instances.items()))
6895
instances_str = ",".join(instance_type for _ in range(count))
6996

97+
proxy_options = proxy_options or ProxyOptions()
98+
99+
proxy_env_str = (
100+
' - name: IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS\n'
101+
' value: "true"\n'
102+
) if proxy_options.use_insecure_credentials else ""
103+
70104
template = string.Template(yaml_template)
71105
substituted_yaml = template.substitute(
72106
PROXY_JOB_NAME=proxy_job_name,
@@ -76,6 +110,7 @@ def _deploy_pathways_proxy_server(
76110
EXPECTED_INSTANCES=instances_str,
77111
GCS_SCRATCH_LOCATION=gcs_scratch_location,
78112
PROXY_SERVER_IMAGE=proxy_server_image,
113+
PROXY_ENV=proxy_env_str,
79114
)
80115

81116
_logger.info("Deploying Pathways proxy: %s", proxy_job_name)
@@ -97,6 +132,7 @@ class _ISCPathways:
97132
of instances.
98133
proxy_job_name: The name to use for the deployed proxy.
99134
proxy_server_image: The image to use for the proxy server.
135+
proxy_options: Configuration options for the Pathways proxy.
100136
"""
101137

102138
def __init__(
@@ -109,6 +145,7 @@ def __init__(
109145
expected_tpu_instances: Mapping[Any, Any],
110146
proxy_job_name: str,
111147
proxy_server_image: str,
148+
proxy_options: ProxyOptions | None = None,
112149
):
113150
"""Initializes the TPU manager."""
114151
self.cluster = cluster
@@ -121,14 +158,16 @@ def __init__(
121158
self._port_forward_process = None
122159
self._proxy_port = None
123160
self.proxy_server_image = proxy_server_image
161+
self.proxy_options = proxy_options or ProxyOptions()
124162

125163
def __repr__(self):
126164
return (
127165
f"_ISCPathways(cluster='{self.cluster}', project='{self.project}', "
128166
f"region='{self.region}', bucket='{self.bucket}', "
129167
f"pathways_service='{self.pathways_service}', "
130168
f"expected_tpu_instances={self.expected_tpu_instances}, "
131-
f"_proxy_job_name='{self._proxy_job_name}')"
169+
f"_proxy_job_name='{self._proxy_job_name}', "
170+
f"proxy_options={self.proxy_options})"
132171
)
133172

134173
def __enter__(self):
@@ -140,6 +179,7 @@ def __enter__(self):
140179
expected_instances=self.expected_tpu_instances,
141180
gcs_scratch_location=self.bucket,
142181
proxy_server_image=self.proxy_server_image,
182+
proxy_options=self.proxy_options,
143183
)
144184
# Print a link to Cloud Logging
145185
cloud_logging_link = gke_utils.get_log_link(
@@ -215,7 +255,8 @@ def connect(
215255
pathways_service: str,
216256
expected_tpu_instances: Mapping[str, int],
217257
proxy_job_name: str | None = None,
218-
proxy_server_image: str = _DEFAULT_PROXY_IMAGE,
258+
proxy_server_image: str = DEFAULT_PROXY_IMAGE,
259+
proxy_options: ProxyOptions | None = None,
219260
) -> Iterator["_ISCPathways"]:
220261
"""Connects to a Pathways server if the cluster exists. If not, creates it.
221262
@@ -231,6 +272,8 @@ def connect(
231272
random name will be generated.
232273
proxy_server_image: The proxy server image to use. If not provided, a
233274
default will be used.
275+
proxy_options: Configuration options for the Pathways proxy. If not
276+
provided, no extra options will be used.
234277
235278
Yields:
236279
The Pathways manager.
@@ -259,5 +302,6 @@ def connect(
259302
expected_tpu_instances=expected_tpu_instances,
260303
proxy_job_name=proxy_job_name,
261304
proxy_server_image=proxy_server_image,
305+
proxy_options=proxy_options,
262306
) as t:
263307
yield t

pathwaysutils/experimental/shared_pathways_service/run_connect_example.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from pathwaysutils.experimental.shared_pathways_service import isc_pathways
1010

1111

12+
from google3.pyglib.flags.contrib import dict_flag
13+
1214
FLAGS = flags.FLAGS
1315

1416
flags.DEFINE_string("cluster", None, "The name of the GKE cluster.")
@@ -35,6 +37,12 @@
3537
None,
3638
"The proxy server image to use. If not provided, a default will be used.",
3739
)
40+
dict_flag.DEFINE_dict(
41+
"proxy_options",
42+
None,
43+
"Configuration options for the Pathways proxy. Specify entries in the form"
44+
' "key:value". For example: --proxy_options=use_insecure_credentials:true',
45+
)
3846

3947
flags.mark_flags_as_required([
4048
"cluster",
@@ -49,11 +57,7 @@ def main(argv: Sequence[str]) -> None:
4957
if len(argv) > 1:
5058
raise app.UsageError("Too many command-line arguments.")
5159

52-
kwargs = {}
53-
if FLAGS.proxy_job_name:
54-
kwargs["proxy_job_name"] = FLAGS.proxy_job_name
55-
if FLAGS.proxy_server_image:
56-
kwargs["proxy_server_image"] = FLAGS.proxy_server_image
60+
proxy_options = isc_pathways.ProxyOptions.from_dict(FLAGS.proxy_options)
5761

5862
with isc_pathways.connect(
5963
cluster=FLAGS.cluster,
@@ -62,7 +66,10 @@ def main(argv: Sequence[str]) -> None:
6266
gcs_bucket=FLAGS.gcs_bucket,
6367
pathways_service=FLAGS.pathways_service,
6468
expected_tpu_instances={FLAGS.tpu_type: FLAGS.tpu_count},
65-
**kwargs,
69+
proxy_job_name=FLAGS.proxy_job_name,
70+
proxy_server_image=FLAGS.proxy_server_image
71+
or isc_pathways.DEFAULT_PROXY_IMAGE,
72+
proxy_options=proxy_options,
6673
):
6774
orig_matrix = jnp.zeros(5)
6875
result_matrix = orig_matrix + 1

pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ spec:
2121
- --resource_manager_address=${PATHWAYS_HEAD_HOSTNAME}:${PATHWAYS_HEAD_PORT}
2222
- --gcs_scratch_location=${GCS_SCRATCH_LOCATION}
2323
- --virtual_slices=${EXPECTED_INSTANCES}
24+
env:
25+
${PROXY_ENV}
2426
ports:
2527
- containerPort: ${PROXY_SERVER_PORT}
2628
protocol: TCP

0 commit comments

Comments
 (0)