22
33from collections .abc import Iterator , Mapping
44import contextlib
5+ import dataclasses
56import gc
67import logging
78import os
2930_JAX_PLATFORM_PROXY = "proxy"
3031_JAX_BACKEND_TARGET_KEY = "jax_backend_target"
3132_JAX_BACKEND_TARGET_HOSTNAME = "grpc://127.0.0.1"
32- _DEFAULT_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest"
33+ DEFAULT_PROXY_IMAGE = (
34+ "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest"
35+ )
3336
3437_logger = logging .getLogger (__name__ )
3538
3639
40+ @dataclasses .dataclass
41+ class ProxyOptions :
42+ """Configuration options for the Pathways proxy.
43+
44+ Attributes:
45+ use_insecure_credentials: Whether to use insecure gRPC credentials for the
46+ proxy server.
47+ """
48+ use_insecure_credentials : bool = False
49+
50+ @classmethod
51+ def from_dict (cls , options : Mapping [str , str ] | None ) -> "ProxyOptions" :
52+ """Creates a ProxyOptions object from a dictionary of options."""
53+ options = options or {}
54+ use_insecure = (
55+ options .get ("use_insecure_credentials" , "false" ).lower () == "true"
56+ )
57+ return cls (use_insecure_credentials = use_insecure )
58+
59+
3760def _deploy_pathways_proxy_server (
38- * , pathways_service : str ,
61+ * ,
62+ pathways_service : str ,
3963 proxy_job_name : str ,
4064 expected_instances : Mapping [Any , Any ],
4165 gcs_scratch_location : str ,
4266 proxy_server_image : str ,
67+ proxy_options : ProxyOptions | None = None ,
4368) -> None :
4469 """Deploys the Pathways proxy pods to the GKE cluster.
4570
@@ -50,6 +75,8 @@ def _deploy_pathways_proxy_server(
5075 instances.
5176 gcs_scratch_location: The Google Cloud Storage location to use.
5277 proxy_server_image: The image to use for the proxy server.
78+ proxy_options: Configuration options for the Pathways proxy. If not
79+ provided, no extra options will be used.
5380
5481 Raises:
5582 subprocess.CalledProcessError: If the kubectl command fails.
@@ -67,6 +94,13 @@ def _deploy_pathways_proxy_server(
6794 instance_type , count = next (iter (expected_instances .items ()))
6895 instances_str = "," .join (instance_type for _ in range (count ))
6996
97+ proxy_options = proxy_options or ProxyOptions ()
98+
99+ proxy_env_str = (
100+ ' - name: IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS\n '
101+ ' value: "true"\n '
102+ ) if proxy_options .use_insecure_credentials else ""
103+
70104 template = string .Template (yaml_template )
71105 substituted_yaml = template .substitute (
72106 PROXY_JOB_NAME = proxy_job_name ,
@@ -76,6 +110,7 @@ def _deploy_pathways_proxy_server(
76110 EXPECTED_INSTANCES = instances_str ,
77111 GCS_SCRATCH_LOCATION = gcs_scratch_location ,
78112 PROXY_SERVER_IMAGE = proxy_server_image ,
113+ PROXY_ENV = proxy_env_str ,
79114 )
80115
81116 _logger .info ("Deploying Pathways proxy: %s" , proxy_job_name )
@@ -97,6 +132,7 @@ class _ISCPathways:
97132 of instances.
98133 proxy_job_name: The name to use for the deployed proxy.
99134 proxy_server_image: The image to use for the proxy server.
135+ proxy_options: Configuration options for the Pathways proxy.
100136 """
101137
102138 def __init__ (
@@ -109,6 +145,7 @@ def __init__(
109145 expected_tpu_instances : Mapping [Any , Any ],
110146 proxy_job_name : str ,
111147 proxy_server_image : str ,
148+ proxy_options : ProxyOptions | None = None ,
112149 ):
113150 """Initializes the TPU manager."""
114151 self .cluster = cluster
@@ -121,14 +158,16 @@ def __init__(
121158 self ._port_forward_process = None
122159 self ._proxy_port = None
123160 self .proxy_server_image = proxy_server_image
161+ self .proxy_options = proxy_options or ProxyOptions ()
124162
125163 def __repr__ (self ):
126164 return (
127165 f"_ISCPathways(cluster='{ self .cluster } ', project='{ self .project } ', "
128166 f"region='{ self .region } ', bucket='{ self .bucket } ', "
129167 f"pathways_service='{ self .pathways_service } ', "
130168 f"expected_tpu_instances={ self .expected_tpu_instances } , "
131- f"_proxy_job_name='{ self ._proxy_job_name } ')"
169+ f"_proxy_job_name='{ self ._proxy_job_name } ', "
170+ f"proxy_options={ self .proxy_options } )"
132171 )
133172
134173 def __enter__ (self ):
@@ -140,6 +179,7 @@ def __enter__(self):
140179 expected_instances = self .expected_tpu_instances ,
141180 gcs_scratch_location = self .bucket ,
142181 proxy_server_image = self .proxy_server_image ,
182+ proxy_options = self .proxy_options ,
143183 )
144184 # Print a link to Cloud Logging
145185 cloud_logging_link = gke_utils .get_log_link (
@@ -215,7 +255,8 @@ def connect(
215255 pathways_service : str ,
216256 expected_tpu_instances : Mapping [str , int ],
217257 proxy_job_name : str | None = None ,
218- proxy_server_image : str = _DEFAULT_PROXY_IMAGE ,
258+ proxy_server_image : str = DEFAULT_PROXY_IMAGE ,
259+ proxy_options : ProxyOptions | None = None ,
219260) -> Iterator ["_ISCPathways" ]:
220261 """Connects to a Pathways server if the cluster exists. If not, creates it.
221262
@@ -231,6 +272,8 @@ def connect(
231272 random name will be generated.
232273 proxy_server_image: The proxy server image to use. If not provided, a
233274 default will be used.
275+ proxy_options: Configuration options for the Pathways proxy. If not
276+ provided, no extra options will be used.
234277
235278 Yields:
236279 The Pathways manager.
@@ -259,5 +302,6 @@ def connect(
259302 expected_tpu_instances = expected_tpu_instances ,
260303 proxy_job_name = proxy_job_name ,
261304 proxy_server_image = proxy_server_image ,
305+ proxy_options = proxy_options ,
262306 ) as t :
263307 yield t
0 commit comments