Skip to content

Commit 229c3d0

Browse files
kshyattlkdvos
andauthored
Add a disamgiguating conversion (#47)
* Add a disamgiguating conversion * Add comment * Incremental progress on getting rid of allowscalar * Support BlockArrays in KernelAbstractions bcasting with new extension * Add test for different eltype * Update Project.toml * Update src/tensors/tensoroperations.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Move element type test out --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 0b3278f commit 229c3d0

6 files changed

Lines changed: 65 additions & 17 deletions

File tree

Project.toml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,25 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
1717

1818
[weakdeps]
1919
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
20+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
2021

2122
[extensions]
2223
BlockTensorKitAdaptExt = "Adapt"
24+
BlockTensorKitGPUArraysExt = "GPUArrays"
2325

2426
[compat]
2527
Adapt = "4"
2628
Aqua = "0.8"
2729
BlockArrays = "1"
2830
Combinatorics = "1"
2931
Compat = "4.13"
32+
GPUArrays = "11.4.1"
33+
JLArrays = "0.3"
3034
LinearAlgebra = "1"
3135
MatrixAlgebraKit = "0.6"
3236
Random = "1"
3337
SafeTestsets = "0.1"
34-
Strided = "2"
38+
Strided = "2.3.3"
3539
TensorKit = "0.16.1"
3640
TensorOperations = "5"
3741
Test = "1"
@@ -44,10 +48,12 @@ julia = "1.10"
4448
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
4549
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4650
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
51+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
52+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
4753
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4854
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4955
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5056
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
5157

5258
[targets]
53-
test = ["Test", "TestExtras", "Random", "Combinatorics", "SafeTestsets", "Aqua", "Adapt"]
59+
test = ["Test", "TestExtras", "Random", "Combinatorics", "SafeTestsets", "Aqua", "Adapt", "JLArrays"]

ext/BlockTensorKitGPUArraysExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module BlockTensorKitGPUArraysExt
2+
3+
using BlockTensorKit, BlockArrays, GPUArrays, Strided
4+
using Strided: StridedViews
5+
using GPUArrays: KernelAbstractions
6+
7+
function KernelAbstractions.get_backend(BA::BlockArrays.BlockArray{T, N, A}) where {T, N, A <: AbstractArray{<:StridedView{T, N, <:AnyGPUArray}}}
8+
return KernelAbstractions.get_backend(first(BA.blocks))
9+
end
10+
11+
end

src/tensors/abstractblocktensor/conversion.jl

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,50 @@
11
# Conversion
22
# ----------
3-
function Base.convert(::Type{TensorMap}, t::AbstractBlockTensorMap)
4-
S = spacetype(t)
5-
N₁, N₂ = numout(t), numin(t)
6-
cod = ProductSpace{S, N₁}(oplus.(codomain(t).spaces))
7-
dom = ProductSpace{S, N₂}(oplus.(domain(t).spaces))
8-
tdst = similar(t, cod dom)
9-
10-
issparse(t) && zerovector!(tdst)
113

4+
function _copy_subblocks!(tdst, tsrc)
5+
S = spacetype(tsrc)
6+
N₁, N₂ = numout(tsrc), numin(tsrc)
127
for ((f₁, f₂), arr) in subblocks(tdst)
138
blockax = ntuple(N₁ + N₂) do i
149
return if i <= N₁
15-
blockedrange(map(Base.Fix2(dim, f₁.uncoupled[i]), space(t, i)))
10+
blockedrange(map(Base.Fix2(dim, f₁.uncoupled[i]), space(tsrc, i)))
1611
else
17-
blockedrange(map(Base.Fix2(dim, f₂.uncoupled[i - N₁]), space(t, i)'))
12+
blockedrange(map(Base.Fix2(dim, f₂.uncoupled[i - N₁]), space(tsrc, i)'))
1813
end
1914
end
2015

21-
for (k, v) in nonzero_pairs(t)
16+
for (k, v) in nonzero_pairs(tsrc)
2217
indices = getindex.(blockax, Block.(Tuple(k)))
2318
arr_slice = arr[indices...]
2419
# need to check for empty since fusion tree pair might not be present
2520
isempty(arr_slice) || copy!(arr_slice, v[f₁, f₂])
2621
end
2722
end
23+
return tdst
24+
end
2825

26+
function Base.convert(::Type{TensorMap}, t::AbstractBlockTensorMap)
27+
S = spacetype(t)
28+
N₁, N₂ = numout(t), numin(t)
29+
cod = ProductSpace{S, N₁}(oplus.(codomain(t).spaces))
30+
dom = ProductSpace{S, N₂}(oplus.(domain(t).spaces))
31+
tdst = TensorKit.TensorMapWithStorage{scalartype(t), storagetype(t)}(undef, cod, dom)
32+
33+
issparse(t) && zerovector!(tdst)
34+
_copy_subblocks!(tdst, t)
2935
return tdst
3036
end
3137

32-
function Base.convert(::Type{T}, t::AbstractBlockTensorMap) where {T <: TensorMap}
33-
tdst = convert(TensorMap, t)
34-
return convert(T, tdst)
38+
function Base.convert(::Type{TT}, t::AbstractBlockTensorMap) where {TT <: TensorKit.TensorMap}
39+
S = spacetype(t)
40+
N₁, N₂ = numout(t), numin(t)
41+
cod = ProductSpace{S, N₁}(oplus.(codomain(t).spaces))
42+
dom = ProductSpace{S, N₂}(oplus.(domain(t).spaces))
43+
tdst = TT(undef, cod dom)
44+
issparse(t) && zerovector!(tdst)
45+
46+
_copy_subblocks!(tdst, t)
47+
return tdst
3548
end
3649

3750
function Base.convert(::Type{TT}, t::AbstractTensorMap) where {TT <: AbstractBlockTensorMap}

src/tensors/blocktensor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ function BlockTensorMap(t::AbstractTensorMap, space::TensorMapSumSpace)
114114
TT = tensormaptype(spacetype(t), numout(t), numin(t), storagetype(t))
115115
tdst = BlockTensorMap{TT}(undef, space)
116116
for (f₁, f₂) in fusiontrees(tdst)
117-
tdst[f₁, f₂] .= t[f₁, f₂]
117+
copy!(tdst[f₁, f₂], t[f₁, f₂])
118118
end
119119
return tdst
120120
end

src/tensors/tensoroperations.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ function TO.tensoradd_type(TC, A::AdjointBlockTensorMap, pA::Index2Tuple, conjA:
1515
return TO.tensoradd_type(TC, A', adjointtensorindices(A, pA), !conjA)
1616
end
1717

18+
function TO.tensorscalar(t::AbstractBlockTensorMap{T, S, 0, 0}) where {T, S}
19+
return nonzero_length(t) == 0 ? zero(T) : TO.tensorscalar(only(nonzero_values(t)))
20+
end
21+
1822
# tensoralloc_contract
1923
# --------------------
2024
for TTA in (:AbstractTensorMap, :AbstractBlockTensorMap), TTB in (:AbstractTensorMap, :AbstractBlockTensorMap)

test/abstracttensor/blocktensor.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using BlockTensorKit
55
using Random
66
using Combinatorics
77
using Adapt
8+
using JLArrays
89

910
Vtr = (
1011
SumSpace(ℂ^3),
@@ -82,7 +83,20 @@ end
8283
t2″ = @inferred BlockTensorMap(t2′, W)
8384
@test t1 t1″
8485
@test t2 t2″
86+
# test conversion to TensorMap that isn't backed by a Vector
87+
jl_bt1 = rand(JLVector{T}, W)
88+
TT = TensorKit.TensorMap{T, spacetype(t1′), numout(t1′), numin(t1′), JLVector{T}}
89+
jl_bt1′ = @constinferred convert(TT, jl_bt1)
90+
jl_bt1″ = @inferred BlockTensorMap(jl_bt1′, W)
91+
@test jl_bt1 jl_bt1″
8592
end
93+
# test conversion to TensorMap with a different element type
94+
t1 = rand(ComplexF32, W)
95+
TT = TensorKit.TensorMap{ComplexF64, spacetype(t1), numout(t1), numin(t1), Vector{ComplexF64}}
96+
t1′ = @constinferred convert(TT, t1)
97+
@test norm(t1) norm(t1′)
98+
t1″ = @inferred BlockTensorMap(t1′, W)
99+
@test t1 t1″
86100
end
87101

88102
@testset "Adapt" begin

0 commit comments

Comments
 (0)