125 changes: 101 additions & 24 deletions fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py
@@ -188,6 +188,27 @@ def forward(self, x):
self.smoothquant,
)

def re_register_qdata(self) -> None:
"""Remove existing self.qdata tensor and register it again as a buffer.
This method is used during TP, after other quantization metadata have been
updated.
"""

del self.qdata
self.register_buffer(
"qdata",
torch.cat(
(
self.w_clip_val,
self.w_clip_valn,
self.a_clip_val,
self.a_clip_valn,
self.zero_shift,
self.smoothquant_scale,
)
),
)
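# A minimal illustration (not part of the PR; hypothetical names) of why the
# re-registration above is needed: qdata is a concatenation of the metadata
# buffers, so once TP sharding has resized those buffers the old concatenation
# is stale and must be rebuilt. torch is assumed imported at module level.
_demo = torch.nn.Module()
_demo.register_buffer("w_clip_val", torch.ones(8))   # pretend: per-channel clips
_demo.register_buffer("qdata", torch.cat((_demo.w_clip_val, -_demo.w_clip_val)))
_demo.w_clip_val = _demo.w_clip_val[:4]    # pretend: TP kept this rank's 4 channels
del _demo.qdata                            # still 16 elements: stale concatenation
_demo.register_buffer("qdata", torch.cat((_demo.w_clip_val, -_demo.w_clip_val)))
assert _demo.qdata.numel() == 8            # now matches the sharded metadata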

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}"
@@ -222,7 +243,7 @@ def get_int8_aiu_linear(
# Preprocess linear_config if its linear_type field is a callable
# (which would not correctly initialize the dataclass parameters).
# We don't want to alter the original linear_config though.
linear_config_for_dataclass: Optional[dict[Union[str, Callable], Any]] = None
linear_config_for_dataclass = None
if callable(linear_config["linear_type"]):
linear_config_for_dataclass = update_from_partial(linear_config)
linear_config_for_dataclass["linear_type"] = linear_type
@@ -240,6 +261,36 @@
return linear


def is_w_clip_per_channel(
w_clip: torch.Tensor,
) -> bool:
"""Determine whether the weight clip value in use for INT8 quantization of the
provided linear module is:
- per-tensor (1 element, 1-dim tensor), or
- per-channel (out_feat elements, 1-dim tensor).
"""

if w_clip.dim() != 1:
raise ValueError(
f"TP error: weight clip value dimensions {str(list(w_clip.size()))} are "
"incompatible with expected per-tensor or per-channel quantization."
)
return w_clip.numel() > 1
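# Illustrative usage (not part of the PR): a per-tensor clip is a single-element
# 1-D tensor, a per-channel clip has one entry per output feature, and anything
# that is not 1-dimensional is rejected with a ValueError.
assert is_w_clip_per_channel(torch.tensor([3.5])) is False       # per-tensor clip
assert is_w_clip_per_channel(torch.full((4096,), 3.5)) is True   # per-channel clip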


def is_smoothquant_enabled(
smoothquant_scale: torch.Tensor,
) -> bool:
"""Determine whether smoothquant is enabled on a module."""

if smoothquant_scale.dim() != 1:
raise ValueError(
"TP error: smoothquant_scale array should always be 1-dimensional but "
f"has size {str(list(smoothquant_scale.size()))}"
)
return smoothquant_scale.numel() > 1
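# Illustrative usage (not part of the PR): a single-element scale means smoothquant
# is disabled, while a multi-element (per-channel) scale means it is enabled.
assert is_smoothquant_enabled(torch.ones(1)) is False     # smoothquant disabled
assert is_smoothquant_enabled(torch.rand(4096)) is True   # smoothquant enabled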


def shard_int8_aiu_linear(
tensor_values: dict[str, torch.Tensor],
tp_module: TPModule,
@@ -259,49 +310,73 @@ def shard_int8_aiu_linear(
| bias | 0 | - |
| others* | N | - |

Other quantization parameters: w_clip_val, w_clip_valn,
a_clip_val, a_clip_valn, zero_shift, smoothquant_scale
No sharding on all these parameters, except w_clip_val and w_clip_valn when
per-channel quantization is used
Other quantization parameters: w_clip_val, w_clip_valn, a_clip_val, a_clip_valn,
zero_shift, smoothquant_scale

No sharding is applied to any of these parameters (they are CLONED on each rank),
with two exceptions:
- w_clip_val and w_clip_valn are sharded, but only under column-sharding and only
when per-channel quantization is used
- smoothquant_scale is sharded, but only under row-sharding and only when
smoothquant is in use

These parameters are 1-dimensional, so when sharding is needed it is always applied
on dim=0 (see the illustrative sketch after this function).
"""

param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {}
w_clip_linear_param = None
for module_name, module_info in module_sharding_info.items():
int8_aiu_mod = module_info.linear_module
int8_aiu_module = module_info.linear_module

# check for every module whether per-channel quantization is in use (sharding depends on the module)
if is_w_clip_per_channel(module_info.linear_module.w_clip_val):
w_clip_linear_param = LinearParameterShardingInfo(
0,
ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.CLONE,
)
else:
w_clip_linear_param = LinearParameterShardingInfo(0, ShardType.CLONE)

# check for every linear module if smoothquant is enabled
if is_smoothquant_enabled(module_info.linear_module.smoothquant_scale):
smoothquant_linear_param = LinearParameterShardingInfo(
0, ShardType.SHARD if module_info.sharding_dim == 1 else ShardType.CLONE
)
else:
smoothquant_linear_param = LinearParameterShardingInfo(0, ShardType.CLONE)

params: dict[str, LinearParameterShardingInfo] = {
"weight": LinearParameterShardingInfo(
module_info.sharding_dim, ShardType.SHARD
),
# FIXME: with per-channel W, clips need to be sharded
# but if per-tensor w, there should be no sharding
# HOW CAN WE DISCRIMINATE THE TWO CASES?
"w_clip_val": LinearParameterShardingInfo(0, ShardType.CLONE),
"w_clip_valn": LinearParameterShardingInfo(0, ShardType.CLONE),
# "w_clip_val": LinearParameterShardingInfo(
# module_info.sharding_dim,
# ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
# ),
# "w_clip_valn": LinearParameterShardingInfo(
# module_info.sharding_dim,
# ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
# ),
"w_clip_val": w_clip_linear_param,
"w_clip_valn": w_clip_linear_param,
"a_clip_val": LinearParameterShardingInfo(0, ShardType.CLONE),
"a_clip_valn": LinearParameterShardingInfo(0, ShardType.CLONE),
"zero_shift": LinearParameterShardingInfo(0, ShardType.CLONE),
"smooqthquant_scale": LinearParameterShardingInfo(0, ShardType.CLONE),
"smoothquant_scale": smoothquant_linear_param,
}
if int8_aiu_mod.bias is not None:
if int8_aiu_module.bias is not None and int8_aiu_module.bias.numel() > 1:
params["bias"] = LinearParameterShardingInfo(
module_info.sharding_dim,
0,
ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
)
param_sharding_info[module_name] = params

# trim qdata from the dictionary of tensors to be copied onto sharded modules.
# if not trimmed, qdata would not be copied anyway, but its keys would be marked as unused
tensor_values = {k: v for k, v in tensor_values.items() if "qdata" not in k}

unused_keys = shard_base_linear(
tensor_values, tp_module, module_sharding_info, param_sharding_info
)

raise NotImplementedError("TP not yet supported for INT8. Work in progress")
# return unused_keys
# qdata contains all quantization metadata to pass to the AIU and needs to be
# updated post-sharding, after the metadata tensors have changed
for module_name, module_info in module_sharding_info.items():
module_info.linear_module.re_register_qdata()

return unused_keys
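# Illustrative sketch (not part of the PR; hypothetical names) of the sharding
# rules documented in shard_int8_aiu_linear, using plain torch.chunk as a
# stand-in for ShardType.SHARD on a 2-way TP setup: per-channel w_clip_val
# follows the weight's output dim (column-sharding), smoothquant_scale follows
# the input dim (row-sharding), and everything else is cloned on each rank.
_world_size, _rank = 2, 0
_weight = torch.randn(8, 16)            # (out_feat, in_feat)
_w_clip_val = torch.rand(8)             # per-channel clip, one entry per out_feat
_smoothquant_scale = torch.rand(16)     # one entry per in_feat

# column-parallel module (sharding_dim == 0): weight and w_clip_val are sharded
_col_weight = _weight.chunk(_world_size, dim=0)[_rank]              # (4, 16)
_col_w_clip = _w_clip_val.chunk(_world_size, dim=0)[_rank]          # (4,)  SHARD
_col_scale = _smoothquant_scale                                      # (16,) CLONE

# row-parallel module (sharding_dim == 1): weight and smoothquant_scale are sharded
_row_weight = _weight.chunk(_world_size, dim=1)[_rank]              # (8, 8)
_row_w_clip = _w_clip_val                                            # (8,)  CLONE
_row_scale = _smoothquant_scale.chunk(_world_size, dim=0)[_rank]    # (8,)  SHARD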


register_linear_type_to_module_map(
@@ -320,4 +395,6 @@ def shard_int8_aiu_linear(
use_smoothquant=True,
),
)

# int8 linear with and w/o smoothquant share a common sharding map
register_linear_type_to_sharding_map("int8_aiu", shard_int8_aiu_linear)