Skip to content
14 changes: 11 additions & 3 deletions msamp/common/tensor/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""MS-AMP ScalingMeta."""

import copy
from typing import Optional
import torch

from msamp.common.dtype import Floating, Dtypes
Expand All @@ -13,7 +14,7 @@ class ScalingMeta:
"""The meta data for scaling tensor."""
in_time_scaling: bool = True

def __init__(self, qtype, scale=None, scale_inv=None, amax=None, window_size=1, group=None):
def __init__(self, qtype, scale=None, scale_inv=None, amax=None, window_size=1, pre_scale=None, group=None):
"""Constructor.

Args:
Expand All @@ -22,11 +23,13 @@ def __init__(self, qtype, scale=None, scale_inv=None, amax=None, window_size=1,
scale_inv (torch.Tensor, optional): The reciprocal of scaling tensor, defaults to None.
amax (torch.Tensor, optional): Absolute maximum tensor, defaults to None.
window_size (int, optional): Window size, defaults to 1.
pre_scale (torch.Tensor, optional): A pre-scale factor
group (torch.distributed.ProcessGroup, optional): Distributed group, defaults to None.
"""
self.qtype = qtype
self.scale = scale if scale is not None else torch.ones((), device='cuda')
self.scale_inv = scale_inv if scale_inv is not None else torch.ones((), device='cuda')
self.pre_scale = pre_scale if pre_scale is not None else torch.ones((), device='cuda')
self.amax = amax if amax is not None else torch.zeros((window_size, ), device='cuda')
self.amax_counter = torch.zeros((), dtype=torch.int32)
self.window_size = window_size
Expand All @@ -36,20 +39,23 @@ def __init__(self, qtype, scale=None, scale_inv=None, amax=None, window_size=1,

@staticmethod
@torch.jit.script
def compute_scaling_factor(amax, scale, fp_max: float, margin: int):
def compute_scaling_factor(amax, scale, fp_max: float, margin: int, pre_scale: Optional[torch.Tensor] = None):
"""A function to compute scaling factor.

Args:
amax (torch.Tensor): Absolute maximum tensor.
scale (torch.Tensor): Scale tensor.
fp_max (float): The maximum value of float point.
margin (int): Margin value.
pre_scale (torch.Tensor, optional): A pre-scale factor

Returns:
return new scaling tensor.
"""
exp = torch.floor(torch.log2(fp_max / amax)) - margin
sf = torch.round(torch.pow(2, torch.abs(exp)))
if pre_scale is not None:
sf.mul_(pre_scale)
sf = torch.where(amax > 0.0, sf, scale)
sf = torch.where(torch.isfinite(amax), sf, scale)
sf = torch.where(exp < 0, 1 / sf, sf)
Expand Down Expand Up @@ -108,7 +114,7 @@ def reset_scaling_factor(self, qtype=None):
self.scale.fill_(1)
else:
fp_max = Floating.qfp_max[qtype]
sf = ScalingMeta.compute_scaling_factor(self.amax[0], self.scale, fp_max, 0)
sf = ScalingMeta.compute_scaling_factor(self.amax[0], self.scale, fp_max, 0, pre_scale=self.pre_scale)
self.scale.copy_(sf)

def copy_(self, src):
Expand All @@ -122,6 +128,7 @@ def copy_(self, src):
self.scale_inv.copy_(src.scale_inv)
self.amax.copy_(src.amax)
self.amax_counter.copy_(src.amax_counter)
self.pre_scale.copy_(src.pre_scale)
self.window_size = src.window_size

def clone(self):
Expand Down Expand Up @@ -156,4 +163,5 @@ def __repr__(self):
"""
return f'ScalingMeta(qtype={self.qtype}, '\
f'scale={self.scale.data:g}, scale_inv={self.scale_inv.data:g}, '\
f'pre_scale={self.pre_scale.data:g}, '\
f'amax={self.amax.max():g}, window_size={self.window_size})'
2 changes: 1 addition & 1 deletion msamp/megatron/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def _get_buffer_type(param):
start = pi * max_fp8_mems
for p in fp8_partitions[pi]:
meta = ScalingMeta(self.wgrad_qtype, scale=scales[t], scale_inv=scale_invs[t], amax=amaxs[t])
meta.pre_scale = pre_scale
meta.pre_scale.fill_(pre_scale)
t += 1
p.main_grad = ScalingTensor(self._grad_buffers[self.wgrad_dtype].get(p.shape, start), meta)
self._grad_buffer_param_index_map[self.wgrad_dtype][p] = (start, start + p.numel())
Expand Down
41 changes: 41 additions & 0 deletions msamp/megatron/optimizer/distrib_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,47 @@ def reduce_model_grads(self, args, timers): # noqa: C901

timers('grads-reduce-scatter').stop()

if args.wgrad_auto_scaling:
# Weight Gradient Auto Scaling
if args.curr_iteration % args.wgrad_auto_scaling_freq == 0:
timers('wgrad-auto-scaling', log_level=1).start(barrier=args.barrier_with_L1_time)

# update pre_scale in this partition
for model_group in self.model_fp8_groups:
for p in model_group:
g = p.main_grad
if g is not None and not torch.is_tensor(g):
if g.qtype != Dtypes.kfloat8_e4m3:
raise TypeError('g.qtype != Dtypes.kfloat8_e4m3: {}'.format(g.qtype))
# stat overflow ratio
num_infs = torch.count_nonzero((g.value & 0x7f) == 126)
overflow_ratio = num_infs / g.numel()
if overflow_ratio > args.wgrad_auto_scaling_ratio:
g.meta.pre_scale.div_(2.0)
else:
g.meta.pre_scale.mul_(2.0**(1.0 / args.wgrad_auto_scaling_window))

# synchonize pre_scale in all partitions
for model_id, model in enumerate(self.models):
# all fp8 gradients
partitions = self.model_gbuf_ranges[model_id][torch.uint8]['partitions']
fp8_grads = [[p.main_grad for p in part.keys()] for part in partitions]
# pre_scales in the partition `data_parallel_rank`
pre_scales = [g.meta.pre_scale for g in fp8_grads[data_parallel_rank]]
max_elems_per_rank = max(model._grad_buffer_num_params)
pre_scales = torch.stack(pre_scales)
# padding to max_elems_per_rank
pad = max_elems_per_rank - pre_scales.numel()
pre_scales = F.pad(pre_scales, (0, pad))
output_pre_scales = pre_scales.new_empty((data_parallel_world_size, max_elems_per_rank))
torch.distributed._all_gather_base(output_pre_scales, pre_scales, group=data_parallel_group)
# assign pre_scale to all fp8 gradients
for grads, pre_scales in zip(fp8_grads, output_pre_scales):
for g, pre_scale in zip(grads, pre_scales):
g.meta.pre_scale.copy_(pre_scale)

timers('wgrad-auto-scaling').stop()

def gather_model_params(self, args, timers): # noqa: C901
"""All-gather updated model params.

Expand Down
6 changes: 4 additions & 2 deletions msamp/operators/arithmetic/arithmetic.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,23 @@ void add_to_fp8(at::Tensor fp8_tensor,
at::Tensor scale,
at::Tensor scale_inv,
at::Tensor amax,
at::Tensor pre_scale,
const at::Tensor& other,
bool is_e4m3) {
const size_t N = other.numel();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_DTYPE_SWITCH_INPUT(other.scalar_type(), IType,
SELECT_FP8_TYPE(is_e4m3, OType,

constexpr int nvec = 32 / sizeof(IType);

VectorizedAddToFp8KernelLauncher<nvec>(
reinterpret_cast<IType*>(other.data_ptr()),
reinterpret_cast<OType*>(fp8_tensor.data_ptr()),
reinterpret_cast<fp32*>(scale.data_ptr()),
reinterpret_cast<fp32*>(scale_inv.data_ptr()),
reinterpret_cast<fp32*>(amax.data_ptr()),
reinterpret_cast<fp32*>(pre_scale.data_ptr()),
N,
stream
);
Expand Down
4 changes: 3 additions & 1 deletion msamp/operators/arithmetic/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@ def add_to_fp8(fp8_tensor, meta, other):

is_e4m3 = meta.qtype == Dtypes.kfloat8_e4m3

msamp_arithmetic.add_to_fp8(fp8_tensor, meta.scale, meta.scale_inv, meta.amax[0], other, is_e4m3)
msamp_arithmetic.add_to_fp8(
fp8_tensor, meta.scale, meta.scale_inv, meta.amax[0], meta.pre_scale, other, is_e4m3
)
26 changes: 15 additions & 11 deletions msamp/operators/arithmetic/vectorized_pointwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ __global__ void add_to_fp8_kernel(InputType *input,
ComputeType *scale,
ComputeType *scale_inv,
ComputeType *amax,
ComputeType *pre_scale,
const size_t N,
const size_t num_aligned_elements) {
if (threadIdx.x == 0 && blockIdx.x == 0) {
Expand Down Expand Up @@ -262,12 +263,14 @@ __global__ void add_to_fp8_kernel(InputType *input,
ComputeType exp = floorf(log2f(fp_max/(amax_value)));
ComputeType sf = roundf(powf(2, fabsf(exp)));

sf *= *pre_scale;

if (amax_value <= 0 || !isfinite(amax_value)) {
sf = *scale;
}

if (exp < 0) {
sf = 1 / sf;
sf = 1.0f / sf;
}

// using new scaling factor to quantize the input
Expand All @@ -280,9 +283,9 @@ __global__ void add_to_fp8_kernel(InputType *input,
for (int i = 0; i < nvec; ++i) {
const InputType val1 = static_cast<InputType>(input_storer.separate()[i]);
const ComputeType val2 = static_cast<ComputeType>(output_storer.separate()[i]);

InputType temp1 = static_cast<InputType>(val2 * s);

if constexpr (is_half<InputType>::value) {
temp1 = static_cast<ComputeType>(__hadd(temp1, val1));
} else {
Expand All @@ -296,7 +299,7 @@ __global__ void add_to_fp8_kernel(InputType *input,

if (threadIdx.x == 0 && blockIdx.x == 0) {
*scale = sf;
*scale_inv = 1.0 / sf;
*scale_inv = 1.0f / sf;
}
}

Expand Down Expand Up @@ -363,6 +366,7 @@ void VectorizedAddToFp8KernelLauncher(InputType *input,
fp32 *scale,
fp32 *scale_inv,
fp32 *amax,
fp32 *pre_scale,
const size_t N,
cudaStream_t stream) {
if (N != 0) {
Expand All @@ -373,26 +377,26 @@ void VectorizedAddToFp8KernelLauncher(InputType *input,
constexpr size_t threads = unary_kernel_threads;
size_t num_blocks = DIVUP(num_aligned_elements, threads);

// We use DeviceSyncer to sync the amax value between blocks, the block number should be less than
// (SMCount*MaxThreadsPerSM)/unary_kernel_threads, which is 132*2048/512 = 528 on H100 SXM. We set
// max_blocks to half of 528 to make sure it works on other H100 GPUs.
// We use DeviceSyncer to sync the amax value between blocks, the block number should be less than
// (SMCount*MaxThreadsPerSM)/unary_kernel_threads, which is 132*2048/512 = 528 on H100 SXM. We set
// max_blocks to half of 528 to make sure it works on other H100 GPUs.
// constexpr size_t max_blocks = 65535;
constexpr size_t max_blocks = 264;
num_blocks = std::min(num_blocks, max_blocks);

switch (align) {
case Alignment::SAME_ALIGNED:
add_to_fp8_kernel<nvec, true, fp32><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, N, num_aligned_elements);
input, output, scale, scale_inv, amax, pre_scale, N, num_aligned_elements);
break;
case Alignment::SAME_UNALIGNED:
add_to_fp8_kernel<nvec, false, fp32><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, N, num_aligned_elements);
input, output, scale, scale_inv, amax, pre_scale, N, num_aligned_elements);
break;
case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize
add_to_fp8_kernel<1, true, fp32><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, N, num_aligned_elements);
input, output, scale, scale_inv, amax, pre_scale, N, num_aligned_elements);
break;
}
}
Expand All @@ -401,4 +405,4 @@ void VectorizedAddToFp8KernelLauncher(InputType *input,

} // namespace msamp

#endif // MSAMP_VECTORIZED_POINTWISE_H
#endif // MSAMP_VECTORIZED_POINTWISE_H
15 changes: 15 additions & 0 deletions tests/common/tensor/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,18 @@ def test_disable_in_time_scaling(self):
meta = ScalingMeta(Dtypes.kfloat8_e4m3)
self.assertFalse(meta.is_in_time_scaling())
ScalingMeta.in_time_scaling = bak

def test_pre_scale(self):
"""Test pre_scale in ScalingMeta."""
x = torch.randn((4, 4), device='cuda')
meta = ScalingMeta(Dtypes.kfloat8_e4m3)
qtype = Dtypes.kfloat8_e4m3
q1 = x.cast(qtype, meta)

r = 0.5
meta2 = ScalingMeta(Dtypes.kfloat8_e4m3)
meta2.pre_scale.fill_(r)
q2 = x.cast(qtype, meta2)
self.assertTrue(torch.allclose(q1.float(), q2.float(), atol=5e-4))
self.assertTrue(torch.allclose(q1.meta.scale * r, q2.meta.scale))
self.assertTrue(torch.allclose(q1.meta.scale_inv / r, q2.meta.scale_inv))
25 changes: 21 additions & 4 deletions tests/operators/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
class ArithmeticTestCase(unittest.TestCase):
"""A class for Arithmetic test cases."""
def _check_scaling_tensor(self, scaling_tensor1, scaling_tensor2):
self.assertTrue(torch.all(torch.eq(scaling_tensor1.value, scaling_tensor2.value)))
self.assertTrue(torch.all(torch.eq(scaling_tensor1.meta.scale, scaling_tensor2.meta.scale)))
self.assertTrue(torch.all(torch.eq(scaling_tensor1.meta.scale_inv, scaling_tensor2.meta.scale_inv)))
self.assertTrue(torch.all(torch.eq(scaling_tensor1.meta.amax, scaling_tensor2.meta.amax)))
atol = 1e-6
self.assertTrue(torch.allclose(scaling_tensor1.value, scaling_tensor2.value, atol=3))
self.assertTrue(torch.allclose(scaling_tensor1.meta.scale, scaling_tensor2.meta.scale, atol=atol))
self.assertTrue(torch.allclose(scaling_tensor1.meta.scale_inv, scaling_tensor2.meta.scale_inv, atol=atol))
self.assertTrue(torch.allclose(scaling_tensor1.meta.amax, scaling_tensor2.meta.amax, atol=atol))

@decorator.cuda_test
def test_add_to_fp8(self):
Expand All @@ -31,10 +32,26 @@ def test_add_to_fp8(self):
for i, j, dtype, qtype, in itertools.product(sizes, sizes, dtypes, qtypes):
size = (i, j)
input1 = torch.rand(size, dtype=dtype, device='cuda')

# w/o pre_scale
scaling_tensor1 = input1.cast(qtype)
scaling_tensor2 = input1.cast(qtype)

for i in range(10):
input2 = torch.rand(size, dtype=dtype, device='cuda')
meta = scaling_tensor1.meta
Arithmetic.add_to_fp8(scaling_tensor1.value, meta, input2)
scaling_tensor2.copy_((scaling_tensor2.to(dtype) + input2).cast(qtype, meta=scaling_tensor2.meta))
self._check_scaling_tensor(scaling_tensor1, scaling_tensor2)

# w/ pre_scale
scaling_tensor1 = input1.cast(qtype)
scaling_tensor2 = input1.cast(qtype)

for i in range(10):
pre_scale = torch.rand(1).item()
scaling_tensor1.meta.pre_scale.fill_(pre_scale)
scaling_tensor2.meta.pre_scale.fill_(pre_scale)
input2 = torch.rand(size, dtype=dtype, device='cuda')
meta = scaling_tensor1.meta
Arithmetic.add_to_fp8(scaling_tensor1.value, meta, input2)
Expand Down