@@ -414,12 +414,10 @@ void csr2csc_template_(
414414 if (nnz == 0 ) {
415415 return ;
416416 }
417- csc.row_indices =
418- static_cast <int *>(fbgemm::fbgemmAlignedAlloc (64 , nnz * sizeof (int )));
417+ csc.row_indices = fbgemm::makeAlignedUniquePtr<int >(64 , nnz);
419418 bool has_weights = csr_weights.data () != nullptr ;
420419 if (IS_VALUE_PAIR) {
421- csc.weights = static_cast <float *>(
422- fbgemm::fbgemmAlignedAlloc (64 , nnz * sizeof (float )));
420+ csc.weights = fbgemm::makeAlignedUniquePtr<float >(64 , nnz);
423421 }
424422
425423 [[maybe_unused]] int column_ptr_curr = 0 ;
@@ -431,16 +429,15 @@ void csr2csc_template_(
431429 using pair_t = std::pair<int , scalar_t >;
432430 using value_t = typename std::conditional<IS_VALUE_PAIR, pair_t , int >::type;
433431
434- csc.column_segment_ids =
435- static_cast <int *>(fbgemm::fbgemmAlignedAlloc (64 , nnz * sizeof (int )));
436- int * tmpBufKeys =
437- static_cast <int *>(fbgemm::fbgemmAlignedAlloc (64 , NS * sizeof (int )));
438- value_t * tmpBufValues = static_cast <value_t *>(
439- fbgemm::fbgemmAlignedAlloc (64 , NS * sizeof (value_t )));
440- int * tmpBuf1Keys =
441- static_cast <int *>(fbgemm::fbgemmAlignedAlloc (64 , NS * sizeof (int )));
442- value_t * tmpBuf1Values = static_cast <value_t *>(
443- fbgemm::fbgemmAlignedAlloc (64 , NS * sizeof (value_t )));
432+ csc.column_segment_ids = fbgemm::makeAlignedUniquePtr<int >(64 , nnz);
433+ auto tmpBufKeys = fbgemm::makeAlignedUniquePtr<int >(64 , NS);
434+ fbgemm::aligned_unique_ptr<value_t > tmpBufValues (
435+ static_cast <value_t *>(
436+ fbgemm::fbgemmAlignedAlloc (64 , NS * sizeof (value_t ))));
437+ auto tmpBuf1Keys = fbgemm::makeAlignedUniquePtr<int >(64 , NS);
438+ fbgemm::aligned_unique_ptr<value_t > tmpBuf1Values (
439+ static_cast <value_t *>(
440+ fbgemm::fbgemmAlignedAlloc (64 , NS * sizeof (value_t ))));
444441
445442 const auto FBo = csr_offsets[table_to_feature_offset[0 ] * B];
446443 for (int feature = table_to_feature_offset[0 ];
@@ -461,11 +458,11 @@ void csr2csc_template_(
461458 : 1.0 ;
462459 for (const auto p : c10::irange (pool_begin, pool_end)) {
463460 tmpBufKeys[p - FBo] = csr_indices[p];
464- if (IS_VALUE_PAIR) {
465- reinterpret_cast < pair_t *>( tmpBufValues) [p - FBo] = std::pair{
461+ if constexpr (IS_VALUE_PAIR) {
462+ tmpBufValues[p - FBo] = std::pair{
466463 FBs + b, scale_factor * (has_weights ? csr_weights[p] : 1 .0f )};
467464 } else {
468- reinterpret_cast < int *>( tmpBufValues) [p - FBo] = FBs + b;
465+ tmpBufValues[p - FBo] = FBs + b;
469466 }
470467 }
471468 }
@@ -475,10 +472,10 @@ void csr2csc_template_(
475472 value_t * sorted_col_row_index_values;
476473 std::tie (sorted_col_row_index_keys, sorted_col_row_index_values) =
477474 fbgemm::radix_sort_parallel (
478- tmpBufKeys,
479- tmpBufValues,
480- tmpBuf1Keys,
481- tmpBuf1Values,
475+ tmpBufKeys. get () ,
476+ tmpBufValues. get () ,
477+ tmpBuf1Keys. get () ,
478+ tmpBuf1Values. get () ,
482479 NS,
483480 num_embeddings);
484481
@@ -509,10 +506,8 @@ void csr2csc_template_(
509506 U = num_uniq[max_thds - 1 ][0 ];
510507 }
511508
512- csc.column_segment_ptr =
513- static_cast <int *>(fbgemm::fbgemmAlignedAlloc (64 , (NS + 1 ) * sizeof (int )));
514- csc.column_segment_indices =
515- static_cast <int *>(fbgemm::fbgemmAlignedAlloc (64 , NS * sizeof (int )));
509+ csc.column_segment_ptr = fbgemm::makeAlignedUniquePtr<int >(64 , NS + 1 );
510+ csc.column_segment_indices = fbgemm::makeAlignedUniquePtr<int >(64 , NS);
516511 csc.column_segment_ptr [0 ] = 0 ;
517512 const pair_t * sorted_col_row_index_values_pair =
518513 reinterpret_cast <const pair_t *>(sorted_col_row_index_values);
@@ -528,26 +523,31 @@ void csr2csc_template_(
528523 }
529524 csc.column_segment_indices [0 ] = sorted_col_row_index_keys[0 ];
530525
526+ int * col_seg_indices = csc.column_segment_indices .get ();
527+ int * col_seg_ptr = csc.column_segment_ptr .get ();
528+
531529#pragma omp parallel
532530 {
533531 int tid = omp_get_thread_num ();
534532 int * tstart =
535- (tid == 0 ? csc. column_segment_indices + 1
536- : csc. column_segment_indices + num_uniq[tid - 1 ][0 ]);
533+ (tid == 0 ? col_seg_indices + 1
534+ : col_seg_indices + num_uniq[tid - 1 ][0 ]);
537535
538536 int * t_offs =
539- (tid == 0 ? csc.column_segment_ptr + 1
540- : csc.column_segment_ptr + num_uniq[tid - 1 ][0 ]);
537+ (tid == 0 ? col_seg_ptr + 1 : col_seg_ptr + num_uniq[tid - 1 ][0 ]);
541538
542539 if (!IS_VALUE_PAIR && !is_shared_table) {
543540 // For non shared table, no need for computing modulo.
544541 // As an optimization, pointer swap instead of copying.
545542#pragma omp master
546- std::swap (
547- csc.row_indices ,
548- *reinterpret_cast <int **>(
549- sorted_col_row_index_values == tmpBufValues ? &tmpBufValues
550- : &tmpBuf1Values));
543+ {
544+ auto & buf = sorted_col_row_index_values == tmpBufValues.get ()
545+ ? tmpBufValues
546+ : tmpBuf1Values;
547+ int * tmp = csc.row_indices .release ();
548+ csc.row_indices .reset (reinterpret_cast <int *>(buf.release ()));
549+ buf.reset (reinterpret_cast <value_t *>(tmp));
550+ }
551551 } else {
552552#ifdef FBCODE_CAFFE2
553553 libdivide::divider<int > divisor (B);
@@ -582,7 +582,7 @@ void csr2csc_template_(
582582
583583 if (at::get_num_threads () == 1 && tid == 0 ) {
584584 // Special handling of single thread case
585- U = t_offs - csc.column_segment_ptr ;
585+ U = t_offs - csc.column_segment_ptr . get () ;
586586 }
587587
588588 } // omp parallel
@@ -591,11 +591,6 @@ void csr2csc_template_(
591591 csc.column_segment_ptr [U] = NS;
592592 column_ptr_curr += NS;
593593
594- fbgemm::fbgemmAlignedFree (tmpBufKeys);
595- fbgemm::fbgemmAlignedFree (tmpBufValues);
596- fbgemm::fbgemmAlignedFree (tmpBuf1Keys);
597- fbgemm::fbgemmAlignedFree (tmpBuf1Values);
598-
599594 assert (column_ptr_curr == nnz);
600595}
601596
0 commit comments