Skip to content

Commit 287403f

Browse files
Merge branch 'main' into flash_attn_pad_bw_seqs
2 parents d734891 + a014300 commit 287403f

20 files changed

Lines changed: 503 additions & 308 deletions
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
# A workflow to automatically label the contributions as community/org
6+
name: Label community contributions

# Runs in the context of the base repository for the listed pull-request events.
on:
  pull_request_target:
    types: [opened, reopened, ready_for_review, synchronize]

# Token scope requested for this workflow: read repository contents, write issues
# (the label call below goes through the issues API).
permissions:
  contents: read
  issues: write

jobs:
  label:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3
        with:
          script: |
            // Labelling rules implemented below:
            //   author_association is not MEMBER/OWNER      -> "community-contribution"
            //   MEMBER/OWNER without write/admin permission -> "org-contribution"
            //   MEMBER/OWNER with write/admin permission    -> no label is added
            const pullRequest = context.payload.pull_request;
            const author = pullRequest.user.login;
            const association = pullRequest.author_association;
            const { owner, repo } = context.repo;

            const COMMUNITY_LABEL = "community-contribution";
            const ORG_LABEL = "org-contribution";

            // Returns the author's permission on this repository.
            // A 404 from the API is mapped to "none"; every other error propagates.
            async function fetchPermission(username) {
              try {
                const { data } = await github.rest.repos.getCollaboratorPermissionLevel({
                  owner,
                  repo,
                  username,
                });
                return data.permission;
              } catch (err) {
                if (err.status !== 404) throw err;
                return "none";
              }
            }

            // Returns the label to apply, or null when nothing should be added.
            async function chooseLabel() {
              if (association !== "MEMBER" && association !== "OWNER") {
                return COMMUNITY_LABEL;
              }

              const permission = await fetchPermission(author);
              if (permission === "write" || permission === "admin") {
                return null;
              }
              return ORG_LABEL;
            }

            const chosenLabel = await chooseLabel();
            if (chosenLabel) {
              await github.rest.issues.addLabels({
                owner,
                repo,
                issue_number: pullRequest.number,
                labels: [chosenLabel],
              });
            }

transformer_engine/common/CMakeLists.txt

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,10 +212,22 @@ list(APPEND transformer_engine_cuda_sources
212212
list(APPEND transformer_engine_cuda_arch_specific_sources
213213
fused_attn/flash_attn.cu
214214
activation/gelu.cu
215+
activation/gelu_dbias.cu
216+
activation/gelu_grouped.cu
217+
activation/gelu_grouped_dbias.cu
215218
activation/glu.cu
216219
activation/relu.cu
220+
activation/relu_dbias.cu
221+
activation/relu_grouped.cu
222+
activation/relu_grouped_dbias.cu
217223
activation/swiglu.cu
224+
activation/swiglu_dbias.cu
225+
activation/swiglu_grouped.cu
226+
activation/swiglu_grouped_dbias.cu
218227
cast/cast.cu
228+
cast/cast_dbias.cu
229+
cast/cast_grouped.cu
230+
cast/cast_grouped_dbias.cu
219231
gemm/cutlass_grouped_gemm.cu
220232
hadamard_transform/group_hadamard_transform.cu
221233
hadamard_transform/graph_safe_group_hadamard_transform.cu
@@ -447,9 +459,18 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
447459
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
448460
if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
449461
list(APPEND nvte_sources_with_fast_math activation/gelu.cu
462+
activation/gelu_dbias.cu
463+
activation/gelu_grouped.cu
464+
activation/gelu_grouped_dbias.cu
450465
activation/glu.cu
451466
activation/relu.cu
452-
activation/swiglu.cu)
467+
activation/relu_dbias.cu
468+
activation/relu_grouped.cu
469+
activation/relu_grouped_dbias.cu
470+
activation/swiglu.cu
471+
activation/swiglu_dbias.cu
472+
activation/swiglu_grouped.cu
473+
activation/swiglu_grouped_dbias.cu)
453474
endif()
454475

455476
foreach(cuda_source IN LISTS nvte_sources_with_fast_math)

transformer_engine/common/activation/gelu.cu

Lines changed: 0 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -13,62 +13,13 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
1313
act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
1414
}
1515

