[Common/PyTorch/JAX] make offset of ClampedSwiGLU configurable #2938

Merged
timmoon10 merged 20 commits into NVIDIA:main from hxbai:swiglu_offset on May 26, 2026

@hxbai
Contributor

@hxbai hxbai commented Apr 28, 2026

Description

The existing ClampedSwiGLU follows GPT-OSS and hard-codes the offset to 1.0.
DeepSeek-V4 uses ClampedSwiGLU without alpha and without the offset.
This PR makes the offset of ClampedSwiGLU configurable so that DeepSeek-V4 can be supported.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps Bot commented Apr 28, 2026

Greptile Summary

This PR makes the glu_linear_offset parameter of ClampedSwiGLU configurable (previously hardcoded to 1.0 following GPT-OSS), enabling support for DeepSeek-V4 which uses glu_linear_offset=0.0. The public C API is preserved via a clean _v2 versioning pattern while the original symbols are deprecated in-place.

  • C layer: nvte_clamped_swiglu_v2/nvte_clamped_dswiglu_v2 added to activation.h; ClampedSwiGLUParam gains glu_linear_offset; all CUDA kernels (vectorized_pointwise.h, gated_fp8.cuh, gated_mxfp8.cuh) correctly apply the offset after clamping and set the dgate boolean mask before adding offset in backward.
  • PyTorch/JAX bindings: ClampedSwiGLU, ScaledClampedQGeGLU, ClampedSwigluParams (JAX) all updated with default 1.0 preserving backward compatibility; pybind defaults and ONNX export path are consistent.
  • Fused grouped MLP (cuDNN path): linear_offset is forwarded to cuDNN only when FE ≥ 1.24.0 (_pass_geglu_runtime_params); no guard exists for non-default offsets when FE is in the [1.23, 1.24) range, which silently produces incorrect results in both forward and backward passes.

Confidence Score: 4/5

The core CUDA kernels, PyTorch/JAX bindings, and the standard ClampedSwiGLU path are all correct and backward-compatible. The fused grouped MLP cuDNN path has a gap where a non-default glu_linear_offset is silently ignored when cuDNN FE is in the [1.23, 1.24) range.

The ClampedSwiGLU standard path is mathematically correct across forward and backward for all kernel types. The cuDNN fused grouped MLP path does not pass linear_offset to the cuDNN kernel when _pass_geglu_runtime_params is False, meaning a user on cuDNN FE >= 1.23.0 but < 1.24.0 with a non-default glu_linear_offset gets silently wrong results.

transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py and backward_grouped_mlp.py need a guard or warning when non-default parameters are configured but the cuDNN FE version cannot support them.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/activation/swiglu.cu | Adds nvte_clamped_swiglu_v2/nvte_clamped_dswiglu_v2 with configurable offset; old API preserved with hardcoded 1.0 offset and deprecation notice. |
| transformer_engine/common/include/transformer_engine/activation.h | Properly deprecates nvte_clamped_swiglu/nvte_clamped_dswiglu in favor of new v2 functions; old symbols preserved for ABI compatibility. |
| transformer_engine/common/util/vectorized_pointwise.h | Forward and backward CUDA kernel correctly applies glu_linear_offset after clamping; backward computes dgate_in before adding offset (correct) and uses offset-adjusted gate_in for the activation gradient (mathematically correct). |
| transformer_engine/pytorch/ops/basic/swiglu.py | Adds glu_linear_offset param to ClampedSwiGLU and ScaledClampedQGeGLU with correct default 1.0; helper methods properly forward the offset. |
| transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Passes glu_linear_offset to cuDNN only when FE >= 1.24.0; no guard exists when FE is in [1.23, 1.24) and a non-default offset is configured, leading to silent incorrect behavior. |
| transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | Same gap as forward_grouped_mlp: cuDNN backward kernel is not given linear_offset when _pass_geglu_runtime_params is False, causing silent numerical mismatch for non-default offsets. |
| transformer_engine/jax/cpp_extensions/activation.py | Adds glu_linear_offset to ClampedSwigluParams and correctly serializes it for XLA FFI. Reference function for clamped_linear correctly applies the offset. |
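
The backward-pass remarks in the overview (clamp mask decided before the offset is added, offset-adjusted value used for the activation gradient) can be illustrated with a scalar sketch. This is an assumed formulation written for illustration, not the code in vectorized_pointwise.h; the names `act`, `linear`, `forward`, and `backward` are mine and do not match the kernel's variable names.

```python
import math


def _sigmoid(x: float) -> float:
    # tanh form of the logistic function; avoids overflow for large |x|.
    return 0.5 * (1.0 + math.tanh(0.5 * x))


def forward(act, linear, limit, alpha, offset):
    a = min(act, limit)
    l = max(-limit, min(linear, limit))
    return a * _sigmoid(alpha * a) * (l + offset)


def backward(grad_out, act, linear, limit, alpha, offset):
    a = min(act, limit)
    l = max(-limit, min(linear, limit))
    s = _sigmoid(alpha * a)
    # Clamp masks depend on the raw inputs, i.e. before the offset is added.
    act_mask = 1.0 if act <= limit else 0.0
    linear_mask = 1.0 if -limit <= linear <= limit else 0.0
    # d/da [a * sigmoid(alpha * a)]
    d_act_fn = s + alpha * a * s * (1.0 - s)
    # The activation gradient sees the offset-adjusted linear value.
    grad_act = grad_out * d_act_fn * (l + offset) * act_mask
    # The offset is a constant, so it does not appear in the linear gradient.
    grad_linear = grad_out * a * s * linear_mask
    return grad_act, grad_linear
```

A central finite difference on `forward` at an unclamped point is a quick way to sanity-check both gradients for any choice of offset.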

Reviews (18): Last reviewed commit: "Merge branch 'main' into swiglu_offset"

Comment on lines 339 to 341
* \param[in] glu_linear_offset Offset added to the linear component after clamping (default 1.0).
* \param[in] stream CUDA stream used for the operation.
*/
Contributor

P1 Breaking public C API change

nvte_clamped_swiglu and nvte_clamped_dswiglu are public symbols declared in a versioned public header. Inserting glu_linear_offset before cudaStream_t is an ABI-breaking change: any external binary or shared library compiled against the old header will silently pass the stream pointer as the offset and a garbage value as the stream, leading to undefined behavior at runtime rather than a clean compile error if called via a pre-compiled library. This should be acknowledged as a breaking change in the PR checklist, and — if this library follows semantic versioning or a compatibility guarantee — a deprecation/transition path or version bump is needed.

Member

@timmoon10 timmoon10 left a comment

The fused op for grouped MLP is hard-coded for GPT-OSS, so we should make sure not to fuse if glu_linear_offset != 1:

elif isinstance(window[1], ScaledClampedQGeGLU) and (
    abs(window[1]._clamped.alpha - 1.702) > 0.001
    or not _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu()
):

@timmoon10
Member

/te-ci

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@hxbai hxbai marked this pull request as draft April 29, 2026 00:28
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@hxbai hxbai marked this pull request as ready for review April 29, 2026 01:01

  void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
-                          cudaStream_t stream) {
+                          float glu_linear_offset, cudaStream_t stream) {
Collaborator

Can we define new APIs named nvte_clamped_swiglu_v2 and nvte_clamped_dswiglu_v2
and deprecate this API here to not break backward compatibility?

Contributor Author

Rewrote this part.
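
The versioning pattern discussed in this thread (keep the old entry point with its original signature, add a `_v2` entry point with the new parameter, and have the old one delegate with the previous hard-coded value) can be sketched in Python. This is an analogue of the pattern only, not the C code from the PR; the function names and scalar signature below are illustrative stand-ins for the real `nvte_*` symbols.

```python
import math
import warnings


def clamped_swiglu_v2(act, linear, limit, alpha, glu_linear_offset):
    """New entry point: the offset is an explicit argument."""
    act = min(act, limit)
    linear = max(-limit, min(linear, limit))
    sig = 0.5 * (1.0 + math.tanh(0.5 * alpha * act))  # stable sigmoid
    return act * sig * (linear + glu_linear_offset)


def clamped_swiglu(act, linear, limit, alpha):
    """Old entry point: signature unchanged, so existing callers keep working."""
    warnings.warn(
        "clamped_swiglu is deprecated; use clamped_swiglu_v2",
        DeprecationWarning,
        stacklevel=2,
    )
    # Delegate with the previously hard-coded offset of 1.0.
    return clamped_swiglu_v2(act, linear, limit, alpha, glu_linear_offset=1.0)
```

Because the old signature is untouched, callers built against it are unaffected, while new callers opt into the extra parameter by switching to the `_v2` name.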

vthumbe1503 and others added 3 commits May 6, 2026 11:38
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@vthumbe1503
Collaborator

/te-ci

hxbai added 2 commits May 12, 2026 15:13
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@vthumbe1503
Collaborator

/te-ci

Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

Overall looks pretty good from the JAX side, thanks for adding the JAX changes too! Left a couple small comments

Comment thread transformer_engine/jax/csrc/extensions.h
Comment thread transformer_engine/jax/cpp_extensions/activation.py
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@jberchtold-nvidia
Collaborator

/te-ci

hxbai and others added 2 commits May 16, 2026 06:38
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

LGTM from the JAX perspective! Once Tim/Varun approve the PyTorch changes and CI passes, you can merge it. Thanks!

@jberchtold-nvidia
Collaborator

/te-ci

@vthumbe1503
Collaborator

/te-ci pytorch

@vthumbe1503
Collaborator

/te-ci

hxbai added 2 commits May 20, 2026 23:14
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 22, 2026
Comment thread transformer_engine/pytorch/ops/_common.py
@timmoon10
Member

/te-ci

timmoon10 previously approved these changes May 22, 2026
Member

@timmoon10 timmoon10 left a comment

LGTM, pending CI

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
@timmoon10
Member

Pipeline 52303026

@timmoon10 timmoon10 merged commit 7e6ffcc into NVIDIA:main May 26, 2026
12 of 13 checks passed
Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

4 participants