@@ -1469,7 +1469,8 @@ def forward(
14691469 fwd_nominal_dtype = q .dtype
14701470 is_input_fp8 = isinstance (q , QuantizedTensorStorage )
14711471 is_output_fp8 = fp8_output
1472- is_bwd_fp8 = int (os .getenv ("NVTE_FP8_DPA_BWD" , "1" ))
1472+ _use_fp8_dpa_bwd = bool (int (os .getenv ("NVTE_FP8_DPA_BWD" , "1" )))
1473+ is_bwd_fp8 = fp8 and _use_fp8_dpa_bwd
14731474 # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE;
14741475 # may be different from fp8_meta["recipe"]
14751476 fp8_recipe = FP8GlobalStateManager .get_fp8_recipe ()
@@ -2063,20 +2064,17 @@ def forward(
20632064 # prepare for return and ctx saves
20642065 out_fp8 = None
20652066 out_f16 = out .to (fwd_nominal_dtype )
2066- if fp8 and (
2067- is_output_fp8
2068- or (
2069- is_bwd_fp8
2070- and not (fp8_recipe .float8_current_scaling () and _dpa_fp8_cs_o_in_f16 )
2071- and not fp8_recipe .mxfp8 ()
2072- )
2067+ if (fp8 and is_output_fp8 ) or (
2068+ is_bwd_fp8
2069+ and not (fp8_recipe .float8_current_scaling () and _dpa_fp8_cs_o_in_f16 )
2070+ and not fp8_recipe .mxfp8 ()
20732071 ):
20742072 out_fp8 = O_quantizer (out_f16 )
20752073 out_ret = out_fp8 if (fp8 and is_output_fp8 ) else out_f16
20762074
20772075 ctx .layer_number = layer_number
20782076 ctx .fp8_recipe = fp8_recipe
2079- ctx .fp8 = fp8 and is_bwd_fp8
2077+ ctx .fp8 = is_bwd_fp8
20802078
20812079 kv_fp8 = None
20822080 kv = p2p_comm_buffers [- 1 ]
@@ -3063,7 +3061,8 @@ def forward(
30633061 ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage."
30643062 is_input_fp8 = isinstance (q , QuantizedTensorStorage )
30653063 is_output_fp8 = fp8_output
3066- is_bwd_fp8 = int (os .getenv ("NVTE_FP8_DPA_BWD" , "1" ))
3064+ _use_fp8_dpa_bwd = bool (int (os .getenv ("NVTE_FP8_DPA_BWD" , "1" )))
3065+ is_bwd_fp8 = fp8 and _use_fp8_dpa_bwd
30673066 fp8_recipe = FP8GlobalStateManager .get_fp8_recipe ()
30683067 if fp8_meta is not None and fp8_meta .get ("local_recipes" , None ) is not None :
30693068 fp8_recipe = fp8_meta ["local_recipes" ][0 ]
@@ -3306,12 +3305,12 @@ def forward(
33063305 or (fp8_recipe .float8_current_scaling () and not _dpa_fp8_cs_o_in_f16 )
33073306 )
33083307 )
3309- if fp8 and ( is_output_fp8 or bwd_requires_o_fp8 ) :
3308+ if ( fp8 and is_output_fp8 ) or bwd_requires_o_fp8 :
33103309 out_fp8 = O_quantizer (out_f16 )
33113310 out_ret = out_fp8 if is_output_fp8 else out_f16
33123311
33133312 # save tensors for backward
3314- ctx .fp8 = fp8 and is_bwd_fp8
3313+ ctx .fp8 = is_bwd_fp8
33153314 ctx .fp8_recipe = fp8_recipe
33163315 fp8_tensors = (None , None , None , None )
33173316 f16_tensors = (None , None , None , None )
@@ -3931,7 +3930,8 @@ def forward(
39313930 ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage."
39323931 is_input_fp8 = isinstance (q , QuantizedTensorStorage )
39333932 is_output_fp8 = fp8_output
3934- is_bwd_fp8 = int (os .getenv ("NVTE_FP8_DPA_BWD" , "1" ))
3933+ _use_fp8_dpa_bwd = bool (int (os .getenv ("NVTE_FP8_DPA_BWD" , "1" )))
3934+ is_bwd_fp8 = fp8 and _use_fp8_dpa_bwd
39353935 # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE;
39363936 # may be different from fp8_meta["recipe"]
39373937 fp8_recipe = FP8GlobalStateManager .get_fp8_recipe ()
@@ -4161,7 +4161,7 @@ def forward(
41614161 ctx .orig_o_shape = orig_o_shape
41624162
41634163 # save tensors for backward
4164- ctx .fp8 = fp8 and is_bwd_fp8
4164+ ctx .fp8 = is_bwd_fp8
41654165 fp8_tensors = (None , None , None , None )
41664166 f16_tensors = (None , None , None , None )
41674167 if is_training :
0 commit comments