diff --git a/xllm/core/distributed_runtime/master.cpp b/xllm/core/distributed_runtime/master.cpp index cf5a50f10..b736d0da5 100644 --- a/xllm/core/distributed_runtime/master.cpp +++ b/xllm/core/distributed_runtime/master.cpp @@ -106,7 +106,8 @@ void resolve_npu_kernel_backend_for_options(Options* options) { return; } - const std::string model_type = get_model_type(options->model_path()); + const std::string model_type = + util::get_model_type(options->model_path(), options->backend()); std::string effective_backend; std::string resolved_name; std::string error_message; diff --git a/xllm/core/distributed_runtime/vlm_engine.cpp b/xllm/core/distributed_runtime/vlm_engine.cpp index b6680cfa3..299f43cff 100644 --- a/xllm/core/distributed_runtime/vlm_engine.cpp +++ b/xllm/core/distributed_runtime/vlm_engine.cpp @@ -137,6 +137,14 @@ bool VLMEngine::init_model() { n_local_kv_heads_ = std::max(1, n_kv_heads / world_size); head_dim_ = args_.head_dim(); dtype_ = util::parse_dtype(args_.dtype(), options_.devices()[0]); + if (has_linear_attention_layers(args_)) { + const int64_t linear_n_k_heads = args_.linear_num_key_heads(); + const int64_t linear_n_v_heads = args_.linear_num_value_heads(); + n_local_linear_k_heads_ = + std::max(1, linear_n_k_heads / world_size); + n_local_linear_v_heads_ = + std::max(1, linear_n_v_heads / world_size); + } // key + value for all layers LOG(INFO) << "Block info, block_size: " << options_.block_size() @@ -247,13 +255,38 @@ Engine::KVCacheCapacity VLMEngine::estimate_kv_cache_capacity() { slot_size = 2 * dtype_size * head_dim_ * n_local_kv_heads_; } kv_cache_cap.slot_size = slot_size; + if (has_linear_attention_layers(args_)) { + const int64_t head_k_dim = args_.linear_key_head_dim(); + const int64_t head_v_dim = args_.linear_value_head_dim(); + const int64_t linear_ssm_slot_size = + dtype_size * n_local_linear_v_heads_ * head_k_dim * head_v_dim; + const int64_t linear_conv_slot_size = + dtype_size * + (head_k_dim * n_local_linear_k_heads_ * 2 + 
+ head_v_dim * n_local_linear_v_heads_) * + (args_.linear_conv_kernel_dim() - 1); + kv_cache_cap.linear_slot_size = + linear_ssm_slot_size + linear_conv_slot_size; + } kv_cache_cap.n_layers = args_.n_layers(); // compute kv cache n_blocks + int64_t full_attention_interval = (args_.full_attention_interval() < 1) + ? 1 + : args_.full_attention_interval(); + int64_t num_full_attention_layers = + kv_cache_cap.n_layers / full_attention_interval; + int64_t num_linear_attention_layers = + kv_cache_cap.n_layers - num_full_attention_layers; const int32_t block_size = options_.block_size(); - const int64_t block_size_in_bytes = block_size * slot_size; - kv_cache_cap.n_blocks = kv_cache_cap.cache_size_in_bytes / - (args_.n_layers() * block_size_in_bytes); + const int64_t full_cache_block_size_in_bytes = block_size * slot_size; + const int64_t total_cache_block_size_in_bytes = + num_full_attention_layers * full_cache_block_size_in_bytes + + num_linear_attention_layers * kv_cache_cap.linear_slot_size; + CHECK_GT(total_cache_block_size_in_bytes, 0) + << "invalid cache block size estimate"; + kv_cache_cap.n_blocks = + kv_cache_cap.cache_size_in_bytes / total_cache_block_size_in_bytes; CHECK_GT(kv_cache_cap.n_blocks, 0) << "no n_blocks for kv cache"; return kv_cache_cap; @@ -266,14 +299,27 @@ bool VLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) { << ", slot_size: " << kv_cache_cap.slot_size; const int32_t block_size = options_.block_size(); + const bool enable_linear_attention = has_linear_attention_layers(args_); // init kv cache for each worker std::vector> kv_cache_shape; - kv_cache_shape.reserve(2); + kv_cache_shape.reserve(enable_linear_attention ? 
4 : 2);
   kv_cache_shape.emplace_back(std::vector{
       kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
   kv_cache_shape.emplace_back(std::vector{
       kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
+  if (enable_linear_attention) {
+    kv_cache_shape.emplace_back(std::vector{
+        kv_cache_cap.n_blocks,
+        args_.linear_key_head_dim() * n_local_linear_k_heads_ * 2 +
+            args_.linear_value_head_dim() * n_local_linear_v_heads_,
+        args_.linear_conv_kernel_dim() - 1});
+    kv_cache_shape.emplace_back(
+        std::vector{kv_cache_cap.n_blocks,
+                    n_local_linear_v_heads_,
+                    args_.linear_key_head_dim(),
+                    args_.linear_value_head_dim()});
+  }
 #if defined(USE_MLU)
   // transpose kv_cache layout for mlu
   // default layout: [n_blocks, block_size, n_head, head_dim]
diff --git a/xllm/core/distributed_runtime/vlm_engine.h b/xllm/core/distributed_runtime/vlm_engine.h
index 2c7668830..75c84acd6 100644
--- a/xllm/core/distributed_runtime/vlm_engine.h
+++ b/xllm/core/distributed_runtime/vlm_engine.h
@@ -91,6 +91,8 @@ class VLMEngine : public Engine {
   // config for kv cache
   int64_t n_local_kv_heads_ = 0;
   int64_t head_dim_ = 0;
+  int64_t n_local_linear_v_heads_ = 0;
+  int64_t n_local_linear_k_heads_ = 0;
 };
 }  // namespace xllm
diff --git a/xllm/core/framework/batch/mposition.cpp b/xllm/core/framework/batch/mposition.cpp
index 369fd5468..7a3041418 100644
--- a/xllm/core/framework/batch/mposition.cpp
+++ b/xllm/core/framework/batch/mposition.cpp
@@ -42,6 +42,11 @@ std::vector> groupByTokenType(
       current_key, start, static_cast(token_types.size()));
   return groups;
 }
+
+bool is_qwen3_vl_position_model(const std::string& model_type) {
+  return absl::StartsWith(model_type, "qwen3_vl") ||
+         absl::StartsWith(model_type, "qwen3_5_vl");
+}
 }  // namespace
 
 torch::Tensor MPositionHelper::get_positions() {
@@ -63,7 +68,7 @@ torch::Tensor MPositionHelper::get_positions() {
   std::tuple res;
   if (absl::StartsWith(args_.model_type(), "glm4v")) {
     res = get_positions_glm(image_grid_thw, video_grid_thw);
- } else if (absl::StartsWith(args_.model_type(), "qwen3_vl")) { + } else if (is_qwen3_vl_position_model(args_.model_type())) { res = get_positions_qwen3(image_grid_thw, video_grid_thw); } else { res = get_positions_p(image_grid_thw, video_grid_thw, second_per_grid_ts); diff --git a/xllm/core/framework/hf_model_loader.cpp b/xllm/core/framework/hf_model_loader.cpp index 0480c9e00..87bde6363 100644 --- a/xllm/core/framework/hf_model_loader.cpp +++ b/xllm/core/framework/hf_model_loader.cpp @@ -47,6 +47,7 @@ limitations under the License. #include "core/util/blocking_counter.h" #include "core/util/json_reader.h" #include "core/util/rec_model_utils.h" +#include "core/util/model_config_utils.h" #include "core/util/scope_guard.h" #include "core/util/tensor_helper.h" #include "models/model_registry.h" @@ -724,13 +725,8 @@ bool HFModelLoader::load_model_args(const std::string& model_weights_path) { return false; } - std::string model_type; - if (auto data = reader.value("model_type")) { - model_type = data.value(); - } else { - LOG(ERROR) << "Failed to find model_type in " << args_file_path; - return false; - } + const std::string model_type = util::get_model_type( + reader, std::filesystem::path(model_weights_path), FLAGS_backend); std::string resolved_model_type; std::string error_message; diff --git a/xllm/core/framework/hf_model_loader_test.cpp b/xllm/core/framework/hf_model_loader_test.cpp index 9777ff02c..9dff2c07a 100644 --- a/xllm/core/framework/hf_model_loader_test.cpp +++ b/xllm/core/framework/hf_model_loader_test.cpp @@ -17,13 +17,38 @@ limitations under the License. 
#include +#include + #include "core/platform/device.h" +#include "core/util/model_config_utils.h" #if defined(USE_NPU) #include "models/model_registry.h" #endif namespace xllm { +TEST(HFModelLoaderTest, Qwen35BackendAwareModelTypeSelection) { + JsonReader reader; + ASSERT_TRUE(reader.parse_text(R"json( + { + "architectures": ["Qwen3_5ForConditionalGeneration"], + "model_type": "qwen3_5", + "text_config": { + "model_type": "qwen3_5_text" + }, + "vision_config": { + "model_type": "qwen3_5" + } + } + )json")); + + const auto fake_model_path = std::filesystem::path("/tmp/Qwen3.5-27B"); + EXPECT_EQ(util::get_model_type(reader, fake_model_path), "qwen3_5_text"); + EXPECT_EQ(util::get_model_type(reader, fake_model_path, "vlm"), "qwen3_5_vl"); + EXPECT_EQ(util::get_model_type(reader, fake_model_path, "llm"), + "qwen3_5_text"); +} + TEST(HFModelLoaderTest, LoadCompressedTensorsFp8StaticConfig) { JsonReader reader; ASSERT_TRUE(reader.parse_text(R"json( @@ -123,6 +148,38 @@ TEST(HFModelLoaderTest, Qwen35MtpModelArgsFromMoeConfig) { EXPECT_EQ(args.layer_types()[0], "full_attention"); EXPECT_EQ(args.layer_types()[1], "full_attention"); } + +TEST(HFModelLoaderTest, Qwen35TextModelArgsKeepTextTypeAndMropeConfig) { + auto loader = ModelRegistry::get_model_args_loader("qwen3_5_text"); + ASSERT_TRUE(loader != nullptr); + + JsonReader reader; + ASSERT_TRUE(reader.parse_text(R"json( + { + "architectures": ["Qwen3_5ForConditionalGeneration"], + "model_type": "qwen3_5", + "text_config": { + "model_type": "qwen3_5_text", + "rope_parameters": { + "mrope_interleaved": true, + "mrope_section": [11, 11, 10], + "partial_rotary_factor": 0.25, + "rope_theta": 10000000 + } + }, + "vision_config": { + "model_type": "qwen3_5" + } + } + )json")); + + ModelArgs args; + ASSERT_TRUE(loader(reader, &args)); + EXPECT_EQ(args.model_type(), "qwen3_5_text"); + EXPECT_EQ(args.rope_scaling_mrope_section(), + (std::vector{11, 11, 10})); + EXPECT_TRUE(args.rope_scaling_mrope_interleaved()); +} #endif } // 
namespace xllm diff --git a/xllm/core/framework/model/causal_vlm.h b/xllm/core/framework/model/causal_vlm.h index 10e794ae2..4e277d4bc 100644 --- a/xllm/core/framework/model/causal_vlm.h +++ b/xllm/core/framework/model/causal_vlm.h @@ -79,19 +79,35 @@ class CausalVLMImpl : public CausalVLM { #if defined(USE_NPU) layer::NpuLmHead get_npu_lm_head() override { - return model_->get_npu_lm_head(); + if constexpr (detail::has_get_npu_lm_head::value) { + return model_->get_npu_lm_head(); + } else { + return CausalLM::get_npu_lm_head(); + } } void set_npu_lm_head(layer::NpuLmHead& head) override { - model_->set_npu_lm_head(head); + if constexpr (detail::has_set_npu_lm_head::value) { + model_->set_npu_lm_head(head); + } else { + CausalLM::set_npu_lm_head(head); + } } layer::NpuWordEmbedding get_npu_word_embedding() override { - return model_->get_npu_word_embedding(); + if constexpr (detail::has_get_npu_word_embedding::value) { + return model_->get_npu_word_embedding(); + } else { + return CausalLM::get_npu_word_embedding(); + } } void set_npu_word_embedding(layer::NpuWordEmbedding& embedding) override { - model_->set_npu_word_embedding(embedding); + if constexpr (detail::has_set_npu_word_embedding::value) { + model_->set_npu_word_embedding(embedding); + } else { + CausalLM::set_npu_word_embedding(embedding); + } } #endif layer::LmHead get_lm_head() override { diff --git a/xllm/core/framework/model/model_args.h b/xllm/core/framework/model/model_args.h index 080977002..dc275e647 100644 --- a/xllm/core/framework/model/model_args.h +++ b/xllm/core/framework/model/model_args.h @@ -83,6 +83,7 @@ struct ModelArgs { PROPERTY(float, rope_scaling_mscale) = 0.0f; PROPERTY(float, rope_scaling_mscale_all_dim) = 0.0f; PROPERTY(std::vector, rope_scaling_mrope_section); + PROPERTY(bool, rope_scaling_mrope_interleaved) = false; // the maximum sequence length to use for rotary position embeddings. 
PROPERTY(int64_t, max_position_embeddings) = 0; diff --git a/xllm/core/layers/common/rotary_embedding_util.cpp b/xllm/core/layers/common/rotary_embedding_util.cpp index 9211a8734..5b3213356 100644 --- a/xllm/core/layers/common/rotary_embedding_util.cpp +++ b/xllm/core/layers/common/rotary_embedding_util.cpp @@ -86,6 +86,17 @@ class CosSinCacheManager { using torch::indexing::None; using ISlice = torch::indexing::Slice; +inline torch::Tensor rotate_every_two(const torch::Tensor& x) { + auto x1 = x.index({ISlice(), ISlice(), ISlice(0, None, 2)}); + auto x2 = x.index({ISlice(), ISlice(), ISlice(1, None, 2)}); + return torch::stack({-x2, x1}, /*dim=*/-1).flatten(/*start_dim=*/-2); +} + +inline torch::Tensor rotate_half(const torch::Tensor& x) { + auto chunks = x.chunk(2, /*dim=*/-1); + return torch::cat({-chunks[1], chunks[0]}, /*dim=*/-1); +} + // Inverse dim formula to find dim based on number of rotations inline double yarn_find_correction_dim(int num_rotations, int dim, @@ -420,6 +431,23 @@ torch::Tensor get_deepseek_rotary_embedding( options); return cos_sin; } + +std::tuple apply_rotary_pos_emb( + const torch::Tensor& q, + const torch::Tensor& k, + const torch::Tensor& cos, + const torch::Tensor& sin, + bool interleaved) { + if (interleaved) { + auto q_embed = (q * cos) + (rotate_every_two(q) * sin); + auto k_embed = (k * cos) + (rotate_every_two(k) * sin); + return std::make_tuple(q_embed, k_embed); + } + + auto q_embed = (q * cos) + (rotate_half(q) * sin); + auto k_embed = (k * cos) + (rotate_half(k) * sin); + return std::make_tuple(q_embed, k_embed); +} } // namespace rotary } // namespace layer } // namespace xllm diff --git a/xllm/core/layers/common/rotary_embedding_util.h b/xllm/core/layers/common/rotary_embedding_util.h index 43b85f410..d4e5fcd7c 100644 --- a/xllm/core/layers/common/rotary_embedding_util.h +++ b/xllm/core/layers/common/rotary_embedding_util.h @@ -116,6 +116,13 @@ torch::Tensor get_deepseek_rotary_embedding( float mscale_all_dim, const 
torch::TensorOptions& options); +std::tuple apply_rotary_pos_emb( + const torch::Tensor& q, + const torch::Tensor& k, + const torch::Tensor& cos, + const torch::Tensor& sin, + bool interleaved); + #if defined(USE_MUSA) torch::Tensor get_interleave_rotary_embedding( int64_t dim, diff --git a/xllm/core/layers/npu_torch/qwen3_next_attention.cpp b/xllm/core/layers/npu_torch/qwen3_next_attention.cpp index 73385f1dd..fd28d8d16 100644 --- a/xllm/core/layers/npu_torch/qwen3_next_attention.cpp +++ b/xllm/core/layers/npu_torch/qwen3_next_attention.cpp @@ -18,9 +18,19 @@ limitations under the License. #include #include + +#include "layers/common/rotary_embedding_util.h" + namespace xllm { namespace layer { +namespace { + +using torch::indexing::None; +using ISlice = torch::indexing::Slice; + +} // namespace + Qwen3NextAttentionImpl::Qwen3NextAttentionImpl( const ModelArgs& args, const QuantArgs& quant_args, @@ -83,16 +93,18 @@ Qwen3NextAttentionImpl::Qwen3NextAttentionImpl( "k_norm", Qwen3NextRMSNorm(head_dim_, args.rms_norm_eps(), options)); // 5. Rotary embedding - const int rotary_dim = - static_cast(head_dim_ * args.partial_rotary_factor()); - rotary_emb_ = - register_module("rotary_emb", - PartialRotaryEmbedding(rotary_dim, + rotary_dim_ = static_cast(head_dim_ * args.partial_rotary_factor()); + CHECK_GT(rotary_dim_, 0) << "rotary_dim must be positive"; + rotary_interleaved_ = !args.rope_scaling_mrope_section().empty() && + args.rope_scaling_mrope_interleaved(); + partial_rotary_emb_ = + register_module("partial_rotary_emb", + PartialRotaryEmbedding(rotary_dim_, args.max_position_embeddings(), args.rope_theta(), head_dim_, - true, - false, + /*is_neox_style=*/true, + /*interleaved=*/false, options)); // 6. 
Attention @@ -158,14 +170,32 @@ torch::Tensor Qwen3NextAttentionImpl::forward( const int64_t T = q.size(0); auto q_reshaped = q.reshape({T, num_heads_, head_dim_}); - auto q_normed = q_norm_->forward(q_reshaped); auto k_reshaped = k.reshape({T, num_kv_heads_, head_dim_}); + auto q_normed = q_norm_->forward(q_reshaped); auto k_normed = k_norm_->forward(k_reshaped); - - q = q_normed.view({T, q_size_}); - k = k_normed.view({T, kv_size_}); - - rotary_emb_->forward(positions, q, k); + if (positions.dim() == 2 && attn_metadata.mrope_cos.defined() && + attn_metadata.mrope_sin.defined()) { + auto q_rotary = q_normed.index({"...", ISlice(0, rotary_dim_)}); + auto k_rotary = k_normed.index({"...", ISlice(0, rotary_dim_)}); + std::tie(q_rotary, k_rotary) = layer::rotary::apply_rotary_pos_emb( + q_rotary, + k_rotary, + attn_metadata.mrope_cos.unsqueeze(1), + attn_metadata.mrope_sin.unsqueeze(1), + rotary_interleaved_); + q_normed = torch::cat( + {q_rotary, q_normed.index({"...", ISlice(rotary_dim_, None)})}, -1); + k_normed = torch::cat( + {k_rotary, k_normed.index({"...", ISlice(rotary_dim_, None)})}, -1); + q = q_normed.reshape({T, q_size_}); + k = k_normed.reshape({T, kv_size_}); + } else { + q = q_normed.reshape({T, q_size_}); + k = k_normed.reshape({T, kv_size_}); + const torch::Tensor rotary_positions = + positions.dim() == 2 ? positions[0] : positions; + partial_rotary_emb_->forward(rotary_positions, q, k); + } auto out = std::get<0>(attn_->forward(attn_metadata, q, k, v, kv_cache)); if (attn_output_gate_) { diff --git a/xllm/core/layers/npu_torch/qwen3_next_attention.h b/xllm/core/layers/npu_torch/qwen3_next_attention.h index b7265a5ea..0c1689a98 100644 --- a/xllm/core/layers/npu_torch/qwen3_next_attention.h +++ b/xllm/core/layers/npu_torch/qwen3_next_attention.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "layers/common/qwen3_next_rms_norm.h" namespace xllm { + namespace layer { class Qwen3NextAttentionImpl : public torch::nn::Module { @@ -63,9 +64,11 @@ class Qwen3NextAttentionImpl : public torch::nn::Module { Qwen3NextRMSNorm q_norm_{nullptr}; Qwen3NextRMSNorm k_norm_{nullptr}; + PartialRotaryEmbedding partial_rotary_emb_{nullptr}; Attention attn_{nullptr}; - PartialRotaryEmbedding rotary_emb_{nullptr}; + int64_t rotary_dim_ = 0; + bool rotary_interleaved_ = false; }; TORCH_MODULE(Qwen3NextAttention); diff --git a/xllm/core/util/model_config_utils.cpp b/xllm/core/util/model_config_utils.cpp index eb2a8d1d9..1ec408269 100644 --- a/xllm/core/util/model_config_utils.cpp +++ b/xllm/core/util/model_config_utils.cpp @@ -18,39 +18,83 @@ limitations under the License. #include +#include #include -#include "core/util/json_reader.h" +namespace xllm::util { -namespace xllm { +namespace { -std::string get_model_type(const std::filesystem::path& model_path) { +bool is_qwen35_multimodal_checkpoint(const JsonReader& reader, + const std::string& model_type) { + if (model_type != "qwen3_5" || !reader.contains("vision_config")) { + return false; + } + const auto architectures = + reader.value>("architectures"); + if (!architectures.has_value()) { + return true; + } + + return std::find(architectures->begin(), + architectures->end(), + "Qwen3_5ForConditionalGeneration") != architectures->end(); +} + +std::string resolve_model_type(const JsonReader& reader, + const std::string& model_type, + const std::optional& backend) { + if (!is_qwen35_multimodal_checkpoint(reader, model_type)) { + return model_type; + } + + if (backend.has_value() && backend.value() == "vlm") { + return "qwen3_5_vl"; + } + + const auto text_model_type = + reader.value("text_config.model_type"); + if (text_model_type.has_value()) { + return text_model_type.value(); + } + + return "qwen3_5_vl"; +} + +} // namespace + +std::string get_model_type(const JsonReader& reader, + const std::filesystem::path& 
model_path, + std::optional backend) { + auto model_type = reader.value("model_type"); + if (!model_type.has_value()) { + model_type = reader.value("model_name"); + } + if (!model_type.has_value()) { + LOG(FATAL) << "Please check config.json file in model path: " << model_path + << ", it should contain model_type or model_name key."; + } + + return resolve_model_type(reader, model_type.value(), backend); +} + +std::string get_model_type(const std::filesystem::path& model_path, + std::optional backend) { JsonReader reader; // for llm, vlm and rec models, the config.json file is in the model path - std::filesystem::path config_json_path = model_path / "config.json"; - - if (std::filesystem::exists(config_json_path)) { - reader.parse(config_json_path); - // Prefer model_type (e.g. LLM/VLM); fall back to model_name for configs - // that only have model_name (e.g. LongCat-Image: {"model_name": - // "LongCat-Image"}). - auto model_type = reader.value("model_type"); - if (!model_type.has_value()) { - model_type = reader.value("model_name"); - } - if (!model_type.has_value()) { - LOG(FATAL) << "Please check config.json file in model path: " - << model_path - << ", it should contain model_type or model_name key."; - } - return model_type.value(); - } else { + const std::filesystem::path config_json_path = model_path / "config.json"; + + if (!std::filesystem::exists(config_json_path)) { LOG(FATAL) << "Please check config.json or model_index.json file, one of " "them should exist in the model path: " << model_path; } + if (!reader.parse(config_json_path.string())) { + LOG(FATAL) << "Failed to parse config.json file in model path: " + << model_path; + } - return ""; + return get_model_type(reader, model_path, std::move(backend)); } -} // namespace xllm +} // namespace xllm::util diff --git a/xllm/core/util/model_config_utils.h b/xllm/core/util/model_config_utils.h index aa2763971..24de37c2d 100644 --- a/xllm/core/util/model_config_utils.h +++ 
b/xllm/core/util/model_config_utils.h @@ -17,10 +17,18 @@ limitations under the License. #pragma once #include +#include #include -namespace xllm { +#include "core/util/json_reader.h" -std::string get_model_type(const std::filesystem::path& model_path); +namespace xllm::util { -} // namespace xllm +std::string get_model_type(const JsonReader& reader, + const std::filesystem::path& model_path, + std::optional backend = std::nullopt); + +std::string get_model_type(const std::filesystem::path& model_path, + std::optional backend = std::nullopt); + +} // namespace xllm::util diff --git a/xllm/core/util/utils.h b/xllm/core/util/utils.h index 9f57375e7..e59a79bdc 100644 --- a/xllm/core/util/utils.h +++ b/xllm/core/util/utils.h @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include #include #include @@ -31,12 +30,12 @@ limitations under the License. // ------------------- #include -#include #include #include #include #include "core/util/json_reader.h" +#include "core/util/model_config_utils.h" #include "models/model_registry.h" namespace xllm { @@ -144,28 +143,6 @@ inline std::string get_model_name( return model_name; } -inline std::string get_model_type(const std::filesystem::path& model_path) { - JsonReader reader; - std::filesystem::path config_json_path = model_path / "config.json"; - - if (!std::filesystem::exists(config_json_path)) { - LOG(FATAL) << "Please check config.json or model_index.json file, one of " - "them should exist in the model path: " - << model_path; - } - - reader.parse(config_json_path); - auto model_type = reader.value("model_type"); - if (!model_type.has_value()) { - model_type = reader.value("model_name"); - } - if (!model_type.has_value()) { - LOG(FATAL) << "Please check config.json file in model path: " << model_path - << ", it should contain model_type or model_name key."; - } - return model_type.value(); -} - inline std::string get_model_backend(const std::filesystem::path& model_path) { JsonReader reader; 
std::filesystem::path model_index_json_path = model_path / "model_index.json"; @@ -190,7 +167,7 @@ inline bool should_enable_mla( if (resolved_backend == "dit") { return false; } - return is_mla_model_type(get_model_type(model_path)); + return is_mla_model_type(get_model_type(model_path, backend)); } } // namespace util diff --git a/xllm/models/llm/qwen3_5.h b/xllm/models/llm/qwen3_5.h index 4ecc9693b..464611740 100644 --- a/xllm/models/llm/qwen3_5.h +++ b/xllm/models/llm/qwen3_5.h @@ -114,6 +114,18 @@ TORCH_MODULE(Qwen3_5ForCausalLM); LOAD_ARG_TEXT_OR_ROOT(linear_num_value_heads, "linear_num_value_heads", 32); \ LOAD_ARG_TEXT_OR_ROOT(linear_value_head_dim, "linear_value_head_dim", 128); \ LOAD_QWEN3_5_ROPE_ARG(partial_rotary_factor, 0.25f); \ + LOAD_ARG_OR(rope_scaling_mrope_interleaved, \ + "text_config.rope_scaling.mrope_interleaved", \ + false); \ + LOAD_ARG_OR(rope_scaling_mrope_interleaved, \ + "rope_scaling.mrope_interleaved", \ + args->rope_scaling_mrope_interleaved()); \ + LOAD_ARG_OR(rope_scaling_mrope_interleaved, \ + "text_config.rope_parameters.mrope_interleaved", \ + args->rope_scaling_mrope_interleaved()); \ + LOAD_ARG_OR(rope_scaling_mrope_interleaved, \ + "rope_parameters.mrope_interleaved", \ + args->rope_scaling_mrope_interleaved()); \ LOAD_ARG_TEXT_OR_ROOT(shared_expert_intermediate_size, \ "shared_expert_intermediate_size", \ default_shared_expert_intermediate_size); \ @@ -153,6 +165,16 @@ TORCH_MODULE(Qwen3_5ForCausalLM); LOAD_ARG_OR(dtype, "text_config.torch_dtype", args->dtype()); \ LOAD_ARG_OR(dtype, "torch_dtype", args->dtype()) +#define LOAD_QWEN3_5_NESTED_ROPE_ARG(arg_name, json_key, default_value) \ + LOAD_ARG_OR(arg_name, "text_config." #arg_name, default_value); \ + LOAD_ARG_OR(arg_name, #arg_name, args->arg_name()); \ + LOAD_ARG_OR( \ + arg_name, "text_config.rope_scaling." json_key, args->arg_name()); \ + LOAD_ARG_OR(arg_name, "rope_scaling." json_key, args->arg_name()); \ + LOAD_ARG_OR( \ + arg_name, "text_config.rope_parameters." 
json_key, args->arg_name()); \ + LOAD_ARG_OR(arg_name, "rope_parameters." json_key, args->arg_name()) + REGISTER_CAUSAL_MODEL(qwen3_5, Qwen3_5ForCausalLM); REGISTER_MODEL_ARGS(qwen3_5, [&] { LOAD_QWEN3_5_TYPE_AND_DTYPE("qwen3_5"); @@ -160,15 +182,23 @@ REGISTER_MODEL_ARGS(qwen3_5, [&] { /*num_experts=*/0, /*num_experts_per_tok=*/0, /*shared_expert_intermediate_size=*/0); + LOAD_QWEN3_5_NESTED_ROPE_ARG( + rope_scaling_mrope_section, "mrope_section", std::vector()); }); REGISTER_CAUSAL_MODEL(qwen3_5_text, Qwen3_5ForCausalLM); REGISTER_MODEL_ARGS(qwen3_5_text, [&] { - LOAD_QWEN3_5_TYPE_AND_DTYPE("qwen3_5_text"); + SET_ARG(model_type, "qwen3_5_text"); + LOAD_ARG_OR(dtype, "text_config.dtype", "bfloat16"); + LOAD_ARG_OR(dtype, "dtype", args->dtype()); + LOAD_ARG_OR(dtype, "text_config.torch_dtype", args->dtype()); + LOAD_ARG_OR(dtype, "torch_dtype", args->dtype()); LOAD_QWEN3_5_NEXT_COMPAT_ARGS(/*moe_intermediate_size=*/0, /*num_experts=*/0, /*num_experts_per_tok=*/0, /*shared_expert_intermediate_size=*/0); + LOAD_QWEN3_5_NESTED_ROPE_ARG( + rope_scaling_mrope_section, "mrope_section", std::vector()); }); REGISTER_CAUSAL_MODEL(qwen3_5_moe, Qwen3_5ForCausalLM); @@ -178,17 +208,26 @@ REGISTER_MODEL_ARGS(qwen3_5_moe, [&] { /*num_experts=*/512, /*num_experts_per_tok=*/10, /*shared_expert_intermediate_size=*/512); + LOAD_QWEN3_5_NESTED_ROPE_ARG( + rope_scaling_mrope_section, "mrope_section", std::vector()); }); REGISTER_CAUSAL_MODEL(qwen3_5_moe_text, Qwen3_5ForCausalLM); REGISTER_MODEL_ARGS(qwen3_5_moe_text, [&] { - LOAD_QWEN3_5_TYPE_AND_DTYPE("qwen3_5_moe_text"); + SET_ARG(model_type, "qwen3_5_moe_text"); + LOAD_ARG_OR(dtype, "text_config.dtype", "bfloat16"); + LOAD_ARG_OR(dtype, "dtype", args->dtype()); + LOAD_ARG_OR(dtype, "text_config.torch_dtype", args->dtype()); + LOAD_ARG_OR(dtype, "torch_dtype", args->dtype()); LOAD_QWEN3_5_NEXT_COMPAT_ARGS(/*moe_intermediate_size=*/512, /*num_experts=*/512, /*num_experts_per_tok=*/10, /*shared_expert_intermediate_size=*/512); + 
LOAD_QWEN3_5_NESTED_ROPE_ARG( + rope_scaling_mrope_section, "mrope_section", std::vector()); }); +#undef LOAD_QWEN3_5_NESTED_ROPE_ARG #undef LOAD_QWEN3_5_TYPE_AND_DTYPE #undef LOAD_QWEN3_5_NEXT_COMPAT_ARGS #undef LOAD_QWEN3_5_ROPE_ARG diff --git a/xllm/models/llm/qwen3_next_hybrid_base.h b/xllm/models/llm/qwen3_next_hybrid_base.h index c765f2a29..54847caf1 100644 --- a/xllm/models/llm/qwen3_next_hybrid_base.h +++ b/xllm/models/llm/qwen3_next_hybrid_base.h @@ -15,6 +15,7 @@ limitations under the License. #pragma once +#include #include #include @@ -32,6 +33,7 @@ limitations under the License. #include "core/layers/common/attention_metadata_builder.h" #include "core/layers/common/lm_head.h" #include "core/layers/common/qwen3_next_rms_norm.h" +#include "core/layers/common/rotary_embedding_util.h" #include "core/layers/common/word_embedding.h" #include "core/layers/npu_torch/qwen3_next_hybrid_decoder_layer_base.h" @@ -74,6 +76,22 @@ class Qwen3HybridModelImplBase : public Qwen3HybridModelModule { options.dtype().toScalarType(), /*mask_value=*/mask_value); dp_size_ = parallel_args.dp_size(); + mrope_section_ = model_args_.rope_scaling_mrope_section(); + rotary_interleaved_ = + !mrope_section_.empty() && model_args_.rope_scaling_mrope_interleaved(); + const int64_t rotary_dim = static_cast( + model_args_.head_dim() * model_args_.partial_rotary_factor()); + CHECK_GT(rotary_dim, 0) << "rotary_dim must be positive"; + auto inv_freq = layer::rotary::compute_inv_freq( + rotary_dim, model_args_.rope_theta(), options); + rotary_cos_sin_cache_ = + register_buffer("rotary_cos_sin_cache", + layer::rotary::compute_cos_sin_cache( + rotary_dim, + model_args_.max_position_embeddings(), + rotary_interleaved_, + inv_freq, + options)); } // tokens: [num_tokens] @@ -94,11 +112,25 @@ class Qwen3HybridModelImplBase : public Qwen3HybridModelModule { layer::AttentionMetadata attn_metadata = layer::AttentionMetadataBuilder::build( input_params, model_args_, build_attention_mask(input_params)); 
- torch::Tensor h = embed_tokens_(tokens); + const bool only_prefill = + attn_metadata.is_prefill || attn_metadata.is_chunked_prefill; + if (positions.dim() == 2 && only_prefill && !mrope_section_.empty()) { + std::tie(attn_metadata.mrope_cos, attn_metadata.mrope_sin) = + build_rotary_cos_sin(positions); + } + const bool use_deepstack = !input_params.deep_stacks.empty(); + auto deep_stacks = input_params.deep_stacks; + torch::Tensor h = input_params.input_embedding; + if (!h.defined()) { + h = embed_tokens_(tokens); + } for (size_t i = 0; i < layers_.size(); i++) { auto& layer = layers_[i]; h = layer->forward( h, positions, attn_metadata, kv_caches[i], input_params); + if (use_deepstack && i < deep_stacks.size()) { + h = deepstack_process(h, input_params.visual_pos_masks, deep_stacks[i]); + } } h = norm_(h); return ModelOutput(h); @@ -138,6 +170,56 @@ class Qwen3HybridModelImplBase : public Qwen3HybridModelModule { } protected: + torch::Tensor deepstack_process(torch::Tensor hidden_states, + torch::Tensor visual_pos_masks, + const torch::Tensor& visual_embeds) { + visual_pos_masks = visual_pos_masks.to(hidden_states.device()); + auto selected = hidden_states.index({visual_pos_masks}); + hidden_states.index_put_({visual_pos_masks}, selected + visual_embeds); + return hidden_states; + } + + std::pair build_rotary_cos_sin( + const torch::Tensor& positions) { + namespace F = torch::nn::functional; + auto cos_sin = F::embedding(positions, rotary_cos_sin_cache_); + auto chunks = cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); + auto cos_pos = chunks[0]; + auto sin_pos = chunks[1]; + if (positions.dim() != 2 || mrope_section_.empty()) { + return std::make_pair(cos_pos, sin_pos); + } + + auto reorder = [this](const torch::Tensor& x) { + std::vector sections = mrope_section_; + if (rotary_interleaved_) { + for (int64_t& section : sections) { + section *= 2; + } + } else { + sections.insert( + sections.end(), mrope_section_.begin(), mrope_section_.end()); + } + + auto split_tensors 
= x.split(sections, /*dim=*/-1); + std::vector selected_tensors; + selected_tensors.reserve(split_tensors.size()); + for (int64_t i = 0; i < static_cast(split_tensors.size()); ++i) { + const int64_t pos_axis = + rotary_interleaved_ + ? i + : i % static_cast(mrope_section_.size()); + CHECK_LT(pos_axis, split_tensors[i].size(0)) + << "mRoPE position axis out of range for section " << i; + selected_tensors.push_back( + split_tensors[i].select(/*dim=*/0, /*index=*/pos_axis)); + } + return torch::cat(selected_tensors, /*dim=*/-1); + }; + + return std::make_pair(reorder(cos_pos), reorder(sin_pos)); + } + torch::Tensor build_attention_mask(const ModelInputParams& input_params) { max_seq_len_ = std::max(input_params.kv_max_seq_len, max_seq_len_); if (!FLAGS_enable_chunked_prefill) { @@ -172,6 +254,9 @@ class Qwen3HybridModelImplBase : public Qwen3HybridModelModule { layer::Qwen3NextRMSNorm norm_{nullptr}; layer::AttentionMask attn_mask_; layer::WordEmbedding embed_tokens_{nullptr}; + torch::Tensor rotary_cos_sin_cache_; + std::vector mrope_section_; + bool rotary_interleaved_ = false; }; class Qwen3HybridForCausalLMImplBase : public torch::nn::Module { diff --git a/xllm/models/model_registry.cpp b/xllm/models/model_registry.cpp index 9d633adb0..945e4a877 100644 --- a/xllm/models/model_registry.cpp +++ b/xllm/models/model_registry.cpp @@ -68,6 +68,7 @@ constexpr char kTorchBackend[] = "TORCH"; bool is_torch_only_model_type(const std::string& model_type) { static const std::unordered_set kTorchOnlyModelTypes = { "qwen3_5", + "qwen3_5_vl", "qwen3_5_text", "qwen3_5_moe", "qwen3_5_moe_text", diff --git a/xllm/models/models.h b/xllm/models/models.h index d7ece4ee8..67b817347 100644 --- a/xllm/models/models.h +++ b/xllm/models/models.h @@ -57,7 +57,7 @@ limitations under the License. 
#include "vlm/npu/qwen3_vl.h" // IWYU pragma: keep #include "vlm/npu/qwen3_vl_mm_embedding.h" // IWYU pragma: keep #include "vlm/npu/qwen3_vl_moe.h" // IWYU pragma: keep - +#include "vlm/qwen3_5_vl.h" // IWYU pragma: keep #elif defined(USE_MLU) #include "dit/pipeline_flux.h" // IWYU pragma: keep #include "dit/pipeline_flux_control.h" // IWYU pragma: keep diff --git a/xllm/models/vlm/qwen3_5_vl.h b/xllm/models/vlm/qwen3_5_vl.h new file mode 100644 index 000000000..3a16aa3d7 --- /dev/null +++ b/xllm/models/vlm/qwen3_5_vl.h @@ -0,0 +1,415 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include <torch/torch.h> + +#include <memory> +#include <optional> +#include <string> +#include <tuple> +#include <vector> + +#define XLLM_DISABLE_GENERIC_VLM_REGISTRATION +#include "models/vlm/qwen3_vl.h" +#undef XLLM_DISABLE_GENERIC_VLM_REGISTRATION + +#if defined(USE_NPU) +#include "models/vlm/npu/qwen3_vl.h" +#endif + +#include "models/llm/qwen3_5.h" + +namespace xllm { + +using Qwen3_5_VLInputProcessor = Qwen3_VLInputProcessor; +using Qwen3_5_VLImageProcessor = Qwen3VLImageProcessor; + +class Qwen3_5_VLForConditionalGenerationImpl : public torch::nn::Module { + public: + explicit Qwen3_5_VLForConditionalGenerationImpl(const ModelContext& context) + : model_args_(context.get_model_args()), + options_(context.get_tensor_options()) { +#if defined(USE_NPU) + visual_ = + register_module("visual", npu::model::Qwen3_VisionTransformer(context)); +#else + visual_ = register_module("visual", Qwen3_VisionTransformer(context)); +#endif + language_model_ = + register_module("language_model", Qwen3_5ForCausalLM(context)); + } + + void prepare_encoder_input(const ModelInputParams& input_params, + std::optional<Qwen3_VLImageInputs>& image_inputs, + std::optional<Qwen3_VLVideoInputs>& video_inputs) { + const auto& mm_data = input_params.mm_data; + torch::Tensor pixel_values; + if (const auto& res = mm_data.get<torch::Tensor>("pixel_values")) { + pixel_values = res.value(); + } + + torch::Tensor image_grid_thw; + if (const auto& res = mm_data.get<torch::Tensor>("image_grid_thw")) { + image_grid_thw = res.value(); + } + + torch::Tensor pixel_values_videos; + if (const auto& res = mm_data.get<torch::Tensor>("pixel_values_videos")) { + pixel_values_videos = res.value(); + } + + torch::Tensor video_grid_thw; + if (const auto& res = mm_data.get<torch::Tensor>("video_grid_thw")) { + video_grid_thw = res.value(); + } + + if (pixel_values.defined() && image_grid_thw.defined()) { + image_inputs = Qwen3_VLImageInputs{pixel_values, image_grid_thw}; + } + + if (pixel_values_videos.defined() && video_grid_thw.defined()) { + video_inputs = 
Qwen3_VLVideoInputs{pixel_values_videos, video_grid_thw}; + } + } + + MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { + std::optional<Qwen3_VLImageInputs> image_input; + std::optional<Qwen3_VLVideoInputs> video_input; + prepare_encoder_input(input_params, image_input, video_input); + + MMDict multimodal_embeds; + const int32_t merge_size = model_args_.mm_image_merge_size(); + if (image_input) { + torch::Tensor image_embeds; + std::vector<torch::Tensor> deep_stacks; + std::tie(image_embeds, deep_stacks) = + visual_(image_input->pixel_values.to(options_), + image_input->image_grid_thw.to(options_.device()), + input_params); + + auto image_tokens = + (image_input->image_grid_thw.prod(-1) / merge_size / merge_size) + .cpu() + .contiguous() + .to(torch::kLong); + + std::vector<int64_t> image_tokens_vec( + image_tokens.data_ptr<int64_t>(), + image_tokens.data_ptr<int64_t>() + image_tokens.numel()); + multimodal_embeds["image|embedding"] = + image_embeds.split(image_tokens_vec, 0); + + for (size_t i = 0; i < deep_stacks.size(); ++i) { + multimodal_embeds[std::string("image|embedding|deepstack_") + + std::to_string(i)] = + deep_stacks[i].split(image_tokens_vec, 0); + } + } + + if (video_input) { + torch::Tensor video_embeds; + std::vector<torch::Tensor> deep_stacks; + std::tie(video_embeds, deep_stacks) = + visual_(video_input->pixel_values_videos.to(options_), + video_input->video_grid_thw.to(options_.device()), + input_params); + + auto video_tokens = + (video_input->video_grid_thw.prod(-1) / merge_size / merge_size) + .cpu() + .contiguous() + .to(torch::kLong); + + std::vector<int64_t> video_tokens_vec( + video_tokens.data_ptr<int64_t>(), + video_tokens.data_ptr<int64_t>() + video_tokens.numel()); + multimodal_embeds["video|embedding"] = + video_embeds.split(video_tokens_vec, 0); + + for (size_t i = 0; i < deep_stacks.size(); ++i) { + multimodal_embeds[std::string("video|embedding|deepstack_") + + std::to_string(i)] = + deep_stacks[i].split(video_tokens_vec, 0); + } + } + return multimodal_embeds; + } + + torch::Tensor generate_multimodal_mask(torch::Tensor input_ids) 
{ + torch::Tensor special_token_ids = torch::tensor( + {model_args_.image_token_id(), model_args_.video_token_id()}, + input_ids.options().dtype(torch::kInt64)); + return torch::isin(input_ids, special_token_ids); + } + + std::vector<torch::Tensor> get_deep_stacks( + const ModelInputParams& input_params) { + const auto& mm_data = input_params.mm_data; + if (!mm_data.has("embedding|deepstack_0")) { + return {}; + } + + std::vector<torch::Tensor> deepstacks = { + mm_data.get<torch::Tensor>("embedding|deepstack_0").value(), + mm_data.get<torch::Tensor>("embedding|deepstack_1").value(), + mm_data.get<torch::Tensor>("embedding|deepstack_2").value()}; + return deepstacks; + } + + torch::Tensor merge_multimodal_embeddings( + torch::Tensor inputs_embeds, + const torch::Tensor& multimodal_embeds, + const torch::Tensor& is_multimodal) { + inputs_embeds.index_put_({is_multimodal}, multimodal_embeds); + return inputs_embeds; + } + + torch::Tensor get_input_embeddings(const torch::Tensor input_ids, + const ModelInputParams& input_params) { + const auto& mm_data = input_params.mm_data; + torch::Tensor multimodal_embeds; + if (const auto& emb = mm_data.get<torch::Tensor>("embedding")) { + multimodal_embeds = emb.value(); + } + + torch::Tensor inputs_embeds = + language_model_->get_word_embedding()(input_ids); + if (!multimodal_embeds.defined()) { + return inputs_embeds; + } + + torch::Tensor is_multimodal = generate_multimodal_mask(input_ids); + input_params.visual_pos_masks = is_multimodal; + return merge_multimodal_embeddings( + inputs_embeds, multimodal_embeds, is_multimodal); + } + + ModelOutput forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector<KVCache>& kv_caches, + const ModelInputParams& input_params) { + input_params.deep_stacks = std::move(get_deep_stacks(input_params)); + return language_model_(tokens, positions, kv_caches, input_params); + } + + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + return language_model_->logits(hidden_states, seleted_idxes); + } + + void 
load_model(std::unique_ptr<ModelLoader> loader) { + for (const auto& state_dict : loader->get_state_dicts()) { + visual_->load_state_dict( + state_dict->get_dict_with_prefix("model.visual.")); + } +#if defined(USE_NPU) + visual_->verify_loaded_weights("model.visual."); + visual_->merge_loaded_weights(); +#endif + + if (!model_args_.image_embedding_mode()) { + language_model_->load_model(std::move(loader)); + } + } + + layer::LmHead get_lm_head() { return language_model_->get_lm_head(); } + + void set_lm_head(layer::LmHead& head) { language_model_->set_lm_head(head); } + + layer::WordEmbedding get_word_embedding() { + return language_model_->get_word_embedding(); + } + + void set_word_embedding(layer::WordEmbedding& word_embedding) { + language_model_->set_word_embedding(word_embedding); + } + + private: + ModelArgs model_args_; + torch::TensorOptions options_; +#if defined(USE_NPU) + npu::model::Qwen3_VisionTransformer visual_{nullptr}; +#else + Qwen3_VisionTransformer visual_{nullptr}; +#endif + Qwen3_5ForCausalLM language_model_{nullptr}; +}; +TORCH_MODULE(Qwen3_5_VLForConditionalGeneration); + +#define LOAD_QWEN3_5_VL_TEXT_OR_ROOT(arg_name, json_key, default_value) \ + LOAD_ARG_OR(arg_name, "text_config." json_key, default_value); \ + LOAD_ARG_OR(arg_name, json_key, args->arg_name()) + +#define LOAD_QWEN3_5_VL_ROPE_ARG(arg_name, default_value) \ + LOAD_ARG_OR(arg_name, "text_config." #arg_name, default_value); \ + LOAD_ARG_OR(arg_name, #arg_name, args->arg_name()); \ + LOAD_ARG_OR( \ + arg_name, "text_config.rope_scaling." #arg_name, args->arg_name()); \ + LOAD_ARG_OR(arg_name, "rope_scaling." #arg_name, args->arg_name()); \ + LOAD_ARG_OR( \ + arg_name, "text_config.rope_parameters." #arg_name, args->arg_name()); \ + LOAD_ARG_OR(arg_name, "rope_parameters." #arg_name, args->arg_name()) + +#define LOAD_QWEN3_5_VL_NESTED_ROPE_ARG(arg_name, json_key, default_value) \ + LOAD_ARG_OR(arg_name, "text_config." 
#arg_name, default_value); \ + LOAD_ARG_OR(arg_name, #arg_name, args->arg_name()); \ + LOAD_ARG_OR( \ + arg_name, "text_config.rope_scaling." json_key, args->arg_name()); \ + LOAD_ARG_OR(arg_name, "rope_scaling." json_key, args->arg_name()); \ + LOAD_ARG_OR( \ + arg_name, "text_config.rope_parameters." json_key, args->arg_name()); \ + LOAD_ARG_OR(arg_name, "rope_parameters." json_key, args->arg_name()) + +#define LOAD_QWEN3_5_VL_TOKEN_ARG(arg_name, default_value) \ + LOAD_ARG_OR(arg_name, "text_config." #arg_name, default_value); \ + LOAD_ARG_OR(arg_name, #arg_name, args->arg_name()) + +REGISTER_INPUT_PROCESSOR(qwen3_5_vl, Qwen3_5_VLInputProcessor); +REGISTER_CAUSAL_VLM_MODEL(qwen3_5_vl, Qwen3_5_VLForConditionalGeneration); +REGISTER_IMAGE_PROCESSOR(qwen3_5_vl, Qwen3_5_VLImageProcessor); + +REGISTER_MODEL_ARGS(qwen3_5_vl, [&] { + SET_ARG(model_type, "qwen3_5_vl"); + + LOAD_ARG_OR(dtype, "text_config.dtype", "bfloat16"); + LOAD_ARG_OR(dtype, "dtype", args->dtype()); + LOAD_ARG_OR(dtype, "text_config.torch_dtype", args->dtype()); + LOAD_ARG_OR(dtype, "torch_dtype", args->dtype()); + + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(attention_bias, "attention_bias", false); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(attention_dropout, "attention_dropout", 0.0f); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(bos_token_id, "bos_token_id", 151643); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(decoder_sparse_step, "decoder_sparse_step", 1); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(eos_token_id, "eos_token_id", 151645); + LOAD_QWEN3_5_VL_TOKEN_ARG(vision_start_token_id, 248053); + LOAD_QWEN3_5_VL_TOKEN_ARG(vision_end_token_id, 248054); + LOAD_QWEN3_5_VL_TOKEN_ARG(vision_token_id, 248055); + LOAD_QWEN3_5_VL_TOKEN_ARG(image_token_id, 248056); + LOAD_QWEN3_5_VL_TOKEN_ARG(video_token_id, 248057); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(head_dim, "head_dim", 256); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(hidden_act, "hidden_act", "silu"); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(hidden_size, "hidden_size", 5120); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(initializer_range, 
"initializer_range", 0.02f); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(intermediate_size, "intermediate_size", 17408); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT( + max_position_embeddings, "max_position_embeddings", 262144); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(max_window_layers, "max_window_layers", 64); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(n_heads, "num_attention_heads", 24); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(n_layers, "num_hidden_layers", 64); + LOAD_ARG_OR(n_kv_heads, "text_config.num_key_value_heads", 4); + LOAD_ARG_OR( + n_kv_heads, "num_key_value_heads", args->n_kv_heads().value_or(4)); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(rms_norm_eps, "rms_norm_eps", 1e-6); + LOAD_QWEN3_5_VL_ROPE_ARG(rope_theta, 10000000.0f); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(use_sliding_window, "use_sliding_window", false); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(sliding_window, "sliding_window", 4096); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT( + tie_word_embeddings, "tie_word_embeddings", false); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(vocab_size, "vocab_size", 248320); + LOAD_ARG_OR( + mlp_only_layers, "text_config.mlp_only_layers", std::vector()); + LOAD_ARG_OR(mlp_only_layers, "mlp_only_layers", args->mlp_only_layers()); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(attn_output_gate, "attn_output_gate", true); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT( + full_attention_interval, "full_attention_interval", 4); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT( + linear_conv_kernel_dim, "linear_conv_kernel_dim", 4); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT(linear_key_head_dim, "linear_key_head_dim", 128); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT( + linear_num_key_heads, "linear_num_key_heads", 16); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT( + linear_num_value_heads, "linear_num_value_heads", 48); + LOAD_QWEN3_5_VL_TEXT_OR_ROOT( + linear_value_head_dim, "linear_value_head_dim", 128); + LOAD_QWEN3_5_VL_ROPE_ARG(partial_rotary_factor, 0.25f); + LOAD_ARG_OR(rope_scaling_mrope_interleaved, + "text_config.rope_scaling.mrope_interleaved", + false); + LOAD_ARG_OR(rope_scaling_mrope_interleaved, + "rope_scaling.mrope_interleaved", + 
args->rope_scaling_mrope_interleaved()); + LOAD_ARG_OR(rope_scaling_mrope_interleaved, + "text_config.rope_parameters.mrope_interleaved", + args->rope_scaling_mrope_interleaved()); + LOAD_ARG_OR(rope_scaling_mrope_interleaved, + "rope_parameters.mrope_interleaved", + args->rope_scaling_mrope_interleaved()); + LOAD_ARG_OR(num_nextn_predict_layers, "text_config.mtp_num_hidden_layers", 0); + LOAD_ARG_OR(num_nextn_predict_layers, + "mtp_num_hidden_layers", + args->num_nextn_predict_layers()); + LOAD_ARG_OR(num_nextn_predict_layers, + "text_config.num_nextn_predict_layers", + args->num_nextn_predict_layers()); + LOAD_ARG_OR(num_nextn_predict_layers, + "num_nextn_predict_layers", + args->num_nextn_predict_layers()); + LOAD_ARG_OR( + layer_types, "text_config.layer_types", std::vector<std::string>()); + LOAD_ARG_OR(layer_types, "layer_types", args->layer_types()); + LOAD_ARG_OR( + layer_types, "text_config.layers_block_type", args->layer_types()); + LOAD_ARG_OR(layer_types, "layers_block_type", args->layer_types()); + + SET_ARG(moe_intermediate_size, 0); + SET_ARG(norm_topk_prob, true); + SET_ARG(num_experts, 0); + SET_ARG(num_experts_per_tok, 0); + SET_ARG(output_router_logits, false); + SET_ARG(router_aux_loss_coef, 0.001f); + SET_ARG(shared_expert_intermediate_size, 0); + SET_ARG(n_routed_experts, 0); + SET_ARG(n_shared_experts, 0); + SET_ARG(scoring_func, "softmax"); + SET_ARG(topk_method, ""); + SET_ARG(n_group, -1); + SET_ARG(topk_group, 0); + SET_ARG(routed_scaling_factor, 1.0f); + SET_ARG(stop_token_ids, std::unordered_set<int32_t>({args->eos_token_id()})); + + LOAD_ARG_OR(mm_num_hidden_layers, "vision_config.depth", 27); + LOAD_ARG_OR(mm_hidden_act, "vision_config.hidden_act", "gelu_pytorch_tanh"); + LOAD_ARG_OR(mm_hidden_size, "vision_config.hidden_size", 1152); + LOAD_ARG_OR(mm_intermediate_size, "vision_config.intermediate_size", 4304); + LOAD_ARG_OR(mm_num_attention_heads, "vision_config.num_heads", 16); + LOAD_ARG_OR(mm_num_channels, "vision_config.in_channels", 3); + 
LOAD_ARG_OR(mm_projection_dim, "vision_config.out_hidden_size", 5120); + LOAD_ARG_OR(mm_patch_size, "vision_config.patch_size", 16); + LOAD_ARG_OR(mm_num_position_embeddings, + "vision_config.num_position_embeddings", + 2304); + LOAD_ARG_OR(mm_spatial_merge_size, "vision_config.spatial_merge_size", 2); + LOAD_ARG(mm_deepstack_visual_indexes, + "vision_config.deepstack_visual_indexes"); + LOAD_ARG_OR(mm_temporal_patch_size, "vision_config.temporal_patch_size", 2); + LOAD_ARG_OR_FUNC(mm_head_dim, "head_dim", [&] { + return args->mm_hidden_size() / args->mm_num_attention_heads(); + }); + LOAD_QWEN3_5_VL_NESTED_ROPE_ARG(rope_scaling_rope_type, "type", "default"); + LOAD_QWEN3_5_VL_NESTED_ROPE_ARG(rope_scaling_mrope_section, + "mrope_section", + std::vector<int64_t>({11, 11, 10})); + if (args->rope_scaling_rope_type() == "default") { + args->rope_scaling_rope_type() = "mrope"; + } +}); + +#undef LOAD_QWEN3_5_VL_TOKEN_ARG +#undef LOAD_QWEN3_5_VL_NESTED_ROPE_ARG +#undef LOAD_QWEN3_5_VL_ROPE_ARG +#undef LOAD_QWEN3_5_VL_TEXT_OR_ROOT + +} // namespace xllm diff --git a/xllm/models/vlm/qwen3_vl.h b/xllm/models/vlm/qwen3_vl.h index 930499639..a12aaeb1d 100644 --- a/xllm/models/vlm/qwen3_vl.h +++ b/xllm/models/vlm/qwen3_vl.h @@ -740,6 +740,7 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { }; TORCH_MODULE(Qwen3_VLForConditionalGeneration); +#ifndef XLLM_DISABLE_GENERIC_VLM_REGISTRATION REGISTER_INPUT_PROCESSOR(qwen3_vl, Qwen3_VLInputProcessor); REGISTER_CAUSAL_VLM_MODEL(qwen3_vl, Qwen3_VLForConditionalGeneration); REGISTER_IMAGE_PROCESSOR(qwen3_vl, Qwen3VLImageProcessor); @@ -803,4 +804,5 @@ REGISTER_MODEL_ARGS(qwen3_vl, [&] { LOAD_ARG_OR(vocab_size, "text_config.vocab_size", 151936); }); +#endif // XLLM_DISABLE_GENERIC_VLM_REGISTRATION } // namespace xllm diff --git a/xllm/xllm.cpp b/xllm/xllm.cpp index 32aab0bb3..f1c3aa53a 100644 --- a/xllm/xllm.cpp +++ b/xllm/xllm.cpp @@ -50,12 +50,27 @@ static const std::unordered_set<std::string> prefill_sp_supported_model_set = { 
"deepseek_v32", "glm_moe_dsa"}; +static const std::unordered_set<std::string> + prefix_cache_unsupported_model_set = {"qwen3_5_text", "qwen3_5_moe_text"}; + +static const std::unordered_set<std::string> + chunked_prefill_unsupported_model_set = {"qwen3_5_text", + "qwen3_5_moe_text"}; + void shutdown_handler(int signal) { // TODO: gracefully shutdown the server LOG(WARNING) << "Received signal " << signal << ", stopping server..."; exit(1); } +bool should_disable_chunked_prefill(const std::string& model_type) { + return chunked_prefill_unsupported_model_set.contains(model_type); +} + +bool should_disable_prefix_cache(const std::string& model_type) { + return prefix_cache_unsupported_model_set.contains(model_type); +} + void validate_flags(const std::string& model_type) { if (FLAGS_backend.empty()) { LOG(FATAL) << "Model is not supported currently, model type: " @@ -167,13 +182,25 @@ int run() { #endif std::string model_type = ""; if (FLAGS_backend != "dit") { - model_type = xllm::util::get_model_type(model_path); + model_type = xllm::util::get_model_type(model_path, FLAGS_backend); FLAGS_tool_call_parser = function_call::FunctionCallParser::get_parser_auto( FLAGS_tool_call_parser, model_type); FLAGS_reasoning_parser = ReasoningParser::get_parser_auto(FLAGS_reasoning_parser, model_type); } + if (FLAGS_enable_chunked_prefill && + should_disable_chunked_prefill(model_type)) { + LOG(WARNING) << "Disabling chunked prefill for model_type=" << model_type + << " due to known output corruption on the llm path."; + FLAGS_enable_chunked_prefill = false; + } + if (FLAGS_enable_prefix_cache && should_disable_prefix_cache(model_type)) { + LOG(WARNING) << "Disabling prefix cache for model_type=" << model_type + << " because it is not supported on the llm path."; + FLAGS_enable_prefix_cache = false; + } + // validate flags before creating master validate_flags(model_type);