Skip to content

Commit aaac32e

Browse files
authored
[perf] feat: clear megatron global buffer memory (verl-project#5173)
### What does this PR do?

- [x] clear megatron global buffer memory
- [x] synchronize device before empty_cache in vllm worker
- [ ] clear vllm encoder cache memory vllm-project/vllm#33452
1 parent 5cef3ca commit aaac32e

4 files changed

Lines changed: 10 additions & 3 deletions

File tree

verl/utils/megatron_utils.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
from megatron.core.distributed import DistributedDataParallelConfig
3232
from megatron.core.enums import ModelType
3333
from megatron.core.optimizer import ChainedOptimizer
34+
from megatron.core.parallel_state import get_global_memory_buffer
3435
from megatron.core.transformer import MLATransformerConfig, TransformerConfig
3536
from megatron.core.transformer.module import Float16Module
3637
from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper
@@ -598,7 +599,7 @@ def _iter_opts(opt):
598599
pass
599600

600601
# Free Megatron-LM's global memory buffer
601-
# get_global_memory_buffer().buffer.clear()
602+
get_global_memory_buffer().buffer.clear()
602603

603604
gc.collect()
604605
get_torch_device().empty_cache()

verl/workers/engine_workers.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -596,6 +596,8 @@ async def update_weights(self):
596596
return
597597

598598
set_expandable_segments(False)
599+
log_gpu_memory_usage("Before resume weights", logger=logger)
600+
599601
# 1. resume weights and update weights
600602
if self.config.rollout.free_cache_engine:
601603
await self.rollout.resume(tags=["weights"])

verl/workers/rollout/vllm_rollout/utils.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -209,7 +209,7 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False
209209
# receive bucket and update weights
210210
while True:
211211
metadata = socket.recv_pyobj()
212-
weights = []
212+
weights, tensor = [], None
213213
for name, meta in metadata["bucket_meta"].items():
214214
shape, dtype, offset = meta["shape"], meta["dtype"], meta["offset"]
215215
size = dtype.itemsize * shape.numel()
@@ -225,7 +225,7 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False
225225
get_torch_device().synchronize()
226226
socket.send(b"")
227227
self._update_weights(weights, peft_config=peft_config, base_sync_done=base_sync_done)
228-
del weights
228+
del weights, tensor
229229
if metadata["is_last"]:
230230
break
231231

@@ -235,6 +235,7 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False
235235
if shm is not None:
236236
shm.close()
237237
del shm
238+
get_torch_device().synchronize()
238239
gc.collect()
239240
get_torch_device().ipc_collect()
240241
get_torch_device().empty_cache()

verl/workers/rollout/vllm_rollout/vllm_async_server.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -580,6 +580,9 @@ async def sleep(self):
580580
if self.rollout_mode == RolloutMode.HYBRID:
581581
# Don't use engine.sleep(level=2) here
582582
await self.engine.collective_rpc("sleep", kwargs={"level": 2})
583+
584+
# clear encoder cache: https://github.com/vllm-project/vllm/pull/33452
585+
# await self.engine.reset_encoder_cache()
583586
elif self.rollout_mode == RolloutMode.COLOCATED:
584587
await self.engine.sleep(level=1)
585588
elif self.rollout_mode == RolloutMode.STANDALONE:

0 commit comments

Comments
 (0)