fms_mo/aiu_addons/__init__.py (1 change: 1 addition & 0 deletions)
@@ -31,6 +31,7 @@ def _infer_quantization_config(quant_config: dict) -> dict | None:
     # First, import required FP8 linear classes from fms-mo
     # Local
     import fms_mo.aiu_addons.fp8.fp8_adapter # pylint: disable=unused-import
+    import fms_mo.aiu_addons.fp8.fp8_attn # pylint: disable=unused-import
     import fms_mo.aiu_addons.fp8.fp8_linear # pylint: disable=unused-import
 
     # This is used by get_linear to decide whether a linear layer
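Note on the one-line change above: the new fp8_attn import, like the fp8_adapter and fp8_linear imports around it, looks unused (hence the pylint suppression) and is presumably loaded purely for its side effects, registering the FP8 attention support at import time. Below is a minimal sketch of this register-on-import pattern; all names in it (OP_REGISTRY, register_op, fp8_attn_math) are hypothetical illustrations, not the actual fms-mo API.

```python
# Minimal sketch of the register-on-import pattern these imports rely on.
# OP_REGISTRY, register_op, and fp8_attn_math are hypothetical names,
# not the actual fms-mo API.
from typing import Callable

OP_REGISTRY: dict[str, Callable] = {}

def register_op(name: str):
    """Decorator that records an implementation in the registry."""
    def wrapper(fn: Callable) -> Callable:
        OP_REGISTRY[name] = fn
        return fn
    return wrapper

# Placed at module level in a file like fp8_attn.py, this runs as soon as
# the module is imported, so a bare `import ...fp8_attn` is enough to make
# the op discoverable even though the import itself looks unused.
@register_op("fp8_attn")
def fp8_attn_math(*args, **kwargs):
    ...
```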
fms_mo/aiu_addons/fp8/fp8_linear.py (6 changes: 3 additions & 3 deletions)
@@ -321,12 +321,12 @@ def shard_fp8_linear(
     sharding  | param          | shard | dim |
     ----------+----------------+-------+-----|
     colwise   | weight         | Y     | 0   |
-              | weight_scale   | N     | -   |
+              | weight_scale   | Y/N   | 0/- |
               | input_scale    | N     | -   |
               | bias           | Y     | 0   |
     ----------+----------------+-------+-----|
     rowwise   | weight         | Y     | 1   |
-              | weight_scale   | Y/N   | 0/- |
+              | weight_scale   | N     | -   |
               | input_scale    | Y/N   | 0/- |
               | bias           | 0     | -   |
     """
@@ -339,7 +339,7 @@ def shard_fp8_linear(
     ]
     # Scales are per-row or per-tensor
     # Only sharding needed when column parallel and per-row
-    shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 1
+    shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 0
     params: dict[str, LinearParameterShardingInfo] = {
         "weight": LinearParameterShardingInfo(
             module_info.sharding_dim, ShardType.SHARD
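Note on the sharding_dim flip above: per the docstring table, per-row weight scales hold one value per output channel, so they run along dim 0 of the (out_features, in_features) weight. Column parallelism (sharding_dim == 0) splits those output rows across ranks, so each rank must also take the matching slice of scales, while row parallelism (sharding_dim == 1) splits input features and leaves every rank producing all output rows with the full scale vector. A minimal sketch under those assumptions follows; shapes, world size, and names are illustrative, not the fms-mo implementation.

```python
# Minimal sketch of why per-row scales shard under column parallelism.
# Shapes, world size, and tensor names are illustrative assumptions.
import torch

out_features, in_features, world_size, rank = 8, 4, 2, 0

weight = torch.randn(out_features, in_features)
weight_scale = torch.rand(out_features)  # per-row: one scale per output channel

# Column parallel (sharding_dim == 0): the weight is split along its output
# rows, so each rank needs the matching slice of the per-row scales.
w_col = weight.chunk(world_size, dim=0)[rank]
s_col = weight_scale.chunk(world_size, dim=0)[rank]
assert w_col.shape[0] == s_col.shape[0]  # scales line up with local rows

# Row parallel (sharding_dim == 1): the weight is split along input features,
# but every rank still produces all output rows, so the full per-row scale
# vector is kept unsharded.
w_row = weight.chunk(world_size, dim=1)[rank]
assert w_row.shape[0] == weight_scale.shape[0]  # full scale vector still applies
```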