Skip to content

Commit fc92624

Browse files
xrennvidiapre-commit-ci[bot]greptile-apps[bot]
authored
Add the getter and setter of skip_fp8_weight_update_tensor (#3015)
* add the getter and setter of skip_fp8_weight_update_tensor Signed-off-by: Xiaowei Ren <xren@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/pytorch/quantization.py return type fix Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> --------- Signed-off-by: Xiaowei Ren <xren@nvidia.com> Signed-off-by: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent 0e58073 commit fc92624

2 files changed

Lines changed: 16 additions & 9 deletions

File tree

transformer_engine/pytorch/graph.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -324,12 +324,7 @@ def _make_graphed_callables(
324324

325325
if cache_quantized_params:
326326
# Initialize flag that controls FP8 weight updates
327-
qstate = FP8GlobalStateManager.quantization_state
328-
if qstate.skip_fp8_weight_update_tensor is None:
329-
qstate.skip_fp8_weight_update_tensor = torch.empty(
330-
1, dtype=torch.float32, device="cuda"
331-
)
332-
qstate.skip_fp8_weight_update_tensor.fill_(False)
327+
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
333328

334329
# Check callables
335330
for c in callables:
@@ -841,9 +836,7 @@ def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *i
841836
# Set flag for whether to update FP8 weight updates
842837
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
843838
if ctx.is_first_module and skip_fp8_weight_update is not None:
844-
FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor.fill_(
845-
skip_fp8_weight_update
846-
)
839+
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update)
847840
ctx.cuda_graph_stream = cuda_graph_stream
848841
ctx.cuda_graph_event = cuda_graph_event
849842
# Copy values from new tensors into static tensors

transformer_engine/pytorch/quantization.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,20 @@ class FP8GlobalStateManager:
409409

410410
quantization_state = FP8GlobalState()
411411

412+
@classmethod
413+
def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None:
414+
"""Set the skip fp8 weight update tensor"""
415+
if cls.quantization_state.skip_fp8_weight_update_tensor is None:
416+
cls.quantization_state.skip_fp8_weight_update_tensor = torch.empty(
417+
1, dtype=torch.float32, device="cuda"
418+
)
419+
cls.quantization_state.skip_fp8_weight_update_tensor.fill_(skip)
420+
421+
@classmethod
422+
def get_skip_fp8_weight_update_tensor(cls) -> Optional[torch.Tensor]:
423+
"""Get the skip fp8 weight update tensor"""
424+
return cls.quantization_state.skip_fp8_weight_update_tensor
425+
412426
@classmethod
413427
def reset(cls) -> None:
414428
"""Reset the global state"""

0 commit comments

Comments
 (0)