|
22 | 22 | static int CeedOperatorDestroy_Hip_gen(CeedOperator op) { |
23 | 23 | Ceed ceed; |
24 | 24 | CeedOperator_Hip_gen *impl; |
| 25 | + bool is_composite; |
25 | 26 |
|
26 | 27 | CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); |
27 | 28 | CeedCallBackend(CeedOperatorGetData(op, &impl)); |
| 29 | + CeedCallBackend(CeedOperatorIsComposite(op, &is_composite)); |
| 30 | + if (is_composite) { |
| 31 | + CeedInt num_suboperators; |
| 32 | + |
| 33 | + CeedCall(CeedCompositeOperatorGetNumSub(op, &num_suboperators)); |
| 34 | + for (CeedInt i = 0; i < num_suboperators; i++) { |
| 35 | + if (impl->streams[i]) CeedCallHip(ceed, hipStreamDestroy(impl->streams[i])); |
| 36 | + impl->streams[i] = NULL; |
| 37 | + } |
| 38 | + } |
28 | 39 | if (impl->module) CeedCallHip(ceed, hipModuleUnload(impl->module)); |
29 | 40 | if (impl->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)impl->points.num_per_elem)); |
30 | 41 | CeedCallBackend(CeedFree(&impl)); |
@@ -239,28 +250,35 @@ static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, C |
239 | 250 | } |
240 | 251 |
|
241 | 252 | static int CeedOperatorApplyAddComposite_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) { |
242 | | - bool is_run_good[CEED_COMPOSITE_MAX] = {false}; |
243 | | - CeedInt num_suboperators; |
244 | | - const CeedScalar *input_arr = NULL; |
245 | | - CeedScalar *output_arr = NULL; |
246 | | - Ceed ceed; |
247 | | - CeedOperator *sub_operators; |
| 253 | + bool is_run_good[CEED_COMPOSITE_MAX] = {true}; |
| 254 | + CeedInt num_suboperators; |
| 255 | + const CeedScalar *input_arr = NULL; |
| 256 | + CeedScalar *output_arr; |
| 257 | + Ceed ceed; |
| 258 | + CeedOperator_Hip_gen *impl; |
| 259 | + CeedOperator *sub_operators; |
248 | 260 |
|
249 | 261 | CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); |
250 | | - CeedCall(CeedCompositeOperatorGetNumSub(op, &num_suboperators)); |
251 | | - CeedCall(CeedCompositeOperatorGetSubList(op, &sub_operators)); |
| 262 | + CeedCallBackend(CeedOperatorGetData(op, &impl)); |
| 263 | + CeedCallBackend(CeedCompositeOperatorGetNumSub(op, &num_suboperators)); |
| 264 | + CeedCallBackend(CeedCompositeOperatorGetSubList(op, &sub_operators)); |
252 | 265 | if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr)); |
253 | 266 | if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr)); |
254 | 267 | for (CeedInt i = 0; i < num_suboperators; i++) { |
255 | 268 | CeedInt num_elem = 0; |
256 | 269 |
|
257 | | - CeedCall(CeedOperatorGetNumElements(sub_operators[i], &num_elem)); |
| 270 | + CeedCallBackend(CeedOperatorGetNumElements(sub_operators[i], &num_elem)); |
258 | 271 | if (num_elem > 0) { |
259 | | - hipStream_t stream = NULL; |
| 272 | + if (!impl->streams[i]) CeedCallHip(ceed, hipStreamCreate(&impl->streams[i])); |
| 273 | + CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(sub_operators[i], impl->streams[i], input_arr, output_arr, &is_run_good[i], request)); |
| 274 | + } else { |
| 275 | + is_run_good[i] = true; |
| 276 | + } |
| 277 | + } |
260 | 278 |
|
261 | | - CeedCallHip(ceed, hipStreamCreate(&stream)); |
262 | | - CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(sub_operators[i], stream, input_arr, output_arr, &is_run_good[i], request)); |
263 | | - CeedCallHip(ceed, hipStreamDestroy(stream)); |
| 279 | + for (CeedInt i = 0; i < num_suboperators; i++) { |
| 280 | + if (impl->streams[i]) { |
| 281 | + if (is_run_good[i]) CeedCallHip(ceed, hipStreamSynchronize(impl->streams[i])); |
264 | 282 | } |
265 | 283 | } |
266 | 284 | if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr)); |
|
0 commit comments