Skip to content

Commit c8cf479

Browse files
Mandy3311hukongyi
authored andcommitted
address reviewer comments.
Co-authored-by: hukongyi <hukongyi@cmbchina.com>
1 parent d1ce873 commit c8cf479

6 files changed

Lines changed: 22 additions & 32 deletions

File tree

specforge/data/preprocessing.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,6 @@ def preprocess_vlm_conversations(
202202
- pixel_values: List of pixel values for images in the examples.
203203
- image_grid_thw: List of image grid tensors.
204204
"""
205-
system_prompt = chat_template.system_prompt
206-
207205
# prepare result
208206
results = {
209207
"input_ids": [],
@@ -213,16 +211,15 @@ def preprocess_vlm_conversations(
213211
"image_grid_thw": [],
214212
}
215213

216-
# Note: currently, we assume that each example has only one image
217-
for i, image in enumerate(examples["image"]):
214+
for i, images in enumerate(examples["images"]):
218215
source = examples["conversations"][i]
219216
messages = []
220217
# messages = [{"role": "system", "content": system_prompt}]
221218
if not source:
222219
# if the source is None, skip it
223220
continue
224221

225-
if not image:
222+
if not images:
226223
text_messages = []
227224
convroles = ["user", "assistant"]
228225
for j, sentence in enumerate(source):
@@ -267,26 +264,17 @@ def preprocess_vlm_conversations(
267264
source = source[1:]
268265

269266
convroles = ["user", "assistant"]
270-
has_added_image = False
267+
has_added_images = False
271268
for j, sentence in enumerate(source):
272269
role = sentence["role"]
273270
assert role == convroles[j % 2], f"unexpected role {role}"
274271
if role == "user":
275-
# if the message is from user and has image, process the image
276-
if not has_added_image:
277-
messages.append(
278-
{
279-
"role": role,
280-
"content": [
281-
{
282-
"type": "image",
283-
"image": image,
284-
},
285-
{"type": "text", "text": sentence["content"]},
286-
],
287-
}
288-
)
289-
has_added_image = True
272+
# Insert all images into the first user message
273+
if not has_added_images:
274+
content = [{"type": "image", "image": img} for img in images]
275+
content.append({"type": "text", "text": sentence["content"]})
276+
messages.append({"role": role, "content": content})
277+
has_added_images = True
290278
else:
291279
messages.append({"role": role, "content": sentence["content"]})
292280
else:
@@ -319,7 +307,7 @@ def preprocess_vlm_conversations(
319307
input_ids = encoding.input_ids[0]
320308
offsets = encoding.offset_mapping[0]
321309
pixel_values = encoding.pixel_values
322-
image_grid_thw = encoding.image_grid_thw[0]
310+
image_grid_thw = encoding.image_grid_thw # shape: (num_images, 3)
323311

324312
# get conversation with image info for loss mask generation
325313
decoded_conversation = processor.tokenizer.decode(
@@ -335,7 +323,7 @@ def preprocess_vlm_conversations(
335323
results["loss_mask"].append(loss_mask[None, :])
336324
results["attention_mask"].append(torch.ones_like(loss_mask)[None, :])
337325
results["pixel_values"].append(pixel_values)
338-
results["image_grid_thw"].append(image_grid_thw[None, :])
326+
results["image_grid_thw"].append(image_grid_thw)
339327
return results
340328

341329

specforge/data/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,10 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
205205
- attention_mask: torch.Tensor of shape (B, N)
206206
- loss_mask: torch.Tensor of shape (B, N)
207207
"""
208+
assert len(features) == 1, (
209+
f"VlmDataCollatorWithPadding requires batch_size=1, got {len(features)}. "
210+
"Set per_device_train_batch_size=1 in your training config."
211+
)
208212
max_length = max(item["input_ids"].shape[1] for item in features)
209213
batch_input_ids = torch.cat(
210214
[self.paddingtensor2D(item["input_ids"], max_length) for item in features]

specforge/modeling/target/dflash_target_model.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,8 @@ def _init_vlm_attributes(self):
110110
self.spatial_merge_size = getattr(vision_config, "spatial_merge_size", 2)
111111
self.vlm_model_type = getattr(vision_config, "model_type", "")
112112

113-
text_config = getattr(hf_config, "text_config", hf_config)
114113
self.tokens_per_second = None
115114

116-
rope_params = getattr(text_config, "rope_parameters", {}) or {}
117-
self.mrope_interleaved = rope_params.get("mrope_interleaved", False)
118-
119115
@classmethod
120116
def from_pretrained(
121117
cls,
@@ -437,9 +433,9 @@ def generate_dflash_data(
437433
output_hidden_states=True,
438434
use_cache=False,
439435
)
440-
if pixel_values:
436+
if pixel_values is not None:
441437
model_kwargs["pixel_values"] = pixel_values
442-
if image_grid_thw:
438+
if image_grid_thw is not None:
443439
model_kwargs["image_grid_thw"] = image_grid_thw
444440
outputs = self.model(**model_kwargs)
445441

specforge/modeling/target/sglang_backend/patch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def initialize_model_parallel(
9898
4 tensor model-parallel groups:
9999
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
100100
2 pipeline model-parallel groups:
101-
[g0, g2, g4, g6], [b1, g3, g5, g7]
101+
[g0, g2, g4, g6], [g1, g3, g5, g7]
102102
Note that for efficiency, the caller should make sure adjacent ranks
103103
are on the same DGX box. For example if we are using 2 DGX-1 boxes
104104
with a total of 16 GPUs, rank 0 to 7 belong to the first box and

specforge/modeling/target/target_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ def from_pretrained(
6363
instance = cls(config)
6464

6565
if embed_key is None:
66-
embed_key = "model.embed_tokens.weight"
66+
if hasattr(config, "text_config") and config.text_config is not None:
67+
embed_key = "model.language_model.embed_tokens.weight"
68+
else:
69+
embed_key = "model.embed_tokens.weight"
6770
if lm_head_key is None:
6871
lm_head_key = "lm_head.weight"
6972

tests/test_modeling/test_target/test_sglang_backend/test_sglang_backend.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def test_dense(rank, world_size, port, tp_size):
3232
device="cuda",
3333
attention_backend="fa3",
3434
mem_fraction_static=0.4,
35-
# enable_torch_compile=True,
3635
enable_nccl_nvls=True,
3736
# enable_symm_mem=True,
3837
enable_symm_mem=False,

0 commit comments

Comments
 (0)