Skip to content

Commit 1095a7a

Browse files
authored
Merge pull request #1812 from CEED/zach/reset-work-vectors
Reset work vectors in operator setup to save memory
2 parents acba419 + 25433d2 commit 1095a7a

5 files changed

Lines changed: 159 additions & 1 deletion

File tree

backends/cuda-ref/ceed-cuda-ref-operator.c

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,16 @@ static int CeedOperatorSetup_Cuda(CeedOperator op) {
353353
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i));
354354
}
355355
}
356+
CeedCallBackend(CeedClearWorkVectors(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len));
357+
{
358+
// Create two work vectors for diagonal assembly
359+
CeedVector temp_1, temp_2;
360+
361+
CeedCallBackend(CeedGetWorkVector(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len, &temp_1));
362+
CeedCallBackend(CeedGetWorkVector(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len, &temp_2));
363+
CeedCallBackend(CeedRestoreWorkVector(CeedOperatorReturnCeed(op), &temp_1));
364+
CeedCallBackend(CeedRestoreWorkVector(CeedOperatorReturnCeed(op), &temp_2));
365+
}
356366
CeedCallBackend(CeedOperatorSetSetupDone(op));
357367
CeedCallBackend(CeedQFunctionDestroy(&qf));
358368
return CEED_ERROR_SUCCESS;
@@ -740,6 +750,16 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) {
740750
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i));
741751
}
742752
}
753+
CeedCallBackend(CeedClearWorkVectors(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len));
754+
{
755+
// Create two work vectors for diagonal assembly
756+
CeedVector temp_1, temp_2;
757+
758+
CeedCallBackend(CeedGetWorkVector(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len, &temp_1));
759+
CeedCallBackend(CeedGetWorkVector(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len, &temp_2));
760+
CeedCallBackend(CeedRestoreWorkVector(CeedOperatorReturnCeed(op), &temp_1));
761+
CeedCallBackend(CeedRestoreWorkVector(CeedOperatorReturnCeed(op), &temp_2));
762+
}
743763
CeedCallBackend(CeedOperatorSetSetupDone(op));
744764
CeedCallBackend(CeedQFunctionDestroy(&qf));
745765
return CEED_ERROR_SUCCESS;

backends/hip-ref/ceed-hip-ref-operator.c

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,16 @@ static int CeedOperatorSetup_Hip(CeedOperator op) {
352352
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i));
353353
}
354354
}
355+
CeedCallBackend(CeedClearWorkVectors(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len));
356+
{
357+
// Create two work vectors for diagonal assembly
358+
CeedVector temp_1, temp_2;
359+
360+
CeedCallBackend(CeedGetWorkVector(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len, &temp_1));
361+
CeedCallBackend(CeedGetWorkVector(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len, &temp_2));
362+
CeedCallBackend(CeedRestoreWorkVector(CeedOperatorReturnCeed(op), &temp_1));
363+
CeedCallBackend(CeedRestoreWorkVector(CeedOperatorReturnCeed(op), &temp_2));
364+
}
355365
CeedCallBackend(CeedOperatorSetSetupDone(op));
356366
CeedCallBackend(CeedQFunctionDestroy(&qf));
357367
return CEED_ERROR_SUCCESS;
@@ -738,6 +748,16 @@ static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) {
738748
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i));
739749
}
740750
}
751+
CeedCallBackend(CeedClearWorkVectors(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len));
752+
{
753+
// Create two work vectors for diagonal assembly
754+
CeedVector temp_1, temp_2;
755+
756+
CeedCallBackend(CeedGetWorkVector(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len, &temp_1));
757+
CeedCallBackend(CeedGetWorkVector(CeedOperatorReturnCeed(op), impl->max_active_e_vec_len, &temp_2));
758+
CeedCallBackend(CeedRestoreWorkVector(CeedOperatorReturnCeed(op), &temp_1));
759+
CeedCallBackend(CeedRestoreWorkVector(CeedOperatorReturnCeed(op), &temp_2));
760+
}
741761
CeedCallBackend(CeedOperatorSetSetupDone(op));
742762
CeedCallBackend(CeedQFunctionDestroy(&qf));
743763
return CEED_ERROR_SUCCESS;

