Commit a5bf748

fix vl mem leak (#1645)
* fix vl mem leak * fix * refine * [Fix] Address review comments for trim_memory - Remove unnecessary logger parameter from trim_memory(); use global logger internally only when needed (on failure). - Rename trim_memory_step to trim_memory_interval to follow trainer naming convention (_step = counter, _interval = frequency). * [Fix] Fix trim_memory counter logic and naming - Move counter increment outside the if-block so it increments on every call, not only when trimming. Previously, with interval > 1 the counter would get stuck after the first trim. - Rename _trim_memory_count to _trim_memory_counter to better convey that it is a step counter, not a count of trims performed. --------- Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
1 parent d215d30

5 files changed: 54 additions & 9 deletions
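The counter fix described in the commit message is easiest to see in isolation. Below is a minimal before/after sketch; the `_trim_memory_interval` / `_trim_memory_counter` names follow the diff, while the `BuggyTrimmer` and `FixedTrimmer` classes are hypothetical, for illustration only:

```python
# Hypothetical minimal reproduction of the counter bug fixed in this commit.

class BuggyTrimmer:
    """Pre-fix behavior: the increment lives inside the if-block."""

    def __init__(self, interval: int):
        self._trim_memory_interval = interval
        self._trim_memory_counter = 0

    def step(self) -> bool:
        if self._trim_memory_counter % self._trim_memory_interval == 0:
            self._trim_memory_counter += 1  # bug: only advances when we trim
            return True  # trim_memory() would run here
        return False


class FixedTrimmer:
    """Post-fix behavior: the counter advances on every call."""

    def __init__(self, interval: int):
        self._trim_memory_interval = interval
        self._trim_memory_counter = 0

    def step(self) -> bool:
        trimmed = self._trim_memory_counter % self._trim_memory_interval == 0
        self._trim_memory_counter += 1
        return trimmed


# With interval=3, the buggy version trims once and then sticks at counter=1
# (1 % 3 != 0 forever); the fixed version trims on every third call.
buggy, fixed = BuggyTrimmer(3), FixedTrimmer(3)
print([buggy.step() for _ in range(6)])  # [True, False, False, False, False, False]
print([fixed.step() for _ in range(6)])  # [True, False, False, True, False, False]
```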

xtuner/v1/datasets/mllm_tokenize_fn/base_mllm_tokenize_fn.py

Lines changed: 15 additions & 2 deletions

```diff
@@ -9,7 +9,7 @@
 
 from xtuner.v1.data_proto.messages import ChatMessages
 from xtuner.v1.data_proto.templates import ChatTemplate, HybridChatTemplate
-from xtuner.v1.utils import get_logger
+from xtuner.v1.utils import get_logger, trim_memory
 
 from ..data_item import BaseMLLMDataItem, CacheItem
 from ..utils import CachableTokenizeFunction, tokenizer_xxhash, with_proxy_attention_flops
@@ -118,7 +118,8 @@ def replace_image_token(
 
 
 def load_image(image_path: str):
-    return Image.open(image_path).convert("RGB")
+    with Image.open(image_path) as img:
+        return img.convert("RGB")
 
 
 def get_image_path(image_path: str, media_root: str):
@@ -144,6 +145,7 @@ def __init__(
         data_name: str | None = None,
         llm_pack_weight: float = 1.0,
         visual_pack_weight: float = 0.0,
+        trim_memory_interval: int = 1,
     ):
         self.max_length = max_length
         self._tokenizer_hash = tokenizer_hash
@@ -158,6 +160,9 @@ def __init__(
         self._video_wh_list: list[list] = []
         self._video_extra_info_list: list[dict] = []
 
+        self._trim_memory_interval = trim_memory_interval
+        self._trim_memory_counter = 0
+
         self._hash_str += f"llm_pack_weight:{llm_pack_weight}_visual_pack_weight:{visual_pack_weight}"
         super().__init__(tokenizer, llm_pack_weight=llm_pack_weight, visual_pack_weight=visual_pack_weight)
 
@@ -213,16 +218,23 @@ def __call__(self, item: dict, media_root: str = "", **kwargs) -> T | CacheItem:
                 ret = self.calc_num_tokens_multi_modal_get_item(item)
             else:
                 ret = self.multi_modal_get_item(item, media_root)
+                if self._trim_memory_counter % self._trim_memory_interval == 0:
+                    trim_memory()
+                self._trim_memory_counter += 1
         elif len(self._video_path) > 0:
             if self.state == "cache":
                 ret = self.calc_num_tokens_video_get_item(item)
             else:
                 ret = self.video_get_item(item, media_root)
+                if self._trim_memory_counter % self._trim_memory_interval == 0:
+                    trim_memory()
+                self._trim_memory_counter += 1
         else:
             if self.state == "cache":
                 ret = self.calc_num_tokens_pure_text_get_item(item)
             else:
                 ret = self.pure_text_get_item(item)
+
         return ret
 
     def hash(self) -> str:
@@ -257,6 +269,7 @@ class BaseMLLMTokenizeFnConfig(BaseModel):
     add_bos_token: bool = False  # for mllm pretrain
     llm_pack_weight: float = 1.0
     visual_pack_weight: float = 0.0
+    trim_memory_interval: int = 1
 
     def build(
         self, tokenizer, tokenizer_hash: str | None = None, anno_name: str = "", **kwargs
```
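One subtlety behind the `load_image` change: `Image.open` is lazy and keeps the file handle open, but `convert("RGB")` forces a full pixel load and returns a new image, so returning the converted copy after the with-block closes the handle is safe. A self-contained sketch (the `/tmp/demo.png` path is just for the demo):

```python
from PIL import Image

def load_image(image_path: str) -> Image.Image:
    # Image.open() is lazy: it reads only the header and keeps the file
    # handle open until the pixel data is loaded or the handle is closed.
    with Image.open(image_path) as img:
        # convert() forces a full load and returns a *new* Image object,
        # so the result remains valid after the with-block closes the file.
        return img.convert("RGB")

# Self-contained demo: write a tiny image first so no sample file is assumed.
Image.new("L", (8, 8)).save("/tmp/demo.png")
img = load_image("/tmp/demo.png")
print(img.size, img.mode)  # (8, 8) RGB
```

Without the with-block, each call leaves a file handle and its decode buffers alive until garbage collection, which in a long-running dataloader worker accumulates as an apparent memory leak.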

xtuner/v1/datasets/mllm_tokenize_fn/qwen3_vl_tokenize_fn.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -233,6 +233,7 @@ def __init__(
         hash: str | None = None,
         add_eos_token: bool = True,  # for mllm pretrain
         add_bos_token: bool = False,  # for mllm pretrain
+        trim_memory_interval: int = 1,
     ):
         self.oss_loader = None
         self.debug = debug
@@ -335,6 +336,7 @@ def __init__(
             data_name=self.data_name,
             llm_pack_weight=llm_pack_weight,
             visual_pack_weight=visual_pack_weight,
+            trim_memory_interval=trim_memory_interval,
         )
 
     def _truncated_data_item(
@@ -904,6 +906,8 @@ class Qwen3VLTokenizeFnConfig(BaseMLLMTokenizeFnConfig):
     # it's helpful to add labels to the images and videos for better reference.
     add_vision_id: bool = True
 
+    trim_memory_interval: int = 1
+
     def build(
         self, tokenizer, tokenizer_hash: str | None = None, anno_name: str = "", **kwargs
     ) -> Qwen3VLTokenizeFunction:
@@ -932,4 +936,5 @@ def build(
             oss_time_log_thr=self.oss_time_log_thr,
             add_eos_token=self.add_eos_token,  # for mllm pretrain
             add_bos_token=self.add_bos_token,  # for mllm pretrain
+            trim_memory_interval=self.trim_memory_interval,
         )
```
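Because the knob is threaded through the Pydantic config, trimming frequency can be controlled from configuration alone. A hedged usage sketch: only `trim_memory_interval` and the `build()` call shape are taken from this diff; the model id is a placeholder, and any other required config fields are omitted for brevity:

```python
# Hypothetical usage sketch, not a complete config: see the config class
# for its full schema. Only trim_memory_interval is confirmed by this diff.
from transformers import AutoTokenizer

from xtuner.v1.datasets.mllm_tokenize_fn.qwen3_vl_tokenize_fn import Qwen3VLTokenizeFnConfig

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-VL")  # placeholder model id

# trim_memory_interval=10: call trim_memory() on every 10th tokenized
# image/video item instead of after every one, trading slightly higher
# peak RSS for fewer malloc_trim calls.
config = Qwen3VLTokenizeFnConfig(trim_memory_interval=10)
tokenize_fn = config.build(tokenizer, anno_name="demo.jsonl")
```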

xtuner/v1/datasets/mllm_tokenize_fn/qwen3_vl_utils.py

Lines changed: 9 additions & 7 deletions

```diff
@@ -22,9 +22,10 @@
 
 
 def pil_loader(img_str):
-    buff = io.BytesIO(img_str)
-    img = Image.open(buff)
-    return img.convert("RGB")
+    # Ensure both the BytesIO buffer and PIL image file handle are closed promptly.
+    with io.BytesIO(img_str) as buff:
+        with Image.open(buff) as img:
+            return img.convert("RGB")
 
 
 def extract_frame_number(filename):
@@ -109,12 +110,13 @@ def read_frames_folder(
                 start_time = time.time()
                 image_byte = client.get(image_list[frame_index])
                 oss_read_time += time.time() - start_time
-                frame = Image.open(io.BytesIO(image_byte))
-                frame_list.append(np.array(frame))
+                with io.BytesIO(image_byte) as buff:
+                    with Image.open(buff) as frame:
+                        frame_list.append(np.array(frame))
             else:
                 fp = os.path.join(video_path, image_list[frame_index])
-                frame = Image.open(fp).convert("RGB")
-                frame_list.append(np.array(frame))
+                with Image.open(fp) as frame:
+                    frame_list.append(np.array(frame.convert("RGB")))
 
     frames = numpy_to_tensor(frame_list)
     return frames, oss_read_time, len(frames), frames_indices, timestamps
```
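The same force-the-load-inside-the-with-block rule applies here: `np.array(frame)` and `convert("RGB")` both materialize the pixel data before the handles close. A self-contained sketch of the `pil_loader` pattern; the in-memory PNG is generated on the fly, so no sample file is assumed:

```python
import io

import numpy as np
from PIL import Image

def pil_loader(img_bytes: bytes) -> Image.Image:
    # Same pattern as the diff: both handles are closed on exit, and
    # convert("RGB") loads the pixels before the buffer goes away.
    with io.BytesIO(img_bytes) as buff:
        with Image.open(buff) as img:
            return img.convert("RGB")

# Build an in-memory PNG so the example needs no file on disk.
raw = io.BytesIO()
Image.new("RGB", (4, 4), color=(255, 0, 0)).save(raw, format="PNG")

img = pil_loader(raw.getvalue())
print(img.size, img.mode)   # (4, 4) RGB
print(np.array(img).shape)  # (4, 4, 3)
```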

xtuner/v1/utils/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -17,6 +17,7 @@
     is_hf_model_path,
     is_local_rank0,
     record_git_info,
+    trim_memory,
 )
 from .pad import pad_to_max_length, pad_to_multiple_of
 from .profile import profile_time, profile_time_and_memory, timer, timer_logger
@@ -61,4 +62,5 @@
     "ray_method",
     "profile_time",
     "clean_param_name",
+    "trim_memory",
 ]
```
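With the re-export in place, callers can import the helper from the package root rather than reaching into `misc`:

```python
# Both import paths resolve to the same function after this change.
from xtuner.v1.utils import trim_memory
from xtuner.v1.utils.misc import trim_memory as trim_memory_direct

assert trim_memory is trim_memory_direct
```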

xtuner/v1/utils/misc.py

Lines changed: 23 additions & 0 deletions

```diff
@@ -214,3 +214,26 @@ def clean_param_name(name: str) -> str:
     if "_orig_mod." in name:
         name = name.replace("_orig_mod.", "")
     return name
+
+
+_TRIM_MEMORY_WARNED = False
+
+
+def trim_memory() -> bool:
+    """Try to return free heap pages to OS.
+
+    Best-effort only: on platforms without `malloc_trim` (or when unavailable),
+    this will fail. We log the failure once per process to avoid spamming.
+    """
+    global _TRIM_MEMORY_WARNED
+    try:
+        import ctypes
+
+        libc = ctypes.CDLL("libc.so.6")
+        return libc.malloc_trim(0)
+    except Exception as e:
+        if not _TRIM_MEMORY_WARNED:
+            _logger = get_logger()
+            _logger.warning(f" >>>>>>>>> [trim_memory] Failed to trim memory: {e} <<<<<<<<")
+            _TRIM_MEMORY_WARNED = True
+        return False
```
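For reference, glibc's `malloc_trim(pad)` returns a C int: 1 when heap pages were released back to the kernel, 0 otherwise. With ctypes' default `c_int` return type, the function therefore yields an int on success despite the `-> bool` annotation, but it remains truthy/falsy as documented. A hedged sketch of how a dataloader worker might use it; the loop and `process_item()` are illustrative, while `trim_memory()` and the counter/interval pattern come from this commit:

```python
# Illustrative sketch only: process_item() and the loop are hypothetical.
from xtuner.v1.utils import trim_memory

TRIM_INTERVAL = 100  # mirror trim_memory_interval: trim on every 100th item

def process_item(item):
    # Stand-in for tokenization work that allocates and frees heavily.
    return [bytes(1024) for _ in range(256)]

def worker_loop(items):
    counter = 0
    for item in items:
        process_item(item)
        if counter % TRIM_INTERVAL == 0:
            # On glibc, truthy (1) if pages were handed back to the kernel;
            # on non-glibc platforms, logs a warning once and returns False.
            trim_memory()
        counter += 1

worker_loop(range(500))  # trims at items 0, 100, 200, 300, 400
```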
