Skip to content

Commit ace2a96

Browse files
[PyTorch] Allocate grouped linear wgrads as tensor views (#3049)
* Allocate grouped linear wgrads as views Signed-off-by: Tim Moon <tmoon@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon <tmoon@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 439ca21 commit ace2a96

4 files changed

Lines changed: 27 additions & 16 deletions

File tree

transformer_engine/pytorch/csrc/extensions/allocate.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@
1212
namespace transformer_engine {
1313
namespace pytorch {
1414

15+
/* Allocate multiple PyTorch tensors backed by the same buffer.
16+
*
17+
* Use with caution and avoid exposing externally.
18+
*
19+
* In order to reduce CPU overhead, we compute pointer offsets
20+
* manually and construct PyTorch tensors with raw pointers. The
21+
* backing buffer is deallocated once the final tensor is destroyed.
22+
* Stream usage is not recorded, so there may be race conditions if
23+
* compute is performed on multiple streams.
24+
*/
1525
std::vector<at::Tensor> bulk_allocate(const std::vector<std::vector<size_t>> &shapes,
1626
const std::vector<at::ScalarType> &dtypes,
1727
std::optional<c10::Device> device,

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -496,13 +496,13 @@ def backward(
496496
if ctx.fuse_wgrad_accumulation:
497497
wgrad_list = main_grads
498498
else:
499-
weight_shape = list(weights[0].size())
500-
wgrad_list = tex.bulk_allocate(
501-
[weight_shape] * ctx.num_gemms,
502-
[ctx.activation_dtype] * ctx.num_gemms,
503-
ctx.device,
504-
[256] * ctx.num_gemms, # alignment
499+
wgrad_packed = torch.empty(
500+
ctx.num_gemms,
501+
*weights[0].size(),
502+
dtype=ctx.activation_dtype,
503+
device=ctx.device,
505504
)
505+
wgrad_list = [wgrad_packed[i] for i in range(ctx.num_gemms)]
506506

507507
if ctx.save_original_input:
508508
inp = inputmats[0]

transformer_engine/pytorch/ops/basic/grouped_linear.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1393,12 +1393,12 @@ def _fuser_backward_split_quantize(
13931393
]
13941394
accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0])
13951395
else:
1396-
grad_weights = tex.bulk_allocate(
1397-
[weight_shape] * num_groups,
1398-
[ctx.dtype] * num_groups,
1399-
device,
1400-
[256] * num_groups, # alignment
1396+
grad_weights_packed = torch.empty(
1397+
grouped_shape,
1398+
dtype=ctx.dtype,
1399+
device=device,
14011400
)
1401+
grad_weights = [grad_weights_packed[i] for i in range(num_groups)]
14021402
final_weight_grads = list(grad_weights)
14031403

14041404
# Perform dgrad GEMMs

transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,13 @@ def _compute_grad_params(
197197
w_list = [get_main_grad_from_param(w, op_label=op_label) for w in weights]
198198
accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0])
199199
else:
200-
w_list = tex.bulk_allocate(
201-
[weight_shape] * num_groups,
202-
[dtype] * num_groups,
203-
device,
204-
[256] * num_groups, # alignment
200+
wgrad_packed = torch.empty(
201+
num_groups,
202+
*weight_shape,
203+
dtype=dtype,
204+
device=device,
205205
)
206+
w_list = [wgrad_packed[i] for i in range(num_groups)]
206207
wgrad_output = w_list
207208

208209
if ctx.weight_requires_grad:

0 commit comments

Comments
 (0)