Skip to content

Commit fc0a995

Browse files
committed
feat: speculative decoding support image size control by max_pixels/min_pixels
1 parent 1c786f2 commit fc0a995

3 files changed

Lines changed: 154 additions & 13 deletions

File tree

angelslim/compressor/speculative/train/data/data_utils.py

Lines changed: 99 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import torch
1919
from transformers.image_utils import load_image
2020

21+
from angelslim.utils import rank0_print
22+
2123
__all__ = [
2224
"process_token_dict_to_mappings",
2325
"convert_sharegpt_data",
@@ -27,9 +29,46 @@
2729
"VLMHunyuanDataCollatorWithPadding",
2830
"AudioDataCollatorWithPadding",
2931
"CosyVoice3DataCollatorWithPadding",
32+
"build_image_processor_kwargs",
3033
]
3134

3235

36+
def build_image_processor_kwargs(image_processor, max_pixels=None, min_pixels=None):
37+
"""
38+
convert max_pixels/min_pixels to the format required by the specific image_processor.
39+
- Qwen2.5-VL: directly use max_pixels / min_pixels
40+
- Qwen3-VL: convert to size={"longest_edge": max_pixels, "shortest_edge": min_pixels}
41+
42+
Args:
43+
image_processor: model's image_processor instance
44+
max_pixels: maximum pixels (total area), None means no limit
45+
min_pixels: minimum pixels (total area), None means no limit
46+
47+
Returns:
48+
dict: can be directly passed to image_processor(...)
49+
"""
50+
if max_pixels is None and min_pixels is None:
51+
return {}
52+
53+
processor_class = type(image_processor).__name__
54+
# Qwen3-VL uses size={"longest_edge": ..., "shortest_edge": ...}
55+
if "Qwen3" in processor_class:
56+
size = {}
57+
if max_pixels is not None:
58+
size["longest_edge"] = max_pixels
59+
if min_pixels is not None:
60+
size["shortest_edge"] = min_pixels
61+
return {"size": size}
62+
else:
63+
# Qwen2.5-VL's accept max_pixels and min_pixels
64+
kwargs = {}
65+
if max_pixels is not None:
66+
kwargs["max_pixels"] = max_pixels
67+
if min_pixels is not None:
68+
kwargs["min_pixels"] = min_pixels
69+
return kwargs
70+
71+
3372
def convert_sharegpt_data(row, dataset_column="conversations"):
3473
converted_messages = []
3574

