Skip to content

Commit 974f049

Browse files
committed
Fix char types to rocblas
1 parent a854ce2 commit 974f049

1 file changed

Lines changed: 22 additions & 18 deletions

File tree

ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -490,20 +490,21 @@ for (heevd, heev, heevx, heevj, elty, relty) in
490490
lda = max(1, stride(A, 2))
491491
length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
492492
if length(V) == 0
493-
jobz = 'N'
493+
jobz = rocSOLVER.rocblas_evect_none
494494
else
495495
size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
496-
jobz = 'O'
496+
jobz = rocSOLVER.rocblas_evect_original
497497
end
498498
dh = rocBLAS.handle()
499499
work = ROCVector{$relty}(undef, n)
500500
dev_info = ROCVector{Cint}(undef, 1)
501-
$heevd(dh, jobz, uplo, n, A, lda, W, work, dev_info)
501+
roc_uplo = convert(rocSOLVER.rocblas_fill, uplo)
502+
$heevd(dh, jobz, roc_uplo, n, A, lda, W, work, dev_info)
502503

503504
info = @allowscalar dev_info[1]
504505
chkargsok(BlasInt(info))
505506

506-
if jobz == 'O' && V !== A
507+
if jobz == rocSOLVER.rocblas_evect_original && V !== A
507508
copy!(V, A)
508509
end
509510
return W, V
@@ -517,20 +518,21 @@ for (heevd, heev, heevx, heevj, elty, relty) in
517518
lda = max(1, stride(A, 2))
518519
length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
519520
if length(V) == 0
520-
jobz = 'N'
521+
jobz = rocSOLVER.rocblas_evect_none
521522
else
522523
size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
523-
jobz = 'O'
524+
jobz = rocSOLVER.rocblas_evect_original
524525
end
525526
dh = rocBLAS.handle()
526527
work = ROCVector{$relty}(undef, n)
527528
dev_info = ROCVector{Cint}(undef, 1)
528-
$heev(dh, jobz, uplo, n, A, lda, W, work, dev_info)
529+
roc_uplo = convert(rocSOLVER.rocblas_fill, uplo)
530+
$heev(dh, jobz, roc_uplo, n, A, lda, W, work, dev_info)
529531

530532
info = @allowscalar dev_info[1]
531533
chkargsok(BlasInt(info))
532534

533-
if jobz == 'O' && V !== A
535+
if jobz == rocSOLVER.rocblas_evect_original && V !== A
534536
copy!(V, A)
535537
end
536538
return W, V
@@ -548,22 +550,22 @@ for (heevd, heev, heevx, heevj, elty, relty) in
548550
il = first(kwargs[:irange])
549551
iu = last(kwargs[:irange])
550552
vl = vu = zero($relty)
551-
range = 'I'
553+
range = rocSOLVER.rocblas_erange_index
552554
elseif haskey(kwargs, :vl) || haskey(kwargs, :vu)
553555
vl = convert($relty, get(kwargs, :vl, -Inf))
554556
vu = convert($relty, get(kwargs, :vu, +Inf))
555557
il = iu = 0
556-
range = 'V'
558+
range = rocSOLVER.rocblas_erange_value
557559
else
558560
il = iu = 0
559561
vl = vu = zero($relty)
560-
range = 'A'
562+
range = rocSOLVER.rocblas_erange_all
561563
end
562564
if length(V) == 0
563-
jobz = 'N'
565+
jobz = rocSOLVER.rocblas_evect_none
564566
else
565567
size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
566-
jobz = 'O'
568+
jobz = rocSOLVER.rocblas_evect_original
567569
end
568570
dh = rocBLAS.handle()
569571
abstol = -one($relty)
@@ -572,7 +574,8 @@ for (heevd, heev, heevx, heevj, elty, relty) in
572574
work = ROCVector{$relty}(undef, n)
573575
ifail = ROCVector{BlasInt}(undef, n)
574576
dev_info = ROCVector{Cint}(undef, 1)
575-
$heevx(dh, jobz, range, uplo, n, A, lda, vl, vu, il, iu, abstol, m, W, V, ldv, ifail, dev_info)
577+
roc_uplo = convert(rocSOLVER.rocblas_fill, uplo)
578+
$heevx(dh, jobz, range, roc_uplo, n, A, lda, vl, vu, il, iu, abstol, m, W, V, ldv, ifail, dev_info)
576579

577580
info = @allowscalar dev_info[1]
578581
chkargsok(BlasInt(info))
@@ -589,21 +592,22 @@ for (heevd, heev, heevx, heevj, elty, relty) in
589592
lda = max(1, stride(A, 2))
590593
length(W) == n || throw(DimensionMismatch("size mismatch between A and W"))
591594
if length(V) == 0
592-
jobz = 'N'
595+
jobz = rocSOLVER.rocblas_evect_none
593596
else
594597
size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
595-
jobz = 'O'
598+
jobz = rocSOLVER.rocblas_evect_original
596599
end
597600
dh = rocBLAS.handle()
598601
dev_info = ROCVector{Cint}(undef, 1)
599602
residual = ROCVector{$relty}(undef, 1)
600603
n_sweeps = ROCVector{Cint}(undef, 1)
601-
$heev(dh, jobz, uplo, n, A, lda, abstol, residual, max_sweeps, n_sweeps, W, dev_info)
604+
roc_uplo = convert(rocSOLVER.rocblas_fill, uplo)
605+
$heev(dh, jobz, roc_uplo, n, A, lda, abstol, residual, max_sweeps, n_sweeps, W, dev_info)
602606

603607
info = @allowscalar dev_info[1]
604608
chkargsok(BlasInt(info))
605609

606-
if jobz == 'O' && V !== A
610+
if jobz == rocSOLVER.rocblas_evect_original && V !== A
607611
copy!(V, A)
608612
end
609613
return W, V

0 commit comments

Comments
 (0)