Skip to content

Commit b1b36a4

Browse files
authored
Merge branch 'main' into ksh/cuda_tweaks
2 parents fecb81d + 518cb2a commit b1b36a4

14 files changed

Lines changed: 476 additions & 107 deletions

File tree

.github/workflows/CompatCheck.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ jobs:
4949
matrix:
5050
downgrade_mode: ['deps']
5151
group: ${{ fromJSON(needs.setup-matrix.outputs.groups) }}
52-
julia-version: ['1', '1.10']
52+
julia-version: ['1.10']
5353
steps:
5454
- uses: actions/checkout@v6
5555
- uses: julia-actions/setup-julia@v3

CITATION.cff

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ authors:
88
given-names: "Jutho"
99
orcid: "https://orcid.org/0000-0002-0858-291X"
1010
title: "TensorKit.jl"
11-
version: "0.16.3"
11+
version: "0.16.4"
1212
doi: "10.5281/zenodo.8421339"
13-
date-released: "2026-02-22"
13+
date-released: "2026-04-23"
1414
url: "https://github.com/QuantumKitHub/TensorKit.jl"
1515
preferred-citation:
1616
type: article

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "TensorKit"
22
uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
3-
version = "0.16.3"
3+
version = "0.17.0"
44
authors = ["Jutho Haegeman, Lukas Devos"]
55

66
[deps]

docs/src/Changelog.md

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,42 @@ When making changes to this project, please update the "Unreleased" section with
1818

1919
When releasing a new version, move the "Unreleased" changes to a new version section with the release date.
2020

21-
## [Unreleased](https://github.com/QuantumKitHub/TensorKit.jl/compare/v0.16.3...HEAD)
21+
## [Unreleased](https://github.com/QuantumKitHub/TensorKit.jl/compare/v0.16.4...HEAD)
2222

2323
### Added
2424

25-
2625
### Changed
2726

28-
2927
### Deprecated
3028

31-
3229
### Removed
3330

