-
Notifications
You must be signed in to change notification settings - Fork 196
feat: support Qwen down_proj fallback for compressed-tensors ignored modules. #1254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,9 @@ limitations under the License. | |
| #pragma once | ||
|
|
||
| #include <ostream> | ||
| #include <regex> | ||
| #include <string> | ||
| #include <vector> | ||
|
|
||
| #include "common/macros.h" | ||
|
|
||
|
|
@@ -55,6 +57,35 @@ struct QuantArgs { | |
| // weight block size | ||
| PROPERTY(std::vector<int64_t>, weight_block_size) = {}; | ||
|
|
||
| // exact module names or regexes prefixed with "re:" that should bypass | ||
| // quantization for compressed-tensors models. | ||
| PROPERTY(std::vector<std::string>, ignored_modules) = {}; | ||
|
|
||
| bool should_ignore_module(const std::string& module_name) const { | ||
| for (const auto& pattern : ignored_modules()) { | ||
| if (pattern == module_name) { | ||
| return true; | ||
| } | ||
| if (pattern.size() > 3 && pattern.rfind("re:", 0) == 0) { | ||
| try { | ||
| if (std::regex_match(module_name, std::regex(pattern.substr(3)))) { | ||
| return true; | ||
| } | ||
| } catch (const std::regex_error&) { | ||
| } | ||
| } | ||
|
Comment on lines
+69
to
+76
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Creating a |
||
| } | ||
| return false; | ||
| } | ||
|
|
||
| QuantArgs for_module(const std::string& module_name) const { | ||
| QuantArgs local_args = *this; | ||
| if (should_ignore_module(module_name)) { | ||
| local_args.quant_method().clear(); | ||
| } | ||
| return local_args; | ||
| } | ||
|
|
||
| // check if weights can be fused | ||
| bool can_be_fused() const { | ||
| // can't fuse quantized weights if desc_act is true | ||
|
|
@@ -72,6 +103,7 @@ inline std::ostream& operator<<(std::ostream& os, const QuantArgs& args) { | |
| os << ", is_sym: " << args.is_sym(); | ||
| os << ", activation_dynamic: " << args.activation_dynamic(); | ||
| os << ", fmt: " << args.fmt(); | ||
| os << ", ignored_modules: " << args.ignored_modules().size(); | ||
| os << "]"; | ||
| return os; | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this rule applicable to all models?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This field is part of the quantization config schema, not a model-specific rule. The JSON (including
ignore) is generated by the quantization tool — at least AngelSlim produces this field. So the applicability depends on which quant tool was used, not the model itself.