16-
void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
17-
NVTE_API_CALL(nvte_group_gelu);
18-
using namespace transformer_engine;
19-
constexpr bool IS_ACT = true;
20-
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, gelu<fp32, fp32>>(input, output, nullptr,
21-
stream);
22-
}
23-
2416
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
2517
cudaStream_t stream) {
2618
NVTE_API_CALL(nvte_dgelu);
2719
using namespace transformer_engine;
2820
dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
2921
}
3022

31-
void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
32-
NVTEGroupedTensor output, cudaStream_t stream) {
33-
NVTE_API_CALL(nvte_group_dgelu);
34-
using namespace transformer_engine;
35-
NVTEGroupedTensor dbias = nullptr;
36-
NVTETensor workspace = nullptr;
37-
38-
constexpr bool IS_DBIAS = false;
39-
constexpr bool IS_DACT = true;
40-
41-
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
42-
grad, input, output, dbias, workspace, nullptr, stream);
43-
}
44-
45-
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
46-
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
47-
cudaStream_t stream) {
48-
NVTE_API_CALL(nvte_quantize_dbias_dgelu);
49-
using namespace transformer_engine;
50-
51-
constexpr bool IS_DBIAS = true;
52-
constexpr bool IS_DACT = true;
53-
54-
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
55-
input, activation_input, output, dbias, workspace, nullptr, stream);
56-
}
57-
58-
void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
59-
const NVTEGroupedTensor activation_input,
60-
NVTEGroupedTensor output, NVTEGroupedTensor dbias,
61-
NVTETensor workspace, cudaStream_t stream) {
62-
NVTE_API_CALL(nvte_group_quantize_dbias_dgelu);
63-
using namespace transformer_engine;
64-
65-
constexpr bool IS_DBIAS = true;
66-
constexpr bool IS_DACT = true;
67-
68-
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
69-
input, activation_input, output, dbias, workspace, nullptr, stream);
70-
}
71-
7223
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
7324
NVTE_API_CALL(nvte_geglu);
7425
using namespace transformer_engine;
@@ -90,63 +41,13 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
9041
act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
9142
}
9243

93-
void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
94-
cudaStream_t stream) {
95-
NVTE_API_CALL(nvte_group_qgelu);
96-
using namespace transformer_engine;
97-
constexpr bool IS_ACT = true;
98-
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, qgelu<fp32, fp32>>(input, output, nullptr,
99-
stream);
100-
}
101-
10244
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
10345
cudaStream_t stream) {
10446
NVTE_API_CALL(nvte_dqgelu);
10547
using namespace transformer_engine;
10648
dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
10749
}
10850

