diff --git a/examples/qwen3_5/example_config.yaml b/examples/qwen3_5/example_config.yaml
new file mode 100644
index 00000000..f9abb002
--- /dev/null
+++ b/examples/qwen3_5/example_config.yaml
@@ -0,0 +1,166 @@
+trainer_type: fsdp2_trainer
+dataset_config:
+  extra_kwargs: {}
+  dataset_type: qwen3_vl_iterable
+  dataset_format: yaml
+  processor_config:
+    processor_name: Qwen/Qwen3-VL-8B-Instruct
+    processor_type: qwen3_vl
+  dataset_path: data/video/debug.yaml
+  datasets: null
+  shuffle: true
+  eval_dataset_path: null
+  object_storage: none
+  bucket_name: null
+  packing: false
+  packing_strategy: first_fit
+  packing_length: 51200
+  filter_overlong: true
+  filter_overlong_workers: 8
+  max_length: null
+  video_sampling_strategy: fps
+  video_max_pixels: 50176
+  video_max_frames: 512
+  frame_num: 64
+  fps: 1
+  video_backend: qwen_vl_utils
+trainer_args:
+  output_dir: ./output/qwen3_5_training
+  do_train: false
+  do_eval: false
+  do_predict: false
+  eval_strategy: 'no'
+  prediction_loss_only: false
+  per_device_train_batch_size: 1
+  per_device_eval_batch_size: 8
+  gradient_accumulation_steps: 1
+  eval_accumulation_steps: null
+  eval_delay: 0
+  torch_empty_cache_steps: null
+  learning_rate: 0.0002
+  weight_decay: 0.0
+  adam_beta1: 0.9
+  adam_beta2: 0.999
+  adam_epsilon: 1.0e-08
+  max_grad_norm: 1.0
+  num_train_epochs: 1
+  max_steps: 1000
+  lr_scheduler_type: cosine
+  lr_scheduler_kwargs: {}
+  warmup_ratio: 0.1
+  warmup_steps: 0
+  log_level: passive
+  log_level_replica: warning
+  log_on_each_node: true
+  logging_dir: ./output/qwen3_5_training/runs
+  logging_strategy: steps
+  logging_first_step: false
+  logging_steps: 1
+  logging_nan_inf_filter: true
+  save_strategy: steps
+  save_steps: 1000
+  save_total_limit: 1
+  save_on_each_node: false
+  save_only_model: false
+  restore_callback_states_from_checkpoint: false
+  use_cpu: false
+  seed: 42
+  data_seed: null
+  bf16: true
+  fp16: false
+  bf16_full_eval: false
+  fp16_full_eval: false
+  tf32: null
+  local_rank: 0
+  ddp_backend: null
+  debug: []
+  dataloader_drop_last: false
+  eval_steps: null
+  dataloader_num_workers: 0
+  dataloader_prefetch_factor: null
+  run_name: qwen3_5_debug
+  disable_tqdm: false
+  remove_unused_columns: true
+  label_names: null
+  load_best_model_at_end: false
+  metric_for_best_model: null
+  greater_is_better: null
+  ignore_data_skip: false
+  fsdp: []
+  fsdp_config:
+    transformer_layer_cls_to_wrap:
+      - Qwen3_5DecoderLayer
+    reshard_after_forward: false
+    min_num_params: 0
+    xla: false
+    xla_fsdp_v2: false
+    xla_fsdp_grad_ckpt: false
+  accelerator_config:
+    split_batches: false
+    dispatch_batches: null
+    even_batches: true
+    use_seedable_sampler: true
+    non_blocking: false
+    gradient_accumulation_kwargs: null
+  parallelism_config: null
+  deepspeed: null
+  label_smoothing_factor: 0.0
+  optim: adamw_torch_fused
+  optim_args: null
+  length_column_name: length
+  report_to: []
+  project: huggingface
+  trackio_space_id: trackio
+  ddp_find_unused_parameters: null
+  ddp_bucket_cap_mb: null
+  ddp_broadcast_buffers: null
+  dataloader_pin_memory: true
+  dataloader_persistent_workers: false
+  skip_memory_metrics: true
+  push_to_hub: false
+  resume_from_checkpoint: null
+  hub_model_id: null
+  hub_strategy: every_save
+  hub_token: null
+  hub_private_repo: null
+  hub_always_push: false
+  hub_revision: null
+  gradient_checkpointing: true
+  gradient_checkpointing_kwargs: null
+  include_for_metrics: []
+  eval_do_concat_batches: true
+  auto_find_batch_size: false
+  full_determinism: false
+  ddp_timeout: 1800
+  torch_compile: false
+  torch_compile_backend: null
+  torch_compile_mode: null
+  include_num_input_tokens_seen: 'no'
+  neftune_noise_alpha: null
+  optim_target_modules: null
+  batch_eval_metrics: false
+  eval_on_start: false
+  use_liger_kernel: true
+  liger_kernel_config: null
+  eval_use_gather_object: false
+  average_tokens_across_devices: true
+  use_muon: false
+  freeze_modules: null
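+  # use_rmpad drops padding tokens and runs flash-attn varlen attention over
+  # the packed sequences; sp_ulysses_degree > 1 enables Ulysses sequence
+  # parallelism.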
+  use_rmpad: true
+  fsdp2: true
+  sp_ulysses_degree: 1
+  reduce_dtype: bfloat16
+  output_dtype: bfloat16
+  print_batch_input_steps: 5
+  enable_profiler: false
+  profiler_config:
+    start_step: 1
+    end_step: 3
+model_config:
+  extra_kwargs: {}
+  load_from_pretrained_path: Qwen/Qwen3.5-VL-8B-Instruct
+  load_from_config: null
+  attn_implementation: flash_attention_2
+  overwrite_config: null
+  monkey_patch_kwargs: null
+extra_kwargs: null
diff --git a/src/lmms_engine/models/__init__.py b/src/lmms_engine/models/__init__.py
index 143abd10..e71204a6 100644
--- a/src/lmms_engine/models/__init__.py
+++ b/src/lmms_engine/models/__init__.py
@@ -17,6 +17,7 @@
 from .qwen2_5_vl import apply_liger_kernel_to_qwen2_5_vl
 from .qwen2_audio import apply_liger_kernel_to_qwen2_audio
 from .qwen3 import apply_liger_kernel_to_qwen3
+from .qwen3_5 import apply_liger_kernel_to_qwen3_5
 from .qwen3_moe import apply_liger_kernel_to_qwen3_moe
 from .qwen3_omni_moe import (
     Qwen3OmniMoeThinkerConfig,
@@ -48,6 +49,7 @@
     "apply_liger_kernel_to_qwen2_5_omni",
     "apply_liger_kernel_to_qwen2_5_vl",
     "apply_liger_kernel_to_qwen2_audio",
+    "apply_liger_kernel_to_qwen3_5",
     "apply_liger_kernel_to_qwen3_vl",
     "apply_liger_kernel_to_qwen3_vl_moe",
     "apply_liger_kernel_to_qwen3_moe",
diff --git a/src/lmms_engine/models/qwen3_5/__init__.py b/src/lmms_engine/models/qwen3_5/__init__.py
new file mode 100644
index 00000000..fa0de525
--- /dev/null
+++ b/src/lmms_engine/models/qwen3_5/__init__.py
@@ -0,0 +1,3 @@
+from .monkey_patch import apply_liger_kernel_to_qwen3_5
+
+__all__ = ["apply_liger_kernel_to_qwen3_5"]
diff --git a/src/lmms_engine/models/qwen3_5/monkey_patch.py b/src/lmms_engine/models/qwen3_5/monkey_patch.py
new file mode 100644
index 00000000..059bd3e9
--- /dev/null
+++ b/src/lmms_engine/models/qwen3_5/monkey_patch.py
@@ -0,0 +1,106 @@
+from functools import partial, wraps
+from typing import Optional
+
+try:
+    from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
+    from liger_kernel.transformers.monkey_patch import (
+        _patch_rms_norm_module,
+        _patch_swiglu_module,
+    )
+    from liger_kernel.transformers.rms_norm import LigerRMSNorm
+    from liger_kernel.transformers.rope import liger_rotary_pos_emb
+    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
+except Exception:
+    print("liger kernel not installed, please install it with `pip install liger-kernel`")
+
+from loguru import logger
+from transformers import PreTrainedModel
+
+from lmms_engine.models.monkey_patch import MONKEY_PATCHER
+
+
+@MONKEY_PATCHER.register("qwen3_5_text", "liger")
+def apply_liger_kernel_to_qwen3_5(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: Optional[PreTrainedModel] = None,
+    use_rmpad: bool = False,
+) -> None:
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    from transformers.models.qwen3_5 import modeling_qwen3_5
+
+    if rope:
+        modeling_qwen3_5.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNorm
+
+    if fused_linear_cross_entropy:
+        from .qwen3_5_liger import qwen3_5_lce_forward
+
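+        # Bind use_rmpad at patch time so the replaced forward needs no extra
+        # kwargs from the trainer.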
+        if use_rmpad:
+
+            def wrap_forward(func):
+                @wraps(func)
+                def wrapper(*args, **kwargs):
+                    return func(use_rmpad=use_rmpad, *args, **kwargs)
+
+                return wrapper
+
+            qwen3_5_lce_forward = wrap_forward(qwen3_5_lce_forward)
+        modeling_qwen3_5.Qwen3_5ForCausalLM.forward = qwen3_5_lce_forward
+
+    if swiglu:
+        modeling_qwen3_5.Qwen3_5MLP = LigerSwiGLUMLP
+
+    if use_rmpad:
+        from .qwen3_5_ops import attn_forward as qwen3_5_ops_attn_forward
+        from .qwen3_5_ops import (
+            decoder_layer_forward as qwen3_5_ops_decoder_layer_forward,
+        )
+        from .qwen3_5_ops import model_forward as qwen3_5_ops_model_forward
+
+        modeling_qwen3_5.Qwen3_5TextModel.forward = qwen3_5_ops_model_forward
+        modeling_qwen3_5.Qwen3_5DecoderLayer.forward = qwen3_5_ops_decoder_layer_forward
+        modeling_qwen3_5.Qwen3_5Attention.forward = qwen3_5_ops_attn_forward
+
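+    # The class-level swaps above only affect models constructed after patching;
+    # for an already-instantiated model, patch its submodules in place below.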
+    if model is not None:
+        from transformers.models.qwen3_5.modeling_qwen3_5 import (
+            Qwen3_5ForCausalLM,
+            Qwen3_5TextModel,
+        )
+
+        if isinstance(model, Qwen3_5ForCausalLM):
+            base_model: Qwen3_5TextModel = model.model
+        elif isinstance(model, Qwen3_5TextModel):
+            base_model: Qwen3_5TextModel = model
+        elif hasattr(model, "language_model"):
+            base_model = getattr(
+                model.language_model,
+                model.language_model.base_model_prefix,
+                model.language_model,
+            )
+        else:
+            base_model = getattr(model, "model", model)
+
+        _patch_qwen3_5_rms_norm = partial(_patch_rms_norm_module, offset=1.0, casting_mode="llama")
+
+        if rms_norm:
+            _patch_qwen3_5_rms_norm(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_qwen3_5_rms_norm(decoder_layer.input_layernorm)
+                _patch_qwen3_5_rms_norm(decoder_layer.post_attention_layernorm)
+                self_attn = getattr(decoder_layer, "self_attn", None)
+                if self_attn is not None:
+                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                        _patch_qwen3_5_rms_norm(self_attn.q_norm)
+                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                        _patch_qwen3_5_rms_norm(self_attn.k_norm)
diff --git a/src/lmms_engine/models/qwen3_5/qwen3_5_liger.py b/src/lmms_engine/models/qwen3_5/qwen3_5_liger.py
new file mode 100644
index 00000000..a31a0337
--- /dev/null
+++ b/src/lmms_engine/models/qwen3_5/qwen3_5_liger.py
@@ -0,0 +1,132 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from lmms_engine.parallel.sequence_parallel.ulysses import (
+    calculate_seq_len_per_rank,
+    gather_outputs_and_unpad,
+    get_ulysses_sequence_parallel_group,
+    get_ulysses_sequence_parallel_world_size,
+    pad_to_max_across_ranks,
+    slice_input_tensor,
+)
+
+from ..sequence_packing_utils import BaseModelOutputWithPastAndRmpad
+
+try:
+    from liger_kernel.transformers.fused_linear_cross_entropy import (
+        LigerFusedLinearCrossEntropyLoss,
+    )
+except Exception:
+    print("Liger Kernel is not installed, pip install liger-kernel to use this patch")
+
+
+def qwen3_5_lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    use_rmpad: bool = False,
+    cu_seq_lens: Optional[torch.IntTensor] = None,
+    indices: Optional[torch.IntTensor] = None,
+    **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+        cu_seq_lens=cu_seq_lens,
+        indices=indices,
+    )
+    seq_lens = outputs.get("seq_lens", None)
+    word_idx = outputs.get("word_idx", None)
+
+    hidden_states = outputs[0]
+
+    # Guard the unpad step: labels can be None at inference time, and word_idx
+    # is only populated by the rmpad model forward.
+    if labels is not None and word_idx is not None:
+        labels_unpad = labels.view(-1)[word_idx.long()]
+        if get_ulysses_sequence_parallel_world_size() > 1:
+            seq_lens = calculate_seq_len_per_rank(seq_lens.tolist()) if seq_lens is not None else None
+            labels_unpad = slice_input_tensor(labels_unpad, dim=0, padding=True)
+        labels = labels_unpad
+
+    logits = None
+    loss = None
+
+    if self.training and (labels is not None):
+        if use_rmpad:
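+            # Shift hidden states and labels within each packed sequence so the
+            # next-token targets never cross a packed-sequence boundary.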
+            shift_hidden_states = []
+            shift_labels = []
+            for i in range(len(seq_lens) - 1):
+                cur_hidden_states = hidden_states[seq_lens[i] : seq_lens[i + 1], :]
+                cur_shift_hidden_states = cur_hidden_states[:-1, :].contiguous()
+                cur_labels = labels[seq_lens[i] : seq_lens[i + 1]]
+                cur_shift_labels = cur_labels[1:].contiguous()
+                shift_hidden_states.append(cur_shift_hidden_states)
+                shift_labels.append(cur_shift_labels)
+            shift_hidden_states = torch.cat(shift_hidden_states, dim=0)
+            shift_labels = torch.cat(shift_labels, dim=0)
+        else:
+            shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+        if get_ulysses_sequence_parallel_world_size() > 1:
+            reduction = "none"
+        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        if get_ulysses_sequence_parallel_world_size() > 1:
+            loss, total_padding = pad_to_max_across_ranks(loss, dim=0)
+            loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding)
+            num_valid_tokens = (shift_labels != -100).sum().float()
+            sp_group = get_ulysses_sequence_parallel_group()
+            if sp_group is not None:
+                dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
+            loss = torch.sum(loss) / (num_valid_tokens + 1e-8)
+
+        if reduction == "sum":
+            loss /= loss_kwargs["num_items_in_batch"]
+
+    else:
+        logits = self.lm_head(hidden_states)
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=hidden_states,
+        attentions=outputs.attentions,
+    )
diff --git a/src/lmms_engine/models/qwen3_5/qwen3_5_ops.py b/src/lmms_engine/models/qwen3_5/qwen3_5_ops.py
new file mode 100644
index 00000000..5af8415d
--- /dev/null
+++ b/src/lmms_engine/models/qwen3_5/qwen3_5_ops.py
@@ -0,0 +1,274 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from transformers.cache_utils import Cache
+from transformers.models.qwen3_5.modeling_qwen3_5 import (
+    Qwen3_5Attention,
+    Qwen3_5DecoderLayer,
+    Qwen3_5DynamicCache,
+    Qwen3_5TextModel,
+    apply_rotary_pos_emb,
+)
+from transformers.utils import is_flash_attn_2_available, logging
+
+from lmms_engine.parallel.sequence_parallel.ulysses import (
+    gather_heads_scatter_seq,
+    gather_seq_scatter_heads,
+    get_ulysses_sequence_parallel_world_size,
+    repeat_kv,
+    slice_input_tensor,
+    ulysses_pad,
+    ulysses_pad_and_slice_inputs,
+)
+
+from ..sequence_packing_utils import BaseModelOutputWithPastAndRmpad, _unpad_input
+
+logger = logging.get_logger(__name__)
+
+
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, rearrange
+
+
+def model_forward(
+    self: Qwen3_5TextModel,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    cu_seq_lens: Optional[torch.IntTensor] = None,
+    indices: Optional[torch.IntTensor] = None,
+    **kwargs,
+) -> Union[Tuple, BaseModelOutputWithPastAndRmpad]:
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+    if cu_seq_lens is None and input_ids is not None:
+        original_inputs = input_ids
+        input_ids, indices, cu_seq_lens, max_seqlen_in_batch = _unpad_input(input_ids, attention_mask)
+        if get_ulysses_sequence_parallel_world_size() > 1:
+            input_ids_rmpad = input_ids.unsqueeze(0)
+            input_ids, _, pad_size = ulysses_pad_and_slice_inputs(
+                input_ids.unsqueeze(0),
+                sp_size=get_ulysses_sequence_parallel_world_size(),
+            )
+            input_ids = input_ids.squeeze(0)
+    elif cu_seq_lens is None and inputs_embeds is not None:
+        original_inputs = inputs_embeds
+        inputs_embeds, indices, cu_seq_lens, max_seqlen_in_batch = _unpad_input(inputs_embeds, attention_mask)
+        if get_ulysses_sequence_parallel_world_size() > 1:
+            inputs_embeds = slice_input_tensor(inputs_embeds, dim=0, padding=True)
+    bs, seqlen = original_inputs.shape[:2]
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    if use_cache and past_key_values is None:
+        past_key_values = Qwen3_5DynamicCache(config=self.config)
+
+    if cache_position is None:
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        cache_position = torch.arange(
+            past_seen_tokens,
+            past_seen_tokens + seqlen,
+            device=inputs_embeds.device,
+        )
+
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+        position_ids = position_ids.repeat_interleave(bs, dim=0)
+
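+    # Gather each kept token's position via the unpad indices and flatten to
+    # shape (1, total_tokens).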
+    position_ids = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(
+        0, 1
+    )
+    original_position_ids = position_ids
+
+    if get_ulysses_sequence_parallel_world_size() > 1:
+        _, position_ids, pad_size = ulysses_pad(
+            input_ids_rmpad,
+            original_position_ids,
+            sp_size=get_ulysses_sequence_parallel_world_size(),
+        )
+
+    # Qwen3.5 uses 4-component position IDs: text + temporal + height + width
+    # For text-only: expand 1D position_ids to 4 components
+    if position_ids.ndim == 2 and position_ids.shape[0] == 1:
+        position_ids = position_ids.expand(4, -1)
+    elif position_ids.ndim == 2:
+        position_ids = position_ids[None, ...].expand(4, position_ids.shape[0], -1)
+
+    if position_ids.ndim == 3 and position_ids.shape[0] == 4:
+        text_position_ids = position_ids[0]
+        mrope_position_ids = position_ids[1:]
+    else:
+        text_position_ids = position_ids[0] if position_ids.ndim == 3 else position_ids
+        mrope_position_ids = position_ids
+
+    hidden_states = inputs_embeds
+
+    position_embeddings = self.rotary_emb(hidden_states, mrope_position_ids)
+
+    for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+        hidden_states = decoder_layer(
+            hidden_states,
+            attention_mask=attention_mask,
+            position_ids=text_position_ids,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            cu_seq_lens=cu_seq_lens,
+            indices=indices,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+
+    hidden_states = self.norm(hidden_states)
+    return BaseModelOutputWithPastAndRmpad(
+        last_hidden_state=hidden_states,
+        past_key_values=past_key_values if use_cache else None,
+        seq_lens=cu_seq_lens,
+        word_idx=indices,
+    )
+
+
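+# Rmpad variant of the decoder layer: hidden_states arrive flattened as
+# (total_tokens, hidden_size) with no batch dimension.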
+def decoder_layer_forward(
+    self: Qwen3_5DecoderLayer,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    use_cache: Optional[bool] = False,
+    cu_seq_lens: Optional[torch.IntTensor] = None,
+    indices: Optional[torch.IntTensor] = None,
+    position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> torch.Tensor:
+    residual = hidden_states
+    hidden_states = self.input_layernorm(hidden_states)
+
+    if self.layer_type == "linear_attention":
+        # GatedDeltaNet expects 3D (batch, seq_len, hidden) but rmpad
+        # flattens to 2D (total_tokens, hidden). Add a batch dim of 1.
+        needs_squeeze = hidden_states.ndim == 2
+        if needs_squeeze:
+            hidden_states = hidden_states.unsqueeze(0)
+        hidden_states = self.linear_attn(
+            hidden_states=hidden_states,
+            cache_params=past_key_values,
+            cache_position=cache_position,
+            attention_mask=None,
+        )
+        if needs_squeeze:
+            hidden_states = hidden_states.squeeze(0)
+    elif self.layer_type == "full_attention":
+        hidden_states, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            cu_seq_lens=cu_seq_lens,
+            indices=indices,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+
+    hidden_states = residual + hidden_states
+
+    residual = hidden_states
+    hidden_states = self.post_attention_layernorm(hidden_states)
+    hidden_states = self.mlp(hidden_states)
+    hidden_states = residual + hidden_states
+
+    return hidden_states
+
+
+def attn_forward(
+    self: Qwen3_5Attention,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    use_cache: bool = False,
+    cu_seq_lens: Optional[torch.IntTensor] = None,
+    indices: Optional[torch.IntTensor] = None,
+    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+    **kwargs,
+):
+    input_shape = hidden_states.shape[:-1]
+    hidden_shape = (*input_shape, -1, self.head_dim)
+
+    # Qwen3.5 uses gated attention: q_proj outputs query + gate (2x size)
+    query_states, gate = torch.chunk(self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1)
+    gate = gate.reshape(*input_shape, -1)
+
+    query_states = self.q_norm(query_states.view(hidden_shape))
+    key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape))
+    value_states = self.v_proj(hidden_states).view(hidden_shape)
+    cos, sin = position_embeddings
+
+    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
+    if ulysses_sp_size > 1:
+        assert position_ids is not None, "position_ids is required for Ulysses sequence parallelism"
+
+        repeats = max(ulysses_sp_size // key_states.size(1), 1)
+        key_states = repeat_kv(key_states, repeats)
+        value_states = repeat_kv(value_states, repeats)
+
+        query_states = gather_seq_scatter_heads(query_states, seq_dim=0, head_dim=1)
+        key_states = gather_seq_scatter_heads(key_states, seq_dim=0, head_dim=1)
+        value_states = gather_seq_scatter_heads(value_states, seq_dim=0, head_dim=1)
+
+        if cu_seq_lens is not None:
+            seq_len_tensor = torch.tensor(
+                query_states.shape[0],
+                device=cu_seq_lens.device,
+                dtype=cu_seq_lens.dtype,
+            )
+            needs_append = (cu_seq_lens.max() < seq_len_tensor).item()
+            if needs_append:
+                cu_seq_lens = torch.cat([cu_seq_lens, seq_len_tensor.unsqueeze(0)])
+
+    query_states = query_states.unsqueeze(0).transpose(1, 2)
+    key_states = key_states.unsqueeze(0).transpose(1, 2)
+
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+    if past_key_values is not None and hasattr(past_key_values, "update"):
+        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": None}
+        key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    query_states = query_states.transpose(1, 2).squeeze(0)
+    key_states = key_states.transpose(1, 2).squeeze(0)
+
+    max_seqlen = torch.diff(cu_seq_lens).max().item() if cu_seq_lens is not None else None
+    window_size = (-1, -1)
+
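+    # Varlen flash attention over the packed tokens: cu_seq_lens marks sequence
+    # boundaries so attention never crosses packed sequences.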
+    attn_output = flash_attn_varlen_func(
+        q=query_states,
+        k=key_states,
+        v=value_states,
+        cu_seqlens_q=cu_seq_lens,
+        cu_seqlens_k=cu_seq_lens,
+        max_seqlen_q=max_seqlen,
+        max_seqlen_k=max_seqlen,
+        causal=True,
+        window_size=window_size,
+        softmax_scale=self.head_dim**-0.5,
+        dropout_p=0.0,
+    )
+
+    if ulysses_sp_size > 1:
+        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=0, head_dim=1)
+
+    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+    # Apply the gated attention mechanism
+    attn_output = attn_output * torch.sigmoid(gate)
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, None
diff --git a/src/lmms_engine/models/utils.py b/src/lmms_engine/models/utils.py
index e88b420b..8082bfc2 100644
--- a/src/lmms_engine/models/utils.py
+++ b/src/lmms_engine/models/utils.py
@@ -33,6 +33,7 @@
     "qwen3_moe",
     "qwen3_omni_moe",
     "qwen3_omni_moe_thinker",
+    "qwen3_5",
    "qwen3_vl",
     "qwen3_vl_moe",
     "deepseek_v3",
@@ -82,6 +83,7 @@ def __init__(self, config: PretrainedConfig):
         if config.model_type in [
             "llava_onevision",
             "qwen2_5_vl",
+            "qwen3_5",
             "qwen3_vl",
             "qwen3_vl_moe",
             "qwen2_5_omni",