Skip to content

Commit 75f8a3c

Browse files
committed
swiglu offset
1 parent 82ace62 commit 75f8a3c

16 files changed

Lines changed: 80 additions & 52 deletions

File tree

tests/pytorch/test_fusible_ops.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1795,6 +1795,7 @@ def test_interleaved_swiglu(self):
17951795
@pytest.mark.parametrize("quantization", _quantization_list)
17961796
@pytest.mark.parametrize("quantize_forward", (False, True))
17971797
@pytest.mark.parametrize("quantize_backward", (False, True))
1798+
@pytest.mark.parametrize("glu_linear_offset", (1.0, 0.0))
17981799
def test_clamped_swiglu(
17991800
self,
18001801
*,
@@ -1805,6 +1806,7 @@ def test_clamped_swiglu(
18051806
quantization: Optional[str],
18061807
quantize_forward: bool,
18071808
quantize_backward: bool,
1809+
glu_linear_offset: float,
18081810
limit: float = 0.75,
18091811
alpha: float = 1.702,
18101812
):
@@ -1847,7 +1849,7 @@ def test_clamped_swiglu(
18471849
x_glu = x_glu.clamp(min=None, max=limit)
18481850
x_linear = x_linear.clamp(min=-limit, max=limit)
18491851
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
1850-
y_ref = out_glu * (x_linear + 1)
1852+
y_ref = out_glu * (x_linear + glu_linear_offset)
18511853
y_ref.backward(dy_ref)
18521854

18531855
# Implementation with fusible operation
@@ -1858,6 +1860,7 @@ def test_clamped_swiglu(
18581860
te_ops.ClampedSwiGLU(
18591861
limit=limit,
18601862
alpha=alpha,
1863+
glu_linear_offset=glu_linear_offset,
18611864
glu_interleave_size=glu_interleave_size,
18621865
),
18631866
te_ops.Quantize(forward=quantize_forward, backward=False),
@@ -2240,6 +2243,7 @@ def test_interleaved_scaled_swiglu(self):
22402243
@pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
22412244
@pytest.mark.parametrize("input_requires_grad", (False, True))
22422245
@pytest.mark.parametrize("scales_requires_grad", (False, True))
2246+
@pytest.mark.parametrize("glu_linear_offset", (1.0, 0.0))
22432247
def test_scaled_clamped_qgeglu(
22442248
self,
22452249
*,
@@ -2249,6 +2253,7 @@ def test_scaled_clamped_qgeglu(
22492253
device: torch.device = "cuda",
22502254
input_requires_grad: bool,
22512255
scales_requires_grad: bool,
2256+
glu_linear_offset: float,
22522257
limit: float = 7.0,
22532258
alpha: float = 1.702,
22542259
) -> None:
@@ -2293,7 +2298,7 @@ def test_scaled_clamped_qgeglu(
22932298
x_glu = x_glu.clamp(min=None, max=limit)
22942299
x_linear = x_linear.clamp(min=-limit, max=limit)
22952300
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
2296-
y = out_glu * (x_linear + 1)
2301+
y = out_glu * (x_linear + glu_linear_offset)
22972302
y_ref = scales_ref.unsqueeze(-1) * y
22982303
if input_requires_grad or scales_requires_grad:
22992304
y_ref.backward(dy_ref)
@@ -2302,6 +2307,7 @@ def test_scaled_clamped_qgeglu(
23022307
glu_interleave_size=glu_interleave_size,
23032308
limit=limit,
23042309
alpha=alpha,
2310+
glu_linear_offset=glu_linear_offset,
23052311
)
23062312
y_test = op(x_test, scales_test)
23072313
if input_requires_grad or scales_requires_grad:

transformer_engine/common/activation/swiglu.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,18 +85,18 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp
8585
}
8686

8787
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
88-
cudaStream_t stream) {
88+
float glu_linear_offset, cudaStream_t stream) {
8989
NVTE_API_CALL(nvte_clamped_swiglu);
9090
using namespace transformer_engine;
91-
ClampedSwiGLUParam param = {limit, alpha};
91+
ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset};
9292
gated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>>(input, output, param, stream);
9393
}
9494

9595
void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
96-
float limit, float alpha, cudaStream_t stream) {
96+
float limit, float alpha, float glu_linear_offset, cudaStream_t stream) {
9797
NVTE_API_CALL(nvte_clamped_dswiglu);
9898
using namespace transformer_engine;
99-
ClampedSwiGLUParam param = {limit, alpha};
99+
ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset};
100100
dgated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>, clamped_dsilu<fp32, fp32>>(
101101
grad, input, output, param, stream);
102102
}

