Skip to content

Commit 6be8f44

Browse files
lkdvoskshyatt
authored andcommitted
change blocktype of TensorMap to StridedView
1 parent d04b475 commit 6be8f44

2 files changed

Lines changed: 19 additions & 12 deletions

File tree

src/spaces/homspace.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ const StridedStructure{N} = Tuple{NTuple{N, Int}, NTuple{N, Int}, Int}
299299

300300
struct FusionBlockStructure{I, N, F₁, F₂}
301301
totaldim::Int
302-
blockstructure::SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}}
302+
blockstructure::SectorDict{I, StridedStructure{2}}
303303
fusiontreelist::Vector{Tuple{F₁, F₂}}
304304
fusiontreestructure::Vector{StridedStructure{N}}
305305
fusiontreeindices::FusionTreeDict{Tuple{F₁, F₂}, Int}
@@ -325,9 +325,9 @@ end
325325
F₂ = fusiontreetype(I, N₂)
326326

327327
# output structure
328-
blockstructure = SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}}() # size, range
328+
blockstructure = SectorDict{I, StridedStructure{2}}() # size, strides, offset
329329
fusiontreelist = Vector{Tuple{F₁, F₂}}()
330-
fusiontreestructure = Vector{Tuple{NTuple{N₁ + N₂, Int}, NTuple{N₁ + N₂, Int}, Int}}() # size, strides, offset
330+
fusiontreestructure = Vector{StridedStructure{N₁ + N₂}}() # size, strides, offset
331331

332332
# temporary data structures
333333
splittingtrees = Vector{F₁}()
@@ -367,8 +367,8 @@ end
367367
blocksize = (blockdim₁, blockdim₂)
368368
blocklength = blockdim₁ * blockdim₂
369369
blockrange = (blockoffset + 1):(blockoffset + blocklength)
370+
blockstructure[c] = (blocksize, strides, blockoffset)
370371
blockoffset = last(blockrange)
371-
blockstructure[c] = (blocksize, blockrange)
372372
end
373373

374374
fusiontreeindices = sizehint!(

src/tensors/tensor.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -458,28 +458,35 @@ blocks(t::TensorMap) = BlockIterator(t, fusionblockstructure(t).blockstructure)
458458
function blocktype(::Type{TT}) where {TT <: TensorMap}
459459
A = storagetype(TT)
460460
T = eltype(A)
461-
return Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}
461+
@static if isdefined(Core, :Memory) # StridedViews normalizes parent types!
462+
if A <: Vector{T}
463+
A = GenericMemory{T}
464+
end
465+
end
466+
return StridedView{T, 2, A, typeof(identity)}
462467
end
463468

464469
function Base.iterate(iter::BlockIterator{<:TensorMap}, state...)
465470
next = iterate(iter.structure, state...)
466471
isnothing(next) && return next
467-
(c, (sz, r)), newstate = next
468-
return c => reshape(view(iter.t.data, r), sz), newstate
472+
(c, (sz, str, offset)), newstate = next
473+
return c => StridedView(iter.t.data, sz, str, offset), newstate
469474
end
470475

471476
function Base.getindex(iter::BlockIterator{<:TensorMap}, c::Sector)
472477
sectortype(iter.t) === typeof(c) || throw(SectorMismatch())
473-
(d₁, d₂), r = get(iter.structure, c) do
474-
# is s is not a key, at least one of the two dimensions will be zero:
478+
(d₁, d₂), (s₁, s₂), offset = get(iter.structure, c) do
479+
# is c is not a key, at least one of the two dimensions will be zero:
475480
# it then does not matter where exactly we construct a view in `t.data`,
476481
# as it will have length zero anyway
477482
d₁′ = blockdim(codomain(iter.t), c)
478483
d₂′ = blockdim(domain(iter.t), c)
479-
l = d₁′ * d₂′
480-
return (d₁′, d₂′), 1:l
484+
s₁ = 1
485+
s₂ = 0
486+
offset = 0
487+
return (d₁′, d₂′), (s₁, s₂), offset
481488
end
482-
return reshape(view(iter.t.data, r), (d₁, d₂))
489+
return StridedView(iter.t.data, (d₁, d₂), (s₁, s₂), offset)
483490
end
484491

485492
# Getting and setting the data at the subblock level

0 commit comments

Comments
 (0)