diff --git a/fms_mo/aiu_addons/__init__.py b/fms_mo/aiu_addons/__init__.py
index 1d9c5cf..4002e96 100644
--- a/fms_mo/aiu_addons/__init__.py
+++ b/fms_mo/aiu_addons/__init__.py
@@ -31,6 +31,7 @@ def _infer_quantization_config(quant_config: dict) -> dict | None:
         # First, import required FP8 linear classes from fms-mo
         # Local
         import fms_mo.aiu_addons.fp8.fp8_adapter  # pylint: disable=unused-import
+        import fms_mo.aiu_addons.fp8.fp8_attn  # pylint: disable=unused-import
         import fms_mo.aiu_addons.fp8.fp8_linear  # pylint: disable=unused-import
 
         # This is used by get_linear to decide whether a linear layer
diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py
index 8456268..866d6aa 100644
--- a/fms_mo/aiu_addons/fp8/fp8_linear.py
+++ b/fms_mo/aiu_addons/fp8/fp8_linear.py
@@ -321,12 +321,12 @@ def shard_fp8_linear(
     sharding  | param          | shard | dim |
     ----------+----------------+-------+-----|
     colwise   | weight         | Y     | 0   |
-              | weight_scale   | N     | -   |
+              | weight_scale   | Y/N   | 0/- |
               | input_scale    | N     | -   |
               | bias           | Y     | 0   |
     ----------+----------------+-------+-----|
     rowwise   | weight         | Y     | 1   |
-              | weight_scale   | Y/N   | 0/- |
+              | weight_scale   | N     | -   |
               | input_scale    | Y/N   | 0/- |
               | bias           | 0     | -   |
     """
@@ -339,7 +339,7 @@ def shard_fp8_linear(
     ]
     # Scales are per-row or per-tensor
-    # Only sharding needed when row parallel and per-row
-    shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 1
+    # Only sharding needed when column parallel and per-row
+    shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 0
     params: dict[str, LinearParameterShardingInfo] = {
         "weight": LinearParameterShardingInfo(
             module_info.sharding_dim, ShardType.SHARD
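
For reviewers, here is a minimal standalone sketch of why per-row weight scales should follow the weight only in the column-parallel case, which is what the `sharding_dim == 0` check above encodes. This is plain PyTorch, not the fms-mo implementation; the `tp_rank`/`tp_world` values and tensor shapes are illustrative only.

```python
import torch

out_features, in_features = 8, 4
tp_rank, tp_world = 1, 2  # illustrative tensor-parallel rank and world size

weight = torch.randn(out_features, in_features)
weight_scale = torch.rand(out_features)  # one scale per output row ("per-row")

# Column-parallel (sharding_dim == 0): each rank owns a slice of the output
# rows, so the matching per-row scales must be sharded along dim 0 as well.
col_weight = weight.chunk(tp_world, dim=0)[tp_rank]       # (4, 4)
col_scale = weight_scale.chunk(tp_world, dim=0)[tp_rank]  # (4,)

# Row-parallel (sharding_dim == 1): each rank owns a slice of the input
# columns; every rank still produces all output rows, so the per-row scales
# stay whole on every rank.
row_weight = weight.chunk(tp_world, dim=1)[tp_rank]       # (8, 2)
row_scale = weight_scale                                   # (8,)

assert col_weight.shape[0] == col_scale.shape[0]
assert row_weight.shape[0] == row_scale.shape[0]
```

Per-tensor scales are a single scalar shared by every shard, which is why the sharding decision is additionally gated on `weight_strategy != "tensor"`.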