@@ -78,19 +117,19 @@ def process_token_dict_to_mappings(
78117
token_dict[token] = 0
79118
if len(token_dict) >= draft_vocab_size:
80119
break
81-
print(f"Added missing tokens to reach draft vocab size: {draft_vocab_size}")
82-
print(f"Total tokens after addition: {len(token_dict)}")
120+
rank0_print(f"Added missing tokens to reach draft vocab size: {draft_vocab_size}")
121+
rank0_print(f"Total tokens after addition: {len(token_dict)}")
83122
total_frequency = sum(token_dict.values())
84123
top_N = token_dict.most_common(draft_vocab_size)
85124
top_N_frequency_sum = sum(freq for key, freq in top_N)
86125

87126
if total_frequency == 0:
88-
print("Warning: Total token frequency is zero. All tokens will have zero ratio.")
127+
rank0_print("Warning: Total token frequency is zero. All tokens will have zero ratio.")
89128
top_N_ratio = 0.0
90129
else:
91130
top_N_ratio = top_N_frequency_sum / total_frequency
92131

93-
print(f"top {draft_vocab_size} token frequency ratio: {top_N_ratio:.2%}")
132+
rank0_print(f"top {draft_vocab_size} token frequency ratio: {top_N_ratio:.2%}")
94133
used_tokens = [key for key, freq in top_N]
95134
used_tokens.sort()
96135

@@ -199,14 +238,29 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
199238

200239
class VLMDataCollatorWithPadding:
201240

202-
def __init__(self, processor=None):
241+
def __init__(self, processor=None, image_processor_kwargs=None):
203242
"""
204243
Args:
205244
processor: VLM processor (e.g. AutoProcessor for qwen3_vl).
206245
When provided, image_paths in features will be decoded
207246
on-the-fly to pixel_values (used in online training).
247+
image_processor_kwargs: Additional kwargs passed to image_processor,
248+
e.g. {"max_pixels": 1003520, "min_pixels": 200704}.
208249
"""
209250
self.processor = processor
251+
max_pixels = image_processor_kwargs.get("max_pixels", None)
252+
min_pixels = image_processor_kwargs.get("min_pixels", None)
253+
if (
254+
processor is not None
255+
and (max_pixels is not None or min_pixels is not None)
256+
and hasattr(processor, "image_processor")
257+
):
258+
self._resolved_image_processor_kwargs = build_image_processor_kwargs(
259+
processor.image_processor, max_pixels, min_pixels
260+
)
261+
else:
262+
self._resolved_image_processor_kwargs = {}
263+
rank0_print(f"_resolved_image_processor_kwargs: {self._resolved_image_processor_kwargs}")
210264

211265
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
212266
max_length = max(item["input_ids"].shape[1] for item in features)
@@ -238,7 +292,18 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
238292
image_paths = json.loads(item["image_paths"])
239293
if image_paths:
240294
images = [load_image(p) for p in image_paths]
241-
vision_enc = self.processor.image_processor(images=images, return_tensors="pt")
295+
if hasattr(self.processor, "image_processor"):
296+
vision_enc = self.processor.image_processor(
297+
images=images,
298+
return_tensors="pt",
299+
**self._resolved_image_processor_kwargs,
300+
)
301+
else:
302+
vision_enc = self.processor(
303+
images=images,
304+
return_tensors="pt",
305+
**self._resolved_image_processor_kwargs,
306+
)
242307
all_pixel_values.append(vision_enc["pixel_values"])
243308
if "image_grid_thw" in vision_enc:
244309
all_image_grid_thw.append(vision_enc["image_grid_thw"])
@@ -300,14 +365,29 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
300365

301366
class VLMHunyuanDataCollatorWithPadding:
302367

303-
def __init__(self, processor=None):
368+
def __init__(self, processor=None, image_processor_kwargs=None):
304369
"""
305370
Args:
306371
processor: VLM processor (e.g. AutoProcessor for hunyuan_vl).
307372
When provided, image_paths in features will be decoded
308373
on-the-fly to pixel_values (used in online training).
374+
image_processor_kwargs: Additional kwargs passed to image_processor,
375+
e.g. {"max_pixels": 1003520, "min_pixels": 200704}.
309376
"""
310377
self.processor = processor
378+
max_pixels = image_processor_kwargs.get("max_pixels", None)
379+
min_pixels = image_processor_kwargs.get("min_pixels", None)
380+
if (
381+
processor is not None
382+
and (max_pixels is not None or min_pixels is not None)
383+
and hasattr(processor, "image_processor")
384+
):
385+
self._resolved_image_processor_kwargs = build_image_processor_kwargs(
386+
processor.image_processor, max_pixels, min_pixels
387+
)
388+
else:
389+
self._resolved_image_processor_kwargs = {}
390+
rank0_print(f"_resolved_image_processor_kwargs: {self._resolved_image_processor_kwargs}")
311391

312392
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
313393
max_length = max(item["input_ids"].shape[1] for item in features)
@@ -338,7 +418,18 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
338418
image_paths = json.loads(item["image_paths"])
339419
if image_paths:
340420
images = [load_image(p) for p in image_paths]
341-
vision_enc = self.processor(images=images, return_tensors="pt")
421+
if hasattr(self.processor, "image_processor"):
422+
vision_enc = self.processor.image_processor(
423+
images=images,
424+
return_tensors="pt",
425+
**self._resolved_image_processor_kwargs,
426+
)
427+
else:
428+
vision_enc = self.processor(
429+
images=images,
430+
return_tensors="pt",
431+
**self._resolved_image_processor_kwargs,
432+
)
342433
all_pixel_values.append(vision_enc["pixel_values"])
343434
if "image_grid_thw" in vision_enc:
344435
all_image_grid_thw.append(vision_enc["image_grid_thw"])

angelslim/compressor/speculative/train/data/dataset_builder/online_dataset_builder.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
DataCollatorWithPadding,
4242
VLMDataCollatorWithPadding,
4343
VLMHunyuanDataCollatorWithPadding,
44+
build_image_processor_kwargs,
4445
)
4546
from .base_dataset_builder import OnlineDatasetBuilder
4647
from .dataset_builder_factory import DatasetBuilderFactory
@@ -87,6 +88,11 @@ def __init__(
8788
chat_template_type,
8889
display,
8990
)
91+
_max_pixels = os.environ.get("MAX_PIXELS")
92+
_min_pixels = os.environ.get("MIN_PIXELS")
93+
self.max_pixels = int(_max_pixels) if _max_pixels is not None else None
94+
self.min_pixels = int(_min_pixels) if _min_pixels is not None else None
95+
rank0_print(f"max_pixels: {self.max_pixels}, min_pixels: {self.min_pixels}")
9096

