@@ -63,13 +63,17 @@ def create_quantized_param(param, weight_block_size=(128, 128)):
6363
6464 block_size_m , block_size_n = weight_block_size
6565 rows , cols = param .shape [- 2 :]
66+ original_device = param .device
6667
6768 # Tensor-wise
6869 if block_size_m == - 1 or block_size_m > rows :
6970 block_size_m = rows
7071 if block_size_n == - 1 or block_size_n > cols :
7172 block_size_n = cols
7273
74+ # Move to CPU for padding to save GPU memory
75+ param = param .cpu ()
76+
7377 if rows % block_size_m != 0 :
7478 pad = torch .zeros (
7579 [* param .shape [:- 2 ], block_size_m - rows % block_size_m , cols ],
@@ -86,6 +90,7 @@ def create_quantized_param(param, weight_block_size=(128, 128)):
8690 param = torch .concat ([param , pad ], dim = - 1 )
8791 param_value_shape = param .shape
8892
93+ # Convert to float on CPU first
8994 param_value = (
9095 param .float ()
9196 .reshape (
@@ -98,6 +103,11 @@ def create_quantized_param(param, weight_block_size=(128, 128)):
98103 .permute (0 , 1 , 3 , 2 , 4 )
99104 )
100105
106+ # Move back to GPU for quantization
107+ param_value = param_value .to (original_device )
108+ del param # Free CPU memory
109+ torch .cuda .empty_cache ()
110+
101111 # Calculate scaling factor for each block
102112 max_abs = torch .amax (torch .abs (param_value ), dim = (- 1 , - 2 ))
103113 scale_inv = fp8_max / max_abs
@@ -108,6 +118,9 @@ def create_quantized_param(param, weight_block_size=(128, 128)):
108118 quantized_param = torch .clamp (param_value * scale_inv , min = fp8_min , max = fp8_max ).to (
109119 torch .float8_e4m3fn
110120 )
121+ del param_value # Free GPU memory
122+ torch .cuda .empty_cache ()
123+
111124 quantized_param = quantized_param .permute (0 , 1 , 3 , 2 , 4 )
112125 quantized_param = quantized_param .reshape (param_value_shape )[..., :rows , :cols ]
113126
@@ -120,33 +133,48 @@ def process_safetensor(rank, file_name, input_path, output_path, block_size=(128
120133 state_dict = {}
121134 index = {}
122135 count = 0
136+
137+ # Load tensors on CPU first to avoid GPU memory issues
123138 with safe_open (
124- os .path .join (input_path , file_name ), framework = "pt" , device = f"cuda: { rank } "
139+ os .path .join (input_path , file_name ), framework = "pt" , device = "cpu "
125140 ) as f :
126141 print (f"Processing { file_name } with { len (f .keys ())} weights" )
127142 for weight_name in f .keys ():
128143 weight = f .get_tensor (weight_name )
129144 if any (weight_name .endswith (suffix ) for suffix in SUFFIX_TO_QUANT ):
145+ # Move to GPU only for quantization
146+ weight = weight .to (f"cuda:{ rank } " )
130147 quant_weight , scale = create_quantized_param (weight , block_size )
131- state_dict [weight_name ] = quant_weight
148+
149+ # Move back to CPU for saving
150+ state_dict [weight_name ] = quant_weight .cpu ()
132151 index [weight_name ] = file_name
133152
134153 # Reference: https://github.com/vllm-project/vllm/blob/v0.10.1/vllm/model_executor/layers/quantization/fp8.py#L295 # noqa: E501
135154 if block_size [0 ] == - 1 and block_size [1 ] == - 1 :
136155 # Tensor-wise
137- state_dict [f"{ weight_name } _scale" ] = scale
156+ state_dict [f"{ weight_name } _scale" ] = scale . cpu ()
138157 index [f"{ weight_name } _scale" ] = file_name
139158 else :
140159 # Block-wise
141- state_dict [f"{ weight_name } _scale_inv" ] = scale
160+ state_dict [f"{ weight_name } _scale_inv" ] = scale . cpu ()
142161 index [f"{ weight_name } _scale_inv" ] = file_name
162+
163+ # Clean up GPU memory after each weight
164+ del weight , quant_weight , scale
165+ torch .cuda .empty_cache ()
143166 else :
144167 state_dict [weight_name ] = weight
145168 index [weight_name ] = file_name
146169 count += 1
147170
148171 new_safetensor_file = os .path .join (output_path , file_name )
149172 save_file (state_dict , new_safetensor_file )
173+
174+ # Final cleanup
175+ del state_dict
176+ torch .cuda .empty_cache ()
177+
150178 return index
151179
152180
0 commit comments