Skip to content

Commit 9b9377a

Browse files
committed
Fix
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
1 parent 4b4ef63 commit 9b9377a

5 files changed

Lines changed: 53 additions & 19 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -727,8 +727,12 @@ def post_quantize(
727727
"""
728728

729729
if args.verbose:
730-
mtq.print_quant_summary(full_model)
731-
save_expert_token_count_table(full_model, args.export_path)
730+
try:
731+
mtq.print_quant_summary(full_model, args.export_path)
732+
save_expert_token_count_table(full_model, args.export_path)
733+
except Exception as e:
734+
print(f"Error saving quant summary: {e}")
735+
print("Continuing with generation...")
732736

733737
# Run some samples
734738
torch.cuda.empty_cache()

modelopt/torch/export/moe_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | Non
3939
return
4040

4141
num_experts = rows[0][1].shape[0]
42+
assert all(r[1].shape[0] == num_experts for r in rows), (
43+
"All MoE layers must have the same number of experts"
44+
)
4245
html_parts = [
4346
"<html><head><style>",
4447
"table { border-collapse: collapse; font-family: monospace; }",
@@ -70,5 +73,5 @@ def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | Non
7073
if output_dir is None:
7174
output_dir = Path(".")
7275
output_path = Path(output_dir) / ".moe.html"
73-
output_path.write_text(html_content)
74-
print(f"Expert token count table saved to {output_path}")
76+
output_path.write_text(html_content, encoding="utf-8")
77+
print(f"\033[1mExpert token count table saved to {output_path}\033[0m")

modelopt/torch/quantization/model_quant.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -508,14 +508,26 @@ def enable_quantizer(model: nn.Module, wildcard_or_filter_func: str | Callable):
508508

509509

510510
@atomic_print
511-
def print_quant_summary(model: nn.Module):
511+
def print_quant_summary(model: nn.Module, output_dir: str | None = None):
512512
"""Print summary of all quantizer modules in the model."""
513-
count = 0
514-
for name, mod in model.named_modules():
515-
if isinstance(mod, TensorQuantizer):
516-
print(f"{name:80} {mod}")
517-
count += 1
518-
print(f"{count} TensorQuantizers found in model")
513+
lines = [
514+
f"{name:80} {mod}"
515+
for name, mod in model.named_modules()
516+
if isinstance(mod, TensorQuantizer)
517+
]
518+
lines.append(f"{len(lines)} TensorQuantizers found in model")
519+
520+
if output_dir:
521+
path = (
522+
output_dir.joinpath(".quant_summary.txt")
523+
if hasattr(output_dir, "joinpath")
524+
else f"{output_dir}/.quant_summary.txt"
525+
)
526+
with open(path, "w", encoding="utf-8") as f:
527+
f.write("\n".join(lines) + "\n")
528+
print(f"\033[1mQuant summary saved to {path}\033[0m")
529+
else:
530+
print("\n".join(lines))
519531

520532

521533
def fold_weight(model: nn.Module):

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,13 @@ def _setup(self):
461461
self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cpu")
462462
self._count_expert_tokens = False
463463

464+
if num_experts == 0:
465+
warnings.warn(
466+
f"{self.__class__.__name__}: could not resolve num_experts; "
467+
"expert routing will not be tracked for this layer."
468+
)
469+
return
470+
464471
if hasattr(self, "gate"):
465472
self.gate.register_forward_hook(self._gate_forward_hook)
466473

@@ -488,8 +495,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
488495
# This is used only for calibration, we need to re-calculate the actual outputs again using
489496
# the original top_k
490497
if TRANSFORMERS_VERSION_GE_5_0:
491-
assert hasattr(self, "gate")
492-
# Path for transformers >= 5.0
498+
assert hasattr(self, "gate") and hasattr(self.gate, "top_k")
493499
original_top_k = self.gate.top_k
494500
self.gate.top_k = self.gate.num_experts
495501
super().forward(hidden_states)

tests/unit/torch/quantization/plugins/test_sparse_moe.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,14 @@ def test_setup_creates_expert_token_count(self):
193193

194194
converted = QuantModuleRegistry.convert(moe_block)
195195
assert hasattr(converted, "expert_token_count")
196-
expected_num_experts = moe_block.num_experts if hasattr(moe_block, "num_experts") else 0
196+
if hasattr(moe_block, "gate") and hasattr(moe_block.gate, "num_experts"):
197+
expected_num_experts = moe_block.gate.num_experts
198+
elif hasattr(moe_block, "num_experts"):
199+
expected_num_experts = moe_block.num_experts
200+
elif hasattr(moe_block, "experts") and hasattr(moe_block.experts, "num_experts"):
201+
expected_num_experts = moe_block.experts.num_experts
202+
else:
203+
expected_num_experts = 0
197204
assert converted.expert_token_count.shape == (expected_num_experts,)
198205
assert converted.expert_token_count.dtype == torch.long
199206
assert (converted.expert_token_count == 0).all()
@@ -298,14 +305,16 @@ def test_gate_forward_hook_counts_tokens(self):
298305
converted.expert_token_count.zero_()
299306
converted._count_expert_tokens = True
300307

301-
hidden_size = converted.gate.in_features
308+
if TRANSFORMERS_VERSION_GE_5_0:
309+
hidden_size = converted.gate.weight.shape[1]
310+
top_k = converted.gate.top_k
311+
else:
312+
hidden_size = converted.gate.in_features
313+
top_k = converted.top_k if hasattr(converted, "top_k") else converted.gate.top_k
314+
302315
x = torch.randn(8, hidden_size)
303316
with torch.no_grad():
304317
converted.gate(x)
305-
306-
# After one gate call with counting enabled, total assigned tokens should equal
307-
# num_tokens * top_k
308-
top_k = converted.top_k if hasattr(converted, "top_k") else converted.gate.top_k
309318
total_assigned = converted.expert_token_count.sum().item()
310319
assert total_assigned == 8 * top_k
311320

0 commit comments

Comments
 (0)