Skip to content

Commit 0002d81

Browse files
committed
Use stream sync instead of device sync for hipBlas calls
1 parent c5a41ae commit 0002d81

2 files changed

Lines changed: 35 additions & 10 deletions

File tree

backends/hip-ref/ceed-hip-ref-vector.c

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -309,15 +309,18 @@ static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize st
309309
CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_DEVICE, &copy_array));
310310
#if (HIP_VERSION >= 60000000)
311311
hipblasHandle_t handle;
312+
hipStream_t stream;
312313
Ceed ceed;
313314

314315
CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
315316
CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
317+
CeedCallHipblas(ceed, hipblasGetStream(handle, &stream));
316318
#if defined(CEED_SCALAR_IS_FP32)
317319
CeedCallHipblas(ceed, hipblasScopy_64(handle, (int64_t)(stop - start), impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
318320
#else /* CEED_SCALAR */
319321
CeedCallHipblas(ceed, hipblasDcopy_64(handle, (int64_t)(stop - start), impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
320322
#endif /* CEED_SCALAR */
323+
CeedCallHip(ceed, hipStreamSynchronize(stream));
321324
#else /* HIP_VERSION */
322325
CeedCallBackend(CeedDeviceCopyStrided_Hip(impl->d_array, start, stop, step, copy_array));
323326
#endif /* HIP_VERSION */
@@ -557,14 +560,15 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
557560
const CeedScalar *d_array;
558561
CeedVector_Hip *impl;
559562
hipblasHandle_t handle;
563+
hipStream_t stream;
560564
Ceed_Hip *hip_data;
561565

562566
CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
563567
CeedCallBackend(CeedGetData(ceed, &hip_data));
564568
CeedCallBackend(CeedVectorGetData(vec, &impl));
565569
CeedCallBackend(CeedVectorGetLength(vec, &length));
566570
CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
567-
571+
CeedCallHipblas(ceed, hipblasGetStream(handle, &stream));
568572
#if (HIP_VERSION < 60000000)
569573
// With ROCm 6, we can use the 64-bit integer interface. Prior to that,
570574
// we need to check if the vector is too long to handle with int32,
@@ -581,6 +585,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
581585
#if defined(CEED_SCALAR_IS_FP32)
582586
#if (HIP_VERSION >= 60000000) // We have ROCm 6, and can use 64-bit integers
583587
CeedCallHipblas(ceed, hipblasSasum_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
588+
CeedCallHip(ceed, hipStreamSynchronize(stream));
584589
#else /* HIP_VERSION */
585590
float sub_norm = 0.0;
586591
float *d_array_start;
@@ -591,12 +596,14 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
591596
CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
592597

593598
CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
599+
CeedCallHip(ceed, hipStreamSynchronize(stream));
594600
*norm += sub_norm;
595601
}
596602
#endif /* HIP_VERSION */
597603
#else /* CEED_SCALAR */
598604
#if (HIP_VERSION >= 60000000)
599605
CeedCallHipblas(ceed, hipblasDasum_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
606+
CeedCallHip(ceed, hipStreamSynchronize(stream));
600607
#else /* HIP_VERSION */
601608
double sub_norm = 0.0;
602609
double *d_array_start;
@@ -607,6 +614,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
607614
CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
608615

609616
CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
617+
CeedCallHip(ceed, hipStreamSynchronize(stream));
610618
*norm += sub_norm;
611619
}
612620
#endif /* HIP_VERSION */
@@ -617,6 +625,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
617625
#if defined(CEED_SCALAR_IS_FP32)
618626
#if (HIP_VERSION >= 60000000)
619627
CeedCallHipblas(ceed, hipblasSnrm2_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
628+
CeedCallHip(ceed, hipStreamSynchronize(stream));
620629
#else /* HIP_VERSION */
621630
float sub_norm = 0.0, norm_sum = 0.0;
622631
float *d_array_start;
@@ -627,13 +636,15 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
627636
CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
628637

629638
CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
639+
CeedCallHip(ceed, hipStreamSynchronize(stream));
630640
norm_sum += sub_norm * sub_norm;
631641
}
632642
*norm = sqrt(norm_sum);
633643
#endif /* HIP_VERSION */
634644
#else /* CEED_SCALAR */
635645
#if (HIP_VERSION >= 60000000)
636646
CeedCallHipblas(ceed, hipblasDnrm2_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
647+
CeedCallHip(ceed, hipStreamSynchronize(stream));
637648
#else /* HIP_VERSION */
638649
double sub_norm = 0.0, norm_sum = 0.0;
639650
double *d_array_start;
@@ -644,6 +655,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
644655
CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
645656

646657
CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
658+
CeedCallHip(ceed, hipStreamSynchronize(stream));
647659
norm_sum += sub_norm * sub_norm;
648660
}
649661
*norm = sqrt(norm_sum);
@@ -658,7 +670,8 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
658670
CeedScalar norm_no_abs;
659671

660672
CeedCallHipblas(ceed, hipblasIsamax_64(handle, (int64_t)length, (float *)d_array, 1, &index));
661-
CeedCallHip(ceed, hipMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
673+
CeedCallHip(ceed, hipMemcpyAsync(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost, stream));
674+
CeedCallHip(ceed, hipStreamSynchronize(stream));
662675
*norm = fabs(norm_no_abs);
663676
#else /* HIP_VERSION */
664677
CeedInt index;
@@ -672,10 +685,11 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
672685

673686
CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
674687
if (hip_data->has_unified_addressing) {
675-
CeedCallHip(ceed, hipDeviceSynchronize());
688+
CeedCallHip(ceed, hipStreamSynchronize(stream));
676689
sub_max = fabs(d_array[index - 1]);
677690
} else {
678-
CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
691+
CeedCallHip(ceed, hipMemcpyAsync(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost, stream));
692+
CeedCallHip(ceed, hipStreamSynchronize(stream));
679693
}
680694
if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
681695
}
@@ -688,10 +702,11 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
688702

689703
CeedCallHipblas(ceed, hipblasIdamax_64(handle, (int64_t)length, (double *)d_array, 1, &index));
690704
if (hip_data->has_unified_addressing) {
691-
CeedCallHip(ceed, hipDeviceSynchronize());
705+
CeedCallHip(ceed, hipStreamSynchronize(stream));
692706
norm_no_abs = fabs(d_array[index - 1]);
693707
} else {
694-
CeedCallHip(ceed, hipMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
708+
CeedCallHip(ceed, hipMemcpyAsync(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost, stream));
709+
CeedCallHip(ceed, hipStreamSynchronize(stream));
695710
}
696711
*norm = fabs(norm_no_abs);
697712
#else /* HIP_VERSION */
@@ -706,10 +721,11 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
706721

707722
CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
708723
if (hip_data->has_unified_addressing) {
709-
CeedCallHip(ceed, hipDeviceSynchronize());
724+
CeedCallHip(ceed, hipStreamSynchronize(stream));
710725
sub_max = fabs(d_array[index - 1]);
711726
} else {
712-
CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
727+
CeedCallHip(ceed, hipMemcpyAsync(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost, stream));
728+
CeedCallHip(ceed, hipStreamSynchronize(stream));
713729
}
714730
if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
715731
}
@@ -780,13 +796,16 @@ static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
780796
if (impl->d_array) {
781797
#if (HIP_VERSION >= 60000000)
782798
hipblasHandle_t handle;
799+
hipStream_t stream;
783800

784801
CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(x), &handle));
802+
CeedCallHipblas(CeedVectorReturnCeed(x), hipblasGetStream(handle, &stream));
785803
#if defined(CEED_SCALAR_IS_FP32)
786804
CeedCallHipblas(CeedVectorReturnCeed(x), hipblasSscal_64(handle, (int64_t)length, &alpha, impl->d_array, 1));
787805
#else /* CEED_SCALAR */
788806
CeedCallHipblas(CeedVectorReturnCeed(x), hipblasDscal_64(handle, (int64_t)length, &alpha, impl->d_array, 1));
789807
#endif /* CEED_SCALAR */
808+
CeedCallHip(CeedVectorReturnCeed(x), hipStreamSynchronize(stream));
790809
#else /* HIP_VERSION */
791810
CeedCallBackend(CeedDeviceScale_Hip(impl->d_array, alpha, length));
792811
#endif /* HIP_VERSION */
@@ -827,13 +846,16 @@ static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
827846
CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
828847
#if (HIP_VERSION >= 60000000)
829848
hipblasHandle_t handle;
849+
hipStream_t stream;
830850

831-
CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(y), &handle));
851+
CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(x), &handle));
852+
CeedCallHipblas(CeedVectorReturnCeed(y), hipblasGetStream(handle, &stream));
832853
#if defined(CEED_SCALAR_IS_FP32)
833854
CeedCallHipblas(CeedVectorReturnCeed(y), hipblasSaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
834855
#else /* CEED_SCALAR */
835856
CeedCallHipblas(CeedVectorReturnCeed(y), hipblasDaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
836857
#endif /* CEED_SCALAR */
858+
CeedCallHip(CeedVectorReturnCeed(y), hipStreamSynchronize(stream));
837859
#else /* HIP_VERSION */
838860
CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
839861
#endif /* HIP_VERSION */

backends/hip-ref/ceed-hip-ref.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ int CeedGetHipblasHandle_Hip(Ceed ceed, hipblasHandle_t *handle) {
2929
Ceed_Hip *data;
3030

3131
CeedCallBackend(CeedGetData(ceed, &data));
32-
if (!data->hipblas_handle) CeedCallHipblas(ceed, hipblasCreate(&data->hipblas_handle));
32+
if (!data->hipblas_handle) {
33+
CeedCallHipblas(ceed, hipblasCreate(&data->hipblas_handle));
34+
CeedCallHipblas(ceed, hipblasSetPointerMode(data->hipblas_handle, HIPBLAS_POINTER_MODE_HOST));
35+
}
3336
*handle = data->hipblas_handle;
3437
return CEED_ERROR_SUCCESS;
3538
}

0 commit comments

Comments
 (0)