Skip to content

Commit fb1287d

Browse files
authored
feat: support Qwen down_proj fallback for compressed-tensors ignored modules. (#1254)
1 parent ef1543a commit fb1287d

11 files changed

Lines changed: 115 additions & 9 deletions

File tree

xllm/core/framework/hf_model_loader.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ bool try_load_compressed_tensors_quant_cfg(const JsonReader& reader,
113113
if (dynamic_it != input_activations_it->end() && !dynamic_it->is_null()) {
114114
quant_args.activation_dynamic() = dynamic_it->get<bool>();
115115
}
116+
if (const auto ignore = reader.value<std::vector<std::string>>(
117+
"quantization_config.ignore");
118+
ignore.has_value()) {
119+
quant_args.ignored_modules() = *ignore;
120+
}
116121
return true;
117122
}
118123

xllm/core/framework/hf_model_loader_test.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ TEST(HFModelLoaderTest, LoadCompressedTensorsFp8StaticConfig) {
4242
}
4343
}
4444
},
45+
"ignore": [
46+
"lm_head",
47+
"model.layers.1.mlp.down_proj"
48+
],
4549
"quant_method": "compressed-tensors"
4650
}
4751
}
@@ -54,6 +58,9 @@ TEST(HFModelLoaderTest, LoadCompressedTensorsFp8StaticConfig) {
5458
EXPECT_EQ(quant_args.bits(), 8);
5559
EXPECT_EQ(quant_args.moe_weight_bits(), 8);
5660
EXPECT_FALSE(quant_args.activation_dynamic());
61+
ASSERT_EQ(quant_args.ignored_modules().size(), 2);
62+
EXPECT_EQ(quant_args.ignored_modules()[0], "lm_head");
63+
EXPECT_EQ(quant_args.ignored_modules()[1], "model.layers.1.mlp.down_proj");
5764
}
5865
}
5966

xllm/core/framework/quant_args.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ limitations under the License.
1717
#pragma once
1818

1919
#include <ostream>
20+
#include <regex>
2021
#include <string>
22+
#include <vector>
2123

2224
#include "common/macros.h"
2325

