Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Michael Goerz <mail@michaelgoerz.net>"]
version = "0.8.5+dev"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand All @@ -25,6 +26,7 @@ QuantumPropagatorsRecursiveArrayToolsExt = "RecursiveArrayTools"
QuantumPropagatorsStaticArraysExt = "StaticArrays"

[compat]
ArrayInterface = "7.0"
OffsetArrays = "1"
OrdinaryDiffEq = "6.59"
ProgressMeter = "1"
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
DisplayAs = "0b91fe84-8a4c-11e9-3e1d-67c38462b6d6"
DocInventories = "43dc2714-ed3b-44b5-b226-857eda1aa7de"
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ links = InterLinks(
"StaticArrays" => "https://juliaarrays.github.io/StaticArrays.jl/stable/",
"ComponentArrays" => "https://sciml.github.io/ComponentArrays.jl/stable/",
"RecursiveArrayTools" => "https://docs.sciml.ai/RecursiveArrayTools/stable/",
"ArrayInterface" => "https://docs.sciml.ai/ArrayInterface/stable/",
"qutip" => "https://qutip.readthedocs.io/en/qutip-5.0.x/",
)

Expand Down
23 changes: 14 additions & 9 deletions src/interfaces/operator.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test

using LinearAlgebra
import ArrayInterface
using ..Controls: get_controls, evaluate


Expand Down Expand Up @@ -52,14 +53,18 @@ for two-dimensional arrays:
* `length(op)` must equal `prod(size(op))`
* `iterate(op)` must be defined
* `similar(op)` must be defined and return a mutable array with the same shape
and element type
and element type. "Mutability" is determined by
[`ArrayInterface.ismutable`](@extref).
* `similar(op, ::Type{S})` must return a mutable array with the same shape and
element type `S`
* `similar(op, dims::Dims)` must return a mutable array with the same element
type and the given dimensions
* `similar(op, ::Type{S}, dims::Dims)` must return a mutable array with the
given element type and dimensions

The read-write method `setindex!` is not a requirement for
`supports_matrix_interface`.