transformer_engine/common/cast/fp8/gated_fp8.cuh

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
169169
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
170170
bool dgate_elt = true; // gating is ideally an identity function
171171
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
172-
// In case of GPT OSS, clamp the activation and gate values
173-
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
174-
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1;
172+
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
173+
gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
175174
}
176175

177176
if constexpr (IS_BWD) {

transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
245245
float after_gate_elt;
246246
bool dgate_elt = true; // gating is ideally an identity function
247247
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
248-
// In case of GPT OSS, clamp the activation and gate values
249-
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
250-
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
248+
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
249+
gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
251250
}
252251
if constexpr (IS_BWD) {
253252
float grad_elt = static_cast<float>(in_grad_sh[shmem_offset_colwise]);
@@ -510,9 +509,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
510509
float after_gate_elt;
511510
bool dgate_elt = true;
512511
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
513-
// In case of GPT OSS, clamp the activation and gate values
514-
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
515-
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
512+
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
513+
gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
516514
}
517515
if constexpr (IS_BWD) {
518516
float grad_elt = static_cast<float>(in_grad.data.elt[e]);

transformer_engine/common/include/transformer_engine/activation.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,10 +336,11 @@ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
336336
* It computes Act(input[N, :H]) x input[N, H:]
337337
* \param[in] limit Clipping limits for gate and pre-activation.
338338
* \param[in] alpha Scaling factor for the sigmoid function used in the activation.
339+
* \param[in] glu_linear_offset Offset added to the linear component after clamping (default 1.0).
339340
* \param[in] stream CUDA stream used for the operation.
340341
*/
341342
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
342-
cudaStream_t stream);
343+
float glu_linear_offset, cudaStream_t stream);
343344

344345
/*! \brief Computes the gated ReLU activation of the input.
345346
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
@@ -413,10 +414,11 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp
413414
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
414415
* \param[in] limit Clipping limits for gate and pre-activation.
415416
* \param[in] alpha Scaling factor for the sigmoid function used in the activation.
417+
* \param[in] glu_linear_offset Offset added to the linear component after clamping (default 1.0).
416418
* \param[in] stream CUDA stream used for the operation.
417419
*/
418420
void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
419-
float limit, float alpha, cudaStream_t stream);
421+
float limit, float alpha, float glu_linear_offset, cudaStream_t stream);
420422

