Skip to content

Commit 969b8ab

Browse files
JenniferWangfacebook-github-bot
authored andcommitted
prefetch weights while waiting for pending requests to complete (#728)
Summary: Feature parity with v0: allow prefetching weights while waiting for the pending requests to finish. ## Test Plan Introduced a benchmark that simulates the on-going requests with actual weight sync logic. Reference Group (V0) ``` ================================================================================ WEIGHT SYNC BENCHMARK RESULTS ================================================================================ Model: Qwen/Qwen3-8B Model size: 15.26 GB Iterations: 3 Prefetch enabled: False -------------------------------------------------------------------------------- Metric Time (s) Throughput (GB/s) -------------------------------------------------------------------------------- Avg push_weights 5.102 s 2.99 GB/s Avg update_weights 43.738 s 0.35 GB/s Avg total (push + update) 48.840 s ================================================================================ ================================================================================ WEIGHT SYNC BENCHMARK RESULTS ================================================================================ Model: Qwen/Qwen3-8B Model size: 15.26 GB Iterations: 3 Prefetch enabled: True Fetcher procs: 8 -------------------------------------------------------------------------------- Metric Time (s) Throughput (GB/s) -------------------------------------------------------------------------------- Avg push_weights 5.208 s 2.93 GB/s Avg update_weights 29.602 s 0.52 GB/s Avg total (push + update) 34.810 s ================================================================================ ``` Test Group (V1) ``` ================================================================================ WEIGHT SYNC BENCHMARK RESULTS ================================================================================ Model: Qwen/Qwen3-8B Model size: 15.26 GB Iterations: 3 Prefetch enabled: False -------------------------------------------------------------------------------- Metric Time (s) Throughput (GB/s) -------------------------------------------------------------------------------- Avg push_weights 5.070 s 3.01 GB/s Avg update_weights 39.974 s 0.38 GB/s Avg total (push + update) 45.044 s ================================================================================ ================================================================================ WEIGHT SYNC BENCHMARK RESULTS ================================================================================ Model: Qwen/Qwen3-8B Model size: 15.26 GB Iterations: 3 Prefetch enabled: True Fetcher procs: 8 -------------------------------------------------------------------------------- Metric Time (s) Throughput (GB/s) -------------------------------------------------------------------------------- Avg push_weights 5.055 s 3.02 GB/s Avg update_weights 28.730 s 0.53 GB/s Avg total (push + update) 33.784 s ================================================================================ ``` ## Next Steps [-] implement the prefetch logic & shared memory [-] Add metric similar to generator v0 [ ] Perf/Throughput testing compared to generator v0 Differential Revision: D91092833
1 parent 5fa56f7 commit 969b8ab

5 files changed

Lines changed: 589 additions & 18 deletions

File tree

Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Weight sync benchmark for torchforge generators.
8+
9+
Measures the time for weight synchronization between trainer and generator,
10+
with and without shared memory prefetching enabled.
11+
12+
Example usage:
13+
# Basic benchmark (no prefetch)
14+
python -m benchmarks.generator.weight_sync --config apps/grpo/qwen3_8b.yaml
15+
16+
# With prefetch enabled
17+
python -m benchmarks.generator.weight_sync \
18+
--config apps/grpo/qwen3_8b.yaml \
19+
benchmark.prefetch_enabled=true \
20+
benchmark.n_fetcher_procs=4 \
21+
benchmark.iterations=5
22+
"""
23+
24+
import asyncio
25+
import logging
26+
import os
27+
import time
28+
from dataclasses import dataclass, field
29+
30+
import torch
31+
import torchstore as ts
32+
from forge.actors.generator import Generator
33+
from forge.actors.trainer import TitanTrainer
34+
from forge.controller.provisioner import init_provisioner, shutdown
35+
from forge.controller.service.service import uuid
36+
from forge.types import LauncherConfig, ProvisionerConfig
37+
from forge.util.config import parse, resolve_hf_hub_paths
38+
from monarch.actor import endpoint
39+
from omegaconf import DictConfig
40+
41+
os.environ.setdefault("HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS", "600")
42+
os.environ.setdefault("HYPERACTOR_CODE_MAX_FRAME_LENGTH", "1073741824")
43+
44+
logger = logging.getLogger(__name__)
45+
logging.basicConfig(level=logging.INFO)
46+
47+
48+
class BenchmarkTitanTrainer(TitanTrainer):
49+
"""TitanTrainer with weight modification capabilities for benchmarking."""
50+
51+
@endpoint
52+
async def modify_weights(self, scale: float = 1.1):
53+
"""Scale all model weights by a factor (simulates training step)."""
54+
for model_part in self.engine.model_parts:
55+
sd = model_part.state_dict()
56+
for k in sd.keys():
57+
if torch.is_floating_point(sd[k]):
58+
sd[k] *= scale
59+
60+
@endpoint
61+
async def get_model_size_bytes(self) -> int:
62+
"""Get total model size in bytes across all model parts."""
63+
total_bytes = 0
64+
for model_part in self.engine.model_parts:
65+
for param in model_part.parameters():
66+
total_bytes += param.numel() * param.element_size()
67+
return total_bytes
68+
69+
70+
@dataclass
71+
class WeightSyncMetrics:
72+
"""Metrics from a single weight sync operation."""
73+
74+
version: int
75+
total_time_s: float
76+
push_time_s: float
77+
update_time_s: float
78+
prefetch_enabled: bool
79+
80+
81+
@dataclass
82+
class BenchmarkResults:
83+
"""Aggregated benchmark results."""
84+
85+
model: str
86+
iterations: int
87+
prefetch_enabled: bool
88+
n_fetcher_procs: int
89+
model_size_bytes: int = 0
90+
metrics: list[WeightSyncMetrics] = field(default_factory=list)
91+
92+
@property
93+
def model_size_gb(self) -> float:
94+
return self.model_size_bytes / (1024**3)
95+
96+
@property
97+
def avg_total_time_s(self) -> float:
98+
if not self.metrics:
99+
return 0.0
100+
return sum(m.total_time_s for m in self.metrics) / len(self.metrics)
101+
102+
@property
103+
def avg_push_time_s(self) -> float:
104+
if not self.metrics:
105+
return 0.0
106+
return sum(m.push_time_s for m in self.metrics) / len(self.metrics)
107+
108+
@property
109+
def avg_update_time_s(self) -> float:
110+
if not self.metrics:
111+
return 0.0
112+
return sum(m.update_time_s for m in self.metrics) / len(self.metrics)
113+
114+
@property
115+
def push_throughput_gb_s(self) -> float:
116+
if self.avg_push_time_s <= 0 or self.model_size_bytes <= 0:
117+
return 0.0
118+
return self.model_size_gb / self.avg_push_time_s
119+
120+
@property
121+
def update_throughput_gb_s(self) -> float:
122+
if self.avg_update_time_s <= 0 or self.model_size_bytes <= 0:
123+
return 0.0
124+
return self.model_size_gb / self.avg_update_time_s
125+
126+
127+
def print_results(results: BenchmarkResults):
128+
"""Print benchmark results."""
129+
print("\n" + "=" * 80)
130+
print("WEIGHT SYNC BENCHMARK RESULTS")
131+
print("=" * 80)
132+
print(f"Model: {results.model}")
133+
print(f"Model size: {results.model_size_gb:.2f} GB")
134+
print(f"Iterations: {results.iterations}")
135+
print(f"Prefetch enabled: {results.prefetch_enabled}")
136+
if results.prefetch_enabled:
137+
print(f"Fetcher procs: {results.n_fetcher_procs}")
138+
print("-" * 80)
139+
print(f"{'Metric':<30} {'Time (s)':<15} {'Throughput (GB/s)':<20}")
140+
print("-" * 80)
141+
print(
142+
f"{'Avg push_weights':<30} {results.avg_push_time_s:>12.3f} s "
143+
f"{results.push_throughput_gb_s:>12.2f} GB/s"
144+
)
145+
print(
146+
f"{'Avg update_weights':<30} {results.avg_update_time_s:>12.3f} s "
147+
f"{results.update_throughput_gb_s:>12.2f} GB/s"
148+
)
149+
print(f"{'Avg total (push + update)':<30} {results.avg_total_time_s:>12.3f} s")
150+
print("=" * 80 + "\n")
151+
152+
153+
async def run_weight_sync_benchmark(
154+
cfg: DictConfig,
155+
iterations: int,
156+
prefetch_enabled: bool,
157+
n_fetcher_procs: int,
158+
warmup_iterations: int,
159+
) -> BenchmarkResults:
160+
"""Run weight sync benchmark with knobs to enable prefetch.
161+
162+
Args:
163+
cfg: TorchForge config from YAML
164+
iterations: Number of weight sync iterations to benchmark
165+
prefetch_enabled: Whether to enable shared memory prefetching
166+
n_fetcher_procs: Number of fetcher processes (when prefetch_enabled=True)
167+
warmup_iterations: Number of warmup iterations before timing
168+
169+
Returns:
170+
BenchmarkResults with timing metrics
171+
"""
172+
model_name = cfg.generator.engine_args.get("model", "unknown")
173+
174+
generator_cfg = cfg.generator.copy()
175+
if prefetch_enabled:
176+
generator_cfg.prefetch_weights_to_shm = True
177+
generator_cfg.n_fetcher_procs = n_fetcher_procs
178+
else:
179+
generator_cfg.prefetch_weights_to_shm = False
180+
generator_cfg.n_fetcher_procs = 0
181+
182+
if cfg.get("provisioner", None) is not None:
183+
provisioner = await init_provisioner(
184+
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
185+
)
186+
else:
187+
provisioner = await init_provisioner()
188+
189+
services_generator_cfg = cfg.services.generator.copy()
190+
services_generator_cfg.num_replicas = 1
191+
192+
logger.info("Spawning Generator and Trainer...")
193+
generator, trainer = await asyncio.gather(
194+
Generator.options(**services_generator_cfg).as_service(**generator_cfg),
195+
BenchmarkTitanTrainer.options(**cfg.actors.trainer).as_actor(**cfg.trainer),
196+
)
197+
logger.info("Generator and Trainer spawned.")
198+
199+
trainer_num_procs = cfg.actors.trainer["procs"]
200+
trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
201+
trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
202+
# same as the main grpo app.
203+
await ts.initialize(
204+
mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
205+
strategy=ts.LocalRankStrategy(),
206+
)
207+
logger.info("Torchstore initialized with LocalRankStrategy")
208+
209+
if warmup_iterations > 0:
210+
logger.info(f"Running {warmup_iterations} warmup iteration(s)...")
211+
for i in range(warmup_iterations):
212+
v = uuid.uuid4().int
213+
await trainer.push_weights.call(policy_version=v)
214+
await generator.update_weights.fanout(version=v)
215+
await trainer.modify_weights.call(scale=1.001)
216+
logger.info("Warmup complete.")
217+
218+
# Get model size for throughput calculation
219+
# With DTensor/TP, each rank's param.numel() returns global size, not shard size
220+
# So just take one rank's value
221+
model_size_result = await trainer.get_model_size_bytes.call()
222+
_, model_size_bytes = next(iter(model_size_result.items()))
223+
model_size_gb = model_size_bytes / (1024**3)
224+
logger.info(f"Model size: {model_size_gb:.2f} GB")
225+
226+
logger.info(f"Running {iterations} timed iteration(s)...")
227+
metrics: list[WeightSyncMetrics] = []
228+
229+
# Generate a test prompt for in-flight requests
230+
test_prompt = "What is the capital of France? Please explain in detail."
231+
232+
for i in range(iterations):
233+
v = uuid.uuid4().int
234+
235+
# Modify weights to simulate training
236+
await trainer.modify_weights.call(scale=1.001)
237+
238+
# Time push_weights
239+
push_start = time.perf_counter()
240+
await trainer.push_weights.call(policy_version=v)
241+
push_end = time.perf_counter()
242+
push_time_s = push_end - push_start
243+
244+
# Simulate in-flight requests that pause_generation must wait for
245+
num_inflight = 4
246+
generation_tasks = [
247+
asyncio.create_task(generator.generate.route(test_prompt))
248+
for _ in range(num_inflight)
249+
]
250+
# Give generation a moment to start
251+
await asyncio.sleep(0.1)
252+
253+
# Time update_weights (includes pause_generation waiting for in-flight)
254+
update_start = time.perf_counter()
255+
await generator.update_weights.fanout(version=v)
256+
update_end = time.perf_counter()
257+
update_time_s = update_end - update_start
258+
259+
# Wait for generation to complete (after weight update)
260+
await asyncio.gather(*generation_tasks)
261+
262+
total_time_s = push_time_s + update_time_s
263+
264+
metrics.append(
265+
WeightSyncMetrics(
266+
version=v,
267+
total_time_s=total_time_s,
268+
push_time_s=push_time_s,
269+
update_time_s=update_time_s,
270+
prefetch_enabled=prefetch_enabled,
271+
)
272+
)
273+
274+
logger.info(
275+
f"Iteration {i + 1}/{iterations}: push={push_time_s:.3f}s, "
276+
f"update={update_time_s:.3f}s, total={total_time_s:.3f}s"
277+
)
278+
279+
logger.info("Cleaning up...")
280+
await trainer.cleanup.call()
281+
await generator.shutdown()
282+
await BenchmarkTitanTrainer.shutdown(trainer)
283+
await ts.shutdown()
284+
285+
return BenchmarkResults(
286+
model=model_name,
287+
iterations=iterations,
288+
prefetch_enabled=prefetch_enabled,
289+
n_fetcher_procs=n_fetcher_procs if prefetch_enabled else 0,
290+
model_size_bytes=model_size_bytes,
291+
metrics=metrics,
292+
)
293+
294+
295+
@parse
296+
def recipe_main(cfg: DictConfig = None) -> None: # type: ignore[assignment]
297+
"""Main entry point for weight sync benchmark.
298+
299+
Args:
300+
cfg: Config loaded from YAML file via @parse decorator.
301+
Benchmark parameters can be specified via key=value overrides:
302+
benchmark.iterations=5
303+
benchmark.prefetch_enabled=true
304+
benchmark.n_fetcher_procs=4
305+
benchmark.warmup_iterations=1
306+
"""
307+
cfg = resolve_hf_hub_paths(cfg)
308+
309+
benchmark_cfg = cfg.get("benchmark", {})
310+
iterations = benchmark_cfg.get("iterations", 3)
311+
prefetch_enabled = benchmark_cfg.get("prefetch_enabled", False)
312+
n_fetcher_procs = benchmark_cfg.get("n_fetcher_procs", 8)
313+
warmup_iterations = benchmark_cfg.get("warmup_iterations", 1)
314+
315+
results = asyncio.run(
316+
run_weight_sync_benchmark(
317+
cfg=cfg,
318+
iterations=iterations,
319+
prefetch_enabled=prefetch_enabled,
320+
n_fetcher_procs=n_fetcher_procs,
321+
warmup_iterations=warmup_iterations,
322+
)
323+
)
324+
print_results(results)
325+
326+
asyncio.run(shutdown())
327+
328+
329+
if __name__ == "__main__":
330+
recipe_main()

src/forge/actors/vllm/v0/generator.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from collections.abc import Mapping
1515
from copy import copy
1616
from dataclasses import dataclass, field
17+
from multiprocessing import resource_tracker
1718
from typing import Optional
1819

1920
import torch
@@ -727,10 +728,15 @@ async def fetch(
727728
sd = {}
728729
for name in param_names:
729730
param_key = get_param_key(version, name)
731+
# Use explicit resource handling instead of context manager because
732+
# ownership is transferred to the Generator (which calls handle.drop()
733+
# to clean up). We must unregister from resource_tracker here, otherwise
734+
# the fetcher process will try to clean up the shared memory on exit.
730735
param = await ts.get(param_key)
731-
# Use context manager to ensure cleanup after getting handle
732-
with SharedTensor(tensor=param) as shared_tensor:
733-
handle = shared_tensor.get_handle()
734-
sd[name] = handle
736+
shared_tensor = SharedTensor(tensor=param)
737+
handle = shared_tensor.get_handle()
738+
resource_tracker.unregister(f"/{handle.shm_name}", "shared_memory")
739+
sd[name] = handle
740+
shared_tensor.close()
735741
del param # Explicitly free the tensor after copying to shared memory
736742
return sd

0 commit comments

Comments
 (0)