Skip to content

Commit 6032f20

Browse files
committed
apply trim memory
1 parent 73d9fbd commit 6032f20

5 files changed

Lines changed: 49 additions & 9 deletions

File tree

xtuner/v1/datasets/mllm_tokenize_fn/base_mllm_tokenize_fn.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from xtuner.v1.data_proto.messages import ChatMessages
1111
from xtuner.v1.data_proto.templates import ChatTemplate, HybridChatTemplate
12-
from xtuner.v1.utils import get_logger
12+
from xtuner.v1.utils import get_logger, trim_memory
1313

1414
from ..data_item import BaseMLLMDataItem, CacheItem
1515
from ..utils import CachableTokenizeFunction, tokenizer_xxhash, with_proxy_attention_flops
@@ -118,7 +118,8 @@ def replace_image_token(
118118

119119

120120
def load_image(image_path: str):
121-
return Image.open(image_path).convert("RGB")
121+
with Image.open(image_path) as img:
122+
return img.convert("RGB")
122123

123124

124125
def get_image_path(image_path: str, media_root: str):
@@ -144,6 +145,7 @@ def __init__(
144145
data_name: str | None = None,
145146
llm_pack_weight: float = 1.0,
146147
visual_pack_weight: float = 0.0,
148+
trim_memory_step: int = 1,
147149
):
148150
self.max_length = max_length
149151
self._tokenizer_hash = tokenizer_hash
@@ -157,10 +159,17 @@ def __init__(
157159
self._image_wh_list: list[list] = []
158160
self._video_wh_list: list[list] = []
159161
self._video_extra_info_list: list[dict] = []
162+
self._trim_memory_step = max(1, trim_memory_step)
163+
self._trim_memory_count = 0
160164

161165
self._hash_str += f"llm_pack_weight:{llm_pack_weight}_visual_pack_weight:{visual_pack_weight}"
162166
super().__init__(tokenizer, llm_pack_weight=llm_pack_weight, visual_pack_weight=visual_pack_weight)
163167

168+
def _maybe_trim_memory(self):
    """Count one call and trim the process heap on every N-th call.

    N is ``self._trim_memory_step``; the running total is kept in
    ``self._trim_memory_count``. Trimming is delegated to
    ``trim_memory`` with the module-level ``logger``.
    """
    calls_so_far = self._trim_memory_count + 1
    self._trim_memory_count = calls_so_far
    # Guard clause: only every ``_trim_memory_step``-th call pays for a trim.
    if calls_so_far % self._trim_memory_step != 0:
        return
    trim_memory(logger)
172+
164173
def calc_num_tokens_multi_modal_get_item(self, data_item: dict) -> CacheItem:
165174
raise NotImplementedError
166175

@@ -213,11 +222,13 @@ def __call__(self, item: dict, media_root: str = "", **kwargs) -> T | CacheItem:
213222
ret = self.calc_num_tokens_multi_modal_get_item(item)
214223
else:
215224
ret = self.multi_modal_get_item(item, media_root)
225+
self._maybe_trim_memory()
216226
elif len(self._video_path) > 0:
217227
if self.state == "cache":
218228
ret = self.calc_num_tokens_video_get_item(item)
219229
else:
220230
ret = self.video_get_item(item, media_root)
231+
self._maybe_trim_memory()
221232
else:
222233
if self.state == "cache":
223234
ret = self.calc_num_tokens_pure_text_get_item(item)
@@ -257,6 +268,7 @@ class BaseMLLMTokenizeFnConfig(BaseModel):
257268
add_bos_token: bool = False # for mllm pretrain
258269
llm_pack_weight: float = 1.0
259270
visual_pack_weight: float = 0.0
271+
trim_memory_step: int = 1
260272

261273
def build(
262274
self, tokenizer, tokenizer_hash: str | None = None, anno_name: str = "", **kwargs

xtuner/v1/datasets/mllm_tokenize_fn/qwen3_vl_tokenize_fn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def __init__(
233233
hash: str | None = None,
234234
add_eos_token: bool = True, # for mllm pretrain
235235
add_bos_token: bool = False, # for mllm pretrain
236+
trim_memory_step: int = 1,
236237
):
237238
self.oss_loader = None
238239
self.debug = debug
@@ -335,6 +336,7 @@ def __init__(
335336
data_name=self.data_name,
336337
llm_pack_weight=llm_pack_weight,
337338
visual_pack_weight=visual_pack_weight,
339+
trim_memory_step=trim_memory_step,
338340
)
339341

340342
def _truncated_data_item(
@@ -903,6 +905,7 @@ class Qwen3VLTokenizeFnConfig(BaseMLLMTokenizeFnConfig):
903905
# When handling multiple images or multiple videos,
904906
# it's helpful to add labels to the images and videos for better reference.
905907
add_vision_id: bool = True
908+
trim_memory_step: int = 1
906909

907910
def build(
908911
self, tokenizer, tokenizer_hash: str | None = None, anno_name: str = "", **kwargs
@@ -932,4 +935,5 @@ def build(
932935
oss_time_log_thr=self.oss_time_log_thr,
933936
add_eos_token=self.add_eos_token, # for mllm pretrain
934937
add_bos_token=self.add_bos_token, # for mllm pretrain
938+
trim_memory_step=self.trim_memory_step,
935939
)

xtuner/v1/datasets/mllm_tokenize_fn/qwen3_vl_utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222

2323

2424
def pil_loader(img_str):
25-
buff = io.BytesIO(img_str)
26-
img = Image.open(buff)
27-
return img.convert("RGB")
25+
with io.BytesIO(img_str) as buff:
26+
with Image.open(buff) as img:
27+
return img.convert("RGB")
2828

2929

3030
def extract_frame_number(filename):
@@ -109,12 +109,13 @@ def read_frames_folder(
109109
start_time = time.time()
110110
image_byte = client.get(image_list[frame_index])
111111
oss_read_time += time.time() - start_time
112-
frame = Image.open(io.BytesIO(image_byte))
113-
frame_list.append(np.array(frame))
112+
with io.BytesIO(image_byte) as buff:
113+
with Image.open(buff) as frame:
114+
frame_list.append(np.array(frame))
114115
else:
115116
fp = os.path.join(video_path, image_list[frame_index])
116-
frame = Image.open(fp).convert("RGB")
117-
frame_list.append(np.array(frame))
117+
with Image.open(fp) as frame:
118+
frame_list.append(np.array(frame.convert("RGB")))
118119

119120
frames = numpy_to_tensor(frame_list)
120121
return frames, oss_read_time, len(frames), frames_indices, timestamps

xtuner/v1/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
is_hf_model_path,
1818
is_local_rank0,
1919
record_git_info,
20+
trim_memory,
2021
)
2122
from .pad import pad_to_max_length, pad_to_multiple_of
2223
from .profile import profile_time, profile_time_and_memory, timer, timer_logger
@@ -61,4 +62,5 @@
6162
"ray_method",
6263
"profile_time",
6364
"clean_param_name",
65+
"trim_memory",
6466
]

xtuner/v1/utils/misc.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import os
23
import sys
34
import threading
@@ -214,3 +215,23 @@ def clean_param_name(name: str) -> str:
214215
if "_orig_mod." in name:
215216
name = name.replace("_orig_mod.", "")
216217
return name
218+
219+
220+
_TRIM_MEMORY_WARNED = False
221+
222+
223+
def trim_memory(logger: logging.Logger | None = None):
224+
"""Try to return free heap pages to OS."""
225+
global _TRIM_MEMORY_WARNED
226+
if logger is None:
227+
logger = get_logger()
228+
try:
229+
import ctypes
230+
231+
libc = ctypes.CDLL("libc.so.6")
232+
return libc.malloc_trim(0)
233+
except Exception as e:
234+
if not _TRIM_MEMORY_WARNED:
235+
logger.warning(f" >>>>>>>>> [trim_memory] Failed to trim memory: {e} <<<<<<<<")
236+
_TRIM_MEMORY_WARNED = True
237+
return False

0 commit comments

Comments
 (0)