421423
/*! \brief Computes the gated ReLU activation gradient.
422424
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,

transformer_engine/common/util/math.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ struct Empty {};
1313

1414
struct ClampedSwiGLUParam {
1515
float limit;
16-
float alpha = 1.702f; // Default value for QuickGELU
16+
float alpha = 1.702f; // Default value for QuickGELU
17+
float glu_linear_offset = 1.0f; // Offset added to the linear (gate) component after clamping
1718
};
1819

1920
template <typename OType, typename IType>

transformer_engine/common/util/vectorized_pointwise.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -434,9 +434,8 @@ __launch_bounds__(unary_kernel_threads) __global__
434434
ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
435435

436436
if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
437-
// Clamp the gated value and add 1 at the end
438437
ComputeType limit = p.limit;
439-
val2 = std::min(std::max(-limit, val2), limit) + 1;
438+
val2 = std::min(std::max(-limit, val2), limit) + p.glu_linear_offset;
440439
}
441440
ComputeType temp = static_cast<ComputeType>(Activation(val, p) * val2);
442441
if (requires_amax) {
@@ -542,10 +541,9 @@ __launch_bounds__(unary_kernel_threads) __global__
542541
bool dgate_in = true;
543542

544543
if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
545-
// In case of GPT OSS, clamp the activation and gate values
546544
const ComputeType limit = p.limit;
547-
dgate_in = gate_in <= limit && gate_in >= -limit; // Derivative of clamp
548-
gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f;
545+
dgate_in = gate_in <= limit && gate_in >= -limit;
546+
gate_in = std::min(std::max(-limit, gate_in), limit) + p.glu_linear_offset;
549547
}
550548

551549
ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;

transformer_engine/jax/cpp_extensions/activation.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,15 @@ class ClampedSwigluParams:
6464

6565
limit: float = 7.0
6666
alpha: float = 1.702
67+
glu_linear_offset: float = 1.0
6768

6869
def __hash__(self):
6970
"""Custom hash function to ensure dataclass is hashable for jax jit to work.
7071
7172
Returns:
7273
int: Hash value of the dataclass instance.
7374
"""
74-
return hash((self.limit, self.alpha))
75+
return hash((self.limit, self.alpha, self.glu_linear_offset))
7576

7677
def to_ffi_lowering_dict(self):
7778
"""Convert the activation parameters to a dictionary format for FFI lowering.
@@ -80,7 +81,11 @@ def to_ffi_lowering_dict(self):
8081
dict: A dictionary representation of the activation parameters consumable by
8182
XLA FFI bindings for activation functions.
8283
"""
83-
return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)}
84+
return {
85+
"limit": np.float32(self.limit),
86+
"alpha": np.float32(self.alpha),
87+
"glu_linear_offset": np.float32(self.glu_linear_offset),
88+
}
8489

8590

8691
@dataclass(frozen=True)
@@ -121,11 +126,9 @@ def _convert_to_activation_function(fn_or_string, act_params: ActivationParams):
121126
if fn_or_string == "linear":
122127
return lambda x: x
123128
if fn_or_string == "clamped_linear":
124-
# This function is used for ClampedSwiGLU
125-
# used in GPT OSS where the gates are not only clamped
126-
# but also shifted by +1
127129
limit = act_params.clamped_swiglu.limit
128-
return lambda x: jnp.clip(x, min=-limit, max=limit) + 1
130+
offset = act_params.clamped_swiglu.glu_linear_offset
131+
return lambda x: jnp.clip(x, min=-limit, max=limit) + offset
129132
if fn_or_string == "quick_gelu":
130133
return lambda x: jax.nn.sigmoid(1.702 * x) * x
131134
if fn_or_string == "squared_relu":

transformer_engine/jax/csrc/extensions.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ namespace jax {
3939
struct ClampedSwigluConfig {
4040
float limit;
4141
float alpha;
42+
float glu_linear_offset;
4243
};
4344

4445
struct ActivationConfig {
@@ -208,7 +209,8 @@ pybind11::tuple GetTopkWorkspaceSizes(int batch_size, int seq_len, int k);
208209

209210
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig,
210211
::xla::ffi::StructMember<float>("limit"),
211-
::xla::ffi::StructMember<float>("alpha"));
212+
::xla::ffi::StructMember<float>("alpha"),
213+
::xla::ffi::StructMember<float>("glu_linear_offset"));
212214

213215
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(
214216
transformer_engine::jax::ActivationConfig,

transformer_engine/jax/csrc/extensions/activation.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
2323
// parameters for clamped swiglu used in GPT OSS
2424
auto swiglu_limit = act_params.clamped_swiglu.limit;
2525
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
26+
auto swiglu_glu_linear_offset = act_params.clamped_swiglu.glu_linear_offset;
2627

2728
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
2829
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
@@ -138,7 +139,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
138139
break;
139140
case NVTE_Activation_Type::CLAMPED_SWIGLU:
140141
nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha,
141-
stream);
142+
swiglu_glu_linear_offset, stream);
142143
break;
143144
default:
144145
NVTE_ERROR("Unsupported ActivationEnum");
@@ -271,6 +272,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
271272
// parameters for clamped swiglu used in GPT OSS
272273
auto swiglu_limit = act_params.clamped_swiglu.limit;
273274
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
275+
auto swiglu_glu_linear_offset = act_params.clamped_swiglu.glu_linear_offset;
274276

275277
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
276278
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
@@ -447,7 +449,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
447449
break;
448450
case NVTE_Activation_Type::CLAMPED_SWIGLU:
449451
nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
450-
swiglu_limit, swiglu_alpha, stream);
452+
swiglu_limit, swiglu_alpha, swiglu_glu_linear_offset, stream);
451453
break;
452454
default:
453455
NVTE_ERROR("Unsupported ActivationEnum");

0 commit comments

Comments
 (0)