File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -26,6 +26,8 @@ namespace ctranslate2 {
2626 void synchronize_device (Device device, int index);
2727 void synchronize_stream (Device device);
2828
29+ void destroy_context (Device device);
30+
2931 class ScopedDeviceSetter {
3032 public:
3133 ScopedDeviceSetter (Device device, int index)
Original file line number Diff line number Diff line change @@ -354,6 +354,8 @@ namespace ctranslate2 {
354354
355355 void finalize () override {
356356 _replica.reset ();
357+
358+ destroy_context (_device);
357359 }
358360
359361 private:
Original file line number Diff line number Diff line change @@ -48,12 +48,17 @@ namespace ctranslate2 {
4848 curandState* _states;
4949 };
5050
51+ static thread_local std::unique_ptr<ScopedCurandStates<curandStatePhilox4_32_10_t>> states;
52+
5153 curandStatePhilox4_32_10_t* get_curand_states (size_t num_states) {
52- static thread_local std::unique_ptr<ScopedCurandStates<curandStatePhilox4_32_10_t>> states;
5354 if (!states || num_states > states->num_states ())
5455 states = std::make_unique<ScopedCurandStates<curandStatePhilox4_32_10_t>>(num_states);
5556 return states->states ();
5657 }
5758
59+ void free_curand_states () {
60+ states.reset ();
61+ }
62+
5863 }
5964}
Original file line number Diff line number Diff line change @@ -6,6 +6,7 @@ namespace ctranslate2 {
66 namespace cuda {
77
88 curandStatePhilox4_32_10_t* get_curand_states (size_t num_states);
9+ void free_curand_states ();
910
1011 }
1112}
Original file line number Diff line number Diff line change 22
33#ifdef CT2_WITH_CUDA
44# include " cuda/utils.h"
5+ # include " cuda/random.h"
56#endif
67#ifdef CT2_WITH_TENSOR_PARALLEL
78# include < unistd.h>
@@ -118,6 +119,17 @@ namespace ctranslate2 {
118119 (void )device;
119120#endif
120121 }
122+
123+ void destroy_context (Device device) {
124+ #ifdef CT2_WITH_CUDA
125+ if (device == Device::CUDA) {
126+ cuda::free_curand_states ();
127+ }
128+ #else
129+ (void )device;
130+ #endif
131+ }
132+
121133 // Initialize the static member variable
122134#ifdef CT2_WITH_TENSOR_PARALLEL
123135 std::vector<ncclComm_t*> ScopedMPISetter::_nccl_comms;
You can’t perform that action at this time.
0 commit comments