
[WIP] feat: Transformers 5.0 compatibility #142

Merged
kcz358 merged 22 commits into main from transformer_5.0
Mar 12, 2026

Conversation

Collaborator

@kcz358 kcz358 commented Mar 8, 2026

Summary

This PR adds compatibility support for transformers >= 5.0 while maintaining backward compatibility with transformers 4.x.

Changes

  • Conditional model imports: Models incompatible with transformers >= 5.0 are now conditionally imported based on version

    • dream_dllm, qwen3_dllm, llada_dllm require transformers < 5.0
    • llava_onevision1_5 requires transformers < 5.0
    • These models are dynamically excluded from __all__ when using transformers >= 5.0
  • Training config: Added group_by_length parameter for backward compatibility

  • Dependencies: Updated transformers requirement from exact version to >= 4.57.1
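The version-gated import pattern described above can be sketched as follows (a minimal standalone illustration; the model names match the PR's list, but the helper names are not the repo's actual API):

```python
# Models that only work on transformers 4.x, per the PR description.
LEGACY_ONLY = {"dream_dllm", "qwen3_dllm", "llada_dllm", "llava_onevision1_5"}

def _major(version: str) -> int:
    """Extract the major version number from a string like '5.3.0'."""
    return int(version.split(".")[0])

def build_all(tf_version: str, base: list) -> list:
    """Return __all__ with transformers-4.x-only models excluded on >= 5.0."""
    if _major(tf_version) >= 5:
        return [m for m in base if m not in LEGACY_ONLY]
    return list(base)
```

In the real code, `tf_version` would come from `transformers.__version__`; the point is that `__all__` is computed at import time rather than hard-coded.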

Status

⚠️ Work in Progress - This PR is not complete. More fixes and compatibility improvements will be added.

Testing

  • ✅ Compatible models import successfully with transformers 5.3.0
  • ✅ Incompatible models raise clear ImportError with transformers >= 5.0
  • ✅ __all__ list correctly excludes incompatible models
  • ⏳ Full testing with transformers 4.x pending

Related Issues

Addresses compatibility issues with transformers 5.0+ upgrade

kcz358 added 4 commits March 8, 2026 10:03
Conditionally import models incompatible with transformers >= 5.0:
- dream_dllm, qwen3_dllm, llada_dllm require transformers < 5.0
- llava_onevision1_5 requires transformers < 5.0
- Dynamically update __all__ based on transformers version
- Prevents ImportError when using transformers 5.0+
Add group_by_length parameter to TrainingArguments to maintain
compatibility with existing training configurations.
Update transformers dependency from exact version to minimum version
to support transformers 5.0+ while maintaining backward compatibility.
@kcz358 kcz358 force-pushed the transformer_5.0 branch from 8b0a65b to 4ab190d on March 8, 2026 10:04
kcz358 and others added 6 commits March 8, 2026 11:20
…al_tokens

Use all_special_tokens for transformers >= 5.0 compatibility while
maintaining backward compatibility with transformers < 5.0.

Changes:
- Add special_tokens property to all processor classes
- Use all_special_tokens if available (transformers >= 5.0)
- Fall back to additional_special_tokens (transformers < 5.0)
- Add <|im_start|> and <|im_end|> tokens to special_tokens list
- Cache special_tokens as instance attribute for performance

Affected processors:
- AeroDataProcessor (base class)
- BaseQwen2_5_DataProcessor (inherits from AeroDataProcessor)
- Qwen2VLDataProcessor
- Qwen2DataProcessor
- LLaVADataProcessor
- LLaVAVideoDataProcessor (inherits from LLaVADataProcessor)
- NanovlmDataProcessor
- Qwen3_VLDataProcessor (inherits from BaseQwen2_5_DataProcessor)
Use processor.apply_chat_template with tokenize=True consistently
across all processors instead of mixing with processor.tokenizer calls.

Changes:
- aero_processor: use processor.apply_chat_template(tokenize=True)[0]
- base_qwen2_5_processor: use processor.apply_chat_template(tokenize=True)[0]
- qwen2_vl_processor: use processor.apply_chat_template(tokenize=True)
- qwen3_vl_processor: use processor.apply_chat_template(tokenize=True)[0]

This ensures all processors return token IDs directly during data
preparation, improving consistency and reducing confusion.
Extract rope index calculation functions into common_ops/rope.py to
ensure consistent behavior across transformers versions.

Changes:
- Add common_ops/rope.py with qwen2_5_vl_rope_index and qwen3_vl_get_rope_index
- Update qwen2_5_vl_ops.py to use qwen2_5_vl_rope_index
- Update qwen3_vl_ops.py to use qwen3_vl_get_rope_index
- Update qwen3_vl_moe_ops.py to use qwen3_vl_get_rope_index

This ensures rope index calculations remain stable even when transformers
internal implementations change.
Add NVIDIA B200/B300 GPU FLOPS (2.25e15) to get_device_flops()
to fix MFU calculation returning 0 on B200 GPUs.

Previously, unknown GPU types returned inf FLOPS, causing MFU
to always be 0.
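The mechanism of the bug and the fix can be sketched as follows (function and dict names are illustrative; the B200/B300 peak-FLOPS value 2.25e15 is from the commit message):

```python
# Map known GPU names to peak FLOPS so MFU (achieved / peak) is nonzero.
PEAK_FLOPS = {"B200": 2.25e15, "B300": 2.25e15}

def get_device_flops(device_name: str) -> float:
    for key, flops in PEAK_FLOPS.items():
        if key in device_name:
            return flops
    return float("inf")  # unknown GPU: achieved / inf == 0.0, i.e. MFU reads 0

def mfu(achieved_flops: float, device_name: str) -> float:
    return achieved_flops / get_device_flops(device_name)
```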
@kcz358
Collaborator Author

kcz358 commented Mar 8, 2026

Now passes the CI/CD test for the Qwen3 VL model on transformers 5.0.

kcz358 and others added 5 commits March 9, 2026 01:37
- Fix vision_model variable reference in liger kernel patch
- Support nested text_config in lce_forward
- Handle rope_scaling/rope_parameters for transformers 5.0+
- Add qwen2_5_vl to FlopsCounter model type mapping
…ormers 5.0 compatibility

- Add apply_chat_template utility method to DataUtilities
- Handles dict-like return values (BatchEncoding) with use_key param
- Handles nested list wrapping from some processors
- Update all processors to use unified method
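The normalization the unified method performs can be sketched like this (a hypothetical standalone version; the real method lives on DataUtilities and its exact signature may differ):

```python
def apply_chat_template(processor, messages, use_key="input_ids"):
    """Normalize processor.apply_chat_template(tokenize=True) output:
    dict-like BatchEncoding -> pick `use_key`; nested [[ids]] -> unwrap."""
    out = processor.apply_chat_template(messages, tokenize=True)
    if hasattr(out, "keys"):              # dict-like (e.g. BatchEncoding)
        out = out[use_key]
    if out and isinstance(out[0], list):  # nested list wrapping
        out = out[0]
    return out
```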
…lity

Filter unsupported TrainingArguments parameters by inspecting
transformers.TrainingArguments.__init__ signature, avoiding errors
from deprecated or removed parameters in newer versions.
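The signature-inspection approach can be sketched as a small standalone helper (illustrative names; the PR applies this idea to transformers.TrainingArguments):

```python
import inspect

def filter_supported_kwargs(cls, kwargs):
    """Drop kwargs that cls.__init__ does not accept, so parameters removed
    in newer library versions are silently filtered instead of raising."""
    allowed = set(inspect.signature(cls.__init__).parameters)
    return {k: v for k, v in kwargs.items() if k in allowed}
```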
Visual model methods (get_image_features, get_video_features, visual())
may return tuples OR dataclass objects (BaseModelOutputWithPooling,
BaseModelOutputWithDeepstackFeatures) in transformers 5.0+.

Add parse_visual_output() to transparently handle both return types.
* [feat] Support Qwen3_5 Training

* style: auto-fix lint (black + isort)

* [feat] Support Qwen3.5 Training

* optimize qwen3.5 dataset process logic

* optimize qwen3.5 dataset process logic

* flop function leave empty

---------

Co-authored-by: charlesswu <charlesswu@tencent.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
@KemingWu KemingWu marked this pull request as ready for review March 10, 2026 08:11
@KemingWu KemingWu self-requested a review March 10, 2026 08:12
Collaborator

@KemingWu KemingWu left a comment


LGTM.

@mwxely mwxely self-requested a review March 10, 2026 09:07
Comment on lines +32 to +39
@property
def special_tokens(self):
    if not hasattr(self, "_special_tokens"):
        if hasattr(self.processor.tokenizer, "all_special_tokens"):
            self._special_tokens = list(self.processor.tokenizer.all_special_tokens)
        else:
            self._special_tokens = list(self.processor.tokenizer.additional_special_tokens)
    return self._special_tokens
Collaborator


Duplicate special_tokens property. Looks like a copy-paste error. Can remove one of them.

Comment on lines 179 to 180
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
Collaborator


.to(inputs_embeds.device, inputs_embeds.dtype) is called twice. Should remove one of them.

Comment on lines 196 to 197
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
Collaborator


.to(inputs_embeds.device, inputs_embeds.dtype) is called twice. Should remove one of them.


if get_ulysses_sequence_parallel_world_size() > 1:
    _, position_ids, pad_size = ulysses_pad(
        input_ids_rmpad,
Collaborator


input_ids_rmpad is only defined inside the if cu_seq_lens is None and input_ids is not None branch (line 54), but referenced unconditionally when ulysses_sp_size > 1 here. If inputs_embeds is provided instead of input_ids (e.g., when a VL model calls this text model after pre-computing embeddings), this will raise NameError. The elif branch on line 60 handles inputs_embeds but never defines input_ids_rmpad.

Comment on lines +44 to +49
def parse_visual_output(output):
    if isinstance(output, tuple):
        return output
    if hasattr(output, "pooler_output"):
        return output.pooler_output
    return output
Collaborator


parse_visual_output is defined independently in 4 files (qwen2_5_vl_ops.py, qwen2_5_omni_liger.py, qwen3_vl_ops.py, qwen3_vl_moe_ops.py). The qwen2.5 versions return a single tensor while the qwen3 versions return a (embeds, deepstack_features) tuple. Since this PR already introduced common_ops/rope.py for shared rope functions, can consider doing the same for parse_visual_output.

Comment on lines +23 to +34
@property
def special_tokens(self):
    if not hasattr(self, "_special_tokens"):
        if hasattr(self.processor.tokenizer, "all_special_tokens"):
            self._special_tokens = list(self.processor.tokenizer.all_special_tokens)
        else:
            self._special_tokens = list(self.processor.tokenizer.additional_special_tokens)
        if "<|im_start|>" not in self._special_tokens:
            self._special_tokens.append("<|im_start|>")
        if "<|im_end|>" not in self._special_tokens:
            self._special_tokens.append("<|im_end|>")
    return self._special_tokens
Collaborator


Similar special_tokens property logic is added across multiple processor files (aero, nanovlm, qwen2, qwen2_vl, llava). qwen2_processor accesses self.processor directly rather than self.processor.tokenizer. But the core pattern (check all_special_tokens → fallback to additional_special_tokens → cache in _special_tokens) is the same. Worth considering a mixin or base class helper to reduce duplication.

@kcz358 kcz358 merged commit 749eeef into main Mar 12, 2026
3 checks passed
@kcz358 kcz358 deleted the transformer_5.0 branch March 12, 2026 05:30
