diff --git a/xllm/core/framework/hf_model_loader.cpp b/xllm/core/framework/hf_model_loader.cpp
index 2e2432cea..0aea9d997 100644
--- a/xllm/core/framework/hf_model_loader.cpp
+++ b/xllm/core/framework/hf_model_loader.cpp
@@ -113,6 +113,11 @@ bool try_load_compressed_tensors_quant_cfg(const JsonReader& reader,
   if (dynamic_it != input_activations_it->end() && !dynamic_it->is_null()) {
     quant_args.activation_dynamic() = dynamic_it->get<bool>();
   }
+  if (const auto ignore = reader.value<std::vector<std::string>>(
+          "quantization_config.ignore");
+      ignore.has_value()) {
+    quant_args.ignored_modules() = *ignore;
+  }
   return true;
 }
 
diff --git a/xllm/core/framework/hf_model_loader_test.cpp b/xllm/core/framework/hf_model_loader_test.cpp
index 9777ff02c..908464eb8 100644
--- a/xllm/core/framework/hf_model_loader_test.cpp
+++ b/xllm/core/framework/hf_model_loader_test.cpp
@@ -42,6 +42,10 @@ TEST(HFModelLoaderTest, LoadCompressedTensorsFp8StaticConfig) {
         }
       }
     },
+    "ignore": [
+      "lm_head",
+      "model.layers.1.mlp.down_proj"
+    ],
     "quant_method": "compressed-tensors"
   }
 }
@@ -54,6 +58,9 @@ TEST(HFModelLoaderTest, LoadCompressedTensorsFp8StaticConfig) {
   EXPECT_EQ(quant_args.bits(), 8);
   EXPECT_EQ(quant_args.moe_weight_bits(), 8);
   EXPECT_FALSE(quant_args.activation_dynamic());
+  ASSERT_EQ(quant_args.ignored_modules().size(), 2);
+  EXPECT_EQ(quant_args.ignored_modules()[0], "lm_head");
+  EXPECT_EQ(quant_args.ignored_modules()[1], "model.layers.1.mlp.down_proj");
 }
 
 }
diff --git a/xllm/core/framework/quant_args.h b/xllm/core/framework/quant_args.h
index aa56cb417..6f2b66156 100644
--- a/xllm/core/framework/quant_args.h
+++ b/xllm/core/framework/quant_args.h
@@ -17,7 +17,9 @@ limitations under the License.
 #pragma once
 
 #include <cstdint>
+#include <regex>
 #include <string>
+#include <vector>
 
 #include "common/macros.h"
 
@@ -55,6 +57,37 @@ struct QuantArgs {
   // weight block size
   PROPERTY(std::vector<int64_t>, weight_block_size) = {};
 
+  // exact module names or regexes prefixed with "re:" that should bypass
+  // quantization for compressed-tensors models.
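+  // e.g. {"lm_head", "re:model\.layers\.\d+\.mlp\.down_proj"}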
+  PROPERTY(std::vector<std::string>, ignored_modules) = {};
+
+  bool should_ignore_module(const std::string& module_name) const {
+    for (const auto& pattern : ignored_modules()) {
+      if (pattern == module_name) {
+        return true;
+      }
+      if (pattern.size() > 3 && pattern.rfind("re:", 0) == 0) {
+        try {
+          if (std::regex_match(module_name, std::regex(pattern.substr(3)))) {
+            return true;
+          }
+        } catch (const std::regex_error&) {
+          // malformed pattern: treat as non-matching instead of failing
+        }
+      }
+    }
+    return false;
+  }
+
+  QuantArgs for_module(const std::string& module_name) const {
+    QuantArgs local_args = *this;
+    if (should_ignore_module(module_name)) {
+      local_args.quant_method().clear();
+    }
+    return local_args;
+  }
+
   // check if weights can be fused
   bool can_be_fused() const {
     // can't fuse quantized weights if desc_act is true
@@ -72,6 +105,7 @@ inline std::ostream& operator<<(std::ostream& os, const QuantArgs& args) {
   os << ", is_sym: " << args.is_sym();
   os << ", activation_dynamic: " << args.activation_dynamic();
   os << ", fmt: " << args.fmt();
+  os << ", ignored_modules: " << args.ignored_modules().size();
   os << "]";
   return os;
 }
diff --git a/xllm/core/layers/common/dense_mlp.cpp b/xllm/core/layers/common/dense_mlp.cpp
index 79ce81346..bb95dd0ff 100644
--- a/xllm/core/layers/common/dense_mlp.cpp
+++ b/xllm/core/layers/common/dense_mlp.cpp
@@ -31,7 +31,8 @@ DenseMLPImpl::DenseMLPImpl(int64_t hidden_size,
                            bool enable_result_reduction,
                            const QuantArgs& quant_args,
                            ProcessGroup* process_group,
-                           const torch::TensorOptions& options)
+                           const torch::TensorOptions& options,
+                           const std::string& module_prefix)
     : is_gated_(is_gated),
       intermediate_size_(intermediate_size),
       process_group_(process_group),
@@ -73,13 +74,19 @@ DenseMLPImpl::DenseMLPImpl(int64_t hidden_size,
   act_ = register_module("act", Activation(hidden_act_, is_gated_));
 
   // 2. down
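+  // resolve quant args for this specific module so a down_proj on the
+  // ignore list is constructed without quantization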
+  const auto down_proj_quant_args =
+      module_prefix.empty()
+          ? quant_args
+          : quant_args.for_module(module_prefix + ".down_proj");
   down_proj_ = register_module("down_proj",
                                RowParallelLinear(intermediate_size_,
                                                  hidden_size,
                                                  /*bias=*/has_bias,
                                                  /*input_is_parallelized=*/true,
                                                  enable_result_reduction,
-                                                 quant_args,
+                                                 down_proj_quant_args,
                                                  process_group_,
                                                  options,
                                                  down_proj_extra_args));
diff --git a/xllm/core/layers/common/dense_mlp.h b/xllm/core/layers/common/dense_mlp.h
index 545799558..8b4b2248d 100644
--- a/xllm/core/layers/common/dense_mlp.h
+++ b/xllm/core/layers/common/dense_mlp.h
@@ -38,7 +38,8 @@ class DenseMLPImpl : public torch::nn::Module {
                bool enable_result_reduction,
                const QuantArgs& quant_args,
                ProcessGroup* process_group,
-               const torch::TensorOptions& options);
+               const torch::TensorOptions& options,
+               const std::string& module_prefix = "");
 
   torch::Tensor forward(const torch::Tensor& hidden_states);
 
diff --git a/xllm/core/layers/common/tests/dense_mlp_tests.cpp b/xllm/core/layers/common/tests/dense_mlp_tests.cpp
index 36cf5ae60..dd0b5c046 100644
--- a/xllm/core/layers/common/tests/dense_mlp_tests.cpp
+++ b/xllm/core/layers/common/tests/dense_mlp_tests.cpp
@@ -318,6 +318,55 @@ TEST_F(DenseMLPTest, SmoothquantLoadStateDictTest) {
   LOG(INFO) << "State dict loading test passed - output sum: " << output_sum;
 }
 
+TEST_F(DenseMLPTest, Fp8IgnoredDownProjLoadsAsUnquantized) {
+  QuantArgs fp8_quant_args;
+  fp8_quant_args.quant_method() = kQuantMethodFp8;
+  fp8_quant_args.bits() = 8;
+  fp8_quant_args.activation_dynamic() = false;
+  fp8_quant_args.ignored_modules() = {"model.layers.1.mlp.down_proj"};
+
+  const int64_t hidden_size = 16;
+  const int64_t intermediate_size = 32;
+  auto mlp = DenseMLP(DenseMLPImpl(hidden_size,
+                                   intermediate_size,
+                                   /*is_gated=*/true,
+                                   /*has_bias=*/false,
+                                   /*hidden_act=*/"silu",
+                                   /*enable_result_reduction=*/true,
+                                   fp8_quant_args,
+                                   parallel_args_.tp_group_,
+                                   options_,
+                                   "model.layers.1.mlp"));
+
+  std::unordered_map<std::string, torch::Tensor> weight_dict;
+  auto fp8_weight_options = options_.dtype(torch::kFloat8_e4m3fn);
+  auto scale_options = options_.dtype(torch::kFloat32);
+
+  weight_dict["gate_proj.weight"] =
+      torch::zeros({intermediate_size, hidden_size}, fp8_weight_options);
+  weight_dict["gate_proj.weight_scale"] = torch::ones({1}, scale_options);
+  weight_dict["gate_proj.input_scale"] = torch::ones({1}, scale_options);
+
+  weight_dict["up_proj.weight"] =
+      torch::zeros({intermediate_size, hidden_size}, fp8_weight_options);
+  weight_dict["up_proj.weight_scale"] = torch::ones({1}, scale_options);
+  weight_dict["up_proj.input_scale"] = torch::ones({1}, scale_options);
+
+  // the ignored down_proj ships plain weights with no FP8 scales
+  weight_dict["down_proj.weight"] =
+      torch::zeros({hidden_size, intermediate_size}, options_);
+
+  StateDict state_dict(weight_dict);
+  mlp->load_state_dict(state_dict);
+
+  const auto params = mlp->named_parameters(/*recurse=*/true);
+  EXPECT_TRUE(params.contains("gate_up_proj.weight_scale"));
+  EXPECT_TRUE(params.contains("gate_up_proj.input_scale"));
+  EXPECT_TRUE(params.contains("down_proj.weight"));
+  EXPECT_FALSE(params.contains("down_proj.weight_scale"));
+  EXPECT_FALSE(params.contains("down_proj.input_scale"));
+}
+
 TEST_F(DenseMLPTest, SmoothquantPrecisionVerificationTest) {
   // Test precision verification with custom input and expected output
   const int64_t batch_size = 16;
diff --git a/xllm/core/layers/qwen2_decoder_layer.cpp b/xllm/core/layers/qwen2_decoder_layer.cpp
index 5260a9bf2..db5f8395a 100644
--- a/xllm/core/layers/qwen2_decoder_layer.cpp
+++ b/xllm/core/layers/qwen2_decoder_layer.cpp
@@ -18,11 +18,14 @@ limitations under the License.
 namespace xllm {
 namespace layer {
 
-Qwen2DecoderLayerImpl::Qwen2DecoderLayerImpl(const ModelContext& context)
+Qwen2DecoderLayerImpl::Qwen2DecoderLayerImpl(const ModelContext& context,
+                                             int32_t layer_id)
     : parallel_args_(context.get_parallel_args()) {
   const auto& model_args = context.get_model_args();
   const auto& quant_args = context.get_quant_args();
   const auto& options = context.get_tensor_options();
+  const std::string mlp_module_prefix =
+      layer_id >= 0 ? "model.layers." + std::to_string(layer_id) + ".mlp" : "";
 
   // Initialize attention layers
   attention_ = register_module("self_attn", Qwen2Attention(context));
@@ -46,7 +49,8 @@ Qwen2DecoderLayerImpl::Qwen2DecoderLayerImpl(const ModelContext& context)
                                /*enable_result_reduction=*/true,
                                quant_args,
                                parallel_args_.tp_group_,
-                               options));
+                               options,
+                               mlp_module_prefix));
 }
 
 void Qwen2DecoderLayerImpl::load_state_dict(const StateDict& state_dict) {
diff --git a/xllm/core/layers/qwen2_decoder_layer.h b/xllm/core/layers/qwen2_decoder_layer.h
index 86892e945..19ed4b601 100644
--- a/xllm/core/layers/qwen2_decoder_layer.h
+++ b/xllm/core/layers/qwen2_decoder_layer.h
@@ -35,7 +35,8 @@ namespace layer {
 
 class Qwen2DecoderLayerImpl : public torch::nn::Module {
  public:
-  explicit Qwen2DecoderLayerImpl(const ModelContext& context);
+  explicit Qwen2DecoderLayerImpl(const ModelContext& context,
+                                 int32_t layer_id = -1);
 
   void load_state_dict(const StateDict& state_dict);
 
diff --git a/xllm/core/layers/qwen3_moe_decoder_layer.cpp b/xllm/core/layers/qwen3_moe_decoder_layer.cpp
index 3c370828d..6fa3f1aef 100644
--- a/xllm/core/layers/qwen3_moe_decoder_layer.cpp
+++ b/xllm/core/layers/qwen3_moe_decoder_layer.cpp
@@ -80,6 +80,8 @@ Qwen3MoeDecoderLayerImpl::Qwen3MoeDecoderLayerImpl(const ModelContext& context,
                                       parallel_args_,
                                       options));
   } else {
+    const std::string mlp_module_prefix =
+        "model.layers." + std::to_string(layer_id) + ".mlp";
     mlp_ = register_module("mlp",
                            DenseMLP(model_args.hidden_size(),
                                     model_args.intermediate_size(),
@@ -89,7 +91,8 @@ Qwen3MoeDecoderLayerImpl::Qwen3MoeDecoderLayerImpl(const ModelContext& context,
                                     /*enable_result_reduction=*/true,
                                     quant_args,
                                     parallel_args_.tp_group_,
-                                    options));
+                                    options,
+                                    mlp_module_prefix));
   }
 }
 
diff --git a/xllm/models/llm/qwen2.h b/xllm/models/llm/qwen2.h
index 09267b9f0..a2cd8acc1 100644
--- a/xllm/models/llm/qwen2.h
+++ b/xllm/models/llm/qwen2.h
@@ -43,7 +43,7 @@ class QWen2ModelImpl : public LlmModelImplBase {
         register_module("embed_tokens", layer::WordEmbedding(context));
 
     for (int32_t i = 0; i < model_args.n_layers(); i++) {
-      auto layer = layer::Qwen2DecoderLayer(context);
+      auto layer = layer::Qwen2DecoderLayer(context, i);
       layers_.push_back(layer);
     }
   }
diff --git a/xllm/models/llm/qwen3.h b/xllm/models/llm/qwen3.h
index d716297c7..f1bd972c2 100644
--- a/xllm/models/llm/qwen3.h
+++ b/xllm/models/llm/qwen3.h
@@ -55,7 +55,7 @@ class QWen3ModelImpl : public LlmModelImplBase {
         options.device(), options.dtype().toScalarType(), mask_value);
 #endif
     for (int32_t i = 0; i < model_args.n_layers(); i++) {
-      auto layer = layer::Qwen3DecoderLayer(context);
+      auto layer = layer::Qwen3DecoderLayer(context, i);
       layers_.push_back(layer);
     }
   }