@@ -116,6 +116,7 @@ inline void gpuAssert(T code, const char *file, int line, bool abort = true) {
116116/* ================================================================================================
117117 * SESSION
118118 * ================================================================================================ */
119+ static size_t s_numStreams = 1 ;
119120
120121/* *
121122 * Singleton for Cuda library handles.
@@ -125,11 +126,12 @@ inline void gpuAssert(T code, const char *file, int line, bool abort = true) {
125126 * The cuBlas handle can be accessed anywhere by `Session::getInstance().cuBlasHandle()`
126127 * The cuSolver handle can be accessed anywhere by `Session::getInstance().cuSolverHandle()`
127128 */
128- static size_t s_numStreams = 1 ;
129-
130129class Session {
131130public:
132-
131+ /* *
132+ *
133+ * @param numStreams
134+ */
133135 static void setStreams (size_t numStreams) {
134136 s_numStreams = numStreams;
135137 }
@@ -140,11 +142,11 @@ public:
140142 }
141143
142144private:
143- Session (size_t numStreams= 10 ) {
144- m_numBublasHandlesStreams = numStreams;
145- m_cublasHandles.resize (m_numBublasHandlesStreams );
146- m_cublasStreams.resize (m_numBublasHandlesStreams );
147- for (size_t i=0 ; i<m_numBublasHandlesStreams ; i++) {
145+ Session (size_t numStreams) {
146+ m_numCublasHandlesStreams = numStreams;
147+ m_cublasHandles.resize (m_numCublasHandlesStreams );
148+ m_cublasStreams.resize (m_numCublasHandlesStreams );
149+ for (size_t i=0 ; i<m_numCublasHandlesStreams ; i++) {
148150 gpuErrChk (cublasCreate (&m_cublasHandles[i]));
149151 gpuErrChk (cudaStreamCreate (&m_cublasStreams[i]));
150152 gpuErrChk (cublasSetStream (m_cublasHandles[i], m_cublasStreams[i]));
@@ -153,7 +155,7 @@ private:
153155 }
154156
155157 ~Session () {
156- for (size_t i=0 ; i<m_numBublasHandlesStreams ; i++) {
158+ for (size_t i=0 ; i<m_numCublasHandlesStreams ; i++) {
157159 gpuErrChk (cublasDestroy (m_cublasHandles[i]));
158160 }
159161 gpuErrChk (cusolverDnDestroy (m_cusolverHandle));
@@ -163,7 +165,7 @@ private:
163165 std::vector<cudaStream_t> m_cublasStreams;
164166 cusolverDnHandle_t m_cusolverHandle;
165167 size_t m_bytesAllocated = 0 ;
166- size_t m_numBublasHandlesStreams = 1 ;
168+ size_t m_numCublasHandlesStreams = 1 ;
167169
168170public:
169171 Session (Session const &) = delete ;
@@ -288,7 +290,7 @@ public:
288290 /* *
289291 * Set the stream ID
290292 */
291- void setStreamIdx (size_t );
293+ DTensor<T> setStreamIdx (size_t );
292294
293295 size_t streamIdx () const { return m_idxStream; }
294296
@@ -614,11 +616,12 @@ public:
614616}; /* END OF DTENSOR */
615617
616618template <typename T>
617- void DTensor<T>::setStreamIdx(size_t idx) {
619+ DTensor<T> DTensor<T>::setStreamIdx(size_t idx) {
618620 if (idx >= s_numStreams) {
619621 throw std::invalid_argument (" Invalid stream index; it exceeds the max allocated streams" );
620622 }
621623 m_idxStream = idx;
624+ return *this ;
622625}
623626
624627template <typename T>
0 commit comments