Skip to content

Commit 7d9feb5

Browse files
Mandy3311 authored and hukongyi committed
address reviewer comments.
Co-authored-by: hukongyi <hukongyi@cmbchina.com>
1 parent 10f35f3 commit 7d9feb5

7 files changed

Lines changed: 25 additions & 33 deletions

File tree

scripts/train_dflash.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def parse_args():
8383
"Suggested: 7 for block_size=16, 5 for 10, 4 for 8. None disables.",
8484
)
8585
model_group.add_argument(
86-
"--embed-key", type=str, default="model.language_model.embed_tokens.weight"
86+
"--embed-key", type=str, default=None,
87+
help="Key for embedding weights in the target model checkpoint. "
88+
"Defaults to auto-detection based on model architecture.",
8789
)
8890
model_group.add_argument("--lm-head-key", type=str, default="lm_head.weight")
8991
model_group.add_argument("--is-vlm", action="store_true")

specforge/data/preprocessing.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,6 @@ def preprocess_vlm_conversations(
201201
- pixel_values: List of pixel values for images in the examples.
202202
- image_grid_thw: List of image grid tensors.
203203
"""
204-
system_prompt = chat_template.system_prompt
205-
206204
# prepare result
207205
results = {
208206
"input_ids": [],
@@ -212,16 +210,15 @@ def preprocess_vlm_conversations(
212210
"image_grid_thw": [],
213211
}
214212

215-
# Note: currently, we assume that each example has only one image
216-
for i, image in enumerate(examples["image"]):
213+
for i, images in enumerate(examples["images"]):
217214
source = examples["conversations"][i]
218215
messages = []
219216
# messages = [{"role": "system", "content": system_prompt}]
220217
if not source:
221218
# if the source is None, skip it
222219
continue
223220

224-
if not image:
221+
if not images:
225222
text_messages = []
226223
convroles = ["user", "assistant"]
227224
for j, sentence in enumerate(source):
@@ -266,26 +263,17 @@ def preprocess_vlm_conversations(
266263
source = source[1:]
267264

268265
convroles = ["user", "assistant"]
269-
has_added_image = False
266+
has_added_images = False
270267
for j, sentence in enumerate(source):
271268
role = sentence["role"]
272269
assert role == convroles[j % 2], f"unexpected role {role}"
273270
if role == "user":
274-
# if the message is from user and has image, process the image
275-
if not has_added_image:
276-
messages.append(
277-
{
278-
"role": role,
279-
"content": [
280-
{
281-
"type": "image",
282-
"image": image,
283-
},
284-
{"type": "text", "text": sentence["content"]},
285-
],
286-
}
287-
)
288-
has_added_image = True
271+
# Insert all images into the first user message
272+
if not has_added_images:
273+
content = [{"type": "image", "image": img} for img in images]
274+
content.append({"type": "text", "text": sentence["content"]})
275+
messages.append({"role": role, "content": content})
276+
has_added_images = True
289277
else:
290278
messages.append({"role": role, "content": sentence["content"]})
291279
else:
@@ -318,7 +306,7 @@ def preprocess_vlm_conversations(
318306
input_ids = encoding.input_ids[0]
319307
offsets = encoding.offset_mapping[0]
320308
pixel_values = encoding.pixel_values
321-
image_grid_thw = encoding.image_grid_thw[0]
309+
image_grid_thw = encoding.image_grid_thw # shape: (num_images, 3)
322310

323311
# get conversation with image info for loss mask generation
324312
decoded_conversation = processor.tokenizer.decode(
@@ -334,7 +322,7 @@ def preprocess_vlm_conversations(
334322
results["loss_mask"].append(loss_mask[None, :])
335323
results["attention_mask"].append(torch.ones_like(loss_mask)[None, :])
336324
results["pixel_values"].append(pixel_values)
337-
results["image_grid_thw"].append(image_grid_thw[None, :])
325+
results["image_grid_thw"].append(image_grid_thw)
338326
return results
339327

340328

specforge/data/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,10 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
205205
- attention_mask: torch.Tensor of shape (B, N)
206206
- loss_mask: torch.Tensor of shape (B, N)
207207
"""
208+
assert len(features) == 1, (
209+
f"VlmDataCollatorWithPadding requires batch_size=1, got {len(features)}. "
210+
"Set per_device_train_batch_size=1 in your training config."
211+
)
208212
max_length = max(item["input_ids"].shape[1] for item in features)
209213
batch_input_ids = torch.cat(
210214
[self.paddingtensor2D(item["input_ids"], max_length) for item in features]

specforge/modeling/target/dflash_target_model.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,8 @@ def _init_vlm_attributes(self):
110110
self.spatial_merge_size = getattr(vision_config, "spatial_merge_size", 2)
111111
self.vlm_model_type = getattr(vision_config, "model_type", "")
112112

113-
text_config = getattr(hf_config, "text_config", hf_config)
114113
self.tokens_per_second = None
115114

116-
rope_params = getattr(text_config, "rope_parameters", {}) or {}
117-
self.mrope_interleaved = rope_params.get("mrope_interleaved", False)
118-
119115
@classmethod
120116
def from_pretrained(
121117
cls,
@@ -434,9 +430,9 @@ def generate_dflash_data(
434430
output_hidden_states=True,
435431
use_cache=False,
436432
)
437-
if pixel_values:
433+
if pixel_values is not None:
438434
model_kwargs["pixel_values"] = pixel_values
439-
if image_grid_thw:
435+
if image_grid_thw is not None:
440436
model_kwargs["image_grid_thw"] = image_grid_thw
441437
outputs = self.model(**model_kwargs)
442438

specforge/modeling/target/sglang_backend/patch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def initialize_model_parallel(
9898
4 tensor model-parallel groups:
9999
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
100100
2 pipeline model-parallel groups:
101-
[g0, g2, g4, g6], [b1, g3, g5, g7]
101+
[g0, g2, g4, g6], [g1, g3, g5, g7]
102102
Note that for efficiency, the caller should make sure adjacent ranks
103103
are on the same DGX box. For example if we are using 2 DGX-1 boxes
104104
with a total of 16 GPUs, rank 0 to 7 belong to the first box and

specforge/modeling/target/target_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ def from_pretrained(
5252
instance = cls(config)
5353

5454
if embed_key is None:
55-
embed_key = "model.embed_tokens.weight"
55+
if hasattr(config, "text_config") and config.text_config is not None:
56+
embed_key = "model.language_model.embed_tokens.weight"
57+
else:
58+
embed_key = "model.embed_tokens.weight"
5659
if lm_head_key is None:
5760
lm_head_key = "lm_head.weight"
5861

tests/test_modeling/test_target/test_sglang_backend/test_sglang_backend.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def test_dense(rank, world_size, port, tp_size):
3232
device="cuda",
3333
attention_backend="fa3",
3434
mem_fraction_static=0.4,
35-
# enable_torch_compile=True,
3635
enable_nccl_nvls=True,
3736
enable_symm_mem=False,
3837
enable_torch_compile=True,

0 commit comments

Comments
 (0)