Skip to content

Commit 58bf8e3

Browse files
authored
add tracer in v1 to log generator perf metrics (#720)
1 parent a3ae18b commit 58bf8e3

2 files changed

Lines changed: 39 additions & 1 deletion

File tree

src/forge/actors/vllm/v1/forge_executor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import cloudpickle
3232
from forge.actors._torchstore_utils import extract_param_name, get_param_prefix
3333
from forge.actors.vllm.v1.monarch_executor import MonarchExecutor, WorkerWrapper
34+
from forge.observability.perf_tracker import trace
3435
from monarch.actor import endpoint
3536
from torchstore.client import LocalClient
3637

@@ -57,6 +58,11 @@ def set_torchstore_controller(self, controller) -> None:
5758
self._torchstore_client = None # Reset cached client
5859

5960
@endpoint
61+
@trace(
62+
prefix="generator_perf/update_weights/generator_worker_update",
63+
track_memory=False,
64+
timer="gpu",
65+
)
6066
def update_weights(self, version: int) -> int:
6167
"""Load weights directly from torchstore.
6268

src/forge/actors/vllm/v1/generator.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import base64
1010
import logging
1111
import os
12+
import time
1213
import uuid
1314
from collections.abc import Mapping
1415
from dataclasses import dataclass, field
@@ -20,6 +21,10 @@
2021
from forge.controller.provisioner import _get_provisioner
2122
from forge.data_models.completion import Completion
2223
from forge.data_models.prompt import to_prompt
24+
from forge.env import FORGE_DISABLE_METRICS
25+
from forge.observability.metric_actors import get_or_create_metric_logger
26+
from forge.observability.metrics import record_metric, Reduce
27+
from forge.observability.perf_tracker import Tracer
2328
from monarch.actor import endpoint, this_host
2429
from torchstore.api import _controller as get_torchstore_controller
2530
from vllm.engine.arg_utils import EngineArgs
@@ -142,6 +147,10 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
142147
)
143148
logger.info("[Generator.launch] Spawned generator_proc on head host")
144149

150+
# Register LocalFetcherActor for generator_proc to enable metrics collection
151+
if not FORGE_DISABLE_METRICS.get_value():
152+
await get_or_create_metric_logger(generator_proc, process_name=mesh_name)
153+
145154
# Import WorkerRegistry here to avoid circular import with monarch_executor
146155
from forge.actors.vllm.v1.monarch_executor import WorkerRegistry
147156

@@ -257,6 +266,10 @@ async def generate(
257266
Returns:
258267
list[Completion]: n completions from vLLM based on your prompt.
259268
"""
269+
t = Tracer("generator_perf/generate", timer="gpu")
270+
t.start()
271+
record_metric("generator/generate/count_requests", 1, Reduce.SUM)
272+
260273
if self.llm is None:
261274
raise RuntimeError("Generator not initialized. Call setup() first.")
262275

@@ -277,6 +290,12 @@ async def generate(
277290

278291
completions = self._to_completions(request_output, prompt)
279292

293+
record_metric(
294+
"generator/generate/count_sequences_completed",
295+
len(completions),
296+
Reduce.SUM,
297+
)
298+
t.stop()
280299
return completions
281300

282301
@endpoint
@@ -347,17 +366,30 @@ async def update_weights(
347366

348367
logger.info(f"Starting weight update to v{version}")
349368

369+
pause_start = time.perf_counter()
350370
await self.llm.pause_generation(
351371
wait_for_inflight_requests=True, clear_cache=True
352372
)
373+
pause_duration = time.perf_counter() - pause_start
374+
record_metric(
375+
"generator_perf/update_weights/pause_generation_duration_s",
376+
pause_duration,
377+
Reduce.MEAN,
378+
)
353379

354380
try:
381+
load_start = time.perf_counter()
355382
await self.workers.update_weights.call(version)
383+
load_duration = time.perf_counter() - load_start
384+
record_metric(
385+
"generator_perf/update_weights/worker_load_weights_duration_s",
386+
load_duration,
387+
Reduce.MEAN,
388+
)
356389
self.generator_version = version
357390
logger.info(f"Updated weights from torchstore v{version}")
358391
finally:
359392
await self.llm.resume_generation()
360-
361393
logger.info(f"Weight update complete, now v{version}")
362394

363395
@endpoint

0 commit comments

Comments
 (0)