Skip to content

Commit 0968c24

Browse files
committed
Free curand states before the thread is destroyed
1 parent 617405f commit 0968c24

5 files changed

Lines changed: 23 additions & 1 deletion

File tree

include/ctranslate2/devices.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ namespace ctranslate2 {
2626
void synchronize_device(Device device, int index);
2727
void synchronize_stream(Device device);
2828

29+
// Releases resources held by the calling thread for the given device type.
// Per the definition in src/devices.cc shown in this commit: for Device::CUDA
// (when built with CT2_WITH_CUDA) it frees the thread-local cuRAND states;
// otherwise it does nothing.
void destroy_context(Device device);
30+
2931
class ScopedDeviceSetter {
3032
public:
3133
ScopedDeviceSetter(Device device, int index)

include/ctranslate2/replica_pool.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,8 @@ namespace ctranslate2 {
354354

355355
// Finalization hook (overrides a base-class method that is not visible here).
// Drops the owned replica, then releases the device resources via
// destroy_context().
// NOTE(review): per the commit title this presumably runs on the thread that
// is about to be destroyed -- confirm against the caller.
void finalize() override {
356356
_replica.reset();
357+
358+
// Called after the replica has been reset; for CUDA this frees the
// thread-local cuRAND states (see destroy_context in src/devices.cc).
destroy_context(_device);
357359
}
358360

359361
private:

src/cuda/random.cu

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,17 @@ namespace ctranslate2 {
4848
curandState* _states;
4949
};
5050

51+
// Per-thread holder of the cuRAND Philox states. It lives at namespace scope
// (it used to be a function-local static inside get_curand_states) so that
// both get_curand_states() and free_curand_states() can access it.
static thread_local std::unique_ptr<ScopedCurandStates<curandStatePhilox4_32_10_t>> states;
52+
5153
// Returns the calling thread's cuRAND states, creating the holder on first use
// and replacing it with a new one when more than states->num_states() states
// are requested.
// NOTE(review): replacing the holder destroys the previous ScopedCurandStates,
// so pointers returned by earlier calls presumably become invalid -- confirm
// against the ScopedCurandStates destructor.
curandStatePhilox4_32_10_t* get_curand_states(size_t num_states) {
// Diff residue: the "52-" gutter marks the function-local declaration below as
// removed by this commit (it moved to namespace scope).
52-
static thread_local std::unique_ptr<ScopedCurandStates<curandStatePhilox4_32_10_t>> states;
5354
if (!states || num_states > states->num_states())
5455
states = std::make_unique<ScopedCurandStates<curandStatePhilox4_32_10_t>>(num_states);
5556
return states->states();
5657
}
5758

59+
// Destroys the calling thread's ScopedCurandStates holder, if any.
// Resetting an empty unique_ptr is a no-op, so this can be called even when
// get_curand_states() was never invoked on this thread.
void free_curand_states() {
60+
states.reset();
61+
}
62+
5863
}
5964
}

src/cuda/random.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ namespace ctranslate2 {
66
namespace cuda {
77

88
curandStatePhilox4_32_10_t* get_curand_states(size_t num_states);
9+
// Releases the cuRAND states cached for the calling thread (the thread-local
// holder used by get_curand_states in src/cuda/random.cu).
void free_curand_states();
910

1011
}
1112
}

src/devices.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#ifdef CT2_WITH_CUDA
44
# include "cuda/utils.h"
5+
# include "cuda/random.h"
56
#endif
67
#ifdef CT2_WITH_TENSOR_PARALLEL
78
# include <unistd.h>
@@ -118,6 +119,17 @@ namespace ctranslate2 {
118119
(void)device;
119120
#endif
120121
}
122+
123+
// Releases resources held by the calling thread for `device`.
// With CT2_WITH_CUDA defined, Device::CUDA frees the thread-local cuRAND
// states; any other device value does nothing. Without CT2_WITH_CUDA the
// function is a no-op.
void destroy_context(Device device) {
124+
#ifdef CT2_WITH_CUDA
125+
if (device == Device::CUDA) {
126+
cuda::free_curand_states();
127+
}
128+
#else
129+
// The parameter is unused in non-CUDA builds; the cast presumably silences
// the unused-parameter warning.
(void)device;
130+
#endif
131+
}
132+
121133
// Initialize the static member variable
122134
#ifdef CT2_WITH_TENSOR_PARALLEL
123135
std::vector<ncclComm_t*> ScopedMPISetter::_nccl_comms;

0 commit comments

Comments
 (0)