Skip to content

Commit 02b1983

Browse files
committed
fix quarot bug
1 parent 37a9f37 commit 02b1983

4 files changed

Lines changed: 31 additions & 27 deletions

File tree

llmc/compression/blockwise_optimization.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def __init__(self, model, quant_config, input, config):
         self.quant_config = quant_config
         self.sparsity_config = quant_config
         self.input = input
+        self.data_free = False if self.input else True
         self.config = config
         self.block_idx = None
         self.num_blocks = len(self.blocks)

llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -246,36 +246,39 @@ def block_opt(self, block):
         handles = []
         self.block_init(block)

-        for name in named_linears:
-            handles.append(
-                named_linears[name].register_forward_hook(
-                    functools.partial(
-                        self.cache_input_hook, name=name, feat_dict=input_feat
+        if not self.data_free:
+            for name in named_linears:
+                handles.append(
+                    named_linears[name].register_forward_hook(
+                        functools.partial(
+                            self.cache_input_hook, name=name, feat_dict=input_feat
+                        )
                     )
                 )
-            )
-
-        if self.quant_out:
-            self.block_forward(block)
-        else:
-            self.input['data'] = self.block_forward(block)

-        for h in handles:
-            h.remove()
-        torch.cuda.empty_cache()
+            if self.quant_out:
+                self.block_forward(block)
+            else:
+                self.input['data'] = self.block_forward(block)

-        self.block_transform(block, input_feat, self.input['kwargs'])
+            for h in handles:
+                h.remove()
+            torch.cuda.empty_cache()
+            self.block_transform(block, input_feat, self.input['kwargs'])
+        else:
+            self.block_transform(block)

-        if self.quant_out:
-            self.model.replace_module_block(
-                FakeQuantLinear,
-                block,
-                self.block_idx,
-                self.get_replacement_params(
-                    mode='fake_quant', w_only=self.w_only, name=None
-                ),
-            )
-            self.input['data'] = self.block_forward(block)
+        if not self.data_free:
+            if self.quant_out:
+                self.model.replace_module_block(
+                    FakeQuantLinear,
+                    block,
+                    self.block_idx,
+                    self.get_replacement_params(
+                        mode='fake_quant', w_only=self.w_only, name=None
+                    ),
+                )
+                self.input['data'] = self.block_forward(block)

         block = block.cpu()
         del input_feat

llmc/compression/quantization/quarot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def get_orthogonal_matrix(self):
         else:
             raise ValueError(f'Unsupported mode {self.mode}')

-    def block_transform(self, block, input_feat, block_kwargs):
+    def block_transform(self, block):
         logger.info(f'Start transform the {self.block_idx+1}-th block')

         if self.online_rotate:

scripts/run_quarot_llama.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ export PYTHONPATH=$llmc:$PYTHONPATH
 task_name=llm_quant_exp

 nohup \
-python -m llmc --config ../configs/quantization/QuaRot/quarot_w4a4.yml\
+python -m llmc --config ../configs/quantization/QuaRot/quarot_w4a4.yml \
 > ${task_name}.log 2>&1 &

 echo $! > ${task_name}.pid

0 commit comments

Comments
 (0)