Skip to content

Commit 1eab18a

Browse files
authored
feat: add support for the Qwen3.5 family of models (#1624)
* feat: add support for the Qwen3.5 family of models
* fix: FP8 weight loading for Qwen3.5
1 parent a63d017 commit 1eab18a

12 files changed

Lines changed: 1312 additions & 31 deletions

File tree

aphrodite/config/speculative.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"deepseek_mtp",
3535
"ernie_mtp",
3636
"qwen3_next_mtp",
37+
"qwen3_5_mtp",
3738
"mimo_mtp",
3839
"longcat_flash_mtp",
3940
"mtp",
@@ -46,6 +47,7 @@
4647
"glm4_moe_lite_mtp",
4748
"ernie_mtp",
4849
"qwen3_next_mtp",
50+
"qwen3_5_mtp",
4951
"longcat_flash_mtp",
5052
)
5153

@@ -218,6 +220,11 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
218220
if hf_config.model_type == "qwen3_next_mtp":
219221
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
220222
hf_config.update({"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]})
223+
if hf_config.model_type in ("qwen3_5", "qwen3_5_moe"):
224+
hf_config.model_type = "qwen3_5_mtp"
225+
if hf_config.model_type == "qwen3_5_mtp":
226+
n_predict = getattr(hf_config.text_config, "mtp_num_hidden_layers", 1)
227+
hf_config.update({"n_predict": n_predict, "architectures": ["Qwen3_5MTP"]})
221228
if hf_config.model_type == "longcat_flash":
222229
hf_config.model_type = "longcat_flash_mtp"
223230
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)

aphrodite/modeling/layers/linear.py

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
7272
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
7373

7474

75+
def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
    """Rescale a shard's size and offset by the first weight block dimension.

    Both values are divided by ``weight_block_size[0]`` and rounded up, so a
    partially covered block still counts as one whole block.

    Args:
        weight_block_size: Indexable block shape; only element ``0`` is read.
            Must not be ``None``.
        shard_size: Shard extent before rescaling.
        shard_offset: Shard start before rescaling.

    Returns:
        ``(shard_size, shard_offset)`` after rescaling -- note that the size
        comes first, mirroring the argument order.
    """
    assert weight_block_size is not None
    rows_per_block = weight_block_size[0]

    def _ceil_div(value):
        # Round-up integer division by the block extent.
        return (value + rows_per_block - 1) // rows_per_block

    return _ceil_div(shard_size), _ceil_div(shard_offset)
81+
82+
7583
def adjust_bitsandbytes_4bit_shard(
7684
param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
7785
) -> tuple[int, int]:
@@ -744,7 +752,12 @@ def weight_loader(
744752
assert param_data.shape == loaded_weight.shape
745753
param_data.copy_(loaded_weight)
746754

747-
def _load_fused_module_from_checkpoint(self, param: BaseAphroditeParameter, loaded_weight: torch.Tensor):
755+
def _load_fused_module_from_checkpoint(
756+
self,
757+
param: BaseAphroditeParameter,
758+
loaded_weight: torch.Tensor,
759+
output_sizes: list[int] | None = None,
760+
):
748761
"""
749762
Handle special case for models where MLP layers are already
750763
fused on disk. In this case, we have no shard id. This function
@@ -757,7 +770,8 @@ def _load_fused_module_from_checkpoint(self, param: BaseAphroditeParameter, load
757770

758771
current_shard_offset = 0
759772
shard_offsets: list[tuple[int, int, int]] = []
760-
for i, output_size in enumerate(self.output_sizes):
773+
output_sizes = output_sizes or self.output_sizes
774+
for i, output_size in enumerate(output_sizes):
761775
shard_offsets.append((i, current_shard_offset, output_size))
762776
current_shard_offset += output_size
763777

@@ -776,37 +790,76 @@ def _load_fused_module_from_checkpoint(self, param: BaseAphroditeParameter, load
776790
loaded_weight_shard = loaded_weight.narrow(param.output_dim, shard_offset, shard_size)
777791
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
778792

793+
def validate_shard_id(self, loaded_shard_id: int | tuple[int, ...] | None):
794+
if loaded_shard_id is None:
795+
return
796+
if isinstance(loaded_shard_id, tuple):
797+
for idx in loaded_shard_id:
798+
if not (0 <= idx < len(self.output_sizes)):
799+
raise ValueError(
800+
f"Shard id index {idx} should be between 0 and "
801+
f"{len(self.output_sizes) - 1}. Got shard id {loaded_shard_id}."
802+
)
803+
if len(loaded_shard_id) > 1 and any(
804+
b - a != 1 for a, b in zip(loaded_shard_id[:-1], loaded_shard_id[1:])
805+
):
806+
raise ValueError(
807+
"Shard id with multiple indices should be consecutive. "
808+
f"Got shard id {loaded_shard_id}."
809+
)
810+
return
811+
if isinstance(loaded_shard_id, int):
812+
if loaded_shard_id < 0 or loaded_shard_id >= len(self.output_sizes):
813+
raise ValueError(
814+
f"Shard id should be between 0 and {len(self.output_sizes) - 1}. "
815+
f"Got shard id {loaded_shard_id}."
816+
)
817+
return
818+
raise ValueError("This line should not be reached")
819+
779820
def weight_loader_v2(
780821
self,
781822
param: BaseAphroditeParameter,
782823
loaded_weight: torch.Tensor,
783-
loaded_shard_id: int | None = None,
824+
loaded_shard_id: tuple[int, ...] | int | None = None,
784825
):
785-
if loaded_shard_id is None:
826+
self.validate_shard_id(loaded_shard_id)
827+
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
786828
if isinstance(param, PerTensorScaleParameter):
787829
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
788830
return
789831
elif type(param) in (RowAphroditeParameter, BaseAphroditeParameter):
790832
param.load_merged_column_weight(loaded_weight=loaded_weight)
791833
return
792-
# TODO: @dsikka - move to parameter.py
793-
self._load_fused_module_from_checkpoint(param, loaded_weight)
834+
output_sizes = (
835+
[self.output_sizes[idx] for idx in loaded_shard_id]
836+
if loaded_shard_id
837+
else None
838+
)
839+
if isinstance(param, BlockQuantScaleParameter):
840+
weight_block_size = getattr(self, "weight_block_size", None)
841+
output_sizes = [
842+
adjust_block_scale_shard(weight_block_size, size, 0)[0]
843+
for size in (output_sizes or self.output_sizes)
844+
]
845+
self._load_fused_module_from_checkpoint(
846+
param, loaded_weight, output_sizes=output_sizes
847+
)
794848
return
795849

796850
assert loaded_shard_id < len(self.output_sizes)
797851

852+
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
853+
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
854+
798855
if isinstance(param, BlockQuantScaleParameter):
799856
assert self.quant_method is not None
800-
# Assume the weight block size has been set by quant method
801857
assert hasattr(self, "weight_block_size")
802858
weight_block_size = self.weight_block_size
803859
assert weight_block_size is not None
804-
block_n, _ = weight_block_size[0], weight_block_size[1]
805-
shard_offset = ((sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n) // self.tp_size
806-
shard_size = (self.output_sizes[loaded_shard_id] + block_n - 1) // block_n // self.tp_size
807-
else:
808-
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
809-
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
860+
shard_size, shard_offset = adjust_block_scale_shard(
861+
weight_block_size, shard_size, shard_offset
862+
)
810863

811864
param.load_merged_column_weight(
812865
loaded_weight=loaded_weight,

aphrodite/modeling/layers/mamba/abstract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def get_state_dtype(self) -> tuple[torch.dtype, ...]:
4848
def get_kv_cache_spec(self, aphrodite_config: AphroditeConfig) -> KVCacheSpec | None:
4949
if (
5050
aphrodite_config.speculative_config is not None
51-
and aphrodite_config.model_config.hf_config.model_type not in ["qwen3_next"]
51+
and aphrodite_config.model_config.hf_config.model_type not in ["qwen3_next", "qwen3_5", "qwen3_5_moe"]
5252
):
5353
raise NotImplementedError("Mamba with speculative decoding is not supported yet.")
5454
mamba_block_size = aphrodite_config.cache_config.mamba_block_size

aphrodite/modeling/layers/mamba/mamba_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,14 @@ def gated_delta_net_state_dtype(
6666
cls,
6767
model_dtype: ModelDType | torch.dtype,
6868
mamba_cache_dtype: MambaDType,
69+
mamba_ssm_cache_dtype: MambaDType = "auto",
6970
) -> tuple[torch.dtype, torch.dtype]:
70-
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
71-
return (state_dtype, state_dtype)
71+
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
72+
if mamba_ssm_cache_dtype == "auto":
73+
temporal_state_dtype = conv_state_dtype
74+
else:
75+
temporal_state_dtype = STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype]
76+
return (conv_state_dtype, temporal_state_dtype)
7277

7378
@classmethod
7479
def kda_state_dtype(

aphrodite/modeling/models/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,17 @@ def verify_and_update_config(cls, aphrodite_config: "AphroditeConfig") -> None:
457457
logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
458458

459459

460+
class Qwen3_5ForConditionalGenerationConfig(HybridAttentionMambaModelConfig):
    """Config hook that resolves an ``"auto"`` Mamba SSM cache dtype."""

    @classmethod
    def verify_and_update_config(cls, aphrodite_config: "AphroditeConfig") -> None:
        """Run the parent checks, then fill in ``mamba_ssm_cache_dtype``.

        When the cache config still holds ``"auto"``, the value is replaced by
        the text config's ``mamba_ssm_dtype`` attribute, or ``"float32"`` if
        that attribute is absent. Any other setting is left untouched.
        """
        super().verify_and_update_config(aphrodite_config)

        cache_cfg = aphrodite_config.cache_config
        if cache_cfg.mamba_ssm_cache_dtype != "auto":
            # An explicit dtype was requested; keep it as is.
            return

        hf_text_cfg = aphrodite_config.model_config.hf_text_config
        cache_cfg.mamba_ssm_cache_dtype = getattr(hf_text_cfg, "mamba_ssm_dtype", "float32")
469+
470+
460471
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
461472
"GteModel": SnowflakeGteNewModelConfig,
462473
"GteNewModel": GteNewModelConfig,
@@ -474,4 +485,6 @@ def verify_and_update_config(cls, aphrodite_config: "AphroditeConfig") -> None:
474485
"Mamba2ForCausalLM": MambaModelConfig,
475486
"FalconMambaForCausalLM": MambaModelConfig,
476487
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
488+
"Qwen3_5ForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
489+
"Qwen3_5MoeForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
477490
}

0 commit comments

Comments
 (0)