31+
### Fixed
32+
33+
### Performance
34+
35+
## [0.16.4](https://github.com/QuantumKitHub/TensorKit.jl/compare/v0.16.3...v0.16.4) - 2026-04-23
36+
37+
### Added
38+
39+
- Partial tensor support for AMDGPU via a new extension ([#341](https://github.com/QuantumKitHub/TensorKit.jl/pull/341))
40+
- Define `spacetype` for `TruncationSpace` ([#403](https://github.com/QuantumKitHub/TensorKit.jl/pull/403))
41+
42+
### Changed
43+
44+
- Updated MatrixAlgebraKit dependency to v0.6.5 with corresponding API updates ([#390](https://github.com/QuantumKitHub/TensorKit.jl/pull/390))
3445

3546
### Fixed
3647

48+
- Fix ignored `adjoint` flag in `BraidingTensor` ([#392](https://github.com/QuantumKitHub/TensorKit.jl/pull/392))
49+
- Fix `MethodError` for certain tensor operations ([#406](https://github.com/QuantumKitHub/TensorKit.jl/pull/406))
50+
- Add square checks for `project_(anti)hermitian` and eigenvalue decompositions ([#408](https://github.com/QuantumKitHub/TensorKit.jl/pull/408))
51+
52+
### Performance
53+
54+
- Vectorize fusiontree manipulations ([#261](https://github.com/QuantumKitHub/TensorKit.jl/pull/261))
55+
- Avoid generic matmul fallback in transformation kernel ([#378](https://github.com/QuantumKitHub/TensorKit.jl/pull/378))
56+
- Reduce cache footprint by decoupling degeneracy-dependent data ([#387](https://github.com/QuantumKitHub/TensorKit.jl/pull/387))
3757

3858
## [0.16.3](https://github.com/QuantumKitHub/TensorKit.jl/compare/v0.16.2...v0.16.3) - 2026-02-22
3959

ext/TensorKitAdaptExt.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@ function Adapt.adapt_structure(to, x::DiagonalTensorMap)
1515
data′ = adapt(to, x.data)
1616
return DiagonalTensorMap(data′, x.domain)
1717
end
18-
function Adapt.adapt_structure(::Type{TorA}, x::BraidingTensor) where {TorA <: Union{Number, DenseArray{<:Number}}}
19-
return BraidingTensor{scalartype(TorA)}(space(x), x.adjoint)
18+
function Adapt.adapt_structure(::Type{T}, x::BraidingTensor{T′, S, A}) where {T <: Number, T′, S, A}
19+
A′ = TensorKit.similarstoragetype(A, T)
20+
return BraidingTensor{T, S, A′}(space(x), x.adjoint)
21+
end
22+
function Adapt.adapt_structure(::Type{TA}, x::BraidingTensor{T, S, A}) where {T′, TA <: DenseArray{T′}, T, S, A}
23+
return BraidingTensor{T′, S, TA}(space(x), x.adjoint)
2024
end
2125

2226
end

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@ module TensorKitCUDAExt
33
using CUDA, CUDA.CUBLAS, CUDA.CUSOLVER, LinearAlgebra
44
using CUDA: @allowscalar
55
using cuTENSOR: cuTENSOR
6+
using Strided: StridedViews
67
import CUDA: rand as curand, rand! as curand!, randn as curandn, randn! as curandn!
8+
using CUDA.KernelAbstractions: @kernel, @index, get_backend
79

810
using TensorKit
911
using TensorKit.Factorizations
1012
using TensorKit.Strided
1113
using TensorKit.Factorizations: AbstractAlgorithm
1214
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
13-
import TensorKit: randisometry, rand, randn
15+
import TensorKit: randisometry, rand, randn, fill_braidingsubblock!
1416

1517
using TensorKit: MatrixAlgebraKit
1618

@@ -19,4 +21,18 @@ using Random
1921
include("cutensormap.jl")
2022
include("truncation.jl")
2123

24+
function TensorKit.fill_braidingsubblock!(data::TD, val) where {T, TD <: Union{<:CuMatrix{T}, <:StridedViews.StridedView{T, 4, <:CuArray{T}}}}
25+
# COV_EXCL_START
26+
# kernels are not reachable by coverage
27+
@kernel function fill_subblock_kernel!(subblock, val)
28+
idx = @index(Global, Cartesian)
29+
idx_val = idx[1] == idx[4] && idx[2] == idx[3] ? val : zero(val)
30+
@inbounds subblock[idx] = idx_val
31+
end
32+
# COV_EXCL_STOP
33+
kernel = fill_subblock_kernel!(get_backend(data))
34+
kernel(data, val; ndrange = size(data))
35+
return data
36+
end
37+
2238
end

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,7 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
158158
return tf
159159
end
160160
end
161+
162+
function TensorKit._add_transform_multi!(tdst::CuTensorMap, tsrc, p, (U, structs_dst, structs_src)::Tuple{<:Array, TD, TS}, buffers, alpha, beta, backend...) where {TD, TS}
163+
return TensorKit._add_transform_multi!(tdst, tsrc, p, (CUDA.Adapt.adapt(CuArray, U), structs_dst, structs_src), buffers, alpha, beta, backend...)
164+
end

src/planar/preprocessors.jl

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,23 @@ _add_adjoint(ex) = Expr(TO.prime, ex)
8383
# spaces from the rest of the expression. Construct the explicit BraidingTensor objects and
8484
# insert them in the expression.
8585
function _construct_braidingtensors(ex)
86+
function filter_f(expr)
87+
if TO.istensor(expr)
88+
return _remove_adjoint(TO.decomposetensor(expr)[1]) !=
89+
elseif TO.istensorexpr(expr)
90+
return any(filter_f, expr.args)
91+
else
92+
return false
93+
end
94+
end
95+
function extract_tensors(tensor_ex)
96+
if TO.istensor(tensor_ex)
97+
return [TO.decomposetensor(tensor_ex)[1]]
98+
elseif TO.istensorexpr(tensor_ex)
99+
return collect(Iterators.flatmap(extract_tensors, filter(filter_f, tensor_ex.args)))
100+
end
101+
end
102+
# get storagetype
86103
ex isa Expr || return ex
87104
if ex.head == :macrocall && ex.args[1] == Symbol("@notensor")
88105
return ex
@@ -104,7 +121,9 @@ function _construct_braidingtensors(ex)
104121
)
105122
end
106123
end
107-
newrhs, success = _construct_braidingtensors!(rhs, preargs, indexmap)
124+
# if this is a definition, the lhs tensor is NOT yet defined
125+
no_τ_ex = reduce(vcat, Iterators.flatmap(extract_tensors, filter(filter_f, rhs.args)); init = Symbol[])
126+
newrhs, success = _construct_braidingtensors!(rhs, preargs, indexmap, no_τ_ex)
108127
success ||
109128
throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex"))
110129
pre = Expr(
@@ -115,7 +134,8 @@ function _construct_braidingtensors(ex)
115134
elseif TO.istensorexpr(ex)
116135
preargs = Vector{Any}()
117136
indexmap = Dict{Any, Any}()
118-
newex, success = _construct_braidingtensors!(ex, preargs, indexmap)
137+
no_τ_ex = reduce(vcat, Iterators.flatmap(extract_tensors, filter(filter_f, ex.args)); init = Symbol[])
138+
newex, success = _construct_braidingtensors!(ex, preargs, indexmap, no_τ_ex)
119139
success ||
120140
throw(ArgumentError("cannot determine the spaces of all braiding tensors in $ex"))
121141
pre = Expr(
@@ -128,7 +148,7 @@ function _construct_braidingtensors(ex)
128148
end
129149
end
130150

131-
function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed to be a single tensor expression
151+
function _construct_braidingtensors!(ex, preargs, indexmap, non_braiding) # ex is guaranteed to be a single tensor expression
132152
if TO.isscalarexpr(ex)
133153
# ex could be tensorscalar call with more braiding tensors
134154
return _construct_braidingtensors(ex), true
@@ -163,7 +183,9 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
163183
end
164184
if foundV1 && foundV2
165185
s = gensym()
166-
constructex = Expr(:call, GlobalRef(TensorKit, :BraidingTensor), V1, V2)
186+
storageex = Expr(:call, GlobalRef(TensorKit, :promote_storagetype), non_braiding...)
187+
braidingex = Expr(:call, GlobalRef(TensorKit, :braidingtensortype), V1, V2, storageex)
188+
constructex = Expr(:call, braidingex, V1, V2)
167189
push!(preargs, Expr(:(=), s, constructex))
168190
obj = _is_adjoint(obj) ? _add_adjoint(s) : s
169191
success = true
@@ -196,7 +218,7 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
196218
newargs = Vector{Any}(undef, length(args))
197219
success = true
198220
for i in 1:length(ex.args)
199-
newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmap)
221+
newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmap, non_braiding)
200222
success = success && successa
201223
end
202224
newex = Expr(ex.head, newargs...)
@@ -212,7 +234,7 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
212234
for i in 2:length(ex.args)
213235
successes[i] && continue
214236
newargs[i], successa = _construct_braidingtensors!(
215-
args[i], preargs, indexmap
237+
args[i], preargs, indexmap, non_braiding
216238
)
217239
successes[i] = successa
218240
end
@@ -232,7 +254,7 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
232254
indices = [TO.getindices(arg) for arg in args]
233255
for i in 2:length(ex.args)
234256
indexmapa = copy(indexmap)
235-
newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmapa)
257+
newargs[i], successa = _construct_braidingtensors!(args[i], preargs, indexmapa, non_braiding)
236258
for l in indices[i]
237259
if !haskey(indexmap, l) && haskey(indexmapa, l)
238260
indexmap[l] = indexmapa[l]
@@ -243,10 +265,10 @@ function _construct_braidingtensors!(ex, preargs, indexmap) # ex is guaranteed t
243265
newex = Expr(ex.head, newargs...)
244266
return newex, success
245267
elseif isexpr(ex, :call) && ex.args[1] == :/ && length(ex.args) == 3
246-
newarg, success = _construct_braidingtensors!(ex.args[2], preargs, indexmap)
268+
newarg, success = _construct_braidingtensors!(ex.args[2], preargs, indexmap, non_braiding)
247269
return Expr(:call, :/, newarg, ex.args[3]), success
248270
elseif isexpr(ex, :call) && ex.args[1] == :\ && length(ex.args) == 3
249-
newarg, success = _construct_braidingtensors!(ex.args[3], preargs, indexmap)
271+
newarg, success = _construct_braidingtensors!(ex.args[3], preargs, indexmap, non_braiding)
250272
return Expr(:call, :\, ex.args[2], newarg), success
251273
else
252274
error("unexpected expression $ex")

src/spaces/vectorspaces.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ while the second will return the spacetype if all types are equal, and throw a [
409409
"""
410410
check_spacetype(::Type{Bool}, x, y, z...) = _allequal(spacetype, (x, y, z...))
411411
@noinline function check_spacetype(x, y, z...)
412-
check_spacetype(Bool, x, y, z...) || throw(SpaceMismatch("incompatible space types"))
412+
check_spacetype(Bool, x, y, z...) || throw(SpaceMismatch(lazy"incompatible space types $(type_repr.(spacetype.((x, y, z...))))"))
413413
return spacetype(x)
414414
end
415415

0 commit comments

Comments
 (0)