|
18 | 18 | import torch |
19 | 19 | from transformers.image_utils import load_image |
20 | 20 |
|
| 21 | +from angelslim.utils import rank0_print |
| 22 | + |
21 | 23 | __all__ = [ |
22 | 24 | "process_token_dict_to_mappings", |
23 | 25 | "convert_sharegpt_data", |
|
27 | 29 | "VLMHunyuanDataCollatorWithPadding", |
28 | 30 | "AudioDataCollatorWithPadding", |
29 | 31 | "CosyVoice3DataCollatorWithPadding", |
| 32 | + "build_image_processor_kwargs", |
30 | 33 | ] |
31 | 34 |
|
32 | 35 |
|
def build_image_processor_kwargs(image_processor, max_pixels=None, min_pixels=None):
    """
    Convert max_pixels/min_pixels to the keyword format expected by the given
    image_processor.

    - Processor class name contains "Qwen3" (Qwen3-VL): the limits are packed
      into size={"longest_edge": max_pixels, "shortest_edge": min_pixels}.
    - Any other processor (e.g. Qwen2.5-VL): max_pixels / min_pixels are
      passed through under their own names.

    Args:
        image_processor: model's image_processor instance
        max_pixels: maximum pixels (total area), None means no override
        min_pixels: minimum pixels (total area), None means no override

    Returns:
        dict: can be directly passed to image_processor(...). Empty when both
        limits are None.
    """
    if max_pixels is None and min_pixels is None:
        return {}

    if "Qwen3" not in type(image_processor).__name__:
        # Qwen2.5-VL style processors accept max_pixels and min_pixels directly.
        requested = {"max_pixels": max_pixels, "min_pixels": min_pixels}
        return {key: value for key, value in requested.items() if value is not None}

    # Qwen3-VL uses size={"longest_edge": ..., "shortest_edge": ...}.
    # NOTE(review): these processors appear to require BOTH keys whenever a
    # size dict is supplied, so when only one limit is overridden the other is
    # taken from the processor's own configured size (if it exposes one).
    configured = getattr(image_processor, "size", None)
    size = {}
    for key, requested in (("longest_edge", max_pixels), ("shortest_edge", min_pixels)):
        value = requested
        if value is None and configured is not None:
            try:
                value = configured[key]
            except (KeyError, TypeError, IndexError):
                # Processor has no usable default for this edge; leave it out.
                value = None
        if value is not None:
            size[key] = value
    return {"size": size}
