|
13 | 13 | #include <string.h> |
14 | 14 | #include "ceed-cuda-ref.h" |
15 | 15 |
|
| 16 | + |
| 17 | +//------------------------------------------------------------------------------ |
| 18 | +// Check if host/device sync is needed |
| 19 | +//------------------------------------------------------------------------------ |
| 20 | +static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, |
| 21 | + CeedMemType mem_type, bool *need_sync) { |
| 22 | + int ierr; |
| 23 | + CeedVector_Cuda *impl; |
| 24 | + ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); |
| 25 | + |
| 26 | + bool has_valid_array = false; |
| 27 | + ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr); |
| 28 | + switch (mem_type) { |
| 29 | + case CEED_MEM_HOST: |
| 30 | + *need_sync = has_valid_array && !impl->h_array; |
| 31 | + break; |
| 32 | + case CEED_MEM_DEVICE: |
| 33 | + *need_sync = has_valid_array && !impl->d_array; |
| 34 | + break; |
| 35 | + } |
| 36 | + |
| 37 | + return CEED_ERROR_SUCCESS; |
| 38 | +} |
| 39 | + |
16 | 40 | //------------------------------------------------------------------------------ |
17 | 41 | // Sync host to device |
18 | 42 | //------------------------------------------------------------------------------ |
@@ -88,8 +112,16 @@ static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) { |
88 | 112 | //------------------------------------------------------------------------------ |
89 | 113 | // Sync arrays |
90 | 114 | //------------------------------------------------------------------------------ |
91 | | -static inline int CeedVectorSync_Cuda(const CeedVector vec, |
92 | | - CeedMemType mem_type) { |
| 115 | +static int CeedVectorSyncArray_Cuda(const CeedVector vec, |
| 116 | + CeedMemType mem_type) { |
| 117 | + int ierr; |
| 118 | + // Check whether device/host sync is needed |
| 119 | + bool need_sync = false; |
| 120 | + ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync); |
| 121 | + CeedChkBackend(ierr); |
| 122 | + if (!need_sync) |
| 123 | + return CEED_ERROR_SUCCESS; |
| 124 | + |
93 | 125 | switch (mem_type) { |
94 | 126 | case CEED_MEM_HOST: return CeedVectorSyncD2H_Cuda(vec); |
95 | 127 | case CEED_MEM_DEVICE: return CeedVectorSyncH2D_Cuda(vec); |
@@ -167,29 +199,6 @@ static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec, |
167 | 199 | return CEED_ERROR_SUCCESS; |
168 | 200 | } |
169 | 201 |
|
170 | | -//------------------------------------------------------------------------------ |
171 | | -// Check if is any array of given type |
172 | | -//------------------------------------------------------------------------------ |
173 | | -static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, |
174 | | - CeedMemType mem_type, bool *need_sync) { |
175 | | - int ierr; |
176 | | - CeedVector_Cuda *impl; |
177 | | - ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); |
178 | | - |
179 | | - bool has_valid_array = false; |
180 | | - ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr); |
181 | | - switch (mem_type) { |
182 | | - case CEED_MEM_HOST: |
183 | | - *need_sync = has_valid_array && !impl->h_array; |
184 | | - break; |
185 | | - case CEED_MEM_DEVICE: |
186 | | - *need_sync = has_valid_array && !impl->d_array; |
187 | | - break; |
188 | | - } |
189 | | - |
190 | | - return CEED_ERROR_SUCCESS; |
191 | | -} |
192 | | - |
193 | 202 | //------------------------------------------------------------------------------ |
194 | 203 | // Set array from host |
195 | 204 | //------------------------------------------------------------------------------ |
@@ -368,11 +377,7 @@ static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type, |
368 | 377 | ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); |
369 | 378 |
|
370 | 379 | // Sync array to requested mem_type |
371 | | - bool need_sync = false; |
372 | | - ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync); CeedChkBackend(ierr); |
373 | | - if (need_sync) { |
374 | | - ierr = CeedVectorSync_Cuda(vec, mem_type); CeedChkBackend(ierr); |
375 | | - } |
| 380 | + ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr); |
376 | 381 |
|
377 | 382 | // Update pointer |
378 | 383 | switch (mem_type) { |
@@ -403,14 +408,8 @@ static int CeedVectorGetArrayCore_Cuda(const CeedVector vec, |
403 | 408 | CeedVector_Cuda *impl; |
404 | 409 | ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr); |
405 | 410 |
|
406 | | - bool need_sync = false, has_array_of_type = true; |
407 | | - ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync); CeedChkBackend(ierr); |
408 | | - ierr = CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type); |
409 | | - CeedChkBackend(ierr); |
410 | | - if (need_sync) { |
411 | | - // Sync array to requested mem_type |
412 | | - ierr = CeedVectorSync_Cuda(vec, mem_type); CeedChkBackend(ierr); |
413 | | - } |
| 411 | + // Sync array to requested mem_type |
| 412 | + ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr); |
414 | 413 |
|
415 | 414 | // Update pointer |
416 | 415 | switch (mem_type) { |
@@ -763,6 +762,8 @@ int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) { |
763 | 762 | ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", |
764 | 763 | (int (*)())(CeedVectorSetValue_Cuda)); |
765 | 764 | CeedChkBackend(ierr); |
| 765 | + ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", |
| 766 | + CeedVectorSyncArray_Cuda); CeedChkBackend(ierr); |
766 | 767 | ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", |
767 | 768 | CeedVectorGetArray_Cuda); CeedChkBackend(ierr); |
768 | 769 | ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", |
|
0 commit comments