Skip to content

Commit 316c490

Browse files
committed
Add the following metrics to the JetStream server:
1) kv_cache_utilization: The percentage of the kv-cache allocated in TPU HBM that is actually in use during decode. It is computed from the percentage of decode slots in use. 2) num_requests_waiting: The total number of requests that are waiting to be decoded. 3) lora_request_info: The list of LoRA adapters that are currently loaded into TPU HBM for serving requests.
1 parent 3c6fcbd commit 316c490

4 files changed

Lines changed: 377 additions & 0 deletions

File tree

jetstream/core/adapter_tensorstore.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,17 @@ def convert_if_np(leaf):
139139
return jax.tree_util.tree_map(convert_if_np, params)
140140

141141

142+
async def get_hbm_loaded_adapters(self):
    """Return the IDs of all adapters whose status is "loaded_hbm".

    The registry is scanned while holding ``self.lock``.

    Returns:
        A single string with the matching adapter IDs joined by ", "
        (an empty string when no adapter matches).
    """
    async with self.lock:
        resident_ids = [
            name
            for name, info in self.adapter_registry.items()
            if info.status == "loaded_hbm"
        ]

    return ", ".join(resident_ids)
151+
152+
142153
async def load_adapter(self, adapter_id: str, adapter_weights = None, to_hbm: bool = True):
143154
"""Loads a LoRA adapter's weights, managing HBM and CPU memory."""
144155
if adapter_id not in self.adapter_registry:

jetstream/core/metrics/prometheus.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,31 @@ def __new__(cls):
214214
],
215215
)
216216

217+
# Gauge for the number of requests still waiting to be processed.
# Labelled by collector id; declared with multiprocess_mode="sum".
_num_requests_waiting = Gauge(
    name="num_requests_waiting",
    documentation="Number of requests waiting to be processed for inference.",
    labelnames=["id"],
    multiprocess_mode="sum",
)

# Gauge for kv-cache utilization, expressed as a percentage.
# NOTE(review): multiprocess_mode="sum" would add percentages if several
# processes report under the same id -- confirm only one process does.
_kv_cache_utilization = Gauge(
    name="kv_cache_utilization_perc",
    documentation="Percentage of kv-cache utilized by the requests under processing.",
    labelnames=["id"],
    multiprocess_mode="sum",
)

# Info-style gauge: the useful data is carried in the labels ("max_lora" and
# the "running_lora_adapters" string) rather than in the sample value.
# Declared with multiprocess_mode="livemostrecent".
_lora_request_info = Gauge(
    name="lora_request_info",
    documentation="Information about LoRA adapters loaded into TPU Memory for serving current requests.",
    labelnames=[
        "id",
        "max_lora",
        "running_lora_adapters",
    ],
    multiprocess_mode="livemostrecent",
)
241+
217242
def get_prefill_backlog_metric(self):
    """Return the prefill-backlog metric bound to this collector's id label."""
    metric = self._prefill_backlog
    return metric.labels(id=self._id)
219244

@@ -255,3 +280,12 @@ def get_request_output_length(self):
255280

256281
def get_request_success_count_metric(self):
    """Return the request-success-count metric bound to this collector's id label."""
    metric = self._request_success_count
    return metric.labels(id=self._id)
283+
284+
def get_num_requests_waiting_metric(self):
    """Return the num-requests-waiting gauge bound to this collector's id label."""
    gauge = self._num_requests_waiting
    return gauge.labels(id=self._id)
286+
287+
def get_kv_cache_utilization_metric(self):
    """Return the kv-cache-utilization gauge bound to this collector's id label."""
    gauge = self._kv_cache_utilization
    return gauge.labels(id=self._id)
289+
290+
def get_lora_request_info_metric(self, max_lora: int, loaded_adapters: str):
    """Return the LoRA-request-info gauge bound to the given label values.

    Args:
        max_lora: Value for the "max_lora" label.
        loaded_adapters: Value for the "running_lora_adapters" label.

    Returns:
        The labelled child of ``self._lora_request_info``.
    """
    bind_labels = self._lora_request_info.labels
    label_values = {
        "id": self._id,
        "max_lora": max_lora,
        "running_lora_adapters": loaded_adapters,
    }
    return bind_labels(**label_values)

jetstream/core/orchestrator.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ def __init__(
314314
self._metrics_collector.get_generate_backlog_metric(idx).set_function(
315315
functools.partial(float, backlog.qsize())
316316
)
317+
317318
# Stage 4
318319
# After prefill and generation, ActiveRequests are placed on the
319320
# detokenization backlog for tokens to be sent into each ActiveRequest's
@@ -433,6 +434,12 @@ def __init__(
433434
self.live = True
434435
self._is_ray_backend = is_ray_backend
435436

437+
if self._metrics_collector:
438+
self._metrics_collector.get_num_requests_waiting_metric().set_function(
439+
self._get_total_requests_waiting_decode)
440+
self._metrics_collector.get_kv_cache_utilization_metric().set_function(
441+
self._get_kv_cache_utilization)
442+
436443
# Start all threads
437444
for t in self._all_threads:
438445
t.start()
@@ -481,6 +488,28 @@ def stop(self):
481488
for t in self._all_threads:
482489
t.join()
483490

491+
def _get_kv_cache_utilization(self):
    """Calculate the kv-cache utilization, in percent, from decode slots in use.

    For every generate engine, ``max_concurrent_decodes`` is counted as the
    number of slots and the current size of the matching
    ``self._generate_slots`` queue as the number of empty slots.

    Returns:
        The share of slots that are not empty, as a float percentage.
        Returns 0.0 when there are no slots at all, so that using this as a
        metric callback never raises ZeroDivisionError.
    """
    total_slots = 0
    empty_slots = 0
    for idx, engine in enumerate(self._generate_engines):
        total_slots += engine.max_concurrent_decodes
        empty_slots += self._generate_slots[idx].qsize()

    # No engines (or engines with zero decode slots): nothing can be in use.
    if total_slots <= 0:
        return 0.0

    return (total_slots - empty_slots) * 100 / total_slots
500+
501+
def _get_total_requests_waiting_decode(self):
    """Return the combined size of the prefill, transfer and generate backlogs.

    Returns:
        The sum of ``qsize()`` over ``self._prefill_backlog``, every queue in
        ``self._transfer_backlogs`` and every value of
        ``self._generate_backlogs``, as a float.
    """
    pending_queues = [self._prefill_backlog]
    pending_queues.extend(self._transfer_backlogs)
    pending_queues.extend(self._generate_backlogs.values())

    return float(sum(backlog.qsize() for backlog in pending_queues))
512+
484513
def get_total_concurrent_requests(self) -> int:
485514
"""Gets the total number of concurrent requests the driver can handle."""
486515
# We don't support filling all backlogs at once because it can cause GIL
@@ -819,6 +848,14 @@ def _generate_thread(self, idx: int):
819848

820849
start_time = time.perf_counter()
821850

851+
if self._metrics_collector:
852+
adapters_list_str = asyncio.run(self._adapter_tensorstore.get_hbm_loaded_adapters())
853+
854+
max_loras = max_concurrent_decodes
855+
856+
self._metrics_collector.get_lora_request_info_metric(max_loras,
857+
adapters_list_str).set_to_current_time()
858+
822859
# Now we actually take a generate step on requests in the slots.
823860
decode_state, sampled_tokens = generate_engine.generate(
824861
generate_params[adapter_id], decode_state

0 commit comments

Comments
 (0)