@@ -309,15 +309,18 @@ static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize st
309309 CeedCallBackend (CeedVectorGetArray (vec_copy , CEED_MEM_DEVICE , & copy_array ));
310310#if (HIP_VERSION >= 60000000 )
311311 hipblasHandle_t handle ;
312+ hipStream_t stream ;
312313 Ceed ceed ;
313314
314315 CeedCallBackend (CeedVectorGetCeed (vec , & ceed ));
315316 CeedCallBackend (CeedGetHipblasHandle_Hip (ceed , & handle ));
317+ CeedCallHipblas (ceed , hipblasGetStream (handle , & stream ));
316318#if defined(CEED_SCALAR_IS_FP32 )
317319 CeedCallHipblas (ceed , hipblasScopy_64 (handle , (int64_t )(stop - start ), impl -> d_array + start , (int64_t )step , copy_array + start , (int64_t )step ));
318320#else /* CEED_SCALAR */
319321 CeedCallHipblas (ceed , hipblasDcopy_64 (handle , (int64_t )(stop - start ), impl -> d_array + start , (int64_t )step , copy_array + start , (int64_t )step ));
320322#endif /* CEED_SCALAR */
323+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
321324#else /* HIP_VERSION */
322325 CeedCallBackend (CeedDeviceCopyStrided_Hip (impl -> d_array , start , stop , step , copy_array ));
323326#endif /* HIP_VERSION */
@@ -557,14 +560,15 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
557560 const CeedScalar * d_array ;
558561 CeedVector_Hip * impl ;
559562 hipblasHandle_t handle ;
563+ hipStream_t stream ;
560564 Ceed_Hip * hip_data ;
561565
562566 CeedCallBackend (CeedVectorGetCeed (vec , & ceed ));
563567 CeedCallBackend (CeedGetData (ceed , & hip_data ));
564568 CeedCallBackend (CeedVectorGetData (vec , & impl ));
565569 CeedCallBackend (CeedVectorGetLength (vec , & length ));
566570 CeedCallBackend (CeedGetHipblasHandle_Hip (ceed , & handle ));
567-
571+ CeedCallHipblas ( ceed , hipblasGetStream ( handle , & stream ));
568572#if (HIP_VERSION < 60000000 )
569573 // With ROCm 6, we can use the 64-bit integer interface. Prior to that,
570574 // we need to check if the vector is too long to handle with int32,
@@ -581,6 +585,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
581585#if defined(CEED_SCALAR_IS_FP32 )
582586#if (HIP_VERSION >= 60000000 ) // We have ROCm 6, and can use 64-bit integers
583587 CeedCallHipblas (ceed , hipblasSasum_64 (handle , (int64_t )length , (float * )d_array , 1 , (float * )norm ));
588+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
584589#else /* HIP_VERSION */
585590 float sub_norm = 0.0 ;
586591 float * d_array_start ;
@@ -591,12 +596,14 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
591596 CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
592597
593598 CeedCallHipblas (ceed , hipblasSasum (handle , (CeedInt )sub_length , (float * )d_array_start , 1 , & sub_norm ));
599+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
594600 * norm += sub_norm ;
595601 }
596602#endif /* HIP_VERSION */
597603#else /* CEED_SCALAR */
598604#if (HIP_VERSION >= 60000000 )
599605 CeedCallHipblas (ceed , hipblasDasum_64 (handle , (int64_t )length , (double * )d_array , 1 , (double * )norm ));
606+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
600607#else /* HIP_VERSION */
601608 double sub_norm = 0.0 ;
602609 double * d_array_start ;
@@ -607,6 +614,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
607614 CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
608615
609616 CeedCallHipblas (ceed , hipblasDasum (handle , (CeedInt )sub_length , (double * )d_array_start , 1 , & sub_norm ));
617+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
610618 * norm += sub_norm ;
611619 }
612620#endif /* HIP_VERSION */
@@ -617,6 +625,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
617625#if defined(CEED_SCALAR_IS_FP32 )
618626#if (HIP_VERSION >= 60000000 )
619627 CeedCallHipblas (ceed , hipblasSnrm2_64 (handle , (int64_t )length , (float * )d_array , 1 , (float * )norm ));
628+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
620629#else /* HIP_VERSION */
621630 float sub_norm = 0.0 , norm_sum = 0.0 ;
622631 float * d_array_start ;
@@ -627,13 +636,15 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
627636 CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
628637
629638 CeedCallHipblas (ceed , hipblasSnrm2 (handle , (CeedInt )sub_length , (float * )d_array_start , 1 , & sub_norm ));
639+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
630640 norm_sum += sub_norm * sub_norm ;
631641 }
632642 * norm = sqrt (norm_sum );
633643#endif /* HIP_VERSION */
634644#else /* CEED_SCALAR */
635645#if (HIP_VERSION >= 60000000 )
636646 CeedCallHipblas (ceed , hipblasDnrm2_64 (handle , (int64_t )length , (double * )d_array , 1 , (double * )norm ));
647+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
637648#else /* HIP_VERSION */
638649 double sub_norm = 0.0 , norm_sum = 0.0 ;
639650 double * d_array_start ;
@@ -644,6 +655,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
644655 CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
645656
646657 CeedCallHipblas (ceed , hipblasDnrm2 (handle , (CeedInt )sub_length , (double * )d_array_start , 1 , & sub_norm ));
658+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
647659 norm_sum += sub_norm * sub_norm ;
648660 }
649661 * norm = sqrt (norm_sum );
@@ -658,7 +670,8 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
658670 CeedScalar norm_no_abs ;
659671
660672 CeedCallHipblas (ceed , hipblasIsamax_64 (handle , (int64_t )length , (float * )d_array , 1 , & index ));
661- CeedCallHip (ceed , hipMemcpy (& norm_no_abs , impl -> d_array + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
673+ CeedCallHip (ceed , hipMemcpyAsync (& norm_no_abs , impl -> d_array + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost , stream ));
674+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
662675 * norm = fabs (norm_no_abs );
663676#else /* HIP_VERSION */
664677 CeedInt index ;
@@ -672,10 +685,11 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
672685
673686 CeedCallHipblas (ceed , hipblasIsamax (handle , (CeedInt )sub_length , (float * )d_array_start , 1 , & index ));
674687 if (hip_data -> has_unified_addressing ) {
675- CeedCallHip (ceed , hipDeviceSynchronize ( ));
688+ CeedCallHip (ceed , hipStreamSynchronize ( stream ));
676689 sub_max = fabs (d_array [index - 1 ]);
677690 } else {
678- CeedCallHip (ceed , hipMemcpy (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
691+ CeedCallHip (ceed , hipMemcpyAsync (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost , stream ));
692+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
679693 }
680694 if (fabs (sub_max ) > current_max ) current_max = fabs (sub_max );
681695 }
@@ -688,10 +702,11 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
688702
689703 CeedCallHipblas (ceed , hipblasIdamax_64 (handle , (int64_t )length , (double * )d_array , 1 , & index ));
690704 if (hip_data -> has_unified_addressing ) {
691- CeedCallHip (ceed , hipDeviceSynchronize ( ));
705+ CeedCallHip (ceed , hipStreamSynchronize ( stream ));
692706 norm_no_abs = fabs (d_array [index - 1 ]);
693707 } else {
694- CeedCallHip (ceed , hipMemcpy (& norm_no_abs , impl -> d_array + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
708+ CeedCallHip (ceed , hipMemcpyAsync (& norm_no_abs , impl -> d_array + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost , stream ));
709+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
695710 }
696711 * norm = fabs (norm_no_abs );
697712#else /* HIP_VERSION */
@@ -706,10 +721,11 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
706721
707722 CeedCallHipblas (ceed , hipblasIdamax (handle , (CeedInt )sub_length , (double * )d_array_start , 1 , & index ));
708723 if (hip_data -> has_unified_addressing ) {
709- CeedCallHip (ceed , hipDeviceSynchronize ( ));
724+ CeedCallHip (ceed , hipStreamSynchronize ( stream ));
710725 sub_max = fabs (d_array [index - 1 ]);
711726 } else {
712- CeedCallHip (ceed , hipMemcpy (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
727+ CeedCallHip (ceed , hipMemcpyAsync (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost , stream ));
728+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
713729 }
714730 if (fabs (sub_max ) > current_max ) current_max = fabs (sub_max );
715731 }
@@ -780,13 +796,16 @@ static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
780796 if (impl -> d_array ) {
781797#if (HIP_VERSION >= 60000000 )
782798 hipblasHandle_t handle ;
799+ hipStream_t stream ;
783800
784801 CeedCallBackend (CeedGetHipblasHandle_Hip (CeedVectorReturnCeed (x ), & handle ));
802+ CeedCallHipblas (CeedVectorReturnCeed (x ), hipblasGetStream (handle , & stream ));
785803#if defined(CEED_SCALAR_IS_FP32 )
786804 CeedCallHipblas (CeedVectorReturnCeed (x ), hipblasSscal_64 (handle , (int64_t )length , & alpha , impl -> d_array , 1 ));
787805#else /* CEED_SCALAR */
788806 CeedCallHipblas (CeedVectorReturnCeed (x ), hipblasDscal_64 (handle , (int64_t )length , & alpha , impl -> d_array , 1 ));
789807#endif /* CEED_SCALAR */
808+ CeedCallHip (CeedVectorReturnCeed (x ), hipStreamSynchronize (stream ));
790809#else /* HIP_VERSION */
791810 CeedCallBackend (CeedDeviceScale_Hip (impl -> d_array , alpha , length ));
792811#endif /* HIP_VERSION */
@@ -827,13 +846,16 @@ static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
827846 CeedCallBackend (CeedVectorSyncArray (x , CEED_MEM_DEVICE ));
828847#if (HIP_VERSION >= 60000000 )
829848 hipblasHandle_t handle ;
849+ hipStream_t stream ;
830850
831851 CeedCallBackend (CeedGetHipblasHandle_Hip (CeedVectorReturnCeed (y ), & handle ));
852+ CeedCallHipblas (CeedVectorReturnCeed (y ), hipblasGetStream (handle , & stream ));
832853#if defined(CEED_SCALAR_IS_FP32 )
833854 CeedCallHipblas (CeedVectorReturnCeed (y ), hipblasSaxpy_64 (handle , (int64_t )length , & alpha , x_impl -> d_array , 1 , y_impl -> d_array , 1 ));
834855#else /* CEED_SCALAR */
835856 CeedCallHipblas (CeedVectorReturnCeed (y ), hipblasDaxpy_64 (handle , (int64_t )length , & alpha , x_impl -> d_array , 1 , y_impl -> d_array , 1 ));
836857#endif /* CEED_SCALAR */
858+ CeedCallHip (CeedVectorReturnCeed (y ), hipStreamSynchronize (stream ));
837859#else /* HIP_VERSION */
838860 CeedCallBackend (CeedDeviceAXPY_Hip (y_impl -> d_array , alpha , x_impl -> d_array , length ));
839861#endif /* HIP_VERSION */
0 commit comments