Skip to content

Commit 7e6ffcc

Browse files
hxbaivthumbe1503pre-commit-ci[bot]timmoon10
authored
[Common/PyTorch/JAX] make offset of ClampedSwiGLU configurable (#2938)
* swiglu offset Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * fix fusion pattern check Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * use swiglu_v2 Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * add default value to v1 Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * fix test Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * add default value to jax version Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * revert the default value change Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update the fusion path Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * update cudnn-frontend to 1.24.0 Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: vthumbe1503 <vthumbe@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
1 parent 80ea313 commit 7e6ffcc

20 files changed

Lines changed: 240 additions & 61 deletions

File tree

tests/jax/test_custom_call_compute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def test_act_grad(self, shape, activation_type):
245245
value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3)
246246
)
247247
act_args = (
248-
{"limit": 0.75, "alpha": 1.702}
248+
{"limit": 0.75, "alpha": 1.702, "glu_linear_offset": 0.5}
249249
if activation_type == ("clamped_silu", "clamped_linear")
250250
else {}
251251
)

tests/pytorch/test_fusible_ops.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1846,6 +1846,7 @@ def test_interleaved_swiglu(self):
18461846
@pytest.mark.parametrize("quantization", _quantization_list)
18471847
@pytest.mark.parametrize("quantize_forward", (False, True))
18481848
@pytest.mark.parametrize("quantize_backward", (False, True))
1849+
@pytest.mark.parametrize("glu_linear_offset", (1.0, 0.0))
18491850
def test_clamped_swiglu(
18501851
self,
18511852
*,
@@ -1856,6 +1857,7 @@ def test_clamped_swiglu(
18561857
quantization: Optional[str],
18571858
quantize_forward: bool,
18581859
quantize_backward: bool,
1860+
glu_linear_offset: float,
18591861
limit: float = 0.75,
18601862
alpha: float = 1.702,
18611863
):
@@ -1898,7 +1900,7 @@ def test_clamped_swiglu(
18981900
x_glu = x_glu.clamp(min=None, max=limit)
18991901
x_linear = x_linear.clamp(min=-limit, max=limit)
19001902
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
1901-
y_ref = out_glu * (x_linear + 1)
1903+
y_ref = out_glu * (x_linear + glu_linear_offset)
19021904
y_ref.backward(dy_ref)
19031905

19041906
# Implementation with fusible operation
@@ -1909,6 +1911,7 @@ def test_clamped_swiglu(
19091911
te_ops.ClampedSwiGLU(
19101912
limit=limit,
19111913
alpha=alpha,
1914+
glu_linear_offset=glu_linear_offset,
19121915
glu_interleave_size=glu_interleave_size,
19131916
),
19141917
te_ops.Quantize(forward=quantize_forward, backward=False),
@@ -1938,6 +1941,7 @@ def test_interleaved_clamped_swiglu(self):
19381941
quantize_forward=False,
19391942
quantize_backward=False,
19401943
glu_interleave_size=32,
1944+
glu_linear_offset=1.0,
19411945
)
19421946

19431947
@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
@@ -2594,6 +2598,7 @@ def test_scaled_activation_recompute_in_mlp_config(self, op_cls) -> None:
25942598
@pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
25952599
@pytest.mark.parametrize("input_requires_grad", (False, True))
25962600
@pytest.mark.parametrize("scales_requires_grad", (False, True))
2601+
@pytest.mark.parametrize("glu_linear_offset", (1.0, 0.0))
25972602
def test_scaled_clamped_qgeglu(
25982603
self,
25992604
*,
@@ -2603,6 +2608,7 @@ def test_scaled_clamped_qgeglu(
26032608
device: torch.device = "cuda",
26042609
input_requires_grad: bool,
26052610
scales_requires_grad: bool,
2611+
glu_linear_offset: float,
26062612
limit: float = 7.0,
26072613
alpha: float = 1.702,
26082614
) -> None:
@@ -2647,7 +2653,7 @@ def test_scaled_clamped_qgeglu(
26472653
x_glu = x_glu.clamp(min=None, max=limit)
26482654
x_linear = x_linear.clamp(min=-limit, max=limit)
26492655
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
2650-
y = out_glu * (x_linear + 1)
2656+
y = out_glu * (x_linear + glu_linear_offset)
26512657
y_ref = scales_ref.unsqueeze(-1) * y
26522658
if input_requires_grad or scales_requires_grad:
26532659
y_ref.backward(dy_ref)
@@ -2656,6 +2662,7 @@ def test_scaled_clamped_qgeglu(
26562662
glu_interleave_size=glu_interleave_size,
26572663
limit=limit,
26582664
alpha=alpha,
2665+
glu_linear_offset=glu_linear_offset,
26592666
)
26602667
y_test = op(x_test, scales_test)
26612668
if input_requires_grad or scales_requires_grad:
@@ -2674,6 +2681,7 @@ def test_interleaved_scaled_clamped_qgeglu(self):
26742681
glu_interleave_size=32,
26752682
input_requires_grad=True,
26762683
scales_requires_grad=True,
2684+
glu_linear_offset=1.0,
26772685
)
26782686

26792687

@@ -3685,7 +3693,13 @@ def test_layernorm_mlp(
36853693
@pytest.mark.parametrize("delay_wgrad_compute", (False, True))
36863694
@pytest.mark.parametrize("hidden_size", (128, 256))
36873695
@pytest.mark.parametrize(
3688-
"activation", ("scaled_swiglu", "scaled_clamped_qgeglu", "scaled_srelu")
3696+
"activation",
3697+
(
3698+
"scaled_swiglu",
3699+
"scaled_clamped_qgeglu",
3700+
"scaled_clamped_qgeglu_custom",
3701+
"scaled_srelu",
3702+
),
36893703
)
36903704
def test_grouped_mlp(
36913705
self,
@@ -3719,7 +3733,7 @@ def test_grouped_mlp(
37193733
with_quantization = quantization is not None
37203734
if activation == "scaled_swiglu":
37213735
scaled_act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
3722-
elif activation == "scaled_clamped_qgeglu":
3736+
elif activation.startswith("scaled_clamped_qgeglu"):
37233737
scaled_act = te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size)
37243738
elif activation == "scaled_srelu":
37253739
scaled_act = te_ops.ScaledSReLU()
@@ -3742,13 +3756,23 @@ def test_grouped_mlp(
37423756
if (
37433757
with_quantization
37443758
and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6")
3745-
and activation == "scaled_clamped_qgeglu"
3759+
and activation.startswith("scaled_clamped_qgeglu")
37463760
and bias
37473761
):
37483762
# TODO: ksivaman: Need to debug numerics for this case.
37493763
pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU")
37503764
fc1_out_features = 2 * hidden_size if activation_is_glu else hidden_size
37513765

3766+
# Activation parameters for clamped QGeGLU variants
3767+
if activation == "scaled_clamped_qgeglu_custom":
3768+
geglu_limit = 5.0
3769+
geglu_alpha = 1.5
3770+
geglu_offset = 0.5
3771+
else:
3772+
geglu_limit = 7.0
3773+
geglu_alpha = 1.702
3774+
geglu_offset = 1.0
3775+
37523776
# Random data
37533777
x_ref, x_test = make_reference_and_test_tensors(
37543778
in_shape,
@@ -3840,13 +3864,12 @@ def test_grouped_mlp(
38403864
if activation == "scaled_swiglu":
38413865
x1, x2 = x.chunk(2, dim=-1)
38423866
x = torch.nn.functional.silu(x1) * x2
3843-
elif activation == "scaled_clamped_qgeglu":
3867+
elif activation.startswith("scaled_clamped_qgeglu"):
38443868
x1, x2 = x.chunk(2, dim=-1)
3845-
lim = torch.tensor(7.0, device=x1.device, dtype=x1.dtype)
3846-
geglu_alpha = 1.702
3869+
lim = torch.tensor(geglu_limit, device=x1.device, dtype=x1.dtype)
38473870
x1c = torch.minimum(x1, lim)
38483871
x2c = torch.clamp(x2, -lim, lim)
3849-
x = (x2c + 1) * (x1c * torch.sigmoid(geglu_alpha * x1c))
3872+
x = (x2c + geglu_offset) * (x1c * torch.sigmoid(geglu_alpha * x1c))
38503873
elif activation == "scaled_srelu":
38513874
x = torch.nn.functional.relu(x).square()
38523875
else:
@@ -3861,6 +3884,13 @@ def test_grouped_mlp(
38613884

38623885
# Construct operations
38633886
recipe = make_recipe(quantization)
3887+
if activation == "scaled_clamped_qgeglu_custom":
3888+
scaled_act = te_ops.ScaledClampedQGeGLU(
3889+
glu_interleave_size=glu_interleave_size,
3890+
limit=geglu_limit,
3891+
alpha=geglu_alpha,
3892+
glu_linear_offset=geglu_offset,
3893+
)
38643894
with te.quantized_model_init(enabled=with_quantization, recipe=recipe):
38653895
fc1 = te_ops.GroupedLinear(
38663896
group_size,

transformer_engine/common/activation/swiglu.cu

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,35 @@ void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit,
3939
cudaStream_t stream) {
4040
NVTE_API_CALL(nvte_clamped_swiglu);
4141
using namespace transformer_engine;
42-
ClampedSwiGLUParam param = {limit, alpha};
42+
// Preserve original behavior: linear (gate) component offset is hard-coded to 1.0f.
43+
ClampedSwiGLUParam param = {limit, alpha, /*glu_linear_offset=*/1.0f};
44+
gated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>>(input, output, param, stream);
45+
}
46+
47+
void nvte_clamped_swiglu_v2(const NVTETensor input, NVTETensor output, float limit, float alpha,
48+
float glu_linear_offset, cudaStream_t stream) {
49+
NVTE_API_CALL(nvte_clamped_swiglu_v2);
50+
using namespace transformer_engine;
51+
ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset};
4352
gated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>>(input, output, param, stream);
4453
}
4554

4655
void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
4756
float limit, float alpha, cudaStream_t stream) {
4857
NVTE_API_CALL(nvte_clamped_dswiglu);
4958
using namespace transformer_engine;
50-
ClampedSwiGLUParam param = {limit, alpha};
59+
// Preserve original behavior: linear (gate) component offset is hard-coded to 1.0f.
60+
ClampedSwiGLUParam param = {limit, alpha, /*glu_linear_offset=*/1.0f};
61+
dgated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>, clamped_dsilu<fp32, fp32>>(
62+
grad, input, output, param, stream);
63+
}
64+
65+
void nvte_clamped_dswiglu_v2(const NVTETensor grad, const NVTETensor input, NVTETensor output,
66+
float limit, float alpha, float glu_linear_offset,
67+
cudaStream_t stream) {
68+
NVTE_API_CALL(nvte_clamped_dswiglu_v2);
69+
using namespace transformer_engine;
70+
ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset};
5171
dgated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>, clamped_dsilu<fp32, fp32>>(
5272
grad, input, output, param, stream);
5373
}

transformer_engine/common/cast/fp8/gated_fp8.cuh

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
169169
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
170170
bool dgate_elt = true; // gating is ideally an identity function
171171
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
172-
// In case of GPT OSS, clamp the activation and gate values
173-
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
174-
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1;
172+
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
173+
gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
175174
}
176175

177176
if constexpr (IS_BWD) {

transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
245245
float after_gate_elt;
246246
bool dgate_elt = true; // gating is ideally an identity function
247247
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
248-
// In case of GPT OSS, clamp the activation and gate values
249-
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
250-
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
248+
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
249+
gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
251250
}
252251
if constexpr (IS_BWD) {
253252
float grad_elt = static_cast<float>(in_grad_sh[shmem_offset_colwise]);
@@ -510,9 +509,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
510509
float after_gate_elt;
511510
bool dgate_elt = true;
512511
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
513-
// In case of GPT OSS, clamp the activation and gate values
514-
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
515-
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
512+
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
513+
gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
516514
}
517515
if constexpr (IS_BWD) {
518516
float grad_elt = static_cast<float>(in_grad.data.elt[e]);

transformer_engine/common/include/transformer_engine/activation.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,11 @@ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
322322
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
323323

324324
/*! \brief Computes the gated Swish activation of the input used in GPT OSS.
325+
*
326+
* \deprecated This function has been deprecated in favor of nvte_clamped_swiglu_v2,
327+
* which exposes a configurable offset for the linear (gate) component.
328+
* This API is preserved for backward compatibility and is equivalent to
329+
* calling nvte_clamped_swiglu_v2 with glu_linear_offset = 1.0.
325330
*
326331
* See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
327332
* This Gated activation has two differences compared to the original SwiGLU
@@ -341,6 +346,28 @@ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
341346
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
342347
cudaStream_t stream);
343348

349+
/*! \brief Computes the gated Swish activation of the input used in GPT OSS, with a configurable
350+
* offset for the linear (gate) component after clamping.
351+
*
352+
* See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
353+
* This Gated activation has two differences compared to the original SwiGLU
354+
* 1. Both gate and pre-activations are clipped based on parameter limit.
355+
* 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation inspired
356+
* by original GELU paper https://arxiv.org/pdf/1606.08415
357+
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
358+
* the block quantization (MXFP8) of the specified shape of the block will be used.
359+
*
360+
* \param[in] input Input tensor of shape [N, H * 2].
361+
* \param[in,out] output Output tensor of shape [N, H].
362+
*                           It computes Act(input[N, :H]) x (clamp(input[N, H:], -limit, limit) + glu_linear_offset)
363+
* \param[in] limit Clipping limits for gate and pre-activation.
364+
* \param[in] alpha Scaling factor for the sigmoid function used in the activation.
365+
* \param[in] glu_linear_offset Offset added to the linear component after clamping (typically 1.0).
366+
* \param[in] stream CUDA stream used for the operation.
367+
*/
368+
void nvte_clamped_swiglu_v2(const NVTETensor input, NVTETensor output, float limit, float alpha,
369+
float glu_linear_offset, cudaStream_t stream);
370+
344371
/*! \brief Computes the gated ReLU activation of the input.
345372
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
346373
* the block quantization (MXFP8) of the specified shape of the block will be used.
@@ -399,6 +426,11 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp
399426
cudaStream_t stream);
400427

401428
/*! \brief Computes the gradient of gated Swish activation of the input used in GPT OSS.
429+
*
430+
* \deprecated This function has been deprecated in favor of nvte_clamped_dswiglu_v2,
431+
* which exposes a configurable offset for the linear (gate) component.
432+
* This API is preserved for backward compatibility and is equivalent to
433+
* calling nvte_clamped_dswiglu_v2 with glu_linear_offset = 1.0.
402434
*
403435
* https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
404436
* This activation has two differences compared to the original SwiGLU
@@ -418,6 +450,29 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp
418450
void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
419451
float limit, float alpha, cudaStream_t stream);
420452

453+
/*! \brief Computes the gradient of gated Swish activation of the input used in GPT OSS, with a
454+
* configurable offset for the linear (gate) component after clamping.
455+
*
456+
* https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
457+
* This activation has two differences compared to the original SwiGLU
458+
* 1. Both gate and pre-activations are clipped based on parameter limit.
459+
* 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation inspired
460+
* by original GELU paper https://arxiv.org/pdf/1606.08415
461+
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
462+
* the block quantization (MXFP8) of the specified shape of the block will be used.
463+
*
464+
* \param[in] grad Incoming gradient of shape [N, H].
465+
* \param[in] input Forward input tensor of shape [N, H * 2].
466+
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
467+
* \param[in] limit Clipping limits for gate and pre-activation.
468+
* \param[in] alpha Scaling factor for the sigmoid function used in the activation.
469+
* \param[in] glu_linear_offset Offset added to the linear component after clamping (typically 1.0).
470+
* \param[in] stream CUDA stream used for the operation.
471+
*/
472+
void nvte_clamped_dswiglu_v2(const NVTETensor grad, const NVTETensor input, NVTETensor output,
473+
float limit, float alpha, float glu_linear_offset,
474+
cudaStream_t stream);
475+
421476
/*! \brief Computes the gated ReLU activation gradient.
422477
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
423478
* the block quantization (MXFP8) of the specified shape of the block will be used.

transformer_engine/common/util/math.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ struct Empty {};
1313

1414
struct ClampedSwiGLUParam {
1515
float limit;
16-
float alpha = 1.702f; // Default value for QuickGELU
16+
float alpha = 1.702f; // Default value for QuickGELU
17+
float glu_linear_offset = 1.0f; // Offset added to the linear (gate) component after clamping
1718
};
1819

1920
template <typename OType, typename IType>

transformer_engine/common/util/vectorized_pointwise.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -434,9 +434,8 @@ __launch_bounds__(unary_kernel_threads) __global__
434434
ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
435435

436436
if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
437-
// Clamp the gated value and add 1 at the end
438437
ComputeType limit = p.limit;
439-
val2 = std::min(std::max(-limit, val2), limit) + 1;
438+
val2 = std::min(std::max(-limit, val2), limit) + p.glu_linear_offset;
440439
}
441440
ComputeType temp = static_cast<ComputeType>(Activation(val, p) * val2);
442441
if (requires_amax) {
@@ -542,10 +541,9 @@ __launch_bounds__(unary_kernel_threads) __global__
542541
bool dgate_in = true;
543542

544543
if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
545-
// In case of GPT OSS, clamp the activation and gate values
546544
const ComputeType limit = p.limit;
547-
dgate_in = gate_in <= limit && gate_in >= -limit; // Derivative of clamp
548-
gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f;
545+
dgate_in = gate_in <= limit && gate_in >= -limit;
546+
gate_in = std::min(std::max(-limit, gate_in), limit) + p.glu_linear_offset;
549547
}
550548

551549
ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;

0 commit comments

Comments
 (0)