44 * See LICENSE for license information.
55 ************************************************************************/
66
7+ #include < cstdlib>
78#include < transformer_engine/layer_norm.h>
9+ #include < string>
810#include < vector>
911#include " ln.h"
1012#include " ../common.h"
@@ -31,6 +33,9 @@ Compute always in FP32
3133namespace transformer_engine {
3234namespace layer_norm {
3335
36+ // [Augment] Forward declare helper kernel added to avoid using memset.
37+ void launch_zero_out (void *, size_t , size_t , cudaStream_t);
38+
3439using namespace transformer_engine ;
3540
3641// Create registries and provide runtime versions of config hash functions.
@@ -232,16 +237,36 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
232237 params.barrier = reinterpret_cast <int *>(barrier->data .dptr );
233238 }
234239
240+ // NOTE[augment]: this envvar exists to restore the prior behavior of TE (ie, use a memset
241+ // kernel. So if you want to get the upstream behavior, run with NVTE_FORCE_MEMSET=1.
242+ const char *envval = std::getenv (" NVTE_FORCE_MEMSET" );
243+ bool force_memset = (envval != nullptr ) && (std::string (envval) == " 1" );
235244 // Clear buffers
236245 if ( params.fp8_out ) {
237- cudaMemsetAsync (params.amax , 0 ,
238- layer_norm::product (z->amax .shape ) *
239- typeToSize (z->amax .dtype ), stream);
246+ if ( force_memset ) {
247+ cudaMemsetAsync (params.amax , 0 ,
248+ layer_norm::product (z->amax .shape ) *
249+ typeToSize (z->amax .dtype ), stream);
250+ } else {
251+ // [Augment] Use the zero-out kernel, not memset.
252+ layer_norm::launch_zero_out (params.amax ,
253+ layer_norm::product (z->amax .shape ),
254+ typeToSize (z->amax .dtype ),
255+ stream);
256+ }
240257 }
241258 if ( launch_params.barrier_size > 0 ) {
242- cudaMemsetAsync (params.barrier , 0 ,
243- layer_norm::product (barrier->data .shape ) *
244- typeToSize (barrier->data .dtype ), stream);
259+ if ( force_memset ) {
260+ cudaMemsetAsync (params.barrier , 0 ,
261+ layer_norm::product (barrier->data .shape ) *
262+ typeToSize (barrier->data .dtype ), stream);
263+ } else {
264+ // [Augment] Use the zero-out kernel, not memset.
265+ layer_norm::launch_zero_out (params.barrier ,
266+ layer_norm::product (barrier->data .shape ),
267+ typeToSize (barrier->data .dtype ),
268+ stream);
269+ }
245270 }
246271
247272 // Launch the kernel.
0 commit comments