Skip to content

Commit e6952fa

Browse files
authored
fix Qwen3-VL-235B fp8 block-wise quant OOM (#180)
1 parent e6204cf commit e6952fa

1 file changed

Lines changed: 32 additions & 4 deletions

File tree

tools/fp8_quant_blockwise.py

Lines changed: 32 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -63,13 +63,17 @@ def create_quantized_param(param, weight_block_size=(128, 128)):
6363

6464
block_size_m, block_size_n = weight_block_size
6565
rows, cols = param.shape[-2:]
66+
original_device = param.device
6667

6768
# Tensor-wise
6869
if block_size_m == -1 or block_size_m > rows:
6970
block_size_m = rows
7071
if block_size_n == -1 or block_size_n > cols:
7172
block_size_n = cols
7273

74+
# Move to CPU for padding to save GPU memory
75+
param = param.cpu()
76+
7377
if rows % block_size_m != 0:
7478
pad = torch.zeros(
7579
[*param.shape[:-2], block_size_m - rows % block_size_m, cols],
@@ -86,6 +90,7 @@ def create_quantized_param(param, weight_block_size=(128, 128)):
8690
param = torch.concat([param, pad], dim=-1)
8791
param_value_shape = param.shape
8892

93+
# Convert to float on CPU first
8994
param_value = (
9095
param.float()
9196
.reshape(
@@ -98,6 +103,11 @@ def create_quantized_param(param, weight_block_size=(128, 128)):
98103
.permute(0, 1, 3, 2, 4)
99104
)
100105

106+
# Move back to GPU for quantization
107+
param_value = param_value.to(original_device)
108+
del param # Free CPU memory
109+
torch.cuda.empty_cache()
110+
101111
# Calculate scaling factor for each block
102112
max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
103113
scale_inv = fp8_max / max_abs
@@ -108,6 +118,9 @@ def create_quantized_param(param, weight_block_size=(128, 128)):
108118
quantized_param = torch.clamp(param_value * scale_inv, min=fp8_min, max=fp8_max).to(
109119
torch.float8_e4m3fn
110120
)
121+
del param_value # Free GPU memory
122+
torch.cuda.empty_cache()
123+
111124
quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
112125
quantized_param = quantized_param.reshape(param_value_shape)[..., :rows, :cols]
113126

@@ -120,33 +133,48 @@ def process_safetensor(rank, file_name, input_path, output_path, block_size=(128
120133
state_dict = {}
121134
index = {}
122135
count = 0
136+
137+
# Load tensors on CPU first to avoid GPU memory issues
123138
with safe_open(
124-
os.path.join(input_path, file_name), framework="pt", device=f"cuda:{rank}"
139+
os.path.join(input_path, file_name), framework="pt", device="cpu"
125140
) as f:
126141
print(f"Processing {file_name} with {len(f.keys())} weights")
127142
for weight_name in f.keys():
128143
weight = f.get_tensor(weight_name)
129144
if any(weight_name.endswith(suffix) for suffix in SUFFIX_TO_QUANT):
145+
# Move to GPU only for quantization
146+
weight = weight.to(f"cuda:{rank}")
130147
quant_weight, scale = create_quantized_param(weight, block_size)
131-
state_dict[weight_name] = quant_weight
148+
149+
# Move back to CPU for saving
150+
state_dict[weight_name] = quant_weight.cpu()
132151
index[weight_name] = file_name
133152

134153
# Reference: https://github.com/vllm-project/vllm/blob/v0.10.1/vllm/model_executor/layers/quantization/fp8.py#L295 # noqa: E501
135154
if block_size[0] == -1 and block_size[1] == -1:
136155
# Tensor-wise
137-
state_dict[f"{weight_name}_scale"] = scale
156+
state_dict[f"{weight_name}_scale"] = scale.cpu()
138157
index[f"{weight_name}_scale"] = file_name
139158
else:
140159
# Block-wise
141-
state_dict[f"{weight_name}_scale_inv"] = scale
160+
state_dict[f"{weight_name}_scale_inv"] = scale.cpu()
142161
index[f"{weight_name}_scale_inv"] = file_name
162+
163+
# Clean up GPU memory after each weight
164+
del weight, quant_weight, scale
165+
torch.cuda.empty_cache()
143166
else:
144167
state_dict[weight_name] = weight
145168
index[weight_name] = file_name
146169
count += 1
147170

148171
new_safetensor_file = os.path.join(output_path, file_name)
149172
save_file(state_dict, new_safetensor_file)
173+
174+
# Final cleanup
175+
del state_dict
176+
torch.cuda.empty_cache()
177+
150178
return index
151179

152180

0 commit comments

Comments
 (0)