Skip to content

Commit 45fb909

Browse files
cyanguwaKshitijLakhani
authored andcommitted
[PyTorch] Fix CP A2A F16 when NVTE_FP8_DPA_BWD=1 (#2917)
fix fp8 and is_bwd_fp8 relationship Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
1 parent a506ec5 commit 45fb909

1 file changed

Lines changed: 14 additions & 14 deletions

File tree

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1469,7 +1469,8 @@ def forward(
14691469
fwd_nominal_dtype = q.dtype
14701470
is_input_fp8 = isinstance(q, QuantizedTensorStorage)
14711471
is_output_fp8 = fp8_output
1472-
is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
1472+
_use_fp8_dpa_bwd = bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1")))
1473+
is_bwd_fp8 = fp8 and _use_fp8_dpa_bwd
14731474
# recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE;
14741475
# may be different from fp8_meta["recipe"]
14751476
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
@@ -2063,20 +2064,17 @@ def forward(
20632064
# prepare for return and ctx saves
20642065
out_fp8 = None
20652066
out_f16 = out.to(fwd_nominal_dtype)
2066-
if fp8 and (
2067-
is_output_fp8
2068-
or (
2069-
is_bwd_fp8
2070-
and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)
2071-
and not fp8_recipe.mxfp8()
2072-
)
2067+
if (fp8 and is_output_fp8) or (
2068+
is_bwd_fp8
2069+
and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)
2070+
and not fp8_recipe.mxfp8()
20732071
):
20742072
out_fp8 = O_quantizer(out_f16)
20752073
out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16
20762074

20772075
ctx.layer_number = layer_number
20782076
ctx.fp8_recipe = fp8_recipe
2079-
ctx.fp8 = fp8 and is_bwd_fp8
2077+
ctx.fp8 = is_bwd_fp8
20802078

20812079
kv_fp8 = None
20822080
kv = p2p_comm_buffers[-1]
@@ -3063,7 +3061,8 @@ def forward(
30633061
), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage."
30643062
is_input_fp8 = isinstance(q, QuantizedTensorStorage)
30653063
is_output_fp8 = fp8_output
3066-
is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
3064+
_use_fp8_dpa_bwd = bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1")))
3065+
is_bwd_fp8 = fp8 and _use_fp8_dpa_bwd
30673066
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
30683067
if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
30693068
fp8_recipe = fp8_meta["local_recipes"][0]
@@ -3306,12 +3305,12 @@ def forward(
33063305
or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16)
33073306
)
33083307
)
3309-
if fp8 and (is_output_fp8 or bwd_requires_o_fp8):
3308+
if (fp8 and is_output_fp8) or bwd_requires_o_fp8:
33103309
out_fp8 = O_quantizer(out_f16)
33113310
out_ret = out_fp8 if is_output_fp8 else out_f16
33123311

33133312
# save tensors for backward
3314-
ctx.fp8 = fp8 and is_bwd_fp8
3313+
ctx.fp8 = is_bwd_fp8
33153314
ctx.fp8_recipe = fp8_recipe
33163315
fp8_tensors = (None, None, None, None)
33173316
f16_tensors = (None, None, None, None)
@@ -3931,7 +3930,8 @@ def forward(
39313930
), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage."
39323931
is_input_fp8 = isinstance(q, QuantizedTensorStorage)
39333932
is_output_fp8 = fp8_output
3934-
is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
3933+
_use_fp8_dpa_bwd = bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1")))
3934+
is_bwd_fp8 = fp8 and _use_fp8_dpa_bwd
39353935
# recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE;
39363936
# may be different from fp8_meta["recipe"]
39373937
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
@@ -4161,7 +4161,7 @@ def forward(
41614161
ctx.orig_o_shape = orig_o_shape
41624162

41634163
# save tensors for backward
4164-
ctx.fp8 = fp8 and is_bwd_fp8
4164+
ctx.fp8 = is_bwd_fp8
41654165
fp8_tensors = (None, None, None, None)
41664166
f16_tensors = (None, None, None, None)
41674167
if is_training:

0 commit comments

Comments
 (0)