Skip to content

Commit 0a17614

Browse files
committed
minor
1 parent 212cd37 commit 0a17614

2 files changed

Lines changed: 16 additions & 14 deletions

File tree

include/tensor.cuh

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ inline void gpuAssert(T code, const char *file, int line, bool abort = true) {
116116
/* ================================================================================================
117117
* SESSION
118118
* ================================================================================================ */
119+
static size_t s_numStreams = 1;
119120

120121
/**
121122
* Singleton for Cuda library handles.
@@ -125,11 +126,12 @@ inline void gpuAssert(T code, const char *file, int line, bool abort = true) {
125126
* The cuBlas handle can be accessed anywhere by `Session::getInstance().cuBlasHandle()`
126127
* The cuSolver handle can be accessed anywhere by `Session::getInstance().cuSolverHandle()`
127128
*/
128-
static size_t s_numStreams = 1;
129-
130129
class Session {
131130
public:
132-
131+
/**
132+
*
133+
* @param numStreams
134+
*/
133135
static void setStreams(size_t numStreams) {
134136
s_numStreams = numStreams;
135137
}
@@ -140,11 +142,11 @@ public:
140142
}
141143

142144
private:
143-
Session(size_t numStreams=10) {
144-
m_numBublasHandlesStreams = numStreams;
145-
m_cublasHandles.resize(m_numBublasHandlesStreams);
146-
m_cublasStreams.resize(m_numBublasHandlesStreams);
147-
for (size_t i=0; i<m_numBublasHandlesStreams; i++) {
145+
Session(size_t numStreams) {
146+
m_numCublasHandlesStreams = numStreams;
147+
m_cublasHandles.resize(m_numCublasHandlesStreams);
148+
m_cublasStreams.resize(m_numCublasHandlesStreams);
149+
for (size_t i=0; i<m_numCublasHandlesStreams; i++) {
148150
gpuErrChk(cublasCreate(&m_cublasHandles[i]));
149151
gpuErrChk(cudaStreamCreate(&m_cublasStreams[i]));
150152
gpuErrChk(cublasSetStream(m_cublasHandles[i], m_cublasStreams[i]));
@@ -153,7 +155,7 @@ private:
153155
}
154156

155157
~Session() {
156-
for (size_t i=0; i<m_numBublasHandlesStreams; i++) {
158+
for (size_t i=0; i<m_numCublasHandlesStreams; i++) {
157159
gpuErrChk(cublasDestroy(m_cublasHandles[i]));
158160
}
159161
gpuErrChk(cusolverDnDestroy(m_cusolverHandle));
@@ -163,7 +165,7 @@ private:
163165
std::vector<cudaStream_t> m_cublasStreams;
164166
cusolverDnHandle_t m_cusolverHandle;
165167
size_t m_bytesAllocated = 0;
166-
size_t m_numBublasHandlesStreams = 1;
168+
size_t m_numCublasHandlesStreams = 1;
167169

168170
public:
169171
Session(Session const &) = delete;
@@ -288,7 +290,7 @@ public:
288290
/**
289291
* Set the stream ID
290292
*/
291-
void setStreamIdx(size_t);
293+
DTensor<T> setStreamIdx(size_t);
292294

293295
size_t streamIdx() const { return m_idxStream; }
294296

@@ -614,11 +616,12 @@ public:
614616
}; /* END OF DTENSOR */
615617

616618
template<typename T>
617-
void DTensor<T>::setStreamIdx(size_t idx) {
619+
DTensor<T> DTensor<T>::setStreamIdx(size_t idx) {
618620
if (idx >= s_numStreams) {
619621
throw std::invalid_argument("Invalid stream index; it exceeds the max allocated streams");
620622
}
621623
m_idxStream = idx;
624+
return *this;
622625
}
623626

624627
template<typename T>

main.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66

77
void xyz() {
88
/* Write to binary file */
9-
auto r = DTensor<double>::createRandomTensor(3, 6, 4, -1, 1);
10-
r.setStreamIdx(1);
9+
DTensor<double> r = DTensor<double>::createRandomTensor(3, 6, 4, -1, 1).setStreamIdx(1);
1110
std::string fName = "abcd.bt"; // binary tensor file extension: .bt
1211
r.saveToFile(fName);
1312

0 commit comments

Comments
 (0)