99
1010from xtuner .v1 .data_proto .messages import ChatMessages
1111from xtuner .v1 .data_proto .templates import ChatTemplate , HybridChatTemplate
12- from xtuner .v1 .utils import get_logger
12+ from xtuner .v1 .utils import get_logger , trim_memory
1313
1414from ..data_item import BaseMLLMDataItem , CacheItem
1515from ..utils import CachableTokenizeFunction , tokenizer_xxhash , with_proxy_attention_flops
@@ -118,7 +118,8 @@ def replace_image_token(
118118
119119
120120def load_image (image_path : str ):
121- return Image .open (image_path ).convert ("RGB" )
121+ with Image .open (image_path ) as img :
122+ return img .convert ("RGB" )
122123
123124
124125def get_image_path (image_path : str , media_root : str ):
@@ -144,6 +145,7 @@ def __init__(
144145 data_name : str | None = None ,
145146 llm_pack_weight : float = 1.0 ,
146147 visual_pack_weight : float = 0.0 ,
148+ trim_memory_step : int = 1 ,
147149 ):
148150 self .max_length = max_length
149151 self ._tokenizer_hash = tokenizer_hash
@@ -157,10 +159,17 @@ def __init__(
157159 self ._image_wh_list : list [list ] = []
158160 self ._video_wh_list : list [list ] = []
159161 self ._video_extra_info_list : list [dict ] = []
162+ self ._trim_memory_step = max (1 , trim_memory_step )
163+ self ._trim_memory_count = 0
160164
161165 self ._hash_str += f"llm_pack_weight:{ llm_pack_weight } _visual_pack_weight:{ visual_pack_weight } "
162166 super ().__init__ (tokenizer , llm_pack_weight = llm_pack_weight , visual_pack_weight = visual_pack_weight )
163167
168+ def _maybe_trim_memory (self ):
169+ self ._trim_memory_count += 1
170+ if self ._trim_memory_count % self ._trim_memory_step == 0 :
171+ trim_memory (logger )
172+
164173 def calc_num_tokens_multi_modal_get_item (self , data_item : dict ) -> CacheItem :
165174 raise NotImplementedError
166175
@@ -213,11 +222,13 @@ def __call__(self, item: dict, media_root: str = "", **kwargs) -> T | CacheItem:
213222 ret = self .calc_num_tokens_multi_modal_get_item (item )
214223 else :
215224 ret = self .multi_modal_get_item (item , media_root )
225+ self ._maybe_trim_memory ()
216226 elif len (self ._video_path ) > 0 :
217227 if self .state == "cache" :
218228 ret = self .calc_num_tokens_video_get_item (item )
219229 else :
220230 ret = self .video_get_item (item , media_root )
231+ self ._maybe_trim_memory ()
221232 else :
222233 if self .state == "cache" :
223234 ret = self .calc_num_tokens_pure_text_get_item (item )
@@ -257,6 +268,7 @@ class BaseMLLMTokenizeFnConfig(BaseModel):
257268 add_bos_token : bool = False # for mllm pretrain
258269 llm_pack_weight : float = 1.0
259270 visual_pack_weight : float = 0.0
271+ trim_memory_step : int = 1
260272
261273 def build (
262274 self , tokenizer , tokenizer_hash : str | None = None , anno_name : str = "" , ** kwargs
0 commit comments