@@ -13,62 +13,13 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
1313 act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
1414}
1515
16- void nvte_group_gelu (const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
17- NVTE_API_CALL (nvte_group_gelu);
18- using namespace transformer_engine ;
19- constexpr bool IS_ACT = true ;
20- dispatch::group_quantize_fwd_helper<IS_ACT, Empty, gelu<fp32, fp32>>(input, output, nullptr ,
21- stream);
22- }
23-
2416void nvte_dgelu (const NVTETensor grad, const NVTETensor input, NVTETensor output,
2517 cudaStream_t stream) {
2618 NVTE_API_CALL (nvte_dgelu);
2719 using namespace transformer_engine ;
2820 dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
2921}
3022
31- void nvte_group_dgelu (const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
32- NVTEGroupedTensor output, cudaStream_t stream) {
33- NVTE_API_CALL (nvte_group_dgelu);
34- using namespace transformer_engine ;
35- NVTEGroupedTensor dbias = nullptr ;
36- NVTETensor workspace = nullptr ;
37-
38- constexpr bool IS_DBIAS = false ;
39- constexpr bool IS_DACT = true ;
40-
41- dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
42- grad, input, output, dbias, workspace, nullptr , stream);
43- }
44-
45- void nvte_quantize_dbias_dgelu (const NVTETensor input, const NVTETensor activation_input,
46- NVTETensor output, NVTETensor dbias, NVTETensor workspace,
47- cudaStream_t stream) {
48- NVTE_API_CALL (nvte_quantize_dbias_dgelu);
49- using namespace transformer_engine ;
50-
51- constexpr bool IS_DBIAS = true ;
52- constexpr bool IS_DACT = true ;
53-
54- dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
55- input, activation_input, output, dbias, workspace, nullptr , stream);
56- }
57-
58- void nvte_group_quantize_dbias_dgelu (const NVTEGroupedTensor input,
59- const NVTEGroupedTensor activation_input,
60- NVTEGroupedTensor output, NVTEGroupedTensor dbias,
61- NVTETensor workspace, cudaStream_t stream) {
62- NVTE_API_CALL (nvte_group_quantize_dbias_dgelu);
63- using namespace transformer_engine ;
64-
65- constexpr bool IS_DBIAS = true ;
66- constexpr bool IS_DACT = true ;
67-
68- dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
69- input, activation_input, output, dbias, workspace, nullptr , stream);
70- }
71-
7223void nvte_geglu (const NVTETensor input, NVTETensor output, cudaStream_t stream) {
7324 NVTE_API_CALL (nvte_geglu);
7425 using namespace transformer_engine ;
@@ -90,63 +41,13 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
9041 act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
9142}
9243
93- void nvte_group_qgelu (const NVTEGroupedTensor input, NVTEGroupedTensor output,
94- cudaStream_t stream) {
95- NVTE_API_CALL (nvte_group_qgelu);
96- using namespace transformer_engine ;
97- constexpr bool IS_ACT = true ;
98- dispatch::group_quantize_fwd_helper<IS_ACT, Empty, qgelu<fp32, fp32>>(input, output, nullptr ,
99- stream);
100- }
101-
10244void nvte_dqgelu (const NVTETensor grad, const NVTETensor input, NVTETensor output,
10345 cudaStream_t stream) {
10446 NVTE_API_CALL (nvte_dqgelu);
10547 using namespace transformer_engine ;
10648 dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
10749}
10850
109- void nvte_group_dqgelu (const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
110- NVTEGroupedTensor output, cudaStream_t stream) {
111- NVTE_API_CALL (nvte_group_dqgelu);
112- using namespace transformer_engine ;
113- NVTEGroupedTensor dbias = nullptr ;
114- NVTETensor workspace = nullptr ;
115-
116- constexpr bool IS_DBIAS = false ;
117- constexpr bool IS_DACT = true ;
118-
119- dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
120- grad, input, output, dbias, workspace, nullptr , stream);
121- }
122-
123- void nvte_quantize_dbias_dqgelu (const NVTETensor input, const NVTETensor activation_input,
124- NVTETensor output, NVTETensor dbias, NVTETensor workspace,
125- cudaStream_t stream) {
126- NVTE_API_CALL (nvte_quantize_dbias_dqgelu);
127- using namespace transformer_engine ;
128-
129- constexpr bool IS_DBIAS = true ;
130- constexpr bool IS_DACT = true ;
131-
132- dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
133- input, activation_input, output, dbias, workspace, nullptr , stream);
134- }
135-
136- void nvte_group_quantize_dbias_dqgelu (const NVTEGroupedTensor input,
137- const NVTEGroupedTensor activation_input,
138- NVTEGroupedTensor output, NVTEGroupedTensor dbias,
139- NVTETensor workspace, cudaStream_t stream) {
140- NVTE_API_CALL (nvte_group_quantize_dbias_dqgelu);
141- using namespace transformer_engine ;
142-
143- constexpr bool IS_DBIAS = true ;
144- constexpr bool IS_DACT = true ;
145-
146- dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
147- input, activation_input, output, dbias, workspace, nullptr , stream);
148- }
149-
15051void nvte_qgeglu (const NVTETensor input, NVTETensor output, cudaStream_t stream) {
15152 NVTE_API_CALL (nvte_qgeglu);
15253 using namespace transformer_engine ;
0 commit comments