1010import string
1111import subprocess
1212import threading
13+ import time
1314from typing import Any
1415
1516import jax
1617import jax .extend .backend as jax_backend
1718import pathwaysutils
1819from pathwaysutils .experimental .shared_pathways_service import gke_utils
20+ from pathwaysutils .experimental .shared_pathways_service import metrics_collector
1921from pathwaysutils .experimental .shared_pathways_service import validators
2022
2123
@@ -128,6 +130,9 @@ def _wait_for_placement(
128130 pod_name : str ,
129131 num_slices : int ,
130132 stream_logs_func = gke_utils .stream_pod_logs ,
133+ metrics_collector_inst : Any = None ,
134+ start_time : float | None = None ,
135+ total_chips : int = 0 ,
131136) -> None :
132137 """Waits for the placement to be complete by checking proxy logs."""
133138 _logger .info ("Streaming proxy logs until the placement is complete..." )
@@ -150,6 +155,8 @@ def _wait_for_placement(
150155 f"STDERR: { stderr } "
151156 )
152157
158+ if metrics_collector_inst :
159+ metrics_collector_inst .record_user_waiting (True )
153160 for line in log_process .stdout :
154161 line_lower = line .lower ()
155162 if any (keyword .lower () in line_lower for keyword in keywords ):
@@ -165,6 +172,13 @@ def _wait_for_placement(
165172 )
166173 else :
167174 _logger .info ("TPU placement for %d slice(s) complete!" , num_slices )
175+ metrics_collector_inst .record_active_user (True )
176+ metrics_collector_inst .record_user_waiting (False )
177+ metrics_collector_inst .record_capacity_in_use (total_chips )
178+ if start_time :
179+ duration = time .time () - start_time
180+ metrics_collector_inst .record_assignment_time (duration )
181+ metrics_collector_inst .record_successful_request ()
168182 break
169183
170184
@@ -195,11 +209,15 @@ class _ISCPathways:
195209 proxy_pod_name: The name of the proxy pod, assigned during deployment.
196210 proxy_server_image: The image to use for the proxy server.
197211 proxy_options: Configuration options for the Pathways proxy.
212+ metrics_collector: The metrics collector instance if enabled.
213+ start_time: The start time of the TPU assignment.
214+ total_chips: The total number of TPU chips expected across all instances.
198215 """
199216
200217 def __init__ (
201218 self ,
202- * , cluster : str ,
219+ * ,
220+ cluster : str ,
203221 project : str ,
204222 region : str ,
205223 gcs_bucket : str ,
@@ -208,6 +226,7 @@ def __init__(
208226 proxy_job_name : str ,
209227 proxy_server_image : str ,
210228 proxy_options : ProxyOptions | None = None ,
229+ collect_service_metrics : bool = False ,
211230 ):
212231 """Initializes the TPU manager."""
213232 self .cluster = cluster
@@ -223,9 +242,19 @@ def __init__(
223242 self .proxy_server_image = proxy_server_image
224243 self .proxy_options = proxy_options or ProxyOptions ()
225244 self ._old_jax_platforms = None
245+ raw_collector = (
246+ metrics_collector .MetricsCollector (self .project )
247+ if collect_service_metrics
248+ else None
249+ )
250+ self .metrics_collector = metrics_collector .SafeMetricsCollector (
251+ raw_collector
252+ )
253+ self .start_time = None
226254 self ._old_jax_backend_target = None
227255 self ._old_jax_platforms_config = None
228256 self ._old_jax_backend_target_config = None
257+ self .total_chips = self ._get_total_chips ()
229258
230259 def __repr__ (self ):
231260 return (
@@ -237,8 +266,23 @@ def __repr__(self):
237266 f"proxy_options={ self .proxy_options } )"
238267 )
239268
def _get_total_chips(self) -> int:
  """Calculates the total chip count from `expected_tpu_instances`.

  Each key of `self.expected_tpu_instances` is split on ":" and the second
  field is read as an "x"-separated topology (for example "2x2" or "4x4x4").
  The chips for one instance are the product of the topology dimensions; that
  product is multiplied by the instance count stored under the key.

  Returns:
    The sum, over all entries, of chips-per-instance times instance count.
    Returns 0 when `expected_tpu_instances` is empty.

  Raises:
    ValueError: If a key has no ":"-separated topology field, or a topology
      dimension is not an integer.
  """
  # Function-scope import: `math` is not otherwise needed by this module.
  import math

  total_chips = 0
  for tpu_type, count in self.expected_tpu_instances.items():
    parts = tpu_type.split(":")
    if len(parts) < 2:
      # Without this check the lookup below fails with a bare IndexError that
      # gives no hint about which entry was malformed.
      raise ValueError(
          f"Cannot derive chip count from TPU type {tpu_type!r}: expected"
          " '<type>:<topology>', e.g. 'tpuv6e:2x2'."
      )
    topology = parts[1]
    try:
      dimensions = [int(d) for d in topology.split("x")]
    except ValueError as err:
      raise ValueError(
          f"Cannot derive chip count from TPU type {tpu_type!r}: topology"
          f" {topology!r} must be integers separated by 'x'."
      ) from err
    total_chips += math.prod(dimensions) * count
  return total_chips
281+
240282 def __enter__ (self ):
241283 """Enters the context manager, ensuring cluster exists."""
284+ self .metrics_collector .record_requested_capacity (self .total_chips )
285+
242286 self ._old_jax_platforms = os .environ .get (_JAX_PLATFORMS_KEY .upper ())
243287 self ._old_jax_backend_target = os .environ .get (
244288 _JAX_BACKEND_TARGET_KEY .upper ()
@@ -251,6 +295,7 @@ def __enter__(self):
251295 )
252296
253297 try :
298+ self .start_time = time .time ()
254299 _deploy_pathways_proxy_server (
255300 pathways_service = self .pathways_service ,
256301 proxy_job_name = self ._proxy_job_name ,
@@ -259,7 +304,7 @@ def __enter__(self):
259304 proxy_server_image = self .proxy_server_image ,
260305 proxy_options = self .proxy_options ,
261306 )
262- # Print a link to Cloud Logging
307+ self . metrics_collector . record_user_waiting ( True )
263308 cloud_logging_link = gke_utils .get_log_link (
264309 cluster = self .cluster ,
265310 project = self .project ,
@@ -303,14 +348,14 @@ def __exit__(self, exc_type, exc_value, traceback):
303348
304349 def _cleanup (self ) -> None :
305350 """Cleans up resources created by the ISCPathways context."""
306- # 1. Clear JAX caches and run garbage collection.
351+ # Clear JAX caches and run garbage collection.
307352 _logger .info ("Starting Pathways proxy cleanup." )
308353 jax_backend .clear_backends ()
309354 jax .clear_caches ()
310355 gc .collect ()
311356 _logger .info ("Cleared JAX caches and ran garbage collection." )
312357
313- # 2. Terminate the port forwarding process.
358+ # Terminate the port forwarding process.
314359 if self ._port_forward_process :
315360 _logger .info ("Terminating port forwarding process..." )
316361 self ._port_forward_process .terminate ()
@@ -323,12 +368,12 @@ def _cleanup(self) -> None:
323368 e ,
324369 )
325370
326- # 3. Delete the proxy GKE job.
371+ # Delete the proxy GKE job.
327372 _logger .info ("Deleting Pathways proxy..." )
328373 gke_utils .delete_gke_job (self ._proxy_job_name )
329374 _logger .info ("Pathways proxy GKE job deletion complete." )
330375
331- # 4. Restore JAX variables.
376+ # Restore JAX variables.
332377 _logger .info ("Restoring JAX env and config variables..." )
333378 _restore_env_var (_JAX_PLATFORMS_KEY .upper (), self ._old_jax_platforms )
334379 _restore_env_var (
@@ -353,6 +398,7 @@ def connect(
353398 proxy_job_name : str | None = None ,
354399 proxy_server_image : str = DEFAULT_PROXY_IMAGE ,
355400 proxy_options : ProxyOptions | None = None ,
401+ collect_service_metrics : bool = False ,
356402) -> Iterator ["_ISCPathways" ]:
357403 """Connects to a Pathways server if the cluster exists. If not, creates it.
358404
@@ -370,6 +416,8 @@ def connect(
370416 default will be used.
371417 proxy_options: Configuration options for the Pathways proxy. If not
372418 provided, no extra options will be used.
419+ collect_service_metrics: Whether to collect usage metrics for Shared
420+ Pathways Service.
373421
374422 Yields:
375423 The Pathways manager.
@@ -399,6 +447,7 @@ def connect(
399447 proxy_job_name = proxy_job_name ,
400448 proxy_server_image = proxy_server_image ,
401449 proxy_options = proxy_options ,
450+ collect_service_metrics = collect_service_metrics ,
402451 ) as t :
403452 if t .proxy_pod_name :
404453 num_slices = sum (t .expected_tpu_instances .values ())
@@ -407,6 +456,10 @@ def connect(
407456 args = (
408457 t .proxy_pod_name ,
409458 num_slices ,
459+ gke_utils .stream_pod_logs ,
460+ t .metrics_collector ,
461+ t .start_time ,
462+ t .total_chips ,
410463 ),
411464 daemon = True ,
412465 )
0 commit comments