From 8b376050ef672f48e2047ae2bea2f1c620d43775 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 13 May 2026 12:11:25 +0200 Subject: [PATCH 01/11] Use a pass-through for gemm --- ext/StridedGPUArraysExt.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ext/StridedGPUArraysExt.jl b/ext/StridedGPUArraysExt.jl index 6bfb21a..88a8609 100644 --- a/ext/StridedGPUArraysExt.jl +++ b/ext/StridedGPUArraysExt.jl @@ -5,6 +5,8 @@ using GPUArrays: Adapt, KernelAbstractions using GPUArrays.KernelAbstractions: @kernel, @index using StridedViews: ParentIndex +import Strided: isblasmatrix + ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)} # StridedView backed by any GPU array type, with element type linked to the parent. @@ -129,4 +131,6 @@ function Strided._mapreduce_block!( return nothing end +Strided.isblasmatrix(A::GPUStridedView) = false + end From e7f30751b7c836b215be31e79667ae17e55535c3 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 2 Jun 2026 16:22:45 -0400 Subject: [PATCH 02/11] Add tests and remove extraneous ones --- test/cuda.jl | 30 ------------------------------ test/gpu.jl | 29 +++++++++++++++++++++++++++++ test/jlarrays.jl | 25 ------------------------- 3 files changed, 29 insertions(+), 55 deletions(-) delete mode 100644 test/cuda.jl delete mode 100644 test/jlarrays.jl diff --git a/test/cuda.jl b/test/cuda.jl deleted file mode 100644 index bc24b26..0000000 --- a/test/cuda.jl +++ /dev/null @@ -1,30 +0,0 @@ -using Test -using Strided -using CUDA: CuMatrix -using CUDA.Adapt: adapt - -for T in (Float32, Float64, Complex{Float32}, Complex{Float64}) - @testset "Copy with CuStridedView: $T, $f1, $f2" for f2 in (identity, conj, adjoint, transpose), f1 in (identity, conj, transpose, adjoint) - for m1 in (0, 16, 32), m2 in (0, 16, 32) - A1 = cuRAND.randn(T, (m1, m2)) - A2 = similar(A1) - zA1 = CuMatrix(f1(zeros(T, (m1, m2)))) - zA2 = CuMatrix(f2(zeros(T, (m1, m2)))) - A1c = copy(A1) - A2c = copy(A2) - B1 = f1(StridedView(A1c)) - B2 = f2(StridedView(A2c)) - axes(f1(A1)) == axes(f2(A2)) || continue - @test collect(CuMatrix(copy!(f2(A2), f1(A1)))) == CUDACore.Adapt.adapt(Vector{T}, copy!(B2, B1)) - @test copy!(zA1, f1(A1)) == copy!(zA2, B1) - A3 = CuArray(randn(T, (m1, m2))) - A3c = copy(A3) - B3 = f1(StridedView(A3c)) - @. B1 = 2 * B1 - B3 / 3 # test copyto! of Broadcasted - @. A1 = 2 * A1 - A3 / 3 # test copyto! of Broadcasted - @test CUDACore.Adapt.adapt(Vector{T}, f1(A1)) == CUDACore.Adapt.adapt(Vector{T}, B1) - x = rand(T) - @test f1(StridedView(CUDACore.Adapt.adapt(Vector{T}, fill!(A1c, x)))) == CUDACore.Adapt.adapt(Vector{T}, fill!(B1, x)) - end - end -end diff --git a/test/gpu.jl b/test/gpu.jl index 4959789..e720d67 100644 --- a/test/gpu.jl +++ b/test/gpu.jl @@ -38,6 +38,19 @@ Metal.functional() && push!(ATs, MtlArray) end end +@testset "mul! ($AT)" for AT in ATs + N = 2 + for T in (Float32, ComplexF32) + α = rand(T) + β = rand(T) + dims = ntuple(Returns(div(64, N)), N) + A1 = permutedims(StridedView(rand(T, dims)), randperm(N)) + A2 = permutedims(StridedView(rand(T, dims)), randperm(N)) + A3 = permutedims(StridedView(rand(T, dims)), randperm(N)) + @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, A1, A2, A3) + @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, A1, A2, A3) + end +end @testset "map, scale!, axpy!, axpby! ($AT)" for AT in ATs for T in (Float32, ComplexF32) for N in 2:6 @@ -69,6 +82,22 @@ end end end +@testset "copy ($AT)" for AT in ATs + N = 2 + for m1 in (0, 16, 32), m2 in (0, 16, 32), T in (Float32, ComplexF32) + dims = (m1, m2) + A1 = StridedView(rand(T, dims)) + A2 = StridedView(rand(T, dims)) + A3 = StridedView(rand(T, dims)) + for f2 in (identity, conj, adjoint, transpose), f1 in (identity, conj, transpose, adjoint) + axes(f1(A1)) == axes(f2(A2)) || continue + B1 = f1(copy(A1)) + B2 = f2(copy(A2)) + @test compare((x, y) -> copy!(y, x), AT, B1, B2) + end + end +end + @testset "broadcasting ($AT)" for AT in ATs for T in (Float32, ComplexF32) A0 = StridedView(rand(T, ())) diff --git a/test/jlarrays.jl b/test/jlarrays.jl deleted file mode 100644 index fa163f6..0000000 --- a/test/jlarrays.jl +++ /dev/null @@ -1,25 +0,0 @@ -@testset for T in (Float32, Float64, Complex{Float32}, Complex{Float64}) - @testset "Copy with JLArrayStridedView: $T, $f1, $f2" for f2 in (identity, conj, adjoint, transpose), f1 in (identity, conj, transpose, adjoint) - for m1 in (0, 16, 32), m2 in (0, 16, 32) - A1 = JLArray(randn(T, (m1, m2))) - A2 = similar(A1) - zA1 = JLArray(f1(zeros(T, (m1, m2)))) - zA2 = JLArray(f2(zeros(T, (m1, m2)))) - A1c = copy(A1) - A2c = copy(A2) - B1 = f1(StridedView(A1c)) - B2 = f2(StridedView(A2c)) - axes(f1(A1)) == axes(f2(A2)) || continue - @test collect(Matrix(copy!(f2(A2), f1(A1)))) == JLArrays.Adapt.adapt(Vector{T}, copy!(B2, B1)) - @test copy!(zA1, f1(A1)) == copy!(zA2, B1) - A3 = JLArray(randn(T, (m1, m2))) - A3c = copy(A3) - B3 = f1(StridedView(A3c)) - @. B1 = 2 * B1 - B3 / 3 # test copyto! of Broadcasted - @. A1 = 2 * A1 - A3 / 3 # test copyto! of Broadcasted - @test JLArrays.Adapt.adapt(Vector{T}, f1(A1)) == JLArrays.Adapt.adapt(Vector{T}, B1) - x = rand(T) - @test f1(StridedView(JLArrays.Adapt.adapt(Vector{T}, fill!(A1c, x)))) == JLArrays.Adapt.adapt(Vector{T}, fill!(B1, x)) - end - end -end From 69d18288c1d0a9bc2f96a753f190026bc5f99d84 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 3 Jun 2026 10:35:39 +0200 Subject: [PATCH 03/11] Fixes --- ext/StridedGPUArraysExt.jl | 2 +- test/gpu.jl | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/ext/StridedGPUArraysExt.jl b/ext/StridedGPUArraysExt.jl index 88a8609..c23c615 100644 --- a/ext/StridedGPUArraysExt.jl +++ b/ext/StridedGPUArraysExt.jl @@ -131,6 +131,6 @@ function Strided._mapreduce_block!( return nothing end -Strided.isblasmatrix(A::GPUStridedView) = false +Strided.isblasmatrix(A::GPUStridedView{T, 2}) where {T <: LinearAlgebra.BlasFloat} = false end diff --git a/test/gpu.jl b/test/gpu.jl index e720d67..523280d 100644 --- a/test/gpu.jl +++ b/test/gpu.jl @@ -18,6 +18,13 @@ CUDACore.functional() && push!(ATs, CuArray) AMDGPU.functional() && push!(ATs, ROCArray) Metal.functional() && push!(ATs, MtlArray) +@testset "isblasmatrix ($AT)" for AT in ATs + for T in (Float32, ComplexF32) + A1 = StridedView(AT(randn(T, 20, 20))) + @test !Strided.isblasmatrix(A1) + end +end + @testset "in-place matrix operations ($AT)" for AT in ATs for T in (Float32, ComplexF32) A1 = StridedView(randn(T, 20, 20)) @@ -51,6 +58,7 @@ end @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, A1, A2, A3) end end + @testset "map, scale!, axpy!, axpby! ($AT)" for AT in ATs for T in (Float32, ComplexF32) for N in 2:6 From 0e908bfd9c314e9e94ddbfe2974dfc9d690a3db8 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 3 Jun 2026 10:39:14 +0200 Subject: [PATCH 04/11] Formatter --- test/gpu.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gpu.jl b/test/gpu.jl index 523280d..689e763 100644 --- a/test/gpu.jl +++ b/test/gpu.jl @@ -18,7 +18,7 @@ CUDACore.functional() && push!(ATs, CuArray) AMDGPU.functional() && push!(ATs, ROCArray) Metal.functional() && push!(ATs, MtlArray) -@testset "isblasmatrix ($AT)" for AT in ATs +@testset "isblasmatrix ($AT)" for AT in ATs for T in (Float32, ComplexF32) A1 = StridedView(AT(randn(T, 20, 20))) @test !Strided.isblasmatrix(A1) From d6edc8868a1436e5b626efa4f9067b9a1100f9da Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 4 Jun 2026 16:26:51 +0200 Subject: [PATCH 05/11] Insert blas_mul layer to shim to vendor libraries --- Project.toml | 8 +++++++- ext/StridedAMDGPUExt.jl | 19 +++++++++++++++++++ ext/StridedCUDACoreExt.jl | 16 ---------------- ext/StridedGPUArraysExt.jl | 10 +++++++++- ext/StridedcuBLASExt.jl | 19 +++++++++++++++++++ src/linalg.jl | 11 +++++++++-- test/gpu.jl | 21 +++++++++------------ 7 files changed, 72 insertions(+), 32 deletions(-) create mode 100644 ext/StridedAMDGPUExt.jl delete mode 100644 ext/StridedCUDACoreExt.jl create mode 100644 ext/StridedcuBLASExt.jl diff --git a/Project.toml b/Project.toml index 62e81db..cc5280c 100644 --- a/Project.toml +++ b/Project.toml @@ -9,16 +9,21 @@ StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" [weakdeps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" +cuBLAS = "182d3088-87b7-4494-8cad-fc6afaa545bc" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" [extensions] +StridedcuBLASExt = "cuBLAS" StridedGPUArraysExt = "GPUArrays" +StridedAMDGPUExt = "AMDGPU" [compat] AMDGPU = "2" Aqua = "0.8" Adapt = "4" CUDACore = "6" +cuBLAS = "6" cuRAND = "6" GPUArrays = "11.4.1" JLArrays = "0.3.1" @@ -35,6 +40,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CUDACore = "bd0ed864-bdfe-4181-a5ed-ce625a5fdea2" +cuBLAS = "182d3088-87b7-4494-8cad-fc6afaa545bc" cuRAND = "20fd9a0b-12d5-4c2f-a8af-7c34e9e60431" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" @@ -43,4 +49,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Random", "Aqua", "AMDGPU", "CUDACore", "cuRAND", "GPUArrays", "JLArrays", "Metal", "Adapt"] +test = ["Test", "Random", "Aqua", "AMDGPU", "CUDACore", "cuBLAS", "cuRAND", "GPUArrays", "JLArrays", "Metal", "Adapt"] diff --git a/ext/StridedAMDGPUExt.jl b/ext/StridedAMDGPUExt.jl new file mode 100644 index 0000000..627446e --- /dev/null +++ b/ext/StridedAMDGPUExt.jl @@ -0,0 +1,19 @@ +module StridedAMDGPUExt + +using Strided, StridedViews, AMDGPU, AMDGPU.rocBLAS, LinearAlgebra +import Strided: blas_mul! + +const ROCStridedView{T, N, A <: ROCArray{T}} = StridedViews.StridedView{T, N, A} + +function Strided.blas_mul!(C::ROCStridedView{T, 2}, A::ROCStridedView{T, 2}, B::ROCStridedView{T, 2}, α::Number, β::Number) where {T <: LinearAlgebra.BlasFloat} + A2, CA = Strided.getblasmatrix(A) + B2, CB = Strided.getblasmatrix(B) + C2, CC = Strided.getblasmatrix(C) + A2a = Base.unsafe_wrap(ROCMatrix{T}, pointer(A2), size(A2)) + B2a = Base.unsafe_wrap(ROCMatrix{T}, pointer(B2), size(B2)) + C2a = Base.unsafe_wrap(ROCMatrix{T}, pointer(C2), size(C2)) + AMDGPU.rocBLAS.gemm!(CA, CB, convert(T, α), A2a, B2a, convert(T, β), C2a) + return C +end + +end diff --git a/ext/StridedCUDACoreExt.jl b/ext/StridedCUDACoreExt.jl deleted file mode 100644 index a92bd97..0000000 --- a/ext/StridedCUDACoreExt.jl +++ /dev/null @@ -1,16 +0,0 @@ -module StridedCUDACoreExt - -using Strided, StridedViews, CUDACore -using CUDACore: Adapt, KernelAdaptor -using CUDACore: GPUArrays - -const ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)} - -function Base.copy!(dst::StridedView{TD, ND, TAD, FD}, src::StridedView{TS, NS, TAS, FS}) where {TD <: Number, ND, TAD <: CuArray{TD}, FD <: ALL_FS, TS <: Number, NS, TAS <: CuArray{TS}, FS <: ALL_FS} - bc_style = Base.Broadcast.BroadcastStyle(TAS) - bc = Base.Broadcast.Broadcasted(bc_style, identity, (src,), axes(dst)) - GPUArrays._copyto!(dst, bc) - return dst -end - -end diff --git a/ext/StridedGPUArraysExt.jl b/ext/StridedGPUArraysExt.jl index c23c615..0dc5f53 100644 --- a/ext/StridedGPUArraysExt.jl +++ b/ext/StridedGPUArraysExt.jl @@ -131,6 +131,14 @@ function Strided._mapreduce_block!( return nothing end -Strided.isblasmatrix(A::GPUStridedView{T, 2}) where {T <: LinearAlgebra.BlasFloat} = false +function Strided.isblasmatrix(A::GPUStridedView{T, 2}) where {T <: LinearAlgebra.BlasFloat} + if A.op == identity + return stride(A, 1) == 1 || stride(A, 2) == 1 + elseif A.op == conj + return stride(A, 2) == 1 + else # should never happen + return false + end +end end diff --git a/ext/StridedcuBLASExt.jl b/ext/StridedcuBLASExt.jl new file mode 100644 index 0000000..3c41a2c --- /dev/null +++ b/ext/StridedcuBLASExt.jl @@ -0,0 +1,19 @@ +module StridedcuBLASExt + +using Strided, StridedViews, cuBLAS, cuBLAS.CUDACore, LinearAlgebra +import Strided: blas_mul! + +const CuStridedView{T, N, A <: CuArray{T}} = StridedViews.StridedView{T, N, A} + +function Strided.blas_mul!(C::CuStridedView{T, 2}, A::CuStridedView{T, 2}, B::CuStridedView{T, 2}, α::Number, β::Number) where {T <: LinearAlgebra.BlasFloat} + A2, CA = Strided.getblasmatrix(A) + B2, CB = Strided.getblasmatrix(B) + C2, CC = Strided.getblasmatrix(C) + A2a = Base.unsafe_wrap(CuMatrix{T}, pointer(A2), size(A2)) + B2a = Base.unsafe_wrap(CuMatrix{T}, pointer(B2), size(B2)) + C2a = Base.unsafe_wrap(CuMatrix{T}, pointer(C2), size(C2)) + cuBLAS.gemm!(CA, CB, convert(T, α), A2a, B2a, convert(T, β), C2a) + return C +end + +end diff --git a/src/linalg.jl b/src/linalg.jl index febc573..2c2d343 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -100,13 +100,20 @@ function _mul!( α::Number, β::Number ) where {T <: LinearAlgebra.BlasFloat} if stride(C, 1) == 1 && isblasmatrix(A) && isblasmatrix(B) - nthreads = use_threaded_mul() ? get_num_threads() : 1 - _threaded_blas_mul!(C, A, B, α, β, nthreads) + return blas_mul!(C, A, B, α, β) else return __mul!(C, A, B, α, β) end end +# for CPU based arrays, this is valid +function blas_mul!( + C::StridedView{T, 2}, A::StridedView{T, 2}, B::StridedView{T, 2}, + α::Number, β::Number) where {T <: LinearAlgebra.BlasFloat} + nthreads = use_threaded_mul() ? get_num_threads() : 1 + return _threaded_blas_mul!(C, A, B, α, β, nthreads) +end + function _threaded_blas_mul!( C::StridedView{T, 2}, A::StridedView{T, 2}, B::StridedView{T, 2}, α::Number, β::Number, diff --git a/test/gpu.jl b/test/gpu.jl index 689e763..1c3228c 100644 --- a/test/gpu.jl +++ b/test/gpu.jl @@ -21,7 +21,7 @@ Metal.functional() && push!(ATs, MtlArray) @testset "isblasmatrix ($AT)" for AT in ATs for T in (Float32, ComplexF32) A1 = StridedView(AT(randn(T, 20, 20))) - @test !Strided.isblasmatrix(A1) + @test Strided.isblasmatrix(A1) end end @@ -45,18 +45,15 @@ end end end -@testset "mul! ($AT)" for AT in ATs +@testset "mul! ($AT{$T})" for AT in ATs, T in (Float32, ComplexF32) N = 2 - for T in (Float32, ComplexF32) - α = rand(T) - β = rand(T) - dims = ntuple(Returns(div(64, N)), N) - A1 = permutedims(StridedView(rand(T, dims)), randperm(N)) - A2 = permutedims(StridedView(rand(T, dims)), randperm(N)) - A3 = permutedims(StridedView(rand(T, dims)), randperm(N)) - @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, A1, A2, A3) - @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, A1, A2, A3) - end + α = rand(T) + β = rand(T) + dims = ntuple(Returns(div(64, N)), N) + A1 = permutedims(StridedView(rand(T, dims)), randperm(N)) + A2 = permutedims(StridedView(rand(T, dims)), randperm(N)) + A3 = permutedims(StridedView(rand(T, dims)), randperm(N)) + @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, A1, A2, A3) end @testset "map, scale!, axpy!, axpby! ($AT)" for AT in ATs From 78d6b17a937fe444ff43ed302d02c8e1f2264aec Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 4 Jun 2026 16:31:10 +0200 Subject: [PATCH 06/11] Formatter --- src/linalg.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/linalg.jl b/src/linalg.jl index 2c2d343..cb56c51 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -109,7 +109,8 @@ end # for CPU based arrays, this is valid function blas_mul!( C::StridedView{T, 2}, A::StridedView{T, 2}, B::StridedView{T, 2}, - α::Number, β::Number) where {T <: LinearAlgebra.BlasFloat} + α::Number, β::Number + ) where {T <: LinearAlgebra.BlasFloat} nthreads = use_threaded_mul() ? get_num_threads() : 1 return _threaded_blas_mul!(C, A, B, α, β, nthreads) end From 45fda9765a40ede9c6f6bc103d00aa82cb9c5616 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 4 Jun 2026 11:00:07 -0400 Subject: [PATCH 07/11] Actually use cuBLAS --- test/gpu.jl | 2 +- test/runtests.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/gpu.jl b/test/gpu.jl index 1c3228c..34d0f4a 100644 --- a/test/gpu.jl +++ b/test/gpu.jl @@ -14,7 +14,7 @@ end # types to test for ATs = [] !is_buildkite && push!(ATs, JLArray) -CUDACore.functional() && push!(ATs, CuArray) +CUDACore.functional() && cuBLAS.functional() && push!(ATs, CuArray) AMDGPU.functional() && push!(ATs, ROCArray) Metal.functional() && push!(ATs, MtlArray) diff --git a/test/runtests.jl b/test/runtests.jl index e5d0da0..a091a4a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,7 +7,7 @@ using Aqua using Adapt, GPUArrays using JLArrays using AMDGPU -using CUDACore, cuRAND +using CUDACore, cuRAND, cuBLAS using Metal Random.seed!(1234) From 7cfc9f21ea2b3f3152d301acce93a565b213a8d7 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 5 Jun 2026 14:22:26 -0400 Subject: [PATCH 08/11] Test some non-BLAS matrices for GPU --- test/gpu.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/gpu.jl b/test/gpu.jl index 34d0f4a..c44b522 100644 --- a/test/gpu.jl +++ b/test/gpu.jl @@ -22,6 +22,10 @@ Metal.functional() && push!(ATs, MtlArray) for T in (Float32, ComplexF32) A1 = StridedView(AT(randn(T, 20, 20))) @test Strided.isblasmatrix(A1) + A2 = view(A1, 1:4:20, 1:5:20) + @test !Strided.isblasmatrix(A2) + A3 = view(conj!(A1), 1:4:20, 1:20) # stride(A3, 2) is not 1 + @test !Strided.isblasmatrix(A3) end end @@ -54,6 +58,10 @@ end A2 = permutedims(StridedView(rand(T, dims)), randperm(N)) A3 = permutedims(StridedView(rand(T, dims)), randperm(N)) @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, A1, A2, A3) + vA1 = view(StridedView(rand(T, (32, 32))), 1:2:32, 1:2:32) + vA2 = view(StridedView(rand(T, (32, 32))), 1:2:32, 1:2:32) + vA3 = view(StridedView(rand(T, (32, 32))), 1:2:32, 1:2:32) + @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, vA1, vA2, vA3) end @testset "map, scale!, axpy!, axpby! ($AT)" for AT in ATs From bf8978efc25ca8553350488efb45dc4f3f7b8435 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 6 Jun 2026 01:41:23 -0400 Subject: [PATCH 09/11] Cover more blas paths --- test/gpu.jl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/test/gpu.jl b/test/gpu.jl index c44b522..6899a06 100644 --- a/test/gpu.jl +++ b/test/gpu.jl @@ -58,9 +58,17 @@ end A2 = permutedims(StridedView(rand(T, dims)), randperm(N)) A3 = permutedims(StridedView(rand(T, dims)), randperm(N)) @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, A1, A2, A3) - vA1 = view(StridedView(rand(T, (32, 32))), 1:2:32, 1:2:32) - vA2 = view(StridedView(rand(T, (32, 32))), 1:2:32, 1:2:32) - vA3 = view(StridedView(rand(T, (32, 32))), 1:2:32, 1:2:32) + # non-BLAS fallback path + vA1 = view(StridedView(rand(T, (32, 32))), 1:32, 1:32) + vA2 = view(StridedView(rand(T, (32, 64))), 1:32, 1:2:64) + vA3 = view(StridedView(rand(T, (64, 32))), 1:2:64, 1:32) + for f1 in (identity, conj, adjoint, transpose), f2 in (identity, conj, adjoint, transpose) + @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, vA1, f1(vA2), f2(vA3)) + end + # non-BLAS fallback path + vA1 = view(StridedView(rand(T, (64, 32))), 1:2:64, 1:32) + vA2 = view(StridedView(rand(T, (32, 64))), 1:32, 1:2:64) + vA3 = view(StridedView(rand(T, (64, 32))), 1:2:64, 1:32) @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, vA1, vA2, vA3) end From 6f07ee78db0001b1d609f9d51b40d9029c862c23 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 6 Jun 2026 02:02:02 -0400 Subject: [PATCH 10/11] Formatter --- test/gpu.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/gpu.jl b/test/gpu.jl index 6899a06..8db4ba4 100644 --- a/test/gpu.jl +++ b/test/gpu.jl @@ -59,14 +59,14 @@ end A3 = permutedims(StridedView(rand(T, dims)), randperm(N)) @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, A1, A2, A3) # non-BLAS fallback path - vA1 = view(StridedView(rand(T, (32, 32))), 1:32, 1:32) + vA1 = view(StridedView(rand(T, (32, 32))), 1:32, 1:32) vA2 = view(StridedView(rand(T, (32, 64))), 1:32, 1:2:64) vA3 = view(StridedView(rand(T, (64, 32))), 1:2:64, 1:32) for f1 in (identity, conj, adjoint, transpose), f2 in (identity, conj, adjoint, transpose) @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, vA1, f1(vA2), f2(vA3)) end # non-BLAS fallback path - vA1 = view(StridedView(rand(T, (64, 32))), 1:2:64, 1:32) + vA1 = view(StridedView(rand(T, (64, 32))), 1:2:64, 1:32) vA2 = view(StridedView(rand(T, (32, 64))), 1:32, 1:2:64) vA3 = view(StridedView(rand(T, (64, 32))), 1:2:64, 1:32) @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, vA1, vA2, vA3) From 353c4cbd2ce37461473320c750f6b467706fc5f0 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 8 Jun 2026 03:06:30 -0400 Subject: [PATCH 11/11] More tests and a fix --- ext/StridedGPUArraysExt.jl | 7 +++++-- test/gpu.jl | 11 ++++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/ext/StridedGPUArraysExt.jl b/ext/StridedGPUArraysExt.jl index 0dc5f53..04728e8 100644 --- a/ext/StridedGPUArraysExt.jl +++ b/ext/StridedGPUArraysExt.jl @@ -133,9 +133,12 @@ end function Strided.isblasmatrix(A::GPUStridedView{T, 2}) where {T <: LinearAlgebra.BlasFloat} if A.op == identity - return stride(A, 1) == 1 || stride(A, 2) == 1 + # unsafe wrap approach doesn't work if second condition not met + return stride(A, 1) == 1 && size(A, 1) == size(parent(A), 1) elseif A.op == conj - return stride(A, 2) == 1 + # this is converted to adjoint + # unsafe wrap approach doesn't work if second condition not met + return stride(A, 2) == 1 && size(A, 2) == size(parent(A), 2) else # should never happen return false end diff --git a/test/gpu.jl b/test/gpu.jl index 8db4ba4..596ca00 100644 --- a/test/gpu.jl +++ b/test/gpu.jl @@ -58,11 +58,20 @@ end A2 = permutedims(StridedView(rand(T, dims)), randperm(N)) A3 = permutedims(StridedView(rand(T, dims)), randperm(N)) @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, A1, A2, A3) + # test BLAS for all op combinations + @testset for sz in ((32, 64), (64, 64), (64, 32)) + vA1 = view(StridedView(rand(T, sz)), 1:32, 1:32) + vA2 = view(StridedView(rand(T, sz)), 1:32, 1:32) + vA3 = view(StridedView(rand(T, sz)), 1:32, 1:32) + @testset for f1 in (identity, conj, adjoint, transpose), f2 in (identity, conj, adjoint, transpose) + @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, vA1, f1(vA2), f2(vA3)) + end + end # non-BLAS fallback path vA1 = view(StridedView(rand(T, (32, 32))), 1:32, 1:32) vA2 = view(StridedView(rand(T, (32, 64))), 1:32, 1:2:64) vA3 = view(StridedView(rand(T, (64, 32))), 1:2:64, 1:32) - for f1 in (identity, conj, adjoint, transpose), f2 in (identity, conj, adjoint, transpose) + @testset for f1 in (identity, conj, adjoint, transpose), f2 in (identity, conj, adjoint, transpose) @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, vA1, f1(vA2), f2(vA3)) end # non-BLAS fallback path