|
20 | 20 | import torch |
21 | 21 | from torch.autograd import Function |
22 | 22 |
|
| 23 | +from modelopt.torch.opt.config import ModeloptBaseConfig |
23 | 24 | from modelopt.torch.quantization.backends.gemm_registry import gemm_registry |
24 | 25 | from modelopt.torch.quantization.config import FP8_DEFAULT_CFG, find_quant_cfg_entry_by_path |
25 | 26 | from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear |
@@ -121,17 +122,33 @@ def _fp8_availability_check(module, input, args, kwargs): |
121 | 122 | # Check quantizer presence and configuration |
122 | 123 | if not hasattr(module, "input_quantizer") or not hasattr(module, "weight_quantizer"): |
123 | 124 | return False |
| 125 | + if not module.input_quantizer.is_enabled or not module.weight_quantizer.is_enabled: |
| 126 | + return False |
124 | 127 |
|
125 | 128 | # Check input quantizer config |
126 | | - for key, value in input_cfg.items(): |
| 129 | + # TODO: Move this compatibility check inside the quantizer; matching config items here |
| 130 | + # is fragile and easy to break as config semantics evolve. |
| 131 | + input_items = input_cfg.items |
| 132 | + if isinstance(input_cfg, ModeloptBaseConfig): |
| 133 | + input_items = input_cfg.explicit_items |
| 134 | + for key, value in input_items(): |
| 135 | + if key == "enable": |
| 136 | + continue |
127 | 137 | if ( |
128 | 138 | not hasattr(module.input_quantizer, key) |
129 | 139 | or getattr(module.input_quantizer, key) != value |
130 | 140 | ): |
131 | 141 | return False |
132 | 142 |
|
133 | 143 | # Check weight quantizer config |
134 | | - for key, value in weight_cfg.items(): |
| 144 | + # TODO: Move this compatibility check inside the quantizer; matching config items here |
| 145 | + # is fragile and easy to break as config semantics evolve. |
| 146 | + weight_items = weight_cfg.items |
| 147 | + if isinstance(weight_cfg, ModeloptBaseConfig): |
| 148 | + weight_items = weight_cfg.explicit_items |
| 149 | + for key, value in weight_items(): |
| 150 | + if key == "enable": |
| 151 | + continue |
135 | 152 | if ( |
136 | 153 | not hasattr(module.weight_quantizer, key) |
137 | 154 | or getattr(module.weight_quantizer, key) != value |
|
0 commit comments