@@ -55,6 +57,35 @@ struct QuantArgs {
5557
// weight block size
5658
PROPERTY(std::vector<int64_t>, weight_block_size) = {};
5759

60+
// exact module names or regexes prefixed with "re:" that should bypass
61+
// quantization for compressed-tensors models.
62+
PROPERTY(std::vector<std::string>, ignored_modules) = {};
63+
64+
bool should_ignore_module(const std::string& module_name) const {
65+
for (const auto& pattern : ignored_modules()) {
66+
if (pattern == module_name) {
67+
return true;
68+
}
69+
if (pattern.size() > 3 && pattern.rfind("re:", 0) == 0) {
70+
try {
71+
if (std::regex_match(module_name, std::regex(pattern.substr(3)))) {
72+
return true;
73+
}
74+
} catch (const std::regex_error&) {
75+
}
76+
}
77+
}
78+
return false;
79+
}
80+
81+
QuantArgs for_module(const std::string& module_name) const {
82+
QuantArgs local_args = *this;
83+
if (should_ignore_module(module_name)) {
84+
local_args.quant_method().clear();
85+
}
86+
return local_args;
87+
}
88+
5889
// check if weights can be fused
5990
bool can_be_fused() const {
6091
// can't fuse quantized weights if desc_act is true
@@ -72,6 +103,7 @@ inline std::ostream& operator<<(std::ostream& os, const QuantArgs& args) {
72103
os << ", is_sym: " << args.is_sym();
73104
os << ", activation_dynamic: " << args.activation_dynamic();
74105
os << ", fmt: " << args.fmt();
106+
os << ", ignored_modules: " << args.ignored_modules().size();
75107
os << "]";
76108
return os;
77109
}

xllm/core/layers/common/dense_mlp.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ DenseMLPImpl::DenseMLPImpl(int64_t hidden_size,
3131
bool enable_result_reduction,
3232
const QuantArgs& quant_args,
3333
ProcessGroup* process_group,
34-
const torch::TensorOptions& options)
34+
const torch::TensorOptions& options,
35+
const std::string& module_prefix)
3536
: is_gated_(is_gated),
3637
intermediate_size_(intermediate_size),
3738
process_group_(process_group),
@@ -73,13 +74,17 @@ DenseMLPImpl::DenseMLPImpl(int64_t hidden_size,
7374
act_ = register_module("act", Activation(hidden_act_, is_gated_));
7475

7576
// 2. down
77+
const auto down_proj_quant_args =
78+
module_prefix.empty()
79+
? quant_args
80+
: quant_args.for_module(module_prefix + ".down_proj");
7681
down_proj_ = register_module("down_proj",
7782
RowParallelLinear(intermediate_size_,
7883
hidden_size,
7984
/*bias=*/has_bias,
8085
/*input_is_parallelized=*/true,
8186
enable_result_reduction,
82-
quant_args,
87+
down_proj_quant_args,
8388
process_group_,
8489
options,
8590
down_proj_extra_args));

xllm/core/layers/common/dense_mlp.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ class DenseMLPImpl : public torch::nn::Module {
3838
bool enable_result_reduction,
3939
const QuantArgs& quant_args,
4040
ProcessGroup* process_group,
41-
const torch::TensorOptions& options);
41+
const torch::TensorOptions& options,
42+
const std::string& module_prefix = "");
4243

4344
torch::Tensor forward(const torch::Tensor& hidden_states);
4445

xllm/core/layers/common/tests/dense_mlp_tests.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,54 @@ TEST_F(DenseMLPTest, SmoothquantLoadStateDictTest) {
318318
LOG(INFO) << "State dict loading test passed - output sum: " << output_sum;
319319
}
320320

321+
TEST_F(DenseMLPTest, Fp8IgnoredDownProjLoadsAsUnquantized) {
322+
QuantArgs fp8_quant_args;
323+
fp8_quant_args.quant_method() = kQuantMethodFp8;
324+
fp8_quant_args.bits() = 8;
325+
fp8_quant_args.activation_dynamic() = false;
326+
fp8_quant_args.ignored_modules() = {"model.layers.1.mlp.down_proj"};
327+
328+
const int64_t hidden_size = 16;
329+
const int64_t intermediate_size = 32;
330+
auto mlp = DenseMLP(DenseMLPImpl(hidden_size,
331+
intermediate_size,
332+
/*is_gated=*/true,
333+
/*has_bias=*/false,
334+
/*hidden_act=*/"silu",
335+
/*enable_result_reduction=*/true,
336+
fp8_quant_args,
337+
parallel_args_.tp_group_,
338+
options_,
339+
"model.layers.1.mlp"));
340+
341+
std::unordered_map<std::string, torch::Tensor> weight_dict;
342+
auto fp8_weight_options = options_.dtype(torch::kFloat8_e4m3fn);
343+
auto scale_options = options_.dtype(torch::kFloat32);
344+
345+
weight_dict["gate_proj.weight"] =
346+
torch::zeros({intermediate_size, hidden_size}, fp8_weight_options);
347+
weight_dict["gate_proj.weight_scale"] = torch::ones({1}, scale_options);
348+
weight_dict["gate_proj.input_scale"] = torch::ones({1}, scale_options);
349+
350+
weight_dict["up_proj.weight"] =
351+
torch::zeros({intermediate_size, hidden_size}, fp8_weight_options);
352+
weight_dict["up_proj.weight_scale"] = torch::ones({1}, scale_options);
353+
weight_dict["up_proj.input_scale"] = torch::ones({1}, scale_options);
354+
355+
weight_dict["down_proj.weight"] =
356+
torch::zeros({hidden_size, intermediate_size}, options_);
357+
358+
StateDict state_dict(weight_dict);
359+
mlp->load_state_dict(state_dict);
360+
361+
const auto params = mlp->named_parameters(/*recurse=*/true);
362+
EXPECT_TRUE(params.contains("gate_up_proj.weight_scale"));
363+
EXPECT_TRUE(params.contains("gate_up_proj.input_scale"));
364+
EXPECT_TRUE(params.contains("down_proj.weight"));
365+
EXPECT_FALSE(params.contains("down_proj.weight_scale"));
366+
EXPECT_FALSE(params.contains("down_proj.input_scale"));
367+
}
368+
321369
TEST_F(DenseMLPTest, SmoothquantPrecisionVerificationTest) {
322370
// Test precision verification with custom input and expected output
323371
const int64_t batch_size = 16;

xllm/core/layers/qwen2_decoder_layer.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@ limitations under the License.
1818
namespace xllm {
1919
namespace layer {
2020

21-
Qwen2DecoderLayerImpl::Qwen2DecoderLayerImpl(const ModelContext& context)
21+
Qwen2DecoderLayerImpl::Qwen2DecoderLayerImpl(const ModelContext& context,
22+
int32_t layer_id)
2223
: parallel_args_(context.get_parallel_args()) {
2324
const auto& model_args = context.get_model_args();
2425
const auto& quant_args = context.get_quant_args();
2526
const auto& options = context.get_tensor_options();
27+
const std::string mlp_module_prefix =
28+
layer_id >= 0 ? "model.layers." + std::to_string(layer_id) + ".mlp" : "";
2629

2730
// Initialize attention layers
2831
attention_ = register_module("self_attn", Qwen2Attention(context));
@@ -46,7 +49,8 @@ Qwen2DecoderLayerImpl::Qwen2DecoderLayerImpl(const ModelContext& context)
4649
/*enable_result_reduction=*/true,
4750
quant_args,
4851
parallel_args_.tp_group_,
49-
options));
52+
options,
53+
mlp_module_prefix));
5054
}
5155

