5 changes: 5 additions & 0 deletions xllm/core/framework/hf_model_loader.cpp
@@ -113,6 +113,11 @@ bool try_load_compressed_tensors_quant_cfg(const JsonReader& reader,
  if (dynamic_it != input_activations_it->end() && !dynamic_it->is_null()) {
    quant_args.activation_dynamic() = dynamic_it->get<bool>();
  }
  if (const auto ignore = reader.value<std::vector<std::string>>(
          "quantization_config.ignore");
      ignore.has_value()) {
    quant_args.ignored_modules() = *ignore;
  }
  return true;
}

7 changes: 7 additions & 0 deletions xllm/core/framework/hf_model_loader_test.cpp
@@ -42,6 +42,10 @@ TEST(HFModelLoaderTest, LoadCompressedTensorsFp8StaticConfig) {
        }
      }
    },
    "ignore": [
      "lm_head",
      "model.layers.1.mlp.down_proj"
    ],
    "quant_method": "compressed-tensors"
  }
}

Collaborator: Is this rule applicable to all models?

Collaborator (author): This field is part of the quantization config schema, not a model-specific rule. The JSON (including ignore) is generated by the quantization tool; AngelSlim, at least, produces this field. So the applicability depends on which quantization tool was used, not on the model itself.
@@ -54,6 +58,9 @@ TEST(HFModelLoaderTest, LoadCompressedTensorsFp8StaticConfig) {
  EXPECT_EQ(quant_args.bits(), 8);
  EXPECT_EQ(quant_args.moe_weight_bits(), 8);
  EXPECT_FALSE(quant_args.activation_dynamic());
  ASSERT_EQ(quant_args.ignored_modules().size(), 2);
  EXPECT_EQ(quant_args.ignored_modules()[0], "lm_head");
  EXPECT_EQ(quant_args.ignored_modules()[1], "model.layers.1.mlp.down_proj");
}
}

32 changes: 32 additions & 0 deletions xllm/core/framework/quant_args.h
@@ -17,7 +17,9 @@ limitations under the License.
#pragma once

#include <ostream>
#include <regex>
#include <string>
#include <vector>

#include "common/macros.h"

@@ -55,6 +57,35 @@ struct QuantArgs {
  // weight block size
  PROPERTY(std::vector<int64_t>, weight_block_size) = {};

  // exact module names or regexes prefixed with "re:" that should bypass
  // quantization for compressed-tensors models.
  PROPERTY(std::vector<std::string>, ignored_modules) = {};

  bool should_ignore_module(const std::string& module_name) const {
    for (const auto& pattern : ignored_modules()) {
      if (pattern == module_name) {
        return true;
      }
      if (pattern.size() > 3 && pattern.rfind("re:", 0) == 0) {
        try {
          if (std::regex_match(module_name, std::regex(pattern.substr(3)))) {
            return true;
          }
        } catch (const std::regex_error&) {
        }
      }
Comment on lines +69 to +76

Contributor [high]: Creating a std::regex object inside a loop for every module check is extremely inefficient, as regex compilation is an expensive operation. This can significantly slow down model initialization, especially for models with many layers and multiple ignore patterns. Additionally, the magic number 3 should be replaced with a named constant (e.g., kRegexPrefixLength). Consider pre-compiling the regexes or using a more efficient matching strategy.
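
A minimal sketch of the pre-compilation idea (the matcher class and both constant names are illustrative, not part of this PR):

#include <cstddef>
#include <regex>
#include <string>
#include <vector>

// Illustrative only: split the ignore list once into exact names and
// pre-compiled regexes, so per-module lookups never construct a std::regex.
constexpr char kRegexPrefix[] = "re:";
constexpr std::size_t kRegexPrefixLength = 3;

class IgnoredModuleMatcher {
 public:
  explicit IgnoredModuleMatcher(const std::vector<std::string>& patterns) {
    for (const auto& pattern : patterns) {
      if (pattern.size() > kRegexPrefixLength &&
          pattern.rfind(kRegexPrefix, 0) == 0) {
        try {
          regexes_.emplace_back(pattern.substr(kRegexPrefixLength));
        } catch (const std::regex_error&) {
          // Skip malformed patterns, mirroring the PR's silent-ignore behavior.
        }
      } else {
        exact_names_.push_back(pattern);
      }
    }
  }

  bool matches(const std::string& module_name) const {
    for (const auto& name : exact_names_) {
      if (name == module_name) {
        return true;
      }
    }
    for (const auto& re : regexes_) {
      if (std::regex_match(module_name, re)) {
        return true;
      }
    }
    return false;
  }

 private:
  std::vector<std::string> exact_names_;
  std::vector<std::regex> regexes_;
};

Built once from quant_args.ignored_modules() and cached, this keeps per-module checks cheap even across hundreds of layers.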

    }
    return false;
  }

  QuantArgs for_module(const std::string& module_name) const {
    QuantArgs local_args = *this;
    if (should_ignore_module(module_name)) {
      local_args.quant_method().clear();
    }
    return local_args;
  }

  // check if weights can be fused
  bool can_be_fused() const {
    // can't fuse quantized weights if desc_act is true
@@ -72,6 +103,7 @@ inline std::ostream& operator<<(std::ostream& os, const QuantArgs& args) {
  os << ", is_sym: " << args.is_sym();
  os << ", activation_dynamic: " << args.activation_dynamic();
  os << ", fmt: " << args.fmt();
  os << ", ignored_modules: " << args.ignored_modules().size();
  os << "]";
  return os;
}
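Taken together, a short usage sketch of the two new helpers (assuming QuantArgs sits in the xllm namespace, as the .cpp files below suggest; the include path follows this PR's tree layout):

#include <cassert>

#include "xllm/core/framework/quant_args.h"

int main() {
  xllm::QuantArgs args;
  args.quant_method() = "compressed-tensors";
  args.ignored_modules() = {"lm_head",
                            "re:model\\.layers\\.\\d+\\.mlp\\.down_proj"};

  // Exact entries match literally; "re:" entries are matched as full regexes.
  assert(args.should_ignore_module("lm_head"));
  assert(args.should_ignore_module("model.layers.7.mlp.down_proj"));
  assert(!args.should_ignore_module("model.layers.7.mlp.up_proj"));

  // for_module() returns a copy with quant_method cleared for ignored
  // modules, so those layers get built as plain unquantized linears.
  assert(args.for_module("model.layers.7.mlp.down_proj").quant_method().empty());
  assert(!args.for_module("model.layers.7.mlp.up_proj").quant_method().empty());
  return 0;
}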
9 changes: 7 additions & 2 deletions xllm/core/layers/common/dense_mlp.cpp
@@ -31,7 +31,8 @@ DenseMLPImpl::DenseMLPImpl(int64_t hidden_size,
                           bool enable_result_reduction,
                           const QuantArgs& quant_args,
                           ProcessGroup* process_group,
                           const torch::TensorOptions& options)
                           const torch::TensorOptions& options,
                           const std::string& module_prefix)
    : is_gated_(is_gated),
      intermediate_size_(intermediate_size),
      process_group_(process_group),
@@ -73,13 +74,17 @@ DenseMLPImpl::DenseMLPImpl(int64_t hidden_size,
  act_ = register_module("act", Activation(hidden_act_, is_gated_));

  // 2. down
  const auto down_proj_quant_args =
      module_prefix.empty()
          ? quant_args
          : quant_args.for_module(module_prefix + ".down_proj");
  down_proj_ = register_module("down_proj",
                               RowParallelLinear(intermediate_size_,
                                                 hidden_size,
                                                 /*bias=*/has_bias,
                                                 /*input_is_parallelized=*/true,
                                                 enable_result_reduction,
                                                 quant_args,
                                                 down_proj_quant_args,
                                                 process_group_,
                                                 options,
                                                 down_proj_extra_args));
3 changes: 2 additions & 1 deletion xllm/core/layers/common/dense_mlp.h
@@ -38,7 +38,8 @@ class DenseMLPImpl : public torch::nn::Module {
               bool enable_result_reduction,
               const QuantArgs& quant_args,
               ProcessGroup* process_group,
               const torch::TensorOptions& options);
               const torch::TensorOptions& options,
               const std::string& module_prefix = "");

  torch::Tensor forward(const torch::Tensor& hidden_states);

48 changes: 48 additions & 0 deletions xllm/core/layers/common/tests/dense_mlp_tests.cpp
@@ -318,6 +318,54 @@ TEST_F(DenseMLPTest, SmoothquantLoadStateDictTest) {
  LOG(INFO) << "State dict loading test passed - output sum: " << output_sum;
}

TEST_F(DenseMLPTest, Fp8IgnoredDownProjLoadsAsUnquantized) {
  QuantArgs fp8_quant_args;
  fp8_quant_args.quant_method() = kQuantMethodFp8;
  fp8_quant_args.bits() = 8;
  fp8_quant_args.activation_dynamic() = false;
  fp8_quant_args.ignored_modules() = {"model.layers.1.mlp.down_proj"};

  const int64_t hidden_size = 16;
  const int64_t intermediate_size = 32;
  auto mlp = DenseMLP(DenseMLPImpl(hidden_size,
                                   intermediate_size,
                                   /*is_gated=*/true,
                                   /*has_bias=*/false,
                                   /*hidden_act=*/"silu",
                                   /*enable_result_reduction=*/true,
                                   fp8_quant_args,
                                   parallel_args_.tp_group_,
                                   options_,
                                   "model.layers.1.mlp"));

  std::unordered_map<std::string, torch::Tensor> weight_dict;
  auto fp8_weight_options = options_.dtype(torch::kFloat8_e4m3fn);
  auto scale_options = options_.dtype(torch::kFloat32);

  weight_dict["gate_proj.weight"] =
      torch::zeros({intermediate_size, hidden_size}, fp8_weight_options);
  weight_dict["gate_proj.weight_scale"] = torch::ones({1}, scale_options);
  weight_dict["gate_proj.input_scale"] = torch::ones({1}, scale_options);

  weight_dict["up_proj.weight"] =
      torch::zeros({intermediate_size, hidden_size}, fp8_weight_options);
  weight_dict["up_proj.weight_scale"] = torch::ones({1}, scale_options);
  weight_dict["up_proj.input_scale"] = torch::ones({1}, scale_options);

  weight_dict["down_proj.weight"] =
      torch::zeros({hidden_size, intermediate_size}, options_);

  StateDict state_dict(weight_dict);
  mlp->load_state_dict(state_dict);

  const auto params = mlp->named_parameters(/*recurse=*/true);
  EXPECT_TRUE(params.contains("gate_up_proj.weight_scale"));
  EXPECT_TRUE(params.contains("gate_up_proj.input_scale"));
  EXPECT_TRUE(params.contains("down_proj.weight"));
  EXPECT_FALSE(params.contains("down_proj.weight_scale"));
  EXPECT_FALSE(params.contains("down_proj.input_scale"));
}

TEST_F(DenseMLPTest, SmoothquantPrecisionVerificationTest) {
  // Test precision verification with custom input and expected output
  const int64_t batch_size = 16;
8 changes: 6 additions & 2 deletions xllm/core/layers/qwen2_decoder_layer.cpp
@@ -18,11 +18,14 @@ limitations under the License.
namespace xllm {
namespace layer {

Qwen2DecoderLayerImpl::Qwen2DecoderLayerImpl(const ModelContext& context)
Qwen2DecoderLayerImpl::Qwen2DecoderLayerImpl(const ModelContext& context,
                                             int32_t layer_id)
    : parallel_args_(context.get_parallel_args()) {
  const auto& model_args = context.get_model_args();
  const auto& quant_args = context.get_quant_args();
  const auto& options = context.get_tensor_options();
  const std::string mlp_module_prefix =
      layer_id >= 0 ? "model.layers." + std::to_string(layer_id) + ".mlp" : "";

  // Initialize attention layers
  attention_ = register_module("self_attn", Qwen2Attention(context));
@@ -46,7 +49,8 @@ Qwen2DecoderLayerImpl::Qwen2DecoderLayerImpl(const ModelContext& context)
                                  /*enable_result_reduction=*/true,
                                  quant_args,
                                  parallel_args_.tp_group_,
                                  options));
                                  options,
                                  mlp_module_prefix));
}

void Qwen2DecoderLayerImpl::load_state_dict(const StateDict& state_dict) {
3 changes: 2 additions & 1 deletion xllm/core/layers/qwen2_decoder_layer.h
@@ -35,7 +35,8 @@ namespace layer {

class Qwen2DecoderLayerImpl : public torch::nn::Module {
 public:
  explicit Qwen2DecoderLayerImpl(const ModelContext& context);
  explicit Qwen2DecoderLayerImpl(const ModelContext& context,
                                 int32_t layer_id = -1);

  void load_state_dict(const StateDict& state_dict);

5 changes: 4 additions & 1 deletion xllm/core/layers/qwen3_moe_decoder_layer.cpp
@@ -80,6 +80,8 @@ Qwen3MoeDecoderLayerImpl::Qwen3MoeDecoderLayerImpl(const ModelContext& context,
                                  parallel_args_,
                                  options));
  } else {
    const std::string mlp_module_prefix =
        "model.layers." + std::to_string(layer_id) + ".mlp";
    mlp_ = register_module("mlp",
                           DenseMLP(model_args.hidden_size(),
                                    model_args.intermediate_size(),
@@ -89,7 +91,8 @@ Qwen3MoeDecoderLayerImpl::Qwen3MoeDecoderLayerImpl(const ModelContext& context,
                                    /*enable_result_reduction=*/true,
                                    quant_args,
                                    parallel_args_.tp_group_,
                                    options));
                                    options,
                                    mlp_module_prefix));
  }
}

2 changes: 1 addition & 1 deletion xllm/models/llm/qwen2.h
@@ -43,7 +43,7 @@ class QWen2ModelImpl : public LlmModelImplBase<layer::Qwen2DecoderLayer> {
    register_module("embed_tokens", layer::WordEmbedding(context));

    for (int32_t i = 0; i < model_args.n_layers(); i++) {
      auto layer = layer::Qwen2DecoderLayer(context);
      auto layer = layer::Qwen2DecoderLayer(context, i);
      layers_.push_back(layer);
    }
  }
2 changes: 1 addition & 1 deletion xllm/models/llm/qwen3.h
@@ -55,7 +55,7 @@ class QWen3ModelImpl : public LlmModelImplBase<layer::Qwen3DecoderLayer> {
        options.device(), options.dtype().toScalarType(), mask_value);
#endif
    for (int32_t i = 0; i < model_args.n_layers(); i++) {
      auto layer = layer::Qwen3DecoderLayer(context);
      auto layer = layer::Qwen3DecoderLayer(context, i);
      layers_.push_back(layer);
    }
  }