Skip to content

Commit 901ba14

Browse files
committed
streams and cusolver handles
1 parent 0a17614 commit 901ba14

1 file changed

Lines changed: 28 additions & 7 deletions

File tree

include/tensor.cuh

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,18 @@ static size_t s_numStreams = 1;
129129
class Session {
130130
public:
131131
/**
132-
*
133-
* @param numStreams
132+
* Sets the total number of available streams
133+
* @param numStreams number of streams (default: 1)
134134
*/
135135
static void setStreams(size_t numStreams) {
136136
s_numStreams = numStreams;
137137
}
138138

139+
/**
140+
* Returns the unique instance of Session (constructed upon first
141+
* invocation)
142+
* @return instance of Session
143+
*/
139144
static Session &getInstance() {
140145
static Session instance(s_numStreams);
141146
return instance;
@@ -146,24 +151,26 @@ private:
146151
m_numCublasHandlesStreams = numStreams;
147152
m_cublasHandles.resize(m_numCublasHandlesStreams);
148153
m_cublasStreams.resize(m_numCublasHandlesStreams);
154+
m_cusolverHandles.resize(m_numCublasHandlesStreams);
149155
for (size_t i=0; i<m_numCublasHandlesStreams; i++) {
150156
gpuErrChk(cublasCreate(&m_cublasHandles[i]));
151157
gpuErrChk(cudaStreamCreate(&m_cublasStreams[i]));
152158
gpuErrChk(cublasSetStream(m_cublasHandles[i], m_cublasStreams[i]));
159+
gpuErrChk(cusolverDnCreate(&m_cusolverHandles[i]));
160+
gpuErrChk(cusolverDnSetStream(m_cusolverHandles[i], m_cublasStreams[i]));
153161
}
154-
gpuErrChk(cusolverDnCreate(&m_cusolverHandle));
155162
}
156163

157164
~Session() {
158165
for (size_t i=0; i<m_numCublasHandlesStreams; i++) {
159166
gpuErrChk(cublasDestroy(m_cublasHandles[i]));
167+
gpuErrChk(cusolverDnDestroy(m_cusolverHandles[i]));
160168
}
161-
gpuErrChk(cusolverDnDestroy(m_cusolverHandle));
162169
}
163170

164171
std::vector<cublasHandle_t> m_cublasHandles;
165172
std::vector<cudaStream_t> m_cublasStreams;
166-
cusolverDnHandle_t m_cusolverHandle;
173+
std::vector<cusolverDnHandle_t> m_cusolverHandles;
167174
size_t m_bytesAllocated = 0;
168175
size_t m_numCublasHandlesStreams = 1;
169176

@@ -172,9 +179,19 @@ public:
172179

173180
void operator=(Session const &) = delete;
174181

182+
/**
183+
* cuBLAS handle
184+
* @param idx index of stream
185+
* @return cuBLAS handle
186+
*/
175187
cublasHandle_t &cuBlasHandle(size_t idx=0) { return m_cublasHandles[idx]; }
176188

177-
cusolverDnHandle_t &cuSolverHandle() { return m_cusolverHandle; }
189+
/**
190+
* cuSolver handle
191+
* @param idx index of stream
192+
* @return cuSolver handle
193+
*/
194+
cusolverDnHandle_t &cuSolverHandle(size_t idx=0) { return m_cusolverHandles[idx]; }
178195

179196
/**
180197
* Preferred method for CUDA memory allocation; it allocated memory on the device
@@ -198,7 +215,11 @@ public:
198215
*/
199216
size_t totalAllocatedBytes() const { return m_bytesAllocated; }
200217

201-
void incrementAllocatedBytes(size_t s) { m_bytesAllocated += s; }
218+
/**
219+
* Increment counter of allocated bytes
220+
* @param s allocated bytes (can be negative)
221+
*/
222+
void incrementAllocatedBytes(int s) { m_bytesAllocated += s; }
202223
};
203224

204225

0 commit comments

Comments
 (0)