Skip to content

Commit a5984c0

Browse files
guptaaka authored and copybara-github committed
Integrate metrics collection into ISC Pathways
This change introduces a new `metrics_collector` module to track key metrics within the ISC Pathways client. The `_ISCPathways` context manager initializes a `MetricsCollector` and uses it to record the following metrics when the `collect_service_metrics` flag is enabled:
- Requested TPU capacity
- Active user count on TPU placement
- Capacity in use
- TPU assignment time upon successful placement
- Successful request count
- Waiting user count

PiperOrigin-RevId: 900220428
1 parent c2a9fe0 commit a5984c0

4 files changed

Lines changed: 374 additions & 6 deletions

File tree

pathwaysutils/experimental/shared_pathways_service/isc_pathways.py

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
import string
1111
import subprocess
1212
import threading
13+
import time
1314
from typing import Any
1415

1516
import jax
1617
import jax.extend.backend as jax_backend
1718
import pathwaysutils
1819
from pathwaysutils.experimental.shared_pathways_service import gke_utils
20+
from pathwaysutils.experimental.shared_pathways_service import metrics_collector
1921
from pathwaysutils.experimental.shared_pathways_service import validators
2022

2123

@@ -128,6 +130,9 @@ def _wait_for_placement(
128130
pod_name: str,
129131
num_slices: int,
130132
stream_logs_func=gke_utils.stream_pod_logs,
133+
metrics_collector_inst: Any = None,
134+
start_time: float | None = None,
135+
total_chips: int = 0,
131136
) -> None:
132137
"""Waits for the placement to be complete by checking proxy logs."""
133138
_logger.info("Streaming proxy logs until the placement is complete...")
@@ -150,6 +155,8 @@ def _wait_for_placement(
150155
f"STDERR: {stderr}"
151156
)
152157

158+
if metrics_collector_inst:
159+
metrics_collector_inst.record_user_waiting(True)
153160
for line in log_process.stdout:
154161
line_lower = line.lower()
155162
if any(keyword.lower() in line_lower for keyword in keywords):
@@ -165,6 +172,13 @@ def _wait_for_placement(
165172
)
166173
else:
167174
_logger.info("TPU placement for %d slice(s) complete!", num_slices)
175+
metrics_collector_inst.record_active_user(True)
176+
metrics_collector_inst.record_user_waiting(False)
177+
metrics_collector_inst.record_capacity_in_use(total_chips)
178+
if start_time:
179+
duration = time.time() - start_time
180+
metrics_collector_inst.record_assignment_time(duration)
181+
metrics_collector_inst.record_successful_request()
168182
break
169183

170184

