diff --git a/fms_mo/aiu_addons/__init__.py b/fms_mo/aiu_addons/__init__.py
index 30367a4..1d9c5cf 100644
--- a/fms_mo/aiu_addons/__init__.py
+++ b/fms_mo/aiu_addons/__init__.py
@@ -1,3 +1,7 @@
+# Local
+from fms_mo.prep import available_packages
+
+
 def _infer_quantization_config(quant_config: dict) -> dict | None:
     """Construct linear_config dictionary carrying FP8 configuration for FMS.
 
@@ -20,6 +24,10 @@ def _infer_quantization_config(quant_config: dict) -> dict | None:
         quant_config["config_groups"]["group_0"]["weights"]["type"] == "float"
         and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8
     ):
+        if not available_packages["torchao"]:
+            raise ImportError(
+                "You need torchao installed to load FP8 checkpoints in FMS"
+            )
         # First, import required FP8 linear classes from fms-mo
         # Local
         import fms_mo.aiu_addons.fp8.fp8_adapter  # pylint: disable=unused-import
diff --git a/fms_mo/aiu_addons/fp8/fp8_attn.py b/fms_mo/aiu_addons/fp8/fp8_attn.py
index 5a091ae..e60438b 100644
--- a/fms_mo/aiu_addons/fp8/fp8_attn.py
+++ b/fms_mo/aiu_addons/fp8/fp8_attn.py
@@ -220,8 +220,9 @@ def _math_fp8_compute_op(
         .to(dtype=orig_dtype)
         .transpose(-2, -1)
     )
-    attn_weight = query @ key_t
-    attn_weight *= scale_factor
+    attn_weight = (query * math.sqrt(scale_factor)) @ (
+        key_t * math.sqrt(scale_factor)
+    )
     attn_weight += attn_bias
     attn_weight = torch.softmax(attn_weight, dim=-1)
     attn_weight = torch.dropout(attn_weight, p_dropout, train=True)
diff --git a/pyproject.toml b/pyproject.toml
index e3f2300..f509d56 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,7 @@ dependencies = [
 
 [project.optional-dependencies]
 examples = ["ninja>=1.11.1.1,<2.0", "evaluate", "huggingface_hub"]
-fp8 = ["llmcompressor", "torchao>=0.11,<=0.12"]
+fp8 = ["llmcompressor", "torchao==0.11"]
 gptq = ["Cython", "gptqmodel>=1.7.3"]
 mx = ["microxcaling>=1.1"]
 opt = ["fms-model-optimizer[fp8, gptq, mx]"]