[Common/PyTorch/JAX] make offset of ClampedSwiGLU configurable #2938
Greptile Summary
This PR makes the
Confidence Score: 4/5
The core CUDA kernels, PyTorch/JAX bindings, and the standard ClampedSwiGLU path are all correct and backward-compatible. The fused grouped MLP cuDNN path has a gap where a non-default glu_linear_offset is silently ignored when cuDNN FE is in the [1.23, 1.24) range.
The ClampedSwiGLU standard path is mathematically correct across forward and backward for all kernel types. The cuDNN fused grouped MLP path does not pass linear_offset to the cuDNN kernel when _pass_geglu_runtime_params is False, meaning a user on cuDNN FE >= 1.23.0 but < 1.24.0 with a non-default glu_linear_offset gets silently wrong results.
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py and backward_grouped_mlp.py need a guard or warning when non-default parameters are configured but the cuDNN FE version cannot support them.
Important Files Changed
Reviews (18): Last reviewed commit: "Merge branch 'main' into swiglu_offset"
 * \param[in] glu_linear_offset Offset added to the linear component after clamping (default 1.0).
 * \param[in] stream CUDA stream used for the operation.
 */
nvte_clamped_swiglu and nvte_clamped_dswiglu are public symbols declared in a versioned public header. Inserting glu_linear_offset before cudaStream_t is an ABI-breaking change: an external binary or shared library compiled against the old header still calls these functions with the old argument list, so the new function reads glu_linear_offset (and, depending on the calling convention, the stream) from a register or stack slot the caller never set. The result is undefined behavior at runtime rather than a clean compile error. This should be acknowledged as a breaking change in the PR checklist, and if this library follows semantic versioning or offers a compatibility guarantee, a deprecation/transition path or version bump is needed.
timmoon10 left a comment:
The fused op for grouped MLP is hard-coded for GPT-OSS, so we should make sure not to fuse if glu_linear_offset != 1:
TransformerEngine/transformer_engine/pytorch/ops/_common.py
Lines 180 to 183 in df0025b
/te-ci
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
 void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
-                         cudaStream_t stream) {
+                         float glu_linear_offset, cudaStream_t stream) {
Can we define new APIs named nvte_clamped_swiglu_v2 and nvte_clamped_dswiglu_v2 and deprecate the existing API here, so that we do not break backward compatibility?
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
/te-ci
/te-ci
jberchtold-nvidia left a comment:
Overall this looks pretty good from the JAX side, thanks for adding the JAX changes too! I left a couple of small comments.
/te-ci
jberchtold-nvidia left a comment:
LGTM from the JAX perspective! Once Tim/Varun approve the PyTorch changes and CI passes, you can merge it. Thanks!
/te-ci
/te-ci pytorch
/te-ci
/te-ci
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Pipeline 52303026
Description
The existing ClampedSwiGLU implementation follows GPT-OSS and hard-codes the offset added to the linear component to 1.0.
DeepSeek-V4 uses ClampedSwiGLU without alpha and without the offset.
This PR makes the offset of ClampedSwiGLU configurable in order to support DeepSeek-V4.
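For reference, the activation with a configurable offset can be sketched in scalar form. This is a minimal illustration, not the Transformer Engine kernel: it assumes the GPT-OSS formulation (gate clamped from above at `limit`, linear component clamped to `[-limit, limit]`, sigmoid scaled by `alpha`), and the default values for `limit` and `alpha` are assumptions taken from that model. Reading "without alpha and offset" as `alpha=1.0` and `glu_linear_offset=0.0` is likewise an interpretation.

```python
import math


def _sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def clamped_swiglu(gate: float, linear: float, limit: float = 7.0,
                   alpha: float = 1.702, glu_linear_offset: float = 1.0) -> float:
    """Scalar reference for ClampedSwiGLU with a configurable linear offset."""
    gate = min(gate, limit)                      # gate is clamped from above only
    linear = max(-limit, min(linear, limit))     # linear is clamped on both sides
    # The offset is added to the linear component after clamping.
    return gate * _sigmoid(alpha * gate) * (linear + glu_linear_offset)


# GPT-OSS style: offset of 1.0 (the previous hard-coded behavior).
gpt_oss_out = clamped_swiglu(1.0, 2.0)
# DeepSeek-V4 style, as interpreted here: no alpha scaling and no offset.
deepseek_out = clamped_swiglu(1.0, 2.0, alpha=1.0, glu_linear_offset=0.0)
```

With `glu_linear_offset=1.0` the function reproduces the previous behavior, which is why the default keeps existing callers unchanged.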