Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_act_grad(self, shape, activation_type):
value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3)
)
act_args = (
{"limit": 0.75, "alpha": 1.702}
{"limit": 0.75, "alpha": 1.702, "glu_linear_offset": 0.5}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
Expand Down
48 changes: 39 additions & 9 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1846,6 +1846,7 @@ def test_interleaved_swiglu(self):
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantize_forward", (False, True))
@pytest.mark.parametrize("quantize_backward", (False, True))
@pytest.mark.parametrize("glu_linear_offset", (1.0, 0.0))
def test_clamped_swiglu(
self,
*,
Expand All @@ -1856,6 +1857,7 @@ def test_clamped_swiglu(
quantization: Optional[str],
quantize_forward: bool,
quantize_backward: bool,
glu_linear_offset: float,
limit: float = 0.75,
alpha: float = 1.702,
):
Expand Down Expand Up @@ -1898,7 +1900,7 @@ def test_clamped_swiglu(
x_glu = x_glu.clamp(min=None, max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
y_ref = out_glu * (x_linear + 1)
y_ref = out_glu * (x_linear + glu_linear_offset)
y_ref.backward(dy_ref)

# Implementation with fusible operation
Expand All @@ -1909,6 +1911,7 @@ def test_clamped_swiglu(
te_ops.ClampedSwiGLU(
limit=limit,
alpha=alpha,
glu_linear_offset=glu_linear_offset,
glu_interleave_size=glu_interleave_size,
),
te_ops.Quantize(forward=quantize_forward, backward=False),
Expand Down Expand Up @@ -1938,6 +1941,7 @@ def test_interleaved_clamped_swiglu(self):
quantize_forward=False,
quantize_backward=False,
glu_interleave_size=32,
glu_linear_offset=1.0,
)

@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
Expand Down Expand Up @@ -2594,6 +2598,7 @@ def test_scaled_activation_recompute_in_mlp_config(self, op_cls) -> None:
@pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("scales_requires_grad", (False, True))
@pytest.mark.parametrize("glu_linear_offset", (1.0, 0.0))
def test_scaled_clamped_qgeglu(
self,
*,
Expand All @@ -2603,6 +2608,7 @@ def test_scaled_clamped_qgeglu(
device: torch.device = "cuda",
input_requires_grad: bool,
scales_requires_grad: bool,
glu_linear_offset: float,
limit: float = 7.0,
alpha: float = 1.702,
) -> None:
Expand Down Expand Up @@ -2647,7 +2653,7 @@ def test_scaled_clamped_qgeglu(
x_glu = x_glu.clamp(min=None, max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
y = out_glu * (x_linear + 1)
y = out_glu * (x_linear + glu_linear_offset)
y_ref = scales_ref.unsqueeze(-1) * y
if input_requires_grad or scales_requires_grad:
y_ref.backward(dy_ref)
Expand All @@ -2656,6 +2662,7 @@ def test_scaled_clamped_qgeglu(
glu_interleave_size=glu_interleave_size,
limit=limit,
alpha=alpha,
glu_linear_offset=glu_linear_offset,
)
y_test = op(x_test, scales_test)
if input_requires_grad or scales_requires_grad:
Expand All @@ -2674,6 +2681,7 @@ def test_interleaved_scaled_clamped_qgeglu(self):
glu_interleave_size=32,
input_requires_grad=True,
scales_requires_grad=True,
glu_linear_offset=1.0,
)


Expand Down Expand Up @@ -3685,7 +3693,13 @@ def test_layernorm_mlp(
@pytest.mark.parametrize("delay_wgrad_compute", (False, True))
@pytest.mark.parametrize("hidden_size", (128, 256))
@pytest.mark.parametrize(
"activation", ("scaled_swiglu", "scaled_clamped_qgeglu", "scaled_srelu")
"activation",
(
"scaled_swiglu",
"scaled_clamped_qgeglu",
"scaled_clamped_qgeglu_custom",
"scaled_srelu",
),
)
def test_grouped_mlp(
self,
Expand Down Expand Up @@ -3719,7 +3733,7 @@ def test_grouped_mlp(
with_quantization = quantization is not None
if activation == "scaled_swiglu":
scaled_act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
elif activation == "scaled_clamped_qgeglu":
elif activation.startswith("scaled_clamped_qgeglu"):
scaled_act = te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size)
elif activation == "scaled_srelu":
scaled_act = te_ops.ScaledSReLU()
Expand All @@ -3742,13 +3756,23 @@ def test_grouped_mlp(
if (
with_quantization
and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6")
and activation == "scaled_clamped_qgeglu"
and activation.startswith("scaled_clamped_qgeglu")
and bias
):
# TODO: ksivaman: Need to debug numerics for this case.
pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU")
fc1_out_features = 2 * hidden_size if activation_is_glu else hidden_size

# Activation parameters for clamped QGeGLU variants
if activation == "scaled_clamped_qgeglu_custom":
geglu_limit = 5.0
geglu_alpha = 1.5
geglu_offset = 0.5
else:
geglu_limit = 7.0
geglu_alpha = 1.702
geglu_offset = 1.0

# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
Expand Down Expand Up @@ -3840,13 +3864,12 @@ def test_grouped_mlp(
if activation == "scaled_swiglu":
x1, x2 = x.chunk(2, dim=-1)
x = torch.nn.functional.silu(x1) * x2
elif activation == "scaled_clamped_qgeglu":
elif activation.startswith("scaled_clamped_qgeglu"):
x1, x2 = x.chunk(2, dim=-1)
lim = torch.tensor(7.0, device=x1.device, dtype=x1.dtype)
geglu_alpha = 1.702
lim = torch.tensor(geglu_limit, device=x1.device, dtype=x1.dtype)
x1c = torch.minimum(x1, lim)
x2c = torch.clamp(x2, -lim, lim)
x = (x2c + 1) * (x1c * torch.sigmoid(geglu_alpha * x1c))
x = (x2c + geglu_offset) * (x1c * torch.sigmoid(geglu_alpha * x1c))
elif activation == "scaled_srelu":
x = torch.nn.functional.relu(x).square()
else:
Expand All @@ -3861,6 +3884,13 @@ def test_grouped_mlp(

# Construct operations
recipe = make_recipe(quantization)
if activation == "scaled_clamped_qgeglu_custom":
scaled_act = te_ops.ScaledClampedQGeGLU(
glu_interleave_size=glu_interleave_size,
limit=geglu_limit,
alpha=geglu_alpha,
glu_linear_offset=geglu_offset,
)
with te.quantized_model_init(enabled=with_quantization, recipe=recipe):
fc1 = te_ops.GroupedLinear(
group_size,
Expand Down
24 changes: 22 additions & 2 deletions transformer_engine/common/activation/swiglu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,35 @@ void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit,
cudaStream_t stream) {
NVTE_API_CALL(nvte_clamped_swiglu);
using namespace transformer_engine;
ClampedSwiGLUParam param = {limit, alpha};
// Preserve original behavior: linear (gate) component offset is hard-coded to 1.0f.
ClampedSwiGLUParam param = {limit, alpha, /*glu_linear_offset=*/1.0f};
gated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>>(input, output, param, stream);
}

void nvte_clamped_swiglu_v2(const NVTETensor input, NVTETensor output, float limit, float alpha,
float glu_linear_offset, cudaStream_t stream) {
NVTE_API_CALL(nvte_clamped_swiglu_v2);
using namespace transformer_engine;
ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset};
gated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>>(input, output, param, stream);
}

void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
float limit, float alpha, cudaStream_t stream) {
NVTE_API_CALL(nvte_clamped_dswiglu);
using namespace transformer_engine;
ClampedSwiGLUParam param = {limit, alpha};
// Preserve original behavior: linear (gate) component offset is hard-coded to 1.0f.
ClampedSwiGLUParam param = {limit, alpha, /*glu_linear_offset=*/1.0f};
dgated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>, clamped_dsilu<fp32, fp32>>(
grad, input, output, param, stream);
}

void nvte_clamped_dswiglu_v2(const NVTETensor grad, const NVTETensor input, NVTETensor output,
float limit, float alpha, float glu_linear_offset,
cudaStream_t stream) {
NVTE_API_CALL(nvte_clamped_dswiglu_v2);
using namespace transformer_engine;
ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset};
dgated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>, clamped_dsilu<fp32, fp32>>(
grad, input, output, param, stream);
}
5 changes: 2 additions & 3 deletions transformer_engine/common/cast/fp8/gated_fp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
bool dgate_elt = true; // gating is ideally an identity function
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1;
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
}

if constexpr (IS_BWD) {
Expand Down
10 changes: 4 additions & 6 deletions transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float after_gate_elt;
bool dgate_elt = true; // gating is ideally an identity function
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
}
if constexpr (IS_BWD) {
float grad_elt = static_cast<float>(in_grad_sh[shmem_offset_colwise]);
Expand Down Expand Up @@ -510,9 +509,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float after_gate_elt;
bool dgate_elt = true;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
}
if constexpr (IS_BWD) {
float grad_elt = static_cast<float>(in_grad.data.elt[e]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,11 @@ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);

/*! \brief Computes the gated Swish activation of the input used in GPT OSS.
*
* \deprecated This function has been deprecated in favor of nvte_clamped_swiglu_v2,
* which exposes a configurable offset for the linear (gate) component.
* This API is preserved for backward compatibility and is equivalent to
* calling nvte_clamped_swiglu_v2 with glu_linear_offset = 1.0.
*
* See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
* This Gated activation has two differences compared to the original SwiGLU
Expand All @@ -341,6 +346,28 @@ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
cudaStream_t stream);

/*! \brief Computes the gated Swish activation of the input used in GPT OSS, with a configurable
* offset for the linear (gate) component after clamping.
*
* See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
* This Gated activation has two differences compared to the original SwiGLU
* 1. Both gate and pre-activations are clipped based on parameter limit.
* 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation inspired
* by original GELU paper https://arxiv.org/pdf/1606.08415
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes Act(input[N, :H]) x (input[N, H:] + glu_linear_offset)
* \param[in] limit Clipping limits for gate and pre-activation.
* \param[in] alpha Scaling factor for the sigmoid function used in the activation.
* \param[in] glu_linear_offset Offset added to the linear component after clamping (typically 1.0).
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_clamped_swiglu_v2(const NVTETensor input, NVTETensor output, float limit, float alpha,
float glu_linear_offset, cudaStream_t stream);

/*! \brief Computes the gated ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
Expand Down Expand Up @@ -399,6 +426,11 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp
cudaStream_t stream);

/*! \brief Computes the gradient of gated Swish activation of the input used in GPT OSS.
*
* \deprecated This function has been deprecated in favor of nvte_clamped_dswiglu_v2,
* which exposes a configurable offset for the linear (gate) component.
* This API is preserved for backward compatibility and is equivalent to
* calling nvte_clamped_dswiglu_v2 with glu_linear_offset = 1.0.
*
* https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
* This activation has two differences compared to the original SwiGLU
Expand All @@ -418,6 +450,29 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp
void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
float limit, float alpha, cudaStream_t stream);

/*! \brief Computes the gradient of gated Swish activation of the input used in GPT OSS, with a
* configurable offset for the linear (gate) component after clamping.
*
* https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
* This activation has two differences compared to the original SwiGLU
* 1. Both gate and pre-activations are clipped based on parameter limit.
* 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation inspired
* by original GELU paper https://arxiv.org/pdf/1606.08415
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] limit Clipping limits for gate and pre-activation.
* \param[in] alpha Scaling factor for the sigmoid function used in the activation.
* \param[in] glu_linear_offset Offset added to the linear component after clamping (typically 1.0).
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_clamped_dswiglu_v2(const NVTETensor grad, const NVTETensor input, NVTETensor output,
float limit, float alpha, float glu_linear_offset,
cudaStream_t stream);

/*! \brief Computes the gated ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
Expand Down
3 changes: 2 additions & 1 deletion transformer_engine/common/util/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ struct Empty {};

struct ClampedSwiGLUParam {
float limit;
float alpha = 1.702f; // Default value for QuickGELU
float alpha = 1.702f; // Default value for QuickGELU
float glu_linear_offset = 1.0f; // Offset added to the linear (gate) component after clamping
};

template <typename OType, typename IType>
Expand Down
8 changes: 3 additions & 5 deletions transformer_engine/common/util/vectorized_pointwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,8 @@ __launch_bounds__(unary_kernel_threads) __global__
ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);

if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
// Clamp the gated value and add 1 at the end
ComputeType limit = p.limit;
val2 = std::min(std::max(-limit, val2), limit) + 1;
val2 = std::min(std::max(-limit, val2), limit) + p.glu_linear_offset;
}
ComputeType temp = static_cast<ComputeType>(Activation(val, p) * val2);
if (requires_amax) {
Expand Down Expand Up @@ -542,10 +541,9 @@ __launch_bounds__(unary_kernel_threads) __global__
bool dgate_in = true;

if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
const ComputeType limit = p.limit;
dgate_in = gate_in <= limit && gate_in >= -limit; // Derivative of clamp
gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f;
dgate_in = gate_in <= limit && gate_in >= -limit;
gate_in = std::min(std::max(-limit, gate_in), limit) + p.glu_linear_offset;
}

ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;
Expand Down
Loading
Loading