File tree Expand file tree Collapse file tree
transformer_engine/pytorch Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1212namespace transformer_engine {
1313namespace pytorch {
1414
15+ /* Allocate multiple PyTorch tensors backed by the same buffer.
16+ *
17+ * Use with caution and avoid exposing externally.
18+ *
19+ * In order to reduce CPU overhead, we compute pointer offsets
20+ * manually and construct PyTorch tensors with raw pointers. The
21+ * backing buffer is deallocated once the final tensor is destroyed.
22+ * Stream usage is not recorded, so there may be race conditions if
23+ * compute is performed on multiple streams.
24+ */
1525std::vector<at::Tensor> bulk_allocate (const std::vector<std::vector<size_t >> &shapes,
1626 const std::vector<at::ScalarType> &dtypes,
1727 std::optional<c10::Device> device,
Original file line number Diff line number Diff line change @@ -496,13 +496,13 @@ def backward(
496496 if ctx .fuse_wgrad_accumulation :
497497 wgrad_list = main_grads
498498 else :
499- weight_shape = list (weights [0 ].size ())
500- wgrad_list = tex .bulk_allocate (
501- [weight_shape ] * ctx .num_gemms ,
502- [ctx .activation_dtype ] * ctx .num_gemms ,
503- ctx .device ,
504- [256 ] * ctx .num_gemms , # alignment
499+ wgrad_packed = torch .empty (
500+ ctx .num_gemms ,
501+ * weights [0 ].size (),
502+ dtype = ctx .activation_dtype ,
503+ device = ctx .device ,
505504 )
505+ wgrad_list = [wgrad_packed [i ] for i in range (ctx .num_gemms )]
506506
507507 if ctx .save_original_input :
508508 inp = inputmats [0 ]
Original file line number Diff line number Diff line change @@ -1393,12 +1393,12 @@ def _fuser_backward_split_quantize(
13931393 ]
13941394 accumulate_into_main_grad = get_accumulate_flag_in_param (weights [0 ])
13951395 else :
1396- grad_weights = tex .bulk_allocate (
1397- [weight_shape ] * num_groups ,
1398- [ctx .dtype ] * num_groups ,
1399- device ,
1400- [256 ] * num_groups , # alignment
1396+ grad_weights_packed = torch .empty (
1397+ grouped_shape ,
1398+ dtype = ctx .dtype ,
1399+ device = device ,
14011400 )
1401+ grad_weights = [grad_weights_packed [i ] for i in range (num_groups )]
14021402 final_weight_grads = list (grad_weights )
14031403
14041404 # Perform dgrad GEMMs
Original file line number Diff line number Diff line change @@ -197,12 +197,13 @@ def _compute_grad_params(
197197 w_list = [get_main_grad_from_param (w , op_label = op_label ) for w in weights ]
198198 accumulate_into_main_grad = get_accumulate_flag_in_param (weights [0 ])
199199 else :
200- w_list = tex . bulk_allocate (
201- [ weight_shape ] * num_groups ,
202- [ dtype ] * num_groups ,
203- device ,
204- [ 256 ] * num_groups , # alignment
200+ wgrad_packed = torch . empty (
201+ num_groups ,
202+ * weight_shape ,
203+ dtype = dtype ,
204+ device = device ,
205205 )
206+ w_list = [wgrad_packed [i ] for i in range (num_groups )]
206207 wgrad_output = w_list
207208
208209 if ctx .weight_requires_grad :
You can’t perform that action at this time.
0 commit comments