109-
void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
110-
NVTEGroupedTensor output, cudaStream_t stream) {
111-
NVTE_API_CALL(nvte_group_dqgelu);
112-
using namespace transformer_engine;
113-
NVTEGroupedTensor dbias = nullptr;
114-
NVTETensor workspace = nullptr;
115-
116-
constexpr bool IS_DBIAS = false;
117-
constexpr bool IS_DACT = true;
118-
119-
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
120-
grad, input, output, dbias, workspace, nullptr, stream);
121-
}
122-
123-
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
124-
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
125-
cudaStream_t stream) {
126-
NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
127-
using namespace transformer_engine;
128-
129-
constexpr bool IS_DBIAS = true;
130-
constexpr bool IS_DACT = true;
131-
132-
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
133-
input, activation_input, output, dbias, workspace, nullptr, stream);
134-
}
135-
136-
void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
137-
const NVTEGroupedTensor activation_input,
138-
NVTEGroupedTensor output, NVTEGroupedTensor dbias,
139-
NVTETensor workspace, cudaStream_t stream) {
140-
NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu);
141-
using namespace transformer_engine;
142-
143-
constexpr bool IS_DBIAS = true;
144-
constexpr bool IS_DACT = true;
145-
146-
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
147-
input, activation_input, output, dbias, workspace, nullptr, stream);
148-
}
149-
15051
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
15152
NVTE_API_CALL(nvte_qgeglu);
15253
using namespace transformer_engine;
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*************************************************************************
2+
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
*
4+
* See LICENSE for license information.
5+
************************************************************************/
6+
7+
#include "../util/math.h"
8+
#include "./activation_template.h"
9+
10+
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
11+
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
12+
cudaStream_t stream) {
13+
NVTE_API_CALL(nvte_quantize_dbias_dgelu);
14+
using namespace transformer_engine;
15+
16+
constexpr bool IS_DBIAS = true;
17+
constexpr bool IS_DACT = true;
18+
19+
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
20+
input, activation_input, output, dbias, workspace, nullptr, stream);
21+
}
22+
23+
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
24+
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
25+
cudaStream_t stream) {
26+
NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
27+
using namespace transformer_engine;
28+
29+
constexpr bool IS_DBIAS = true;
30+
constexpr bool IS_DACT = true;
31+
32+
dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
33+
input, activation_input, output, dbias, workspace, nullptr, stream);
34+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*************************************************************************
2+
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
*
4+
* See LICENSE for license information.
5+
************************************************************************/
6+
7+
#include "../util/math.h"
8+
#include "./activation_template.h"
9+
10+
void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
11+
NVTE_API_CALL(nvte_group_gelu);
12+
using namespace transformer_engine;
13+
constexpr bool IS_ACT = true;
14+
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, gelu<fp32, fp32>>(input, output, nullptr,
15+
stream);
16+
}
17+
18+
void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
19+
NVTEGroupedTensor output, cudaStream_t stream) {
20+
NVTE_API_CALL(nvte_group_dgelu);
21+
using namespace transformer_engine;
22+
NVTEGroupedTensor dbias = nullptr;
23+
NVTETensor workspace = nullptr;
24+
25+
constexpr bool IS_DBIAS = false;
26+
constexpr bool IS_DACT = true;
27+
28+
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
29+
grad, input, output, dbias, workspace, nullptr, stream);
30+
}
31+
32+
void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
33+
cudaStream_t stream) {
34+
NVTE_API_CALL(nvte_group_qgelu);
35+
using namespace transformer_engine;
36+
constexpr bool IS_ACT = true;
37+
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, qgelu<fp32, fp32>>(input, output, nullptr,
38+
stream);
39+
}
40+
41+
void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
42+
NVTEGroupedTensor output, cudaStream_t stream) {
43+
NVTE_API_CALL(nvte_group_dqgelu);
44+
using namespace transformer_engine;
45+
NVTEGroupedTensor dbias = nullptr;
46+
NVTETensor workspace = nullptr;
47+
48+
constexpr bool IS_DBIAS = false;
49+
constexpr bool IS_DACT = true;
50+
51+
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
52+
grad, input, output, dbias, workspace, nullptr, stream);
53+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*************************************************************************
2+
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
*
4+
* See LICENSE for license information.
5+
************************************************************************/
6+
7+
#include "../util/math.h"
8+
#include "./activation_template.h"
9+
10+
void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
11+
const NVTEGroupedTensor activation_input,
12+
NVTEGroupedTensor output, NVTEGroupedTensor dbias,
13+
NVTETensor workspace, cudaStream_t stream) {
14+
NVTE_API_CALL(nvte_group_quantize_dbias_dgelu);
15+
using namespace transformer_engine;
16+
17+
constexpr bool IS_DBIAS = true;
18+
constexpr bool IS_DACT = true;
19+
20+
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
21+
input, activation_input, output, dbias, workspace, nullptr, stream);
22+
}
23+
24+
void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
25+
const NVTEGroupedTensor activation_input,
26+
NVTEGroupedTensor output, NVTEGroupedTensor dbias,
27+
NVTETensor workspace, cudaStream_t stream) {
28+
NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu);
29+
using namespace transformer_engine;
30+
31+
constexpr bool IS_DBIAS = true;
32+
constexpr bool IS_DACT = true;
33+
34+
dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
35+
input, activation_input, output, dbias, workspace, nullptr, stream);
36+
}

0 commit comments

Comments
 (0)