Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions angelslim/compressor/quant/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from .packing_utils import dequantize_gemm, pack_weight_to_int8 # noqa: F401
from .quant_func import * # noqa: F401 F403
from .sample_func import EMASampler, MultiStepSampler # noqa: F401
from .save import DeepseekV3HfPTQSave # noqa: F401
from .save import DeepseekV3PTQSaveTRTLLM # noqa: F401
from .save import DeepSeekV3PTQSaveMulti # noqa: F401
from .save import DeepSeekV3PTQSaveSingle # noqa: F401
from .save import PTQDiffusionSave # noqa: F401
from .save import PTQOnlyScaleSave # noqa: F401
from .save import PTQPTMSave # noqa: F401
Expand Down
197 changes: 106 additions & 91 deletions angelslim/compressor/quant/core/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def save(self, save_path):
print_info("save weight scales done.")


class DeepseekV3HfPTQSave(PTQSaveBase):
class DeepSeekV3PTQSaveMulti(PTQSaveBase):
def __init__(self, quant_model, check_scales=False):
super().__init__(quant_model=quant_model)
self.moe_act_scales_dict = {}
Expand Down Expand Up @@ -371,7 +371,6 @@ def __init__(self, quant_model, check_scales=False):
".mlp.down_proj",
".mlp.shared_experts.down_proj",
]
self.exclude_key = ["*head*", "*kv_b*"]

def save(self, save_path):
save_path = os.path.join(save_path, "scales")
Expand Down Expand Up @@ -411,7 +410,7 @@ def save(self, save_path):
_save_path = os.path.join(
save_path, "{}.weight_scale.{}.{}.pt".format(k, "int4", _index)
)
scale_int4 = max_value_group_wise / 7
scale_int4 = max_value_group_wise / 8

# save weigth-int4-pergroup scale
if "experts" in k and "shared_experts" not in k:
Expand Down Expand Up @@ -464,7 +463,7 @@ def save(self, save_path):
)
os.makedirs(save_model_path, exist_ok=True)

self.convert_scales_to_safetensors(save_path, save_model_path)
self.convert_scales_to_safetensors(save_path, tmp_path)
print_info("convert scales to safetensors done.")

file_name = self.merge_model(
Expand All @@ -474,6 +473,17 @@ def save(self, save_path):

self.add_mtp_weight(save_path=save_model_path, file_name=file_name)

if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
parent_dir = os.path.dirname(
self.quant_model.model.ori_model_path.rstrip("/")
)
tp_model_path = os.path.join(
parent_dir, f"ds_ckpt_tp{self.quant_model.model.world_size}"
)
if os.path.exists(tp_model_path):
shutil.rmtree(tp_model_path)

def _save_ckpt(self, scale, save_path, all_reduce=True):
if all_reduce:
if self.rank == 0:
Expand Down Expand Up @@ -604,7 +614,7 @@ def merge_model(self, input_path, save_model_path, mp=16):
model_save_ind = 0
localind = 0

scale_path = os.path.join(save_model_path, "model-scales.safetensors")
scale_path = os.path.join(input_path, "model-scales.safetensors")
scales_dict = load_file(scale_path)

for mpind in range(mp):
Expand Down Expand Up @@ -684,7 +694,6 @@ def merge_model(self, input_path, save_model_path, mp=16):
param_list.append(param)
newparam = torch.cat(param_list, dim=0)
new_save_dict[k] = newparam
print_info(f"shape of {k}: {new_save_dict[k].shape}")
index_dict["weight_map"][k] = str(filename)
safe_save(new_save_dict, os.path.join(save_model_path, filename))
# process others
Expand Down Expand Up @@ -715,13 +724,8 @@ def merge_model(self, input_path, save_model_path, mp=16):
index_dict,
filename,
)
print_info(f"shape of {k}: {new_save_dict[k].shape}")
safe_save(new_save_dict, os.path.join(save_model_path, filename))

# update scales map
for k, _ in scales_dict.items():
index_dict["weight_map"][k] = "model-scales.safetensors"