The function returns `true` for a valid operator and `false` for an invalid
operator. Unless `quiet=true`, it will log an error to indicate which of the
conditions failed.
Expand Down Expand Up @@ -328,9 +333,9 @@ function check_operator(

try
op2 = similar(op)
if !supports_inplace(op2)
if !ArrayInterface.ismutable(op2)
quiet ||
@error "$(px)`similar(op)` must return a mutable array (`supports_inplace` must be `true`), got $(typeof(op2))"
@error "$(px)`similar(op)` must return a mutable array (`ArrayInterface.ismutable` must be `true`), got $(typeof(op2))"
success = false
end
if size(op2) != size(op)
Expand All @@ -354,9 +359,9 @@ function check_operator(
try
S = (eltype(op) == ComplexF64) ? ComplexF32 : ComplexF64
op2 = similar(op, S)
if !supports_inplace(op2)
if !ArrayInterface.ismutable(op2)
quiet ||
@error "$(px)`similar(op, $S)` must return a mutable array (`supports_inplace` must be `true`), got $(typeof(op2))"
@error "$(px)`similar(op, $S)` must return a mutable array (`ArrayInterface.ismutable` must be `true`), got $(typeof(op2))"
success = false
end
if size(op2) != size(op)
Expand All @@ -380,9 +385,9 @@ function check_operator(
try
dims = size(op)
op2 = similar(op, dims)
if !supports_inplace(op2)
if !ArrayInterface.ismutable(op2)
quiet ||
@error "$(px)`similar(op, dims)` must return a mutable array (`supports_inplace` must be `true`), got $(typeof(op2))"
@error "$(px)`similar(op, dims)` must return a mutable array (`ArrayInterface.ismutable` must be `true`), got $(typeof(op2))"
success = false
end
if size(op2) != dims
Expand All @@ -407,9 +412,9 @@ function check_operator(
S = (eltype(op) == ComplexF64) ? ComplexF32 : ComplexF64
dims = size(op)
op2 = similar(op, S, dims)
if !supports_inplace(op2)
if !ArrayInterface.ismutable(op2)
quiet ||
@error "$(px)`similar(op, $S, dims)` must return a mutable array (`supports_inplace` must be `true`), got $(typeof(op2))"
@error "$(px)`similar(op, $S, dims)` must return a mutable array (`ArrayInterface.ismutable` must be `true`), got $(typeof(op2))"
success = false
end
if size(op2) != dims
Expand Down
24 changes: 13 additions & 11 deletions src/interfaces/state.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test

import ArrayInterface
using LinearAlgebra


Expand Down Expand Up @@ -36,8 +37,8 @@ Any `state` must support the following not-in-place operations:
If `supports_inplace(state)` is `true`, the `state` must also support the
following:

* `similar(state)` must be defined and return a valid state of the same type a
`state`
* `similar(state)` must be defined and return a valid state of the same type as
the original `state`
* `copyto!(other, state)` must be defined
* `fill!(state, c)` must be defined
* `LinearAlgebra.lmul!(c, state)` for a scalar `c` must be defined
Expand All @@ -60,7 +61,8 @@ for one-dimensional arrays:
* `length(state)` must equal `prod(size(state))`
* `iterate(state)` must be defined
* `similar(state)` must be defined and return a mutable vector with the same
length and element type.
length and element type. "Mutability" is determined by
[`ArrayInterface.ismutable`](@extref).
* `similar(state, ::Type{S})` must return a mutable vector with the same length
and element type `S`
* `similar(state, dims::Dims)` must return a mutable array with the same
Expand Down Expand Up @@ -454,9 +456,9 @@ function check_state(

try
st2 = similar(state)
if !supports_inplace(st2)
if !ArrayInterface.ismutable(st2)
quiet ||
@error "$(px)`similar(state)` must return a mutable vector (`supports_inplace` must be `true`), got $(typeof(st2))"
@error "$(px)`similar(state)` must return a mutable vector (`ArrayInterface.ismutable` must be `true`), got $(typeof(st2))"
success = false
end
if size(st2) != size(state)
Expand All @@ -480,9 +482,9 @@ function check_state(
try
S = (eltype(state) == ComplexF64) ? ComplexF32 : ComplexF64
st2 = similar(state, S)
if !supports_inplace(st2)
if !ArrayInterface.ismutable(st2)
quiet ||
@error "$(px)`similar(state, $S)` must return a mutable vector (`supports_inplace` must be `true`), got $(typeof(st2))"
@error "$(px)`similar(state, $S)` must return a mutable vector (`ArrayInterface.ismutable` must be `true`), got $(typeof(st2))"
success = false
end
if size(st2) != size(state)
Expand All @@ -506,9 +508,9 @@ function check_state(
try
dims = size(state)
st2 = similar(state, dims)
if !supports_inplace(st2)
if !ArrayInterface.ismutable(st2)
quiet ||
@error "$(px)`similar(state, dims)` must return a mutable array (`supports_inplace` must be `true`), got $(typeof(st2))"
@error "$(px)`similar(state, dims)` must return a mutable array (`ArrayInterface.ismutable` must be `true`), got $(typeof(st2))"
success = false
end
if size(st2) != dims
Expand All @@ -533,9 +535,9 @@ function check_state(
S = (eltype(state) == ComplexF64) ? ComplexF32 : ComplexF64
dims = size(state)
st2 = similar(state, S, dims)
if !supports_inplace(st2)
if !ArrayInterface.ismutable(st2)
quiet ||
@error "$(px)`similar(state, $S, dims)` must return a mutable array (`supports_inplace` must be `true`), got $(typeof(st2))"
@error "$(px)`similar(state, $S, dims)` must return a mutable array (`ArrayInterface.ismutable` must be `true`), got $(typeof(st2))"
success = false
end
if size(st2) != dims
Expand Down
31 changes: 19 additions & 12 deletions src/interfaces/supports_inplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import ..Operator
import ..ScaledOperator
import LinearAlgebra
import SparseArrays: SparseMatrixCSC
import ArrayInterface

"""Indicate whether a type supports in-place operations.

Expand All @@ -22,27 +23,33 @@ see [`QuantumPropagators.Interfaces.check_operator`](@ref).

For operators, a `true` result indicates that the operator can be evaluated
in-place with [`evaluate!`](@ref), see
[`QuantumPropagators.Interfaces.check_generator`](@ref).

Note that `supports_inplace` is not quite the same as
[`Base.ismutabletype`](@extref) and/or [`Base.ismutable`](@extref): When using
[custom structs](@extref Julia :label:`Mutable-Composite-Types`) for states or
[`QuantumPropagators.Interfaces.check_generator`](@ref). Again, this is
intended only as an indicator for what assumptions can be made in the
implementation of a particular propagator: `supports_inplace` is semantically
separate from [`Base.ismutabletype`](@extref), [`Base.ismutable`](@extref), or
similar "traits": When using [custom structs](@extref Julia
:label:`Mutable-Composite-Types`) for states or
operators, even if those structs are not defined as `mutable`, they may still
define the in-place interface (typically because their *components* are
mutable).
mutable). Conversely, even types that are "mutable" may want to opt out
Comment thread
goerz marked this conversation as resolved.
of `evaluate!` for performance reasons.

Mutable abstract arrays ([`ArrayInterface.ismutable`](@extref)) without
considerable performance issues
([`ArrayInterface.fast_scalar_indexing`](@extref))
should support in-place operations.
"""
supports_inplace(::Type{<:Vector{ComplexF64}}) = true
supports_inplace(::Type{T}) where {T<:AbstractVector} = ismutabletype(T) # fallback
# The fallback doesn't actually guarantee that the required interface implied
# by `supports_inplace` is fulfilled, but it's a reasonable expectation to
# have, and the `check_state` function will test it.

supports_inplace(::Type{<:Matrix}) = true
supports_inplace(::Type{<:Operator}) = true
supports_inplace(::Type{<:LinearAlgebra.Diagonal}) = true
supports_inplace(::Type{<:SparseMatrixCSC}) = true
supports_inplace(::Type{<:SparseMatrixCSC}) = true # XXX is this a good idea?
supports_inplace(::Type{<:ScaledOperator{<:Any,OT}}) where {OT} = supports_inplace(OT)
supports_inplace(::Type{T}) where {T<:AbstractMatrix} = ismutabletype(T) # fallback

# Fallback (both for operators and states)
supports_inplace(::Type{T}) where {T<:AbstractArray} =
ArrayInterface.ismutable(T) && ArrayInterface.fast_scalar_indexing(T)

# Generic catch-all for types without a specific method (prevents StackOverflow
# from the value→type fallback below)
Expand Down
7 changes: 6 additions & 1 deletion src/interfaces/supports_matrix_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@ returns `true` if `T` implements the
for two-dimensional arrays. This is `true` for all subtypes of
`AbstractMatrix`, but may also be `true` for types that implement an array
interface (`size`, `getindex`, etc.) without declaring themselves subtypes
of `AbstractMatrix`. Calling `supports_matrix_interface` on an instance
of `AbstractMatrix`.

A `setindex!` method is not a requirement for `supports_matrix_interface`. That
is, only the read-interface for matrices is enforced.

Calling `supports_matrix_interface` on an instance
`x` also works via a convenience fallback that forwards to
`supports_matrix_interface(typeof(x))`.

Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Coverage = "a2441757-f6aa-5fb2-8edb-039e3f45d037"
Expand Down
10 changes: 2 additions & 8 deletions test/test_invalid_interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,7 @@ end
end
@test captured.value ≡ false
# similar(op): wrong mutability, shape, and eltype
@test contains(
captured.output,
"`similar(op)` must return a mutable array (`supports_inplace` must be `true`)"
)
@test contains(captured.output, "`similar(op)` must return a mutable array")
@test contains(
captured.output,
"`similar(op)` must return an array with the same shape"
Expand Down Expand Up @@ -627,10 +624,7 @@ end
end
@test captured.value ≡ false
# similar(state): wrong mutability, shape, and eltype
@test contains(
captured.output,
"`similar(state)` must return a mutable vector (`supports_inplace` must be `true`)"
)
@test contains(captured.output, "`similar(state)` must return a mutable vector")
@test contains(
captured.output,
"`similar(state)` must return a vector with the same shape"
Expand Down
28 changes: 27 additions & 1 deletion test/test_operator_linalg.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using Test
using LinearAlgebra
using QuantumControlTestUtils.RandomObjects: random_matrix, random_state_vector
using QuantumPropagators.Interfaces: check_operator, supports_matrix_interface
using QuantumPropagators.Interfaces:
check_operator, check_generator, supports_matrix_interface
import QuantumPropagators.Interfaces: supports_inplace
import QuantumPropagators.Controls: get_controls, evaluate
import ArrayInterface
using StaticArrays: SMatrix, SVector

using QuantumPropagators: Generator, Operator, ScaledOperator
Expand Down Expand Up @@ -332,3 +334,27 @@ end
@test check_operator(SFreeOp; state = Ψ)

end


@testset "Hermitian matrix supports in-place operations" begin

# Test the resolution of
# https://github.com/JuliaQuantumControl/QuantumPropagators.jl/issues/102

Ψ0 = ComplexF64[1, 0]
Ĥ = ComplexF64[
0 0.5
0.5 0
]
tlist = collect(range(0.0, 1.0, length = 101))
generator = (Hermitian(Ĥ),)
op = evaluate(generator, tlist, 1)
T = Hermitian{ComplexF64,Matrix{ComplexF64}}
@test op isa T
@test supports_inplace(T)
op2 = similar(op)
@test op2 isa T
@test ArrayInterface.ismutable(T)
@test check_generator(generator; state = Ψ0, tlist)

end