| 70 | + |
| 71 | + |
33 | 72 | def convert_sharegpt_data(row, dataset_column="conversations"): |
34 | 73 | converted_messages = [] |
35 | 74 |
|
@@ -78,19 +117,19 @@ def process_token_dict_to_mappings( |
78 | 117 | token_dict[token] = 0 |
79 | 118 | if len(token_dict) >= draft_vocab_size: |
80 | 119 | break |
81 | | - print(f"Added missing tokens to reach draft vocab size: {draft_vocab_size}") |
82 | | - print(f"Total tokens after addition: {len(token_dict)}") |
| 120 | + rank0_print(f"Added missing tokens to reach draft vocab size: {draft_vocab_size}") |
| 121 | + rank0_print(f"Total tokens after addition: {len(token_dict)}") |
83 | 122 | total_frequency = sum(token_dict.values()) |
84 | 123 | top_N = token_dict.most_common(draft_vocab_size) |
85 | 124 | top_N_frequency_sum = sum(freq for key, freq in top_N) |
86 | 125 |
|
87 | 126 | if total_frequency == 0: |
88 | | - print("Warning: Total token frequency is zero. All tokens will have zero ratio.") |
| 127 | + rank0_print("Warning: Total token frequency is zero. All tokens will have zero ratio.") |
89 | 128 | top_N_ratio = 0.0 |
90 | 129 | else: |
91 | 130 | top_N_ratio = top_N_frequency_sum / total_frequency |
92 | 131 |
|
93 | | - print(f"top {draft_vocab_size} token frequency ratio: {top_N_ratio:.2%}") |
| 132 | + rank0_print(f"top {draft_vocab_size} token frequency ratio: {top_N_ratio:.2%}") |
94 | 133 | used_tokens = [key for key, freq in top_N] |
95 | 134 | used_tokens.sort() |
96 | 135 |
|
@@ -199,14 +238,29 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: |
199 | 238 |
|
200 | 239 | class VLMDataCollatorWithPadding: |
201 | 240 |
|
def __init__(self, processor=None, image_processor_kwargs=None):
    """
    Args:
        processor: VLM processor (e.g. AutoProcessor for qwen3_vl).
            When provided, image_paths in features will be decoded
            on-the-fly to pixel_values (used in online training).
        image_processor_kwargs: Optional dict of pixel limits for the
            image_processor, e.g. {"max_pixels": 1003520, "min_pixels": 200704}.
            Only the "max_pixels" and "min_pixels" keys are read. None
            (the default) means no overrides.
    """
    self.processor = processor
    # The parameter defaults to None, so normalise it before calling .get();
    # otherwise constructing the collator with no kwargs raises AttributeError.
    pixel_limits = image_processor_kwargs or {}
    max_pixels = pixel_limits.get("max_pixels")
    min_pixels = pixel_limits.get("min_pixels")
    has_limits = max_pixels is not None or min_pixels is not None
    if processor is not None and has_limits and hasattr(processor, "image_processor"):
        # Translate the generic limits into the keyword names this
        # particular image_processor understands.
        self._resolved_image_processor_kwargs = build_image_processor_kwargs(
            processor.image_processor, max_pixels, min_pixels
        )
    else:
        self._resolved_image_processor_kwargs = {}
    rank0_print(f"_resolved_image_processor_kwargs: {self._resolved_image_processor_kwargs}")
210 | 264 |
|
211 | 265 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: |
212 | 266 | max_length = max(item["input_ids"].shape[1] for item in features) |
@@ -238,7 +292,18 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: |
238 | 292 | image_paths = json.loads(item["image_paths"]) |
239 | 293 | if image_paths: |
240 | 294 | images = [load_image(p) for p in image_paths] |
241 | | - vision_enc = self.processor.image_processor(images=images, return_tensors="pt") |
| 295 | + if hasattr(self.processor, "image_processor"): |
| 296 | + vision_enc = self.processor.image_processor( |
| 297 | + images=images, |
| 298 | + return_tensors="pt", |
| 299 | + **self._resolved_image_processor_kwargs, |
| 300 | + ) |
| 301 | + else: |
| 302 | + vision_enc = self.processor( |
| 303 | + images=images, |
| 304 | + return_tensors="pt", |
| 305 | + **self._resolved_image_processor_kwargs, |
| 306 | + ) |
242 | 307 | all_pixel_values.append(vision_enc["pixel_values"]) |
243 | 308 | if "image_grid_thw" in vision_enc: |
244 | 309 | all_image_grid_thw.append(vision_enc["image_grid_thw"]) |
@@ -300,14 +365,29 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: |
300 | 365 |
|
301 | 366 | class VLMHunyuanDataCollatorWithPadding: |
302 | 367 |
|
def __init__(self, processor=None, image_processor_kwargs=None):
    """
    Args:
        processor: VLM processor (e.g. AutoProcessor for hunyuan_vl).
            When provided, image_paths in features will be decoded
            on-the-fly to pixel_values (used in online training).
        image_processor_kwargs: Optional dict of pixel limits for the
            image_processor, e.g. {"max_pixels": 1003520, "min_pixels": 200704}.
            Only the "max_pixels" and "min_pixels" keys are read. None
            (the default) means no overrides.
    """
    self.processor = processor
    # The parameter defaults to None, so normalise it before calling .get();
    # otherwise constructing the collator with no kwargs raises AttributeError.
    pixel_limits = image_processor_kwargs or {}
    max_pixels = pixel_limits.get("max_pixels")
    min_pixels = pixel_limits.get("min_pixels")
    has_limits = max_pixels is not None or min_pixels is not None
    if processor is not None and has_limits and hasattr(processor, "image_processor"):
        # Translate the generic limits into the keyword names this
        # particular image_processor understands.
        self._resolved_image_processor_kwargs = build_image_processor_kwargs(
            processor.image_processor, max_pixels, min_pixels
        )
    else:
        self._resolved_image_processor_kwargs = {}
    rank0_print(f"_resolved_image_processor_kwargs: {self._resolved_image_processor_kwargs}")
311 | 391 |
|
312 | 392 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: |
313 | 393 | max_length = max(item["input_ids"].shape[1] for item in features) |
@@ -338,7 +418,18 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: |
338 | 418 | image_paths = json.loads(item["image_paths"]) |
339 | 419 | if image_paths: |
340 | 420 | images = [load_image(p) for p in image_paths] |
341 | | - vision_enc = self.processor(images=images, return_tensors="pt") |
| 421 | + if hasattr(self.processor, "image_processor"): |
| 422 | + vision_enc = self.processor.image_processor( |
| 423 | + images=images, |
| 424 | + return_tensors="pt", |
| 425 | + **self._resolved_image_processor_kwargs, |
| 426 | + ) |
| 427 | + else: |
| 428 | + vision_enc = self.processor( |
| 429 | + images=images, |
| 430 | + return_tensors="pt", |
| 431 | + **self._resolved_image_processor_kwargs, |
| 432 | + ) |
342 | 433 | all_pixel_values.append(vision_enc["pixel_values"]) |
343 | 434 | if "image_grid_thw" in vision_enc: |
344 | 435 | all_image_grid_thw.append(vision_enc["image_grid_thw"]) |
|
0 commit comments