path = self.quant_model.model.ori_model_path
for file_path in glob(os.path.join(path, "*token*")):
new_file_path = os.path.join(save_model_path, os.path.basename(file_path))
Expand Down Expand Up @@ -754,21 +758,25 @@ def merge_model(self, input_path, save_model_path, mp=16):
quant_dict = {
"quantization_config": {
"quant_method": "w4a8_awq",
"kv_cache_quant_method": "fp8",
"weight_group_size": 128,
"activation_scheme": (
"dynamic" if "dynamic" in a_quant_algo else "static"
),
"ignored_modules": [
"kv_cache_quant_method": "fp8",
"ignored_layers": [
"*self_attn*",
"*gate_up_proj",
"*down_proj",
"*layers.61*",
],
"ignored_quantization_config": {
"quant_method": "fp8_block_scales",
"quant_method": "fp8",
"activation_scheme": "dynamic",
"fmt": "e4m3",
"kv_cache_quant_method": "fp8",
"weight_block_size": [128, 128],
},
},
}
}
else:
raise NotImplementedError(
Expand Down Expand Up @@ -813,40 +821,44 @@ def _transform_keys(
filename,
):
if "fp8" in self.quant_model.quant_config.quant_algo:
if "w4a8" in self.quant_model.quant_config.quant_algo:
if self._is_packed(param_name, param, scales_dict):
param = self._packed_weight(
param_name,
param,
self.quant_model.quant_config.quant_algo_info["w_group_size"],
scales_dict,
)
param_name = param_name.replace("weight", "qweight")
else:
if not any(
substring in param_name
for substring in self.quant_model.quant_config.quant_algo_info[
"ignore_layers"
]
):
if param_name.endswith("weight_scale_inv"):
return
weight_scale = scales_dict.get(f"{param_name}_scale", None)
if weight_scale is not None:
new_save_dict[f"{param_name}_scale"] = weight_scale
new_save_dict[f"{param_name[:-7]}.input_scale"] = scales_dict[
f"{param_name[:-7]}.input_scale"
]
index_dict["weight_map"][f"{param_name}_scale"] = str(filename)
index_dict["weight_map"][f"{param_name[:-7]}.input_scale"] = str(
filename
)
if "w4a8" in self.quant_model.quant_config.quant_algo:
param = self._packed_weight(
param_name,
param,
self.quant_model.quant_config.quant_algo_info[
"w_group_size"
],
scales_dict,
)
new_save_dict[f"{param_name}_scale.int4"] = scales_dict[
f"{param_name}_scale.int4"
]
index_dict["weight_map"][f"{param_name}_scale.int4"] = str(
filename
)
param_name = param_name.replace("weight", "qweight")

new_save_dict[param_name] = param
index_dict["weight_map"][param_name] = str(filename)

def _is_packed(self, weight_name, weight, scales_dict):
if weight_name.endswith("weight_scale_inv") or self._is_exclude(weight_name):
return False
elif weight.element_size() == 1:
if f"{weight_name}_scale.int4" in scales_dict.keys():
return True
return False
else:
return False

def _is_exclude(self, tensor_name):
for pattern in self.exclude_key:
# Convert fnmatch-style pattern to regex
regex_pattern = pattern.replace("*", ".*").replace("?", ".")
if re.fullmatch(regex_pattern, tensor_name):
return True
return False

def _packed_weight(self, weight_name, weight, block_wise, scales_dict):
target_shape = (weight.shape[0] // block_wise, weight.shape[1] // block_wise)
scale_inv = scales_dict[f"{weight_name}_scale"]
Expand Down Expand Up @@ -920,65 +932,68 @@ def add_mtp_weight(self, input_path=None, save_path=None, file_name=None):
json.dump(new_model_index, f, indent=2)


class DeepseekV3PTQSaveTRTLLM(DeepseekV3HfPTQSave):
class DeepSeekV3PTQSaveSingle(DeepSeekV3PTQSaveMulti):
def __init__(self, quant_model):
super().__init__(quant_model=quant_model)

def save(self, save_path):
# setting quantization config
a_quant_algo = self.quant_model.quant_config.quant_algo_info["a"]
if "w4a8" in self.quant_model.quant_config.quant_algo:
quant_dict = {
"quantization_config": {
"quant_method": "w4a8_awq",
"kv_cache_quant_method": "fp8",
"activation_scheme": (
"dynamic" if "dynamic" in a_quant_algo else "static"
),
"ignored_modules": [
"*self_attn*",
"*gate_up_proj",
"*down_proj",
"*layers.61*",
],
"ignored_quantization_config": {
"quant_method": "fp8_block_scales",
"kv_cache_quant_method": "fp8",
},
},
}
int4_scales = {}
for name, sub_layer in self.quant_model.model.named_modules():
if isinstance(sub_layer, QDQModule):
max_value_group_wise = sub_layer.weight_scale.data.clone()
int4_scales[f"{name}.weight_scale.int4"] = max_value_group_wise / 8
sub_layer.weight_scale = None
sub_layer.weight_scale = torch.nn.Parameter(
(max_value_group_wise.max() / 448.0).to(
max_value_group_wise.dtype
),
requires_grad=False,
if "fp8" in self.quant_model.quant_config.quant_algo:
if "w4a8" in self.quant_model.quant_config.quant_algo:
if self.quant_model.deploy_backend == "trtllm":
quant_dict = {
"quantization_config": {
"quant_method": "w4a8_awq",
"weight_group_size": 128,
"activation_scheme": (
"dynamic" if "dynamic" in a_quant_algo else "static"
),
"kv_cache_quant_method": "fp8",
"ignored_layers": [
"*self_attn*",
"*gate_up_proj",
"*down_proj",
"*layers.61*",
],
"ignored_quantization_config": {
"quant_method": "fp8",
"activation_scheme": "dynamic",
"fmt": "e4m3",
"kv_cache_quant_method": "fp8",
"weight_block_size": [128, 128],
},
}
}
else:
raise NotImplementedError(
f"deploy_backend {self.quant_model.deploy_backend} \
is not supported for w4a8_fp8."
)
else:
ignore_layers = self.quant_model.quant_config.quant_algo_info[
"ignore_layers"
]
if self.quant_model.deploy_backend == "vllm":
quant_dict = {
"quantization_config": {
"quant_method": "fp8",
"activation_scheme": (
"dynamic" if "dynamic" in a_quant_algo else "static"
),
"ignored_layers": ignore_layers,
}
}
else:
raise NotImplementedError(
f"deploy_backend {self.quant_model.deploy_backend} \
is not supported for fp8_static."
)

os.makedirs(save_path, exist_ok=True)
safetensor_file = os.path.join(save_path, "model-scales.safetensors")
safe_save(int4_scales, safetensor_file)
print_info(f"Save int4 scales to {safetensor_file}")

self.quant_model.get_model().config.update(quant_dict)
print_info("Save quantization_config: {}".format(quant_dict))

self.quant_model.get_model().save_pretrained(save_path)

new_model_index_file = os.path.join(
save_path, "model.safetensors.index.json"
)
with open(new_model_index_file, "r") as f:
new_model_index = json.load(f)
for key in int4_scales.keys():
new_model_index["weight_map"][key] = "model-scales.safetensors"
with open(new_model_index_file, "w") as f:
json.dump(new_model_index, f, indent=2)

self.add_mtp_weight(save_path=save_path)
else:
raise ValueError(
Expand Down
18 changes: 15 additions & 3 deletions angelslim/compressor/quant/modules/helper_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,8 +576,8 @@ def __init__(
self.quant_algo = quant_algo
if "fp8" in quant_algo:
if "w4a8" in self.quant_algo:
tensor_max_value = weight_scale.clone()
tensor_wise_scale = tensor_max_value.max() / 448.0
max_value_group_wise = weight_scale.clone()
tensor_wise_scale = max_value_group_wise.max() / 448.0
quant_weight, _ = quantize_weight_per_tensor_fp8(
weight, tensor_wise_scale
)
Expand All @@ -587,10 +587,14 @@ def __init__(
new_weight_bf16, method="groupwise", bits=4, group_size=group_size
)
quant_weight, _ = quantize_weight_int(
new_weight_bf16_qdq, tensor_max_value, bits=4
new_weight_bf16_qdq, max_value_group_wise, bits=4
)
quant_weight = pack_weight_to_int8(quant_weight)
del new_weight_bf16_qdq, new_weight_bf16
self.weight_scale_int4 = torch.nn.Parameter(
max_value_group_wise / 8, requires_grad=False
)
weight_scale = tensor_wise_scale
else:
quant_weight, weight_scale = quantize_weight_per_tensor_fp8(
weight, weight_scale
Expand Down Expand Up @@ -653,6 +657,14 @@ def forward(self, x):
output = qoutput.to(output.dtype) * self.output_scale
return output

def state_dict(self, *args, **kwargs):
state_dict = super().state_dict(*args, **kwargs)
keys_to_rename = [k for k in state_dict.keys() if "weight_scale_int4" in k]
for old_key in keys_to_rename:
new_key = old_key.replace("weight_scale_int4", "weight_scale.int4")
state_dict[new_key] = state_dict.pop(old_key)
return state_dict


class QLinear(torch.nn.Module):
def __init__(
Expand Down
3 changes: 2 additions & 1 deletion angelslim/compressor/quant/ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ def _convert(self):
)

qdq_module = self.quant_model.get_qdq_module(sub_layer, name)
setattr(parent_layer, sub_name, qdq_module)
if qdq_module is not sub_layer:
setattr(parent_layer, sub_name, qdq_module)
self.quant_model.quantized = True

def __getattr__(self, item):
Expand Down
18 changes: 18 additions & 0 deletions angelslim/data/text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,24 @@ def _load_jsonl_data(self, data_path: str, num_samples: int):
messages, tokenize=False, add_generation_prompt=True
)

thinking_data = False
for dic in messages:
if dic["role"] == "assistant":
if "<think>" and "</think>" in dic["content"]:
thinking_data = True
break
if thinking_data:
text = self.processor.bos_token
for dic in messages:
if dic["role"] == "system":
text += dic["content"]
elif dic["role"] == "user":
text = (
text + "<|User|>" + dic["content"] + "<|Assistant|>"
)
elif dic["role"] == "assistant":
text = text + dic["content"] + self.processor.eos_token

model_inputs = self.processor(
[text],
return_tensors="pt",
Expand Down
Loading