9197
def build_dataset(
9298
self,
@@ -168,7 +174,15 @@ def build_dataset(
168174

169175
def get_data_collator(self) -> Any:
170176
# for online vlm training: dynamically compute pixel_values during collate stage
171-
return VLMDataCollatorWithPadding(processor=self.tokenizer)
177+
image_processor_kwargs = {}
178+
if self.max_pixels is not None:
179+
image_processor_kwargs["max_pixels"] = self.max_pixels
180+
if self.min_pixels is not None:
181+
image_processor_kwargs["min_pixels"] = self.min_pixels
182+
return VLMDataCollatorWithPadding(
183+
processor=self.tokenizer,
184+
image_processor_kwargs=image_processor_kwargs or None,
185+
)
172186

173187
def _preprocess_function(self, examples: Dict[str, List]) -> Dict[str, List]:
174188
new_examples = {
@@ -258,6 +272,12 @@ def _process_single_conversation(self, conversation_data: List[Dict]) -> Optiona
258272
del message["content"]
259273
message["content"] = new_content
260274

275+
image_kwargs = {}
276+
if image_paths and hasattr(self.tokenizer, "image_processor"):
277+
image_kwargs = build_image_processor_kwargs(
278+
self.tokenizer.image_processor, self.max_pixels, self.min_pixels
279+
)
280+
261281
encoding = self.tokenizer.apply_chat_template(
262282
messages,
263283
tokenize=True,
@@ -268,6 +288,7 @@ def _process_single_conversation(self, conversation_data: List[Dict]) -> Optiona
268288
max_length=self.max_length,
269289
truncation=True,
270290
padding=False,
291+
**image_kwargs,
271292
)
272293

273294
input_ids = encoding["input_ids"]
@@ -326,6 +347,11 @@ def __init__(
326347
chat_template_type,
327348
display,
328349
)
350+
_max_pixels = os.environ.get("MAX_PIXELS")
351+
_min_pixels = os.environ.get("MIN_PIXELS")
352+
self.max_pixels = int(_max_pixels) if _max_pixels is not None else None
353+
self.min_pixels = int(_min_pixels) if _min_pixels is not None else None
354+
rank0_print(f"max_pixels: {self.max_pixels}, min_pixels: {self.min_pixels}")
329355

330356
def build_dataset(
331357
self,
@@ -404,7 +430,15 @@ def build_dataset(
404430

405431
def get_data_collator(self) -> Any:
406432
# for online training, we need to use VLMHunyuanDataCollatorWithPadding
407-
return VLMHunyuanDataCollatorWithPadding(processor=self.tokenizer)
433+
image_processor_kwargs = {}
434+
if self.max_pixels is not None:
435+
image_processor_kwargs["max_pixels"] = self.max_pixels
436+
if self.min_pixels is not None:
437+
image_processor_kwargs["min_pixels"] = self.min_pixels
438+
return VLMHunyuanDataCollatorWithPadding(
439+
processor=self.tokenizer,
440+
image_processor_kwargs=image_processor_kwargs or None,
441+
)
408442

409443
def _preprocess_function(self, examples: Dict[str, List]) -> Dict[str, List]:
410444
new_examples = {
@@ -482,6 +516,11 @@ def _process_single_conversation(self, conversation_data: List[Dict]) -> Optiona
482516
)
483517
image_inputs, _ = self._extract_vision_info(messages)
484518

519+
image_kwargs = {}
520+
if image_inputs and hasattr(self.tokenizer, "image_processor"):
521+
image_kwargs = build_image_processor_kwargs(
522+
self.tokenizer.image_processor, self.max_pixels, self.min_pixels
523+
)
485524
encoding = self.tokenizer(
486525
text=[text],
487526
images=image_inputs,
@@ -490,6 +529,7 @@ def _process_single_conversation(self, conversation_data: List[Dict]) -> Optiona
490529
max_length=self.max_length,
491530
truncation=True,
492531
padding=False,
532+
**image_kwargs,
493533
)
494534
input_ids = encoding["input_ids"]
495535
offsets = encoding["offset_mapping"]

tools/generate_hidden_for_draft_model.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
infer_model_params,
3434
)
3535
from angelslim.compressor.speculative.train.data.data_utils import (
36+
build_image_processor_kwargs,
3637
process_token_dict_to_mappings,
3738
)
3839
from angelslim.utils import decide_device_for_distributed
@@ -102,6 +103,10 @@ def __init__(
102103
self.rank = rank
103104
self.draft_vocab_size = draft_vocab_size
104105
self.target_vocab_size = target_vocab_size
106+
_max_pixels = os.environ.get("MAX_PIXELS")
107+
_min_pixels = os.environ.get("MIN_PIXELS")
108+
self.max_pixels = int(_max_pixels) if _max_pixels is not None else None
109+
self.min_pixels = int(_min_pixels) if _min_pixels is not None else None
105110
self.output_dir.mkdir(parents=True, exist_ok=True)
106111
self.token_dict = Counter()
107112

@@ -151,11 +156,17 @@ def _process_single_sample(self, idx: int, row: Dict[str, Any]) -> bool:
151156
images = [load_image(p) for p in image_paths]
152157
processor = self.target_model.tokenizer
153158
if hasattr(processor, "image_processor"):
159+
kwargs = build_image_processor_kwargs(
160+
processor.image_processor, self.max_pixels, self.min_pixels
161+
)
154162
vision_encoding = processor.image_processor(
155-
images=images, return_tensors="pt"
163+
images=images, return_tensors="pt", **kwargs
156164
)
157165
else:
158-
vision_encoding = processor(images=images, return_tensors="pt")
166+
kwargs = build_image_processor_kwargs(
167+
processor, self.max_pixels, self.min_pixels
168+
)
169+
vision_encoding = processor(images=images, return_tensors="pt", **kwargs)
159170
row["pixel_values"] = vision_encoding["pixel_values"].to(device)
160171
if "video_pixel_values" in vision_encoding:
161172
row["video_pixel_values"] = vision_encoding["video_pixel_values"].to(
@@ -406,7 +417,6 @@ def parse_arguments() -> argparse.Namespace:
406417
help="Path to draft model config file, used to read draft_vocab_size and vocab_size "
407418
"for computing vocab mapping",
408419
)
409-
410420
return parser.parse_args()
411421

412422

0 commit comments

Comments
 (0)