Skip to content

Commit 6a56e15

Browse files
Cleanup TE cuda graphs with the right api (#3459)
Signed-off-by: Gautham Kollu <gkollu@nvidia.com> Signed-off-by: gautham-kollu <gkollu@nvidia.com>
1 parent 0fc0e61 commit 6a56e15

1 file changed

Lines changed: 1 addition & 5 deletions

File tree

src/megatron/bridge/training/train.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1567,11 +1567,7 @@ def _delete_cuda_graphs(cuda_graph_helper: TECudaGraphHelper):
15671567

15681568
# Cleanup CUDA graphs object for partial Cuda-graphs (implemented in TransformerEngine)
15691569
if cuda_graph_helper is not None:
1570-
for layers in cuda_graph_helper.callables_per_chunk:
1571-
for layer in layers:
1572-
for cuda_graph in layer.cuda_graphs:
1573-
del cuda_graph
1574-
del layer.cuda_graphs
1570+
cuda_graph_helper.delete_cuda_graphs()
15751571

15761572
# Run GC to collect the freshed object
15771573
gc.collect()

0 commit comments

Comments
 (0)