Skip to content

Commit a719ae2

Browse files
committed
minor
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 1829ee7 commit a719ae2

File tree

1 file changed

+22
-9
lines changed

1 file changed

+22
-9
lines changed

modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ def _fakequant_module_weights(
170170
weight_name = attr_name.removesuffix("_quantizer")
171171
prefix = f"{module_name}." if module_name else ""
172172
sd_key = f"{prefix}{weight_name}"
173-
assert sd_key not in fakequant_weights, f"Weight {sd_key} has already been fakequantized"
173+
if sd_key in fakequant_weights:
174+
raise RuntimeError(f"Weight {sd_key} has already been fakequantized")
174175

175176
if inplace:
176177
w = getattr(module, weight_name)
@@ -179,7 +180,8 @@ def _fakequant_module_weights(
179180
else:
180181
w_quant = quantizer(w.float()).to(w.dtype)
181182
else:
182-
assert state_dict is not None
183+
if state_dict is None:
184+
raise RuntimeError("state_dict is required when inplace=False for fakequant export")
183185
if sd_key not in state_dict:
184186
continue
185187
w = state_dict[sd_key]
@@ -209,7 +211,8 @@ def _fakequant_module_weights(
209211
if inplace:
210212
w.data.copy_(w_quant)
211213
else:
212-
assert state_dict is not None
214+
if state_dict is None:
215+
raise RuntimeError("state_dict is required when inplace=False for fakequant export")
213216
state_dict[sd_key] = w_quant.cpu()
214217
fakequant_weights.add(sd_key)
215218

@@ -390,7 +393,10 @@ def _process_group(modules: list[nn.Module]) -> None:
390393
)
391394
w_param.data.copy_((w_param.to(torch.float32) * ratio).to(w_param.dtype))
392395
else:
393-
assert state_dict is not None
396+
if state_dict is None:
397+
raise RuntimeError(
398+
"state_dict is required when inplace=False in _resmooth_experts_for_export"
399+
)
394400
weight = state_dict[w_key]
395401
ratio = old_pqs.to(dtype=torch.float32, device=weight.device) / avg_pqs.to(
396402
device=weight.device
@@ -424,7 +430,10 @@ def _process_group(modules: list[nn.Module]) -> None:
424430
if not experts:
425431
continue
426432
if inplace:
427-
assert name_to_module is not None
433+
if name_to_module is None:
434+
raise RuntimeError(
435+
"name_to_module is required when inplace=True in _resmooth_experts_for_export"
436+
)
428437
with _enable_writeback_for_group(experts, model, name_to_module):
429438
_process_group(experts)
430439
else:
@@ -450,7 +459,10 @@ def _dummy_forward() -> None:
450459
if len(modules) <= 1:
451460
continue
452461
if inplace:
453-
assert name_to_module is not None
462+
if name_to_module is None:
463+
raise RuntimeError(
464+
"name_to_module is required when inplace=True in _resmooth_experts_for_export"
465+
)
454466
with _enable_writeback_for_group(modules, model, name_to_module):
455467
_process_group(modules)
456468
else:
@@ -499,9 +511,10 @@ def export_hf_vllm_fq_checkpoint(
499511
pqs_overrides, requant_weights = _resmooth_experts_for_export(model, None, inplace=True)
500512
# Inplace path: iterate decoder layers, one offload<->onload per layer.
501513
decoder_layers = LayerActivationCollector.get_decoder_layers(model)
502-
assert decoder_layers is not None, (
503-
"inplace_mem_efficient=True requires a model with discoverable decoder layers"
504-
)
514+
if decoder_layers is None:
515+
raise RuntimeError(
516+
"inplace_mem_efficient=True requires a model with discoverable decoder layers"
517+
)
505518
for name, module in model.named_modules():
506519
if module not in decoder_layers:
507520
continue

0 commit comments

Comments (0)