Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ext/LinearSolveAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase

# LU Factorization
function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::AMDGPUOffloadLUFactorization;
cache::LinearSolve.LinearCacheType, alg::AMDGPUOffloadLUFactorization;
kwargs...
)
if cache.isfresh
Expand Down Expand Up @@ -36,7 +36,7 @@ end

# QR Factorization
function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::AMDGPUOffloadQRFactorization;
cache::LinearSolve.LinearCacheType, alg::AMDGPUOffloadQRFactorization;
kwargs...
)
if cache.isfresh
Expand Down
4 changes: 2 additions & 2 deletions ext/LinearSolveAlgebraicMultigridExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module LinearSolveAlgebraicMultigridExt

using LinearSolve, AlgebraicMultigrid, LinearAlgebra
using LinearSolve: LinearCache, LinearVerbosity, OperatorAssumptions
using LinearSolve: LinearCache, LinearCacheType, LinearVerbosity, OperatorAssumptions
using SciMLBase: SciMLBase, ReturnCode

function LinearSolve.init_cacheval(
Expand All @@ -19,7 +19,7 @@ function LinearSolve.init_cacheval(
return SciMLBase.init(amg_alg, A, b; alg.kwargs...)
end

function SciMLBase.solve!(cache::LinearCache, alg::AlgebraicMultigridJL; kwargs...)
function SciMLBase.solve!(cache::LinearCacheType, alg::AlgebraicMultigridJL; kwargs...)
if cache.isfresh
cache.cacheval = LinearSolve.init_cacheval(
alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr,
Expand Down
4 changes: 2 additions & 2 deletions ext/LinearSolveBLISExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using LinearSolve
using LinearAlgebra: BlasInt, LU
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
@blasfunc, chkargsok
using LinearSolve: ArrayInterface, BLISLUFactorization, @get_cacheval, LinearCache, SciMLBase, LinearVerbosity, get_blas_operation_info, blas_info_msg
using LinearSolve: ArrayInterface, BLISLUFactorization, @get_cacheval, LinearCache, LinearCacheType, SciMLBase, LinearVerbosity, get_blas_operation_info, blas_info_msg
using SciMLLogging: SciMLLogging, @SciMLMessage
using SciMLBase: ReturnCode

Expand Down Expand Up @@ -272,7 +272,7 @@ function LinearSolve.init_cacheval(
end

function SciMLBase.solve!(
cache::LinearCache, alg::BLISLUFactorization;
cache::LinearCacheType, alg::BLISLUFactorization;
kwargs...
)
A = cache.A
Expand Down
8 changes: 4 additions & 4 deletions ext/LinearSolveCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
end

function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::CudaOffloadLUFactorization;
cache::LinearSolve.LinearCacheType, alg::CudaOffloadLUFactorization;
kwargs...
)
if cache.isfresh
Expand Down Expand Up @@ -92,7 +92,7 @@ function LinearSolve.init_cacheval(
end

function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::CudaOffloadQRFactorization;
cache::LinearSolve.LinearCacheType, alg::CudaOffloadQRFactorization;
kwargs...
)
if cache.isfresh
Expand Down Expand Up @@ -120,7 +120,7 @@ end

# Keep the deprecated CudaOffloadFactorization working by forwarding to QR
function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization;
cache::LinearSolve.LinearCacheType, alg::CudaOffloadFactorization;
kwargs...
)
if cache.isfresh
Expand Down Expand Up @@ -164,7 +164,7 @@ end

# Mixed precision CUDA LU implementation
function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
cache::LinearSolve.LinearCacheType, alg::CUDAOffload32MixedLUFactorization;
kwargs...
)
if cache.isfresh
Expand Down
2 changes: 1 addition & 1 deletion ext/LinearSolveCUSOLVERRFExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function LinearSolve.init_cacheval(
return RFLU(A_gpu; nrhs = nrhs, symbolic = symbolic)
end

function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::LinearSolve.CUSOLVERRFFactorization; kwargs...)
function SciMLBase.solve!(cache::LinearSolve.LinearCacheType, alg::LinearSolve.CUSOLVERRFFactorization; kwargs...)
A = cache.A

# Convert to appropriate GPU format if needed
Expand Down
2 changes: 1 addition & 1 deletion ext/LinearSolveCliqueTreesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function LinearSolve.init_cacheval(
return makefactor(A, alg.alg, alg.snd)
end

function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CliqueTreesFactorization; kwargs...)
function SciMLBase.solve!(cache::LinearSolve.LinearCacheType, alg::CliqueTreesFactorization; kwargs...)
A = cache.A
u = cache.u
b = cache.b
Expand Down
4 changes: 2 additions & 2 deletions ext/LinearSolveFastLapackInterfaceExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function LinearSolve.init_cacheval(
end

function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::FastLUFactorization; kwargs...
cache::LinearSolve.LinearCacheType, alg::FastLUFactorization; kwargs...
)
A = cache.A
A = convert(AbstractMatrix, A)
Expand Down Expand Up @@ -78,7 +78,7 @@ function LinearSolve.init_cacheval(
end

function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::FastQRFactorization{P};
cache::LinearSolve.LinearCacheType, alg::FastQRFactorization{P};
kwargs...
) where {P}
A = cache.A
Expand Down
6 changes: 3 additions & 3 deletions ext/LinearSolveHYPREExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module LinearSolveHYPREExt
using LinearAlgebra
using HYPRE.LibHYPRE: HYPRE_Complex
using HYPRE: HYPRE, HYPREMatrix, HYPRESolver, HYPREVector
using LinearSolve: HYPREAlgorithm, LinearCache, LinearProblem, LinearSolve,
using LinearSolve: HYPREAlgorithm, LinearCache, LinearCacheType, LinearProblem, LinearSolve,
OperatorAssumptions, default_tol, init_cacheval, __issquare,
__conditioning, LinearSolveAdjoint, LinearVerbosity
using SciMLLogging: SciMLLogging, verbosity_to_int, @SciMLMessage
Expand Down Expand Up @@ -176,7 +176,7 @@ create_solver(::Type{S}, comm) where {S <: COMM_SOLVERS} = S(comm)
const NO_COMM_SOLVERS = Union{HYPRE.BoomerAMG, HYPRE.Hybrid, HYPRE.ILU}
create_solver(::Type{S}, comm) where {S <: NO_COMM_SOLVERS} = S()

function create_solver(alg::HYPREAlgorithm, cache::LinearCache)
function create_solver(alg::HYPREAlgorithm, cache::LinearCacheType)
# If the solver is already instantiated, return it directly
if alg.solver isa HYPRE.HYPRESolver
return alg.solver
Expand Down Expand Up @@ -231,7 +231,7 @@ function create_solver(alg::HYPREAlgorithm, cache::LinearCache)
end

# TODO: How are args... and kwargs... supposed to be used here?
function SciMLBase.solve!(cache::LinearCache, alg::HYPREAlgorithm, args...; kwargs...)
function SciMLBase.solve!(cache::LinearCacheType, alg::HYPREAlgorithm, args...; kwargs...)
# It is possible to reach here without HYPRE.Init() being called if HYPRE structures are
# only to be created here internally (i.e. when cache.A::SparseMatrixCSC and not a
# ::HYPREMatrix created externally by the user). Be nice to the user and call it :)
Expand Down
4 changes: 2 additions & 2 deletions ext/LinearSolveIterativeSolversExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module LinearSolveIterativeSolversExt

using LinearSolve, LinearAlgebra
using LinearSolve: LinearCache, DEFAULT_PRECS, LinearVerbosity
using LinearSolve: LinearCache, LinearCacheType, DEFAULT_PRECS, LinearVerbosity
import LinearSolve: IterativeSolversJL
using SciMLLogging: SciMLLogging, @SciMLMessage

Expand Down Expand Up @@ -132,7 +132,7 @@ function LinearSolve.init_cacheval(
return iterable
end

function SciMLBase.solve!(cache::LinearCache, alg::IterativeSolversJL; kwargs...)
function SciMLBase.solve!(cache::LinearCacheType, alg::IterativeSolversJL; kwargs...)
if cache.precsisfresh && !isnothing(alg.precs)
Pl, Pr = alg.precs(cache.Pl, cache.Pr)
cache.Pl = Pl
Expand Down
4 changes: 2 additions & 2 deletions ext/LinearSolveKrylovKitExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module LinearSolveKrylovKitExt

using LinearSolve, KrylovKit, LinearAlgebra
using LinearSolve: LinearCache, DEFAULT_PRECS
using LinearSolve: LinearCache, LinearCacheType, DEFAULT_PRECS
using SciMLLogging: SciMLLogging, @SciMLMessage, verbosity_to_int

function LinearSolve.KrylovKitJL(
Expand All @@ -24,7 +24,7 @@ end
LinearSolve.default_alias_A(::KrylovKitJL, ::Any, ::Any) = true
LinearSolve.default_alias_b(::KrylovKitJL, ::Any, ::Any) = true

function SciMLBase.solve!(cache::LinearCache, alg::KrylovKitJL; kwargs...)
function SciMLBase.solve!(cache::LinearCacheType, alg::KrylovKitJL; kwargs...)
atol = float(cache.abstol)
rtol = float(cache.reltol)
maxiter = cache.maxiters
Expand Down
6 changes: 3 additions & 3 deletions ext/LinearSolveMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Metal, LinearSolve
using LinearAlgebra, SciMLBase
using SciMLBase: AbstractSciMLOperator
using LinearSolve: ArrayInterface, MKLLUFactorization, MetalOffload32MixedLUFactorization,
@get_cacheval, LinearCache, SciMLBase, OperatorAssumptions, LinearVerbosity
@get_cacheval, LinearCache, LinearCacheType, SciMLBase, OperatorAssumptions, LinearVerbosity

@static if Sys.isapple()

Expand All @@ -24,7 +24,7 @@ function LinearSolve.init_cacheval(
end

function SciMLBase.solve!(
cache::LinearCache, alg::MetalLUFactorization;
cache::LinearCacheType, alg::MetalLUFactorization;
kwargs...
)
A = cache.A
Expand Down Expand Up @@ -63,7 +63,7 @@ function LinearSolve.init_cacheval(
end

function SciMLBase.solve!(
cache::LinearCache, alg::MetalOffload32MixedLUFactorization;
cache::LinearCacheType, alg::MetalOffload32MixedLUFactorization;
kwargs...
)
A = cache.A
Expand Down
4 changes: 2 additions & 2 deletions ext/LinearSolvePETScExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using PETSc
using PETSc: MPI
using PETSc: petsclibs
using SparseArrays: SparseMatrixCSC, sparse
using LinearSolve: PETScAlgorithm, LinearCache, LinearProblem, LinearSolve,
using LinearSolve: PETScAlgorithm, LinearCache, LinearCacheType, LinearProblem, LinearSolve,
OperatorAssumptions, default_tol, init_cacheval, __issquare,
__conditioning, LinearSolveAdjoint, LinearVerbosity
using SciMLLogging: SciMLLogging, verbosity_to_int, @SciMLMessage
Expand Down Expand Up @@ -98,7 +98,7 @@ function pc_type_string(pc_type::Symbol)
return get(pc_types, pc_type, string(pc_type))
end

function SciMLBase.solve!(cache::LinearCache, alg::PETScAlgorithm, args...; kwargs...)
function SciMLBase.solve!(cache::LinearCacheType, alg::PETScAlgorithm, args...; kwargs...)
pcache = cache.cacheval

# Get element type from the problem
Expand Down
2 changes: 1 addition & 1 deletion ext/LinearSolvePardisoExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ function LinearSolve.init_cacheval(
return solver
end

function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::PardisoJL; kwargs...)
function SciMLBase.solve!(cache::LinearSolve.LinearCacheType, alg::PardisoJL; kwargs...)
(; A, b, u) = cache
A = convert(AbstractMatrix, A)
if cache.isfresh
Expand Down
6 changes: 3 additions & 3 deletions ext/LinearSolveRecursiveFactorizationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using SciMLLogging: @SciMLMessage
LinearSolve.userecursivefactorization(A::Union{Nothing, AbstractMatrix}) = true

function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::RFLUFactorization{P, T};
cache::LinearSolve.LinearCacheType, alg::RFLUFactorization{P, T};
kwargs...
) where {P, T}
A = cache.A
Expand Down Expand Up @@ -63,7 +63,7 @@ function LinearSolve.init_cacheval(
end

function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::RF32MixedLUFactorization{P, T};
cache::LinearSolve.LinearCacheType, alg::RF32MixedLUFactorization{P, T};
kwargs...
) where {P, T}
A = cache.A
Expand Down Expand Up @@ -115,7 +115,7 @@ function SciMLBase.solve!(
end

function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::ButterflyFactorization;
cache::LinearSolve.LinearCacheType, alg::ButterflyFactorization;
kwargs...
)
cache_A = cache.A
Expand Down
6 changes: 3 additions & 3 deletions ext/LinearSolveSparseArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ end
end

function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::UMFPACKFactorization; kwargs...
cache::LinearSolve.LinearCacheType, alg::UMFPACKFactorization; kwargs...
)
A = cache.A
A = convert(AbstractMatrix, A)
Expand Down Expand Up @@ -284,7 +284,7 @@ end

else
function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::UMFPACKFactorization; kwargs...
cache::LinearSolve.LinearCacheType, alg::UMFPACKFactorization; kwargs...
)
error("UMFPACKFactorization requires GPL libraries (UMFPACK). Rebuild Julia with USE_GPL_LIBS=1 or use an alternative algorithm like SparspakFactorization")
end
Expand Down Expand Up @@ -399,7 +399,7 @@ function LinearSolve.init_cacheval(
end
end

function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::KLUFactorization; kwargs...)
function SciMLBase.solve!(cache::LinearSolve.LinearCacheType, alg::KLUFactorization; kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
Expand Down
2 changes: 1 addition & 1 deletion ext/LinearSolveSparspakExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function LinearSolve.init_cacheval(
end

function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::SparspakFactorization; kwargs...
cache::LinearSolve.LinearCacheType, alg::SparspakFactorization; kwargs...
)
A = cache.A
if cache.isfresh
Expand Down
6 changes: 6 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,9 @@ end
const BLASELTYPES = Union{Float32, Float64, ComplexF32, ComplexF64}

function defaultalg_symbol end
function defaultalg_symbol(::Type{T}) where {T}
return Base.typename(SciMLBase.parameterless_type(T)).name
end

include("verbosity.jl")
include("blas_logging.jl")
Expand All @@ -405,6 +408,7 @@ include("preconditioners.jl")
include("preferences.jl")
include("solve_function.jl")
include("default.jl")
include("vf64_types.jl")
include("init.jl")
include("adjoint.jl")

Expand Down Expand Up @@ -537,4 +541,6 @@ export LinearSolveAdjoint

export LinearVerbosity

export LinearCacheVF64, LinearCacheType, DefaultLinearSolverInitVF64, DefaultLinearSolverInitType

end
4 changes: 2 additions & 2 deletions src/appleaccelerate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ function LinearSolve.init_cacheval(
end

function SciMLBase.solve!(
cache::LinearCache, alg::AppleAccelerateLUFactorization;
cache::LinearCacheType, alg::AppleAccelerateLUFactorization;
kwargs...
)
__appleaccelerate_isavailable() ||
Expand Down Expand Up @@ -407,7 +407,7 @@ function LinearSolve.init_cacheval(
end

function SciMLBase.solve!(
cache::LinearCache, alg::AppleAccelerate32MixedLUFactorization;
cache::LinearCacheType, alg::AppleAccelerate32MixedLUFactorization;
kwargs...
)
__appleaccelerate_isavailable() ||
Expand Down
Loading
Loading