Skip to content

Commit 6b72f97

Browse files
cyyever and meta-codesync[bot]
authored and committed
Use aligned_unique_ptr in more places to avoid leak (#5621)
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2578
Pull Request resolved: #5621
Reviewed By: henrylhtsang
Differential Revision: D100536982
Pulled By: q10
fbshipit-source-id: c0b1b9f5dbd462a123ffce359b5e9bc4092c2b12
1 parent f528b0d commit 6b72f97

4 files changed

Lines changed: 63 additions & 71 deletions

File tree

fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp

Lines changed: 14 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -111,8 +111,11 @@ for (const auto t : c10::irange(num_tables)) {
111111
int feature_begin = table_to_feature_offset[t];
112112

113113
int num_non_zero_columns = cscs[t].num_non_zero_columns;
114-
int* col_segment_ptr = cscs[t].column_segment_ptr;
115-
int* col_segment_indices = cscs[t].column_segment_indices;
114+
int* col_segment_ptr = cscs[t].column_segment_ptr.get();
115+
int* col_segment_indices = cscs[t].column_segment_indices.get();
116+
int* col_segment_ids = cscs[t].column_segment_ids.get();
117+
int* row_indices = cscs[t].row_indices.get();
118+
float* ind_weights = cscs[t].weights.get();
116119

117120
const auto D_begin = D_offsets_data[feature_begin];
118121
const auto D =
@@ -134,7 +137,7 @@ for (const auto t : c10::irange(num_tables)) {
134137
/*IndexType=*/int32_t,
135138
/*OffsetType=*/int32_t>(
136139
D,
137-
cscs[t].weights != nullptr,
140+
ind_weights != nullptr,
138141
/*normalize_by_lengths=*/false,
139142
/*prefetch=*/16,
140143
/*is_weight_positional=*/false,
@@ -156,11 +159,11 @@ for (const auto t : c10::irange(num_tables)) {
156159
B,
157160
reinterpret_cast<const fbgemm_weight_t*>(
158161
grad_output_data + D_begin),
159-
cscs[t].row_indices + *offsets_begin_ptr,
162+
row_indices + *offsets_begin_ptr,
160163
offsets_begin_ptr,
161-
cscs[t].weights == nullptr
164+
ind_weights == nullptr
162165
? nullptr
163-
: cscs[t].weights + *offsets_begin_ptr,
166+
: ind_weights + *offsets_begin_ptr,
164167
reinterpret_cast<float*>(grad_blocked_buffer));
165168

166169
if (!success) {
@@ -170,7 +173,7 @@ for (const auto t : c10::irange(num_tables)) {
170173
c,
171174
c_block_end,
172175
col_segment_ptr,
173-
cscs[t].row_indices,
176+
row_indices,
174177
hash_size,
175178
/*allow_minus_one=*/false);
176179
}
@@ -218,14 +221,14 @@ for (const auto t : c10::irange(num_tables)) {
218221
for (int r = col_segment_ptr[c]; r < col_segment_ptr[c + 1]; ++r) {
219222
int D_offset = D_begin;
220223
if (is_shared_table) {
221-
D_offset += cscs[t].column_segment_ids[r] * D;
224+
D_offset += col_segment_ids[r] * D;
222225
}
223-
int b = cscs[t].row_indices[r];
226+
int b = row_indices[r];
224227

225228
for (const auto d : c10::irange(D)) {
226-
if (cscs[t].weights != nullptr) {
229+
if (ind_weights != nullptr) {
227230
grad_buffer[d] += grad_output_data[b * grad_stride + D_offset + d] *
228-
cscs[t].weights[r];
231+
ind_weights[r];
229232
} else {
230233
grad_buffer[d] += grad_output_data[b * grad_stride + D_offset + d];
231234
}

fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp

Lines changed: 35 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -414,12 +414,10 @@ void csr2csc_template_(
414414
if (nnz == 0) {
415415
return;
416416
}
417-
csc.row_indices =
418-
static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(int)));
417+
csc.row_indices = fbgemm::makeAlignedUniquePtr<int>(64, nnz);
419418
bool has_weights = csr_weights.data() != nullptr;
420419
if (IS_VALUE_PAIR) {
421-
csc.weights = static_cast<float*>(
422-
fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(float)));
420+
csc.weights = fbgemm::makeAlignedUniquePtr<float>(64, nnz);
423421
}
424422

425423
[[maybe_unused]] int column_ptr_curr = 0;
@@ -431,16 +429,15 @@ void csr2csc_template_(
431429
using pair_t = std::pair<int, scalar_t>;
432430
using value_t = typename std::conditional<IS_VALUE_PAIR, pair_t, int>::type;
433431

434-
csc.column_segment_ids =
435-
static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(int)));
436-
int* tmpBufKeys =
437-
static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
438-
value_t* tmpBufValues = static_cast<value_t*>(
439-
fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(value_t)));
440-
int* tmpBuf1Keys =
441-
static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
442-
value_t* tmpBuf1Values = static_cast<value_t*>(
443-
fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(value_t)));
432+
csc.column_segment_ids = fbgemm::makeAlignedUniquePtr<int>(64, nnz);
433+
auto tmpBufKeys = fbgemm::makeAlignedUniquePtr<int>(64, NS);
434+
fbgemm::aligned_unique_ptr<value_t> tmpBufValues(
435+
static_cast<value_t*>(
436+
fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(value_t))));
437+
auto tmpBuf1Keys = fbgemm::makeAlignedUniquePtr<int>(64, NS);
438+
fbgemm::aligned_unique_ptr<value_t> tmpBuf1Values(
439+
static_cast<value_t*>(
440+
fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(value_t))));
444441

445442
const auto FBo = csr_offsets[table_to_feature_offset[0] * B];
446443
for (int feature = table_to_feature_offset[0];
@@ -461,11 +458,11 @@ void csr2csc_template_(
461458
: 1.0;
462459
for (const auto p : c10::irange(pool_begin, pool_end)) {
463460
tmpBufKeys[p - FBo] = csr_indices[p];
464-
if (IS_VALUE_PAIR) {
465-
reinterpret_cast<pair_t*>(tmpBufValues)[p - FBo] = std::pair{
461+
if constexpr (IS_VALUE_PAIR) {
462+
tmpBufValues[p - FBo] = std::pair{
466463
FBs + b, scale_factor * (has_weights ? csr_weights[p] : 1.0f)};
467464
} else {
468-
reinterpret_cast<int*>(tmpBufValues)[p - FBo] = FBs + b;
465+
tmpBufValues[p - FBo] = FBs + b;
469466
}
470467
}
471468
}
@@ -475,10 +472,10 @@ void csr2csc_template_(
475472
value_t* sorted_col_row_index_values;
476473
std::tie(sorted_col_row_index_keys, sorted_col_row_index_values) =
477474
fbgemm::radix_sort_parallel(
478-
tmpBufKeys,
479-
tmpBufValues,
480-
tmpBuf1Keys,
481-
tmpBuf1Values,
475+
tmpBufKeys.get(),
476+
tmpBufValues.get(),
477+
tmpBuf1Keys.get(),
478+
tmpBuf1Values.get(),
482479
NS,
483480
num_embeddings);
484481

@@ -509,10 +506,8 @@ void csr2csc_template_(
509506
U = num_uniq[max_thds - 1][0];
510507
}
511508

512-
csc.column_segment_ptr =
513-
static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, (NS + 1) * sizeof(int)));
514-
csc.column_segment_indices =
515-
static_cast<int*>(fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
509+
csc.column_segment_ptr = fbgemm::makeAlignedUniquePtr<int>(64, NS + 1);
510+
csc.column_segment_indices = fbgemm::makeAlignedUniquePtr<int>(64, NS);
516511
csc.column_segment_ptr[0] = 0;
517512
const pair_t* sorted_col_row_index_values_pair =
518513
reinterpret_cast<const pair_t*>(sorted_col_row_index_values);
@@ -528,26 +523,31 @@ void csr2csc_template_(
528523
}
529524
csc.column_segment_indices[0] = sorted_col_row_index_keys[0];
530525

526+
int* col_seg_indices = csc.column_segment_indices.get();
527+
int* col_seg_ptr = csc.column_segment_ptr.get();
528+
531529
#pragma omp parallel
532530
{
533531
int tid = omp_get_thread_num();
534532
int* tstart =
535-
(tid == 0 ? csc.column_segment_indices + 1
536-
: csc.column_segment_indices + num_uniq[tid - 1][0]);
533+
(tid == 0 ? col_seg_indices + 1
534+
: col_seg_indices + num_uniq[tid - 1][0]);
537535

538536
int* t_offs =
539-
(tid == 0 ? csc.column_segment_ptr + 1
540-
: csc.column_segment_ptr + num_uniq[tid - 1][0]);
537+
(tid == 0 ? col_seg_ptr + 1 : col_seg_ptr + num_uniq[tid - 1][0]);
541538

542539
if (!IS_VALUE_PAIR && !is_shared_table) {
543540
// For non shared table, no need for computing modulo.
544541
// As an optimization, pointer swap instead of copying.
545542
#pragma omp master
546-
std::swap(
547-
csc.row_indices,
548-
*reinterpret_cast<int**>(
549-
sorted_col_row_index_values == tmpBufValues ? &tmpBufValues
550-
: &tmpBuf1Values));
543+
{
544+
auto& buf = sorted_col_row_index_values == tmpBufValues.get()
545+
? tmpBufValues
546+
: tmpBuf1Values;
547+
int* tmp = csc.row_indices.release();
548+
csc.row_indices.reset(reinterpret_cast<int*>(buf.release()));
549+
buf.reset(reinterpret_cast<value_t*>(tmp));
550+
}
551551
} else {
552552
#ifdef FBCODE_CAFFE2
553553
libdivide::divider<int> divisor(B);
@@ -582,7 +582,7 @@ void csr2csc_template_(
582582

583583
if (at::get_num_threads() == 1 && tid == 0) {
584584
// Special handling of single thread case
585-
U = t_offs - csc.column_segment_ptr;
585+
U = t_offs - csc.column_segment_ptr.get();
586586
}
587587

588588
} // omp parallel
@@ -591,11 +591,6 @@ void csr2csc_template_(
591591
csc.column_segment_ptr[U] = NS;
592592
column_ptr_curr += NS;
593593

594-
fbgemm::fbgemmAlignedFree(tmpBufKeys);
595-
fbgemm::fbgemmAlignedFree(tmpBufValues);
596-
fbgemm::fbgemmAlignedFree(tmpBuf1Keys);
597-
fbgemm::fbgemmAlignedFree(tmpBuf1Values);
598-
599594
assert(column_ptr_curr == nnz);
600595
}
601596

fbgemm_gpu/include/fbgemm_gpu/embedding_forward_split_cpu.h

Lines changed: 9 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -44,22 +44,15 @@ struct HyperCompressedSparseColumn {
4444
// For a shared table, a column can have multiple segments, each for a
4545
// feature sharing the table. In this case, the segments will have the
4646
// same column_segment_indices but different column_segment_ids.
47-
int* column_segment_ptr = nullptr;
48-
int* column_segment_indices = nullptr; // length num_non_zero_columns
49-
int* column_segment_ids = nullptr; // length num_non_zero_columns
50-
int* row_indices = nullptr; // length column_ptr[num_non_zero_columns]
51-
float* weights = nullptr; // length column_ptr[num_non_zero_columns]
52-
~HyperCompressedSparseColumn() {
53-
if (column_segment_ptr) {
54-
fbgemm::fbgemmAlignedFree(column_segment_ptr);
55-
fbgemm::fbgemmAlignedFree(column_segment_indices);
56-
fbgemm::fbgemmAlignedFree(column_segment_ids);
57-
fbgemm::fbgemmAlignedFree(row_indices);
58-
}
59-
if (weights) {
60-
fbgemm::fbgemmAlignedFree(weights);
61-
}
62-
}
47+
fbgemm::aligned_unique_ptr<int> column_segment_ptr;
48+
fbgemm::aligned_unique_ptr<int>
49+
column_segment_indices; // length num_non_zero_columns
50+
fbgemm::aligned_unique_ptr<int>
51+
column_segment_ids; // length num_non_zero_columns
52+
fbgemm::aligned_unique_ptr<int>
53+
row_indices; // length column_ptr[num_non_zero_columns]
54+
fbgemm::aligned_unique_ptr<float>
55+
weights; // length column_ptr[num_non_zero_columns]
6356
};
6457

6558
template <typename index_t, typename scalar_t>

src/Utils.cc

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -369,15 +369,16 @@ void* fbgemmAlignedAlloc(
369369
size_t size,
370370
bool raiseException /*=false*/) {
371371
void* aligned_mem = nullptr;
372-
int ret = 0;
373372
#ifdef _MSC_VER
374373
aligned_mem = _aligned_malloc(size, align);
375-
ret = 0;
376374
#else
377-
ret = posix_memalign(&aligned_mem, align, size);
375+
int ret = posix_memalign(&aligned_mem, align, size);
376+
if (ret != 0) {
377+
aligned_mem = nullptr;
378+
}
378379
#endif
379380
// Throw std::bad_alloc in the case of memory allocation failure.
380-
if (raiseException && (ret || aligned_mem == nullptr)) {
381+
if (raiseException && aligned_mem == nullptr) {
381382
throw std::bad_alloc();
382383
}
383384
return aligned_mem;

0 commit comments

Comments
 (0)