5256
void Qwen2DecoderLayerImpl::load_state_dict(const StateDict& state_dict) {

xllm/core/layers/qwen2_decoder_layer.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ namespace layer {
3535

3636
class Qwen2DecoderLayerImpl : public torch::nn::Module {
3737
public:
38-
explicit Qwen2DecoderLayerImpl(const ModelContext& context);
38+
explicit Qwen2DecoderLayerImpl(const ModelContext& context,
39+
int32_t layer_id = -1);
3940

4041
void load_state_dict(const StateDict& state_dict);
4142

xllm/core/layers/qwen3_moe_decoder_layer.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ Qwen3MoeDecoderLayerImpl::Qwen3MoeDecoderLayerImpl(const ModelContext& context,
8080
parallel_args_,
8181
options));
8282
} else {
83+
const std::string mlp_module_prefix =
84+
"model.layers." + std::to_string(layer_id) + ".mlp";
8385
mlp_ = register_module("mlp",
8486
DenseMLP(model_args.hidden_size(),
8587
model_args.intermediate_size(),
@@ -89,7 +91,8 @@ Qwen3MoeDecoderLayerImpl::Qwen3MoeDecoderLayerImpl(const ModelContext& context,
8991
/*enable_result_reduction=*/true,
9092
quant_args,
9193
parallel_args_.tp_group_,
92-
options));
94+
options,
95+
mlp_module_prefix));
9396
}
9497
}
9598

xllm/models/llm/qwen2.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class QWen2ModelImpl : public LlmModelImplBase<layer::Qwen2DecoderLayer> {
4343
register_module("embed_tokens", layer::WordEmbedding(context));
4444

4545
for (int32_t i = 0; i < model_args.n_layers(); i++) {
46-
auto layer = layer::Qwen2DecoderLayer(context);
46+
auto layer = layer::Qwen2DecoderLayer(context, i);
4747
layers_.push_back(layer);
4848
}
4949
}

0 commit comments

Comments
 (0)