1010import string
1111import subprocess
1212import threading
13+ import time
1314from typing import Any
1415
1516import jax
1617import jax .extend .backend as jax_backend
1718import pathwaysutils
1819from pathwaysutils .experimental .shared_pathways_service import gke_utils
20+ from pathwaysutils .experimental .shared_pathways_service import metrics_collector
1921from pathwaysutils .experimental .shared_pathways_service import validators
2022
2123
@@ -128,6 +130,8 @@ def _wait_for_placement(
128130 pod_name : str ,
129131 num_slices : int ,
130132 stream_logs_func = gke_utils .stream_pod_logs ,
133+ metrics_collector_inst : Any = None ,
134+ start_time : float | None = None ,
131135) -> None :
132136 """Waits for the placement to be complete by checking proxy logs."""
133137 _logger .info ("Streaming proxy logs until the placement is complete..." )
@@ -165,7 +169,16 @@ def _wait_for_placement(
165169 )
166170 else :
167171 _logger .info ("TPU placement for %d slice(s) complete!" , num_slices )
172+ if metrics_collector_inst :
173+ metrics_collector_inst .record_user_waiting (0 )
174+ if start_time :
175+ duration = time .time () - start_time
176+ metrics_collector_inst .record_assignment_time (duration )
177+ metrics_collector_inst .record_successful_request ()
168178 break
179+ else :
180+ if metrics_collector_inst :
181+ metrics_collector_inst .record_user_waiting (1 )
169182
170183
171184def _restore_env_var (key : str , original_value : str | None ) -> None :
@@ -195,11 +208,14 @@ class _ISCPathways:
195208 proxy_pod_name: The name of the proxy pod, assigned during deployment.
196209 proxy_server_image: The image to use for the proxy server.
197210 proxy_options: Configuration options for the Pathways proxy.
211+ metrics_collector: The metrics collector instance if enabled.
212+ start_time: The start time of the TPU assignment.
198213 """
199214
200215 def __init__ (
201216 self ,
202- * , cluster : str ,
217+ * ,
218+ cluster : str ,
203219 project : str ,
204220 region : str ,
205221 gcs_bucket : str ,
@@ -208,6 +224,7 @@ def __init__(
208224 proxy_job_name : str ,
209225 proxy_server_image : str ,
210226 proxy_options : ProxyOptions | None = None ,
227+ collect_service_metrics : bool = False ,
211228 ):
212229 """Initializes the TPU manager."""
213230 self .cluster = cluster
@@ -223,6 +240,10 @@ def __init__(
223240 self .proxy_server_image = proxy_server_image
224241 self .proxy_options = proxy_options or ProxyOptions ()
225242 self ._old_jax_platforms = None
243+ self .metrics_collector = None
244+ if collect_service_metrics :
245+ self .metrics_collector = metrics_collector .MetricsCollector (self .project )
246+ self .start_time = None
226247 self ._old_jax_backend_target = None
227248 self ._old_jax_platforms_config = None
228249 self ._old_jax_backend_target_config = None
@@ -237,8 +258,25 @@ def __repr__(self):
237258 f"proxy_options={ self .proxy_options } )"
238259 )
239260
261+ def _get_total_chips (self ) -> int :
262+ """Calculates total chips from expected_tpu_instances."""
263+ total_chips = 0
264+ for tpu_type , count in self .expected_tpu_instances .items ():
265+ parts = tpu_type .split (":" )
266+ topology = parts [1 ]
267+ dimensions = [int (d ) for d in topology .split ("x" )]
268+ chips_per_instance = 1
269+ for d in dimensions :
270+ chips_per_instance *= d
271+ total_chips += chips_per_instance * count
272+ return total_chips
273+
240274 def __enter__ (self ):
241275 """Enters the context manager, ensuring cluster exists."""
276+ if self .metrics_collector :
277+ self .metrics_collector .record_active_user (True )
278+ self .metrics_collector .record_capacity_in_use (self ._get_total_chips ())
279+
242280 self ._old_jax_platforms = os .environ .get (_JAX_PLATFORMS_KEY .upper ())
243281 self ._old_jax_backend_target = os .environ .get (
244282 _JAX_BACKEND_TARGET_KEY .upper ()
@@ -251,6 +289,7 @@ def __enter__(self):
251289 )
252290
253291 try :
292+ self .start_time = time .time ()
254293 _deploy_pathways_proxy_server (
255294 pathways_service = self .pathways_service ,
256295 proxy_job_name = self ._proxy_job_name ,
@@ -303,14 +342,19 @@ def __exit__(self, exc_type, exc_value, traceback):
303342
304343 def _cleanup (self ) -> None :
305344 """Cleans up resources created by the ISCPathways context."""
306- # 1. Clear JAX caches and run garbage collection.
345+ # Reset metrics on exit.
346+ if self .metrics_collector :
347+ self .metrics_collector .record_active_user (False )
348+ self .metrics_collector .record_capacity_in_use (0 )
349+
350+ # Clear JAX caches and run garbage collection.
307351 _logger .info ("Starting Pathways proxy cleanup." )
308352 jax_backend .clear_backends ()
309353 jax .clear_caches ()
310354 gc .collect ()
311355 _logger .info ("Cleared JAX caches and ran garbage collection." )
312356
313- # 2. Terminate the port forwarding process.
357+ # Terminate the port forwarding process.
314358 if self ._port_forward_process :
315359 _logger .info ("Terminating port forwarding process..." )
316360 self ._port_forward_process .terminate ()
@@ -323,12 +367,12 @@ def _cleanup(self) -> None:
323367 e ,
324368 )
325369
326- # 3. Delete the proxy GKE job.
370+ # Delete the proxy GKE job.
327371 _logger .info ("Deleting Pathways proxy..." )
328372 gke_utils .delete_gke_job (self ._proxy_job_name )
329373 _logger .info ("Pathways proxy GKE job deletion complete." )
330374
331- # 4. Restore JAX variables.
375+ # Restore JAX variables.
332376 _logger .info ("Restoring JAX env and config variables..." )
333377 _restore_env_var (_JAX_PLATFORMS_KEY .upper (), self ._old_jax_platforms )
334378 _restore_env_var (
@@ -353,6 +397,7 @@ def connect(
353397 proxy_job_name : str | None = None ,
354398 proxy_server_image : str = DEFAULT_PROXY_IMAGE ,
355399 proxy_options : ProxyOptions | None = None ,
400+ collect_service_metrics : bool = False ,
356401) -> Iterator ["_ISCPathways" ]:
357402 """Connects to a Pathways server if the cluster exists. If not, creates it.
358403
@@ -370,6 +415,8 @@ def connect(
370415 default will be used.
371416 proxy_options: Configuration options for the Pathways proxy. If not
372417 provided, no extra options will be used.
418+ collect_service_metrics: Whether to collect usage metrics for Shared
419+ Pathways Service.
373420
374421 Yields:
375422 The Pathways manager.
@@ -399,6 +446,7 @@ def connect(
399446 proxy_job_name = proxy_job_name ,
400447 proxy_server_image = proxy_server_image ,
401448 proxy_options = proxy_options ,
449+ collect_service_metrics = collect_service_metrics ,
402450 ) as t :
403451 if t .proxy_pod_name :
404452 num_slices = sum (t .expected_tpu_instances .values ())
@@ -407,6 +455,9 @@ def connect(
407455 args = (
408456 t .proxy_pod_name ,
409457 num_slices ,
458+ gke_utils .stream_pod_logs ,
459+ t .metrics_collector ,
460+ t .start_time ,
410461 ),
411462 daemon = True ,
412463 )
0 commit comments