@@ -195,11 +209,15 @@ class _ISCPathways:
195209
proxy_pod_name: The name of the proxy pod, assigned during deployment.
196210
proxy_server_image: The image to use for the proxy server.
197211
proxy_options: Configuration options for the Pathways proxy.
212+
metrics_collector: The metrics collector instance if enabled.
213+
start_time: The start time of the TPU assignment.
214+
total_chips: The total number of TPU chips expected across all instances.
198215
"""
199216

200217
def __init__(
201218
self,
202-
*, cluster: str,
219+
*,
220+
cluster: str,
203221
project: str,
204222
region: str,
205223
gcs_bucket: str,
@@ -208,6 +226,7 @@ def __init__(
208226
proxy_job_name: str,
209227
proxy_server_image: str,
210228
proxy_options: ProxyOptions | None = None,
229+
collect_service_metrics: bool = False,
211230
):
212231
"""Initializes the TPU manager."""
213232
self.cluster = cluster
@@ -223,9 +242,19 @@ def __init__(
223242
self.proxy_server_image = proxy_server_image
224243
self.proxy_options = proxy_options or ProxyOptions()
225244
self._old_jax_platforms = None
245+
raw_collector = (
246+
metrics_collector.MetricsCollector(self.project)
247+
if collect_service_metrics
248+
else None
249+
)
250+
self.metrics_collector = metrics_collector.SafeMetricsCollector(
251+
raw_collector
252+
)
253+
self.start_time = None
226254
self._old_jax_backend_target = None
227255
self._old_jax_platforms_config = None
228256
self._old_jax_backend_target_config = None
257+
self.total_chips = self._get_total_chips()
229258

230259
def __repr__(self):
231260
return (
@@ -237,8 +266,23 @@ def __repr__(self):
237266
f"proxy_options={self.proxy_options})"
238267
)
239268

269+
def _get_total_chips(self) -> int:
  """Calculates total chips from expected_tpu_instances.

  Each key of `self.expected_tpu_instances` is split on ":" and its second
  field is read as an "x"-separated topology (presumably keys look like
  "<tpu_type>:<AxB[xC]>" -- confirm against the validators module). The chip
  count of one instance is the product of the topology dimensions, which is
  then multiplied by the instance count mapped to that key.

  Returns:
    The sum of chips over all entries; 0 when the mapping is empty.

  Raises:
    ValueError: If a key has no ":"-separated topology field, or if the
      topology contains a dimension that is not an integer.
  """
  import math  # Local import: only this helper needs it.

  total_chips = 0
  for tpu_type, count in self.expected_tpu_instances.items():
    parts = tpu_type.split(":")
    # Fail with a message naming the bad key instead of a bare IndexError.
    if len(parts) < 2:
      raise ValueError(
          f"Cannot determine topology from TPU instance key {tpu_type!r}; "
          "expected a ':'-separated topology field."
      )
    topology = parts[1]
    try:
      dimensions = [int(d) for d in topology.split("x")]
    except ValueError as e:
      raise ValueError(
          f"Invalid topology {topology!r} in TPU instance key {tpu_type!r}; "
          "expected integer dimensions separated by 'x'."
      ) from e
    total_chips += math.prod(dimensions) * count
  return total_chips
281+
240282
def __enter__(self):
241283
"""Enters the context manager, ensuring cluster exists."""
284+
self.metrics_collector.record_requested_capacity(self.total_chips)
285+
242286
self._old_jax_platforms = os.environ.get(_JAX_PLATFORMS_KEY.upper())
243287
self._old_jax_backend_target = os.environ.get(
244288
_JAX_BACKEND_TARGET_KEY.upper()
@@ -251,6 +295,7 @@ def __enter__(self):
251295
)
252296

253297
try:
298+
self.start_time = time.time()
254299
_deploy_pathways_proxy_server(
255300
pathways_service=self.pathways_service,
256301
proxy_job_name=self._proxy_job_name,
@@ -259,7 +304,7 @@ def __enter__(self):
259304
proxy_server_image=self.proxy_server_image,
260305
proxy_options=self.proxy_options,
261306
)
262-
# Print a link to Cloud Logging
307+
self.metrics_collector.record_user_waiting(True)
263308
cloud_logging_link = gke_utils.get_log_link(
264309
cluster=self.cluster,
265310
project=self.project,
@@ -303,14 +348,14 @@ def __exit__(self, exc_type, exc_value, traceback):
303348

304349
def _cleanup(self) -> None:
305350
"""Cleans up resources created by the ISCPathways context."""
306-
# 1. Clear JAX caches and run garbage collection.
351+
# Clear JAX caches and run garbage collection.
307352
_logger.info("Starting Pathways proxy cleanup.")
308353
jax_backend.clear_backends()
309354
jax.clear_caches()
310355
gc.collect()
311356
_logger.info("Cleared JAX caches and ran garbage collection.")
312357

313-
# 2. Terminate the port forwarding process.
358+
# Terminate the port forwarding process.
314359
if self._port_forward_process:
315360
_logger.info("Terminating port forwarding process...")
316361
self._port_forward_process.terminate()
@@ -323,12 +368,12 @@ def _cleanup(self) -> None:
323368
e,
324369
)
325370

326-
# 3. Delete the proxy GKE job.
371+
# Delete the proxy GKE job.
327372
_logger.info("Deleting Pathways proxy...")
328373
gke_utils.delete_gke_job(self._proxy_job_name)
329374
_logger.info("Pathways proxy GKE job deletion complete.")
330375

331-
# 4. Restore JAX variables.
376+
# Restore JAX variables.
332377
_logger.info("Restoring JAX env and config variables...")
333378
_restore_env_var(_JAX_PLATFORMS_KEY.upper(), self._old_jax_platforms)
334379
_restore_env_var(
@@ -353,6 +398,7 @@ def connect(
353398
proxy_job_name: str | None = None,
354399
proxy_server_image: str = DEFAULT_PROXY_IMAGE,
355400
proxy_options: ProxyOptions | None = None,
401+
collect_service_metrics: bool = False,
356402
) -> Iterator["_ISCPathways"]:
357403
"""Connects to a Pathways server if the cluster exists. If not, creates it.
358404
@@ -370,6 +416,8 @@ def connect(
370416
default will be used.
371417
proxy_options: Configuration options for the Pathways proxy. If not
372418
provided, no extra options will be used.
419+
collect_service_metrics: Whether to collect usage metrics for Shared
420+
Pathways Service.
373421
374422
Yields:
375423
The Pathways manager.
@@ -399,6 +447,7 @@ def connect(
399447
proxy_job_name=proxy_job_name,
400448
proxy_server_image=proxy_server_image,
401449
proxy_options=proxy_options,
450+
collect_service_metrics=collect_service_metrics,
402451
) as t:
403452
if t.proxy_pod_name:
404453
num_slices = sum(t.expected_tpu_instances.values())
@@ -407,6 +456,10 @@ def connect(
407456
args=(
408457
t.proxy_pod_name,
409458
num_slices,
459+
gke_utils.stream_pod_logs,
460+
t.metrics_collector,
461+
t.start_time,
462+
t.total_chips,
410463
),
411464
daemon=True,
412465
)

0 commit comments

Comments
 (0)