include/ceed/backend.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ CEED_EXTERN int CeedSetData(Ceed ceed, void *data);
256256
CEED_EXTERN int CeedReference(Ceed ceed);
257257
CEED_EXTERN int CeedGetWorkVector(Ceed ceed, CeedSize len, CeedVector *vec);
258258
CEED_EXTERN int CeedRestoreWorkVector(Ceed ceed, CeedVector *vec);
259+
CEED_EXTERN int CeedClearWorkVectors(Ceed ceed, CeedSize min_len);
260+
CEED_EXTERN int CeedGetWorkVectorMemoryUsage(Ceed ceed, CeedScalar *usage_mb);
259261
CEED_EXTERN int CeedGetJitSourceRoots(Ceed ceed, CeedInt *num_source_roots, const char ***jit_source_roots);
260262
CEED_EXTERN int CeedRestoreJitSourceRoots(Ceed ceed, const char ***jit_source_roots);
261263
CEED_EXTERN int CeedGetJitDefines(Ceed ceed, CeedInt *num_defines, const char ***jit_defines);

interface/ceed.c

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,62 @@ int CeedReference(Ceed ceed) {
817817
return CEED_ERROR_SUCCESS;
818818
}
819819

820+
/**
821+
@brief Computes the current memory usage, in MB, of the work vectors in a `Ceed` context and prints it to the debug output
822+
823+
@param[in] ceed `Ceed` context
@param[out] usage_mb Memory usage of the work vectors, in MB
824+
825+
@return An error code: 0 - success, otherwise - failure
826+
827+
@ref Developer
828+
**/
829+
int CeedGetWorkVectorMemoryUsage(Ceed ceed, CeedScalar *usage_mb) {
830+
// Start from zero so the output is defined even when no work vector storage exists
*usage_mb = 0.0;
831+
if (ceed->work_vectors) {
832+
// Sum the lengths of all stored work vectors (no distinction between in-use and idle ones)
for (CeedInt i = 0; i < ceed->work_vectors->num_vecs; i++) {
833+
CeedSize vec_len;
834+
CeedCall(CeedVectorGetLength(ceed->work_vectors->vecs[i], &vec_len));
835+
*usage_mb += vec_len;
836+
}
837+
// Convert the entry count to bytes, then scale by 1e-6 to express the total in MB
*usage_mb *= sizeof(CeedScalar) * 1e-6;
838+
// Report the vector count and the total through the debug channel
CeedDebug(ceed, "Resource {%s}: Work vectors memory usage: %" CeedInt_FMT " vectors, %g MB\n", ceed->resource, ceed->work_vectors->num_vecs,
839+
*usage_mb);
840+
}
841+
return CEED_ERROR_SUCCESS;
842+
}
843+
844+
/**
845+
@brief Clear inactive work vectors in a `Ceed` context below a minimum length.
846+
847+
@param[in,out] ceed `Ceed` context
848+
@param[in] min_len Minimum length of work vector to keep
849+
850+
@return An error code: 0 - success, otherwise - failure
851+
852+
@ref Backend
853+
**/
854+
int CeedClearWorkVectors(Ceed ceed, CeedSize min_len) {
855+
// Nothing to clear if work vector storage was never created
if (!ceed->work_vectors) return CEED_ERROR_SUCCESS;
856+
for (CeedInt i = 0; i < ceed->work_vectors->num_vecs; i++) {
857+
// Vectors that are currently checked out are never destroyed
if (ceed->work_vectors->is_in_use[i]) continue;
858+
CeedSize vec_len;
859+
CeedCall(CeedVectorGetLength(ceed->work_vectors->vecs[i], &vec_len));
860+
// Strictly shorter than min_len: a vector of exactly min_len entries is kept
if (vec_len < min_len) {
861+
// NOTE(review): +2 here and -1 below leave a net +1, presumably offsetting a decrement applied inside CeedVectorDestroy - confirm
ceed->ref_count += 2; // Note: increase ref_count to prevent Ceed destructor from triggering
862+
CeedCall(CeedVectorDestroy(&ceed->work_vectors->vecs[i]));
863+
ceed->ref_count -= 1; // Note: restore ref_count
864+
ceed->work_vectors->num_vecs--;
865+
if (ceed->work_vectors->num_vecs > 0) {
866+
// Fill the hole at index i with the last stored vector and its in-use flag so the arrays stay dense
ceed->work_vectors->vecs[i] = ceed->work_vectors->vecs[ceed->work_vectors->num_vecs];
867+
ceed->work_vectors->is_in_use[i] = ceed->work_vectors->is_in_use[ceed->work_vectors->num_vecs];
868+
ceed->work_vectors->is_in_use[ceed->work_vectors->num_vecs] = false;
869+
// Step back so the next iteration re-examines index i, which now holds the moved vector
i--;
870+
}
871+
}
872+
}
873+
return CEED_ERROR_SUCCESS;
874+
}
875+
820876
/**
821877
@brief Get a `CeedVector` for scratch work from a `Ceed` context.
822878
@@ -831,7 +887,8 @@ int CeedReference(Ceed ceed) {
831887
@ref Backend
832888
**/
833889
int CeedGetWorkVector(Ceed ceed, CeedSize len, CeedVector *vec) {
834-
CeedInt i = 0;
890+
CeedInt i = 0;
891+
CeedScalar usage_mb;
835892

836893
if (!ceed->work_vectors) CeedCall(CeedWorkVectorsCreate(ceed));
837894

@@ -858,6 +915,7 @@ int CeedGetWorkVector(Ceed ceed, CeedSize len, CeedVector *vec) {
858915
ceed->work_vectors->num_vecs++;
859916
CeedCallBackend(CeedVectorCreate(ceed, len, &ceed->work_vectors->vecs[i]));
860917
ceed->ref_count--; // Note: ref_count manipulation to prevent a ref-loop
918+
if (ceed->is_debug) CeedGetWorkVectorMemoryUsage(ceed, &usage_mb);
861919
}
862920
// Return pointer to work vector
863921
ceed->work_vectors->is_in_use[i] = true;

tests/t131-vector.c

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/// @file
2+
/// Test clearing work vectors
3+
/// \test Test clearing work vectors
4+
5+
#include <ceed.h>
6+
#include <ceed/backend.h>
7+
#include <math.h>
8+
#include <stdio.h>
9+
10+
// Expected work vector memory usage in MB for `length` entries of type CeedScalar (bytes scaled by 1e-6)
static CeedScalar expected_usage(CeedSize length) { return length * sizeof(CeedScalar) * 1e-6; }
11+
12+
// Test driver: checks out three work vectors, clears them with increasing minimum lengths,
// and compares the reported memory usage against the expected value after each step.
// A mismatch is only reported with printf; the function returns 0 in every case.
int main(int argc, char **argv) {
13+
Ceed ceed;
14+
CeedVector x, y, z;
15+
CeedScalar usage_mb;
16+
17+
// argv[1] is passed straight to CeedInit; argc is not checked
CeedInit(argv[1], &ceed);
18+
19+
// Add work vectors of different lengths
20+
CeedGetWorkVector(ceed, 10, &x);
21+
CeedGetWorkVector(ceed, 20, &y);
22+
CeedGetWorkVector(ceed, 30, &z);
23+
24+
// Check memory usage, should be 60 * sizeof(CeedScalar) bytes (the comparison is done in MB)
25+
CeedGetWorkVectorMemoryUsage(ceed, &usage_mb);
26+
if (fabs(usage_mb - expected_usage(60)) > 100. * CEED_EPSILON) printf("Wrong usage: %0.8g MB != %0.8g MB\n", usage_mb, expected_usage(60));
27+
28+
// Restore x and z
29+
CeedRestoreWorkVector(ceed, &x);
30+
CeedRestoreWorkVector(ceed, &z);
31+
32+
// Clear work vectors with length < 30. This should:
33+
// - Remove x
34+
// - Leave y, since it is still in use
35+
// - Leave z, since it is length 30
36+
CeedClearWorkVectors(ceed, 30);
37+
CeedGetWorkVectorMemoryUsage(ceed, &usage_mb);
38+
if (fabs(usage_mb - expected_usage(50)) > 100. * CEED_EPSILON) printf("Wrong usage: %0.8g MB != %0.8g MB\n", usage_mb, expected_usage(50));
39+
40+
// Clear work vectors with length < 31. This should:
41+
// - Leave y, since it is still in use
42+
// - Remove z
43+
CeedClearWorkVectors(ceed, 31);
44+
CeedGetWorkVectorMemoryUsage(ceed, &usage_mb);
45+
if (fabs(usage_mb - expected_usage(20)) > 100. * CEED_EPSILON) printf("Wrong usage: %0.8g MB != %0.8g MB\n", usage_mb, expected_usage(20));
46+
47+
// Restore y
48+
CeedRestoreWorkVector(ceed, &y);
49+
50+
// Make sure we can still get back y without allocating a new work vector
51+
CeedGetWorkVector(ceed, 20, &y);
52+
CeedGetWorkVectorMemoryUsage(ceed, &usage_mb);
53+
if (fabs(usage_mb - expected_usage(20)) > 100. * CEED_EPSILON) printf("Wrong usage: %0.8g MB != %0.8g MB\n", usage_mb, expected_usage(20));
54+
CeedRestoreWorkVector(ceed, &y);
55+
56+
CeedDestroy(&ceed);
57+
return 0;
58+
}

0 commit comments

Comments
 (0)