Skip to content

Commit 5efbbf4

Browse files
Guyue Huang (guyueh1)
authored and committed
Discard weight when finish generation in the main loop
Signed-off-by: Guyue Huang <guyueh@login-lyris01.lyris.clusters.nvidia.com>
1 parent 89da2d4 commit 5efbbf4

4 files changed

Lines changed: 11 additions & 5 deletions

File tree

nemo_rl/algorithms/grpo.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1586,7 +1586,11 @@ def grpo_train(
15861586
max_rollout_turns=master_config.grpo["max_rollout_turns"],
15871587
greedy=False,
15881588
)
1589-
policy_generation.finish_generation()
1589+
policy_generation.finish_generation(
1590+
discard_weights=colocated_inference
1591+
)
1592+
if colocated_inference:
1593+
POLICY_GENERATION_STALE = True
15901594
# Collect generation logger metrics for performance reporting after each generation step
15911595
# inflight batch sizes and num pending samples are collected from each worker
15921596
if policy_generation is not None:

nemo_rl/models/generation/vllm/vllm_generation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,10 +732,12 @@ def finish_generation(self, *args: Any, **kwargs: Any) -> bool:
732732
if self.cfg["vllm_cfg"]["async_engine"]
733733
else "reset_prefix_cache"
734734
)
735+
kwargs = {}
735736
# Use run_all_workers_single_data for methods that don't need data
736737
futures = self.worker_group.run_all_workers_single_data(
737738
method_name,
738739
run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
740+
**kwargs,
739741
)
740742
# Wait for all futures to complete
741743
results = ray.get(futures)

nemo_rl/models/generation/vllm/vllm_worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,7 @@ def reset_prefix_cache(self):
986986
gc.collect()
987987
torch.cuda.empty_cache()
988988

989-
def sleep(self):
989+
def sleep(self, discard_weights: bool = False):
990990
"""Put the vLLM engine to sleep."""
991991
assert self.llm is not None, (
992992
"Attempting to sleep with either an uninitialized vLLM or non-model-owner"
@@ -1009,7 +1009,7 @@ def sleep(self):
10091009
self.llm.renderer, "clear_mm_cache"
10101010
):
10111011
self.llm.renderer.clear_mm_cache()
1012-
self.llm.sleep(level=1)
1012+
self.llm.sleep(level=2 if discard_weights else 1)
10131013

10141014
gc.collect()
10151015
torch.cuda.empty_cache()

nemo_rl/models/generation/vllm/vllm_worker_async.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,7 +1129,7 @@ async def reset_prefix_cache_async(self):
11291129
gc.collect()
11301130
torch.cuda.empty_cache()
11311131

1132-
async def sleep_async(self):
1132+
async def sleep_async(self, discard_weights: bool = False):
11331133
"""Async version of sleep."""
11341134
assert self.llm is not None, (
11351135
"Attempting to sleep with either an uninitialized vLLM or non-model-owner"
@@ -1148,7 +1148,7 @@ async def sleep_async(self):
11481148
# the receiver and sends data=None, causing an assertion error.
11491149
if hasattr(self.llm, "reset_mm_cache"):
11501150
await self.llm.reset_mm_cache()
1151-
await self.llm.sleep(level=1)
1151+
await self.llm.sleep(level=2 if discard_weights else 1)
11521152

11531153
gc.collect()
11541154
torch.cuda.empty_cache()

0 commit comments

Comments (0)