Skip to content

Commit f44d26a

Browse files
committed
use more tuples instead of vectors
1 parent 1c77760 commit f44d26a

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

src/blockmap.jl

+25-25
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
struct BlockMap{T,As<:Tuple{Vararg{LinearMap}},Rs<:Tuple{Vararg{Int}}} <: LinearMap{T}
1+
struct BlockMap{T,As<:Tuple{Vararg{LinearMap}},Rs<:Tuple{Vararg{Int}},Rranges<:Tuple{Vararg{UnitRange{Int}}},Cranges<:Tuple{Vararg{UnitRange{Int}}}} <: LinearMap{T}
22
maps::As
33
rows::Rs
4-
rowranges::Vector{UnitRange{Int}}
5-
colranges::Vector{UnitRange{Int}}
4+
rowranges::Rranges
5+
colranges::Cranges
66
function BlockMap{T,R,S}(maps::R, rows::S) where {T, R<:Tuple{Vararg{LinearMap}}, S<:Tuple{Vararg{Int}}}
77
for A in maps
88
promote_type(T, eltype(A)) == T || throw(InexactError())
99
end
1010
rowranges, colranges = rowcolranges(maps, rows)
11-
return new{T,R,S}(maps, rows, rowranges, colranges)
11+
return new{T,R,S,typeof(rowranges),typeof(colranges)}(maps, rows, rowranges, colranges)
1212
end
1313
end
1414

@@ -28,28 +28,28 @@ Determines the range of rows for each block row and the range of columns for eac
2828
map in `maps`, according to its position in a virtual matrix representation of the
2929
block linear map obtained from `hvcat(rows, maps...)`.
3030
"""
31-
function rowcolranges(maps, rows)::Tuple{Vector{UnitRange{Int}},Vector{UnitRange{Int}}}
32-
rowranges = Vector{UnitRange{Int}}(undef, length(rows))
33-
colranges = Vector{UnitRange{Int}}(undef, length(maps))
31+
function rowcolranges(maps, rows)
32+
rowranges = ()
33+
colranges = ()
3434
mapind = 0
3535
rowstart = 1
36-
for rowind in 1:length(rows)
37-
xinds = vcat(1, map(a -> size(a, 2), maps[mapind+1:mapind+rows[rowind]])...)
36+
for row in rows
37+
xinds = vcat(1, map(a -> size(a, 2), maps[mapind+1:mapind+row])...)
3838
cumsum!(xinds, xinds)
3939
mapind += 1
4040
rowend = rowstart + size(maps[mapind], 1) - 1
41-
rowranges[rowind] = rowstart:rowend
42-
colranges[mapind] = xinds[1]:xinds[2]-1
43-
for colind in 2:rows[rowind]
41+
rowranges = (rowranges..., rowstart:rowend)
42+
colranges = (colranges..., xinds[1]:xinds[2]-1)
43+
for colind in 2:row
4444
mapind +=1
45-
colranges[mapind] = xinds[colind]:xinds[colind+1]-1
45+
colranges = (colranges..., xinds[colind]:xinds[colind+1]-1)
4646
end
4747
rowstart = rowend + 1
4848
end
49-
return rowranges, colranges
49+
return rowranges::NTuple{length(rows), UnitRange{Int}}, colranges::NTuple{length(maps), UnitRange{Int}}
5050
end
5151

52-
Base.size(A::BlockMap) = (last(A.rowranges[end]), last(A.colranges[end]))
52+
Base.size(A::BlockMap) = (last(last(A.rowranges)), last(last(A.colranges)))
5353

5454
############
5555
# concatenation
@@ -305,11 +305,11 @@ LinearAlgebra.adjoint(A::BlockMap) = AdjointMap(A)
305305
@inline function _blockmul!(y, A::BlockMap, x, α, β)
306306
maps, rows, yinds, xinds = A.maps, A.rows, A.rowranges, A.colranges
307307
mapind = 0
308-
@views @inbounds for rowind in 1:length(rows)
309-
yrow = selectdim(y, 1, yinds[rowind])
308+
@views @inbounds for (row, yi) in zip(rows, yinds)
309+
yrow = selectdim(y, 1, yi)
310310
mapind += 1
311311
mul!(yrow, maps[mapind], selectdim(x, 1, xinds[mapind]), α, β)
312-
for colind in 2:rows[rowind]
312+
for colind in 2:row
313313
mapind +=1
314314
mul!(yrow, maps[mapind], selectdim(x, 1, xinds[mapind]), α, true)
315315
end
@@ -399,23 +399,23 @@ end
399399
# BlockDiagonalMap
400400
############
401401

402-
struct BlockDiagonalMap{T,As<:Tuple{Vararg{LinearMap}}} <: LinearMap{T}
402+
struct BlockDiagonalMap{T,As<:Tuple{Vararg{LinearMap}},Ranges<:Tuple{Vararg{UnitRange{Int}}}} <: LinearMap{T}
403403
maps::As
404-
rowranges::Vector{UnitRange{Int}}
405-
colranges::Vector{UnitRange{Int}}
404+
rowranges::Ranges
405+
colranges::Ranges
406406
function BlockDiagonalMap{T,As}(maps::As) where {T, As<:Tuple{Vararg{LinearMap}}}
407407
for A in maps
408408
promote_type(T, eltype(A)) == T || throw(InexactError())
409409
end
410410
# row ranges
411411
inds = vcat(1, size.(maps, 1)...)
412412
cumsum!(inds, inds)
413-
rowranges = map(i -> inds[i]:inds[i+1]-1, 1:length(maps))
413+
rowranges = ntuple(i -> inds[i]:inds[i+1]-1, Val(length(maps)))
414414
# column ranges
415415
inds[2:end] .= size.(maps, 2)
416416
cumsum!(inds, inds)
417-
colranges = map(i -> inds[i]:inds[i+1]-1, 1:length(maps))
418-
return new{T,As}(maps, rowranges, colranges)
417+
colranges = ntuple(i -> inds[i]:inds[i+1]-1, Val(length(maps)))
418+
return new{T,As,typeof(rowranges)}(maps, rowranges, colranges)
419419
end
420420
end
421421

@@ -477,7 +477,7 @@ end
477477
@inline function _blockscaling!(y, A::BlockDiagonalMap, x, α, β)
478478
maps, yinds, xinds = A.maps, A.rowranges, A.colranges
479479
# TODO: think about multi-threading here
480-
@views @inbounds for i in 1:length(maps)
480+
@views @inbounds for i in eachindex(yinds, maps, xinds)
481481
mul!(selectdim(y, 1, yinds[i]), maps[i], selectdim(x, 1, xinds[i]), α, β)
482482
end
483483
return y

0 commit comments

Comments
 (0)