Skip to content

Commit 0917ab8

Browse files
committed
fix test errors
Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
1 parent d33ee36 commit 0917ab8

2 files changed

Lines changed: 35 additions & 4 deletions

File tree

modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch
2121
from torch.autograd import Function
2222

23+
from modelopt.torch.opt.config import ModeloptBaseConfig
2324
from modelopt.torch.quantization.backends.gemm_registry import gemm_registry
2425
from modelopt.torch.quantization.config import FP8_DEFAULT_CFG, find_quant_cfg_entry_by_path
2526
from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear
@@ -121,17 +122,33 @@ def _fp8_availability_check(module, input, args, kwargs):
121122
# Check quantizer presence and configuration
122123
if not hasattr(module, "input_quantizer") or not hasattr(module, "weight_quantizer"):
123124
return False
125+
if not module.input_quantizer.is_enabled or not module.weight_quantizer.is_enabled:
126+
return False
124127

125128
# Check input quantizer config
126-
for key, value in input_cfg.items():
129+
# TODO: Move this compatibility check inside the quantizer; matching config items here
130+
# is fragile and easy to break as config semantics evolve.
131+
input_items = input_cfg.items
132+
if isinstance(input_cfg, ModeloptBaseConfig):
133+
input_items = input_cfg.explicit_items
134+
for key, value in input_items():
135+
if key == "enable":
136+
continue
127137
if (
128138
not hasattr(module.input_quantizer, key)
129139
or getattr(module.input_quantizer, key) != value
130140
):
131141
return False
132142

133143
# Check weight quantizer config
134-
for key, value in weight_cfg.items():
144+
# TODO: Move this compatibility check inside the quantizer; matching config items here
145+
# is fragile and easy to break as config semantics evolve.
146+
weight_items = weight_cfg.items
147+
if isinstance(weight_cfg, ModeloptBaseConfig):
148+
weight_items = weight_cfg.explicit_items
149+
for key, value in weight_items():
150+
if key == "enable":
151+
continue
135152
if (
136153
not hasattr(module.weight_quantizer, key)
137154
or getattr(module.weight_quantizer, key) != value

modelopt/torch/quantization/backends/nvfp4_gemm.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from torch.autograd import Function
2222

2323
import modelopt.torch.quantization as mtq
24+
from modelopt.torch.opt.config import ModeloptBaseConfig
2425
from modelopt.torch.quantization.backends.gemm_registry import gemm_registry
2526
from modelopt.torch.quantization.backends.utils import fp4_compatible
2627
from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear
@@ -230,7 +231,15 @@ def _nvfp4_availability_check(module, input, args, kwargs):
230231
return False
231232

232233
# Check input quantizer config
233-
for key, value in input_cfg.items():
234+
# TODO: Move this compatibility check inside the quantizer; matching config items here
235+
# is fragile and easy to break as config semantics evolve.
236+
if not module.input_quantizer.is_enabled or not module.weight_quantizer.is_enabled:
237+
return False
238+
239+
input_items = input_cfg.items
240+
if isinstance(input_cfg, ModeloptBaseConfig):
241+
input_items = input_cfg.explicit_items
242+
for key, value in input_items():
234243
if key == "enable":
235244
continue
236245
if (
@@ -240,7 +249,12 @@ def _nvfp4_availability_check(module, input, args, kwargs):
240249
return False
241250

242251
# Check weight quantizer config
243-
for key, value in weight_cfg.items():
252+
# TODO: Move this compatibility check inside the quantizer; matching config items here
253+
# is fragile and easy to break as config semantics evolve.
254+
weight_items = weight_cfg.items
255+
if isinstance(weight_cfg, ModeloptBaseConfig):
256+
weight_items = weight_cfg.explicit_items
257+
for key, value in weight_items():
244258
if key == "enable":
245259
continue
246260
if (

0 commit comments

Comments
 (0)