1
- struct BlockMap{T,As<: Tuple{Vararg{LinearMap}} ,Rs<: Tuple{Vararg{Int}} } <: LinearMap{T}
1
+ struct BlockMap{T,As<: Tuple{Vararg{LinearMap}} ,Rs<: Tuple{Vararg{Int}} ,Rranges <: Tuple{Vararg{UnitRange{Int}}} ,Cranges <: Tuple{Vararg{UnitRange{Int}}} } <: LinearMap{T}
2
2
maps:: As
3
3
rows:: Rs
4
- rowranges:: Vector{UnitRange{Int}}
5
- colranges:: Vector{UnitRange{Int}}
4
+ rowranges:: Rranges
5
+ colranges:: Cranges
6
6
function BlockMap {T,R,S} (maps:: R , rows:: S ) where {T, R<: Tuple{Vararg{LinearMap}} , S<: Tuple{Vararg{Int}} }
7
7
for A in maps
8
8
promote_type (T, eltype (A)) == T || throw (InexactError ())
9
9
end
10
10
rowranges, colranges = rowcolranges (maps, rows)
11
- return new {T,R,S} (maps, rows, rowranges, colranges)
11
+ return new {T,R,S,typeof(rowranges),typeof(colranges) } (maps, rows, rowranges, colranges)
12
12
end
13
13
end
14
14
@@ -28,28 +28,28 @@ Determines the range of rows for each block row and the range of columns for eac
28
28
map in `maps`, according to its position in a virtual matrix representation of the
29
29
block linear map obtained from `hvcat(rows, maps...)`.
30
30
"""
31
- function rowcolranges (maps, rows):: Tuple{Vector{UnitRange{Int}},Vector{UnitRange{Int}}}
32
- rowranges = Vector {UnitRange{Int}} (undef, length (rows) )
33
- colranges = Vector {UnitRange{Int}} (undef, length (maps) )
31
+ function rowcolranges (maps, rows)
32
+ rowranges = ( )
33
+ colranges = ( )
34
34
mapind = 0
35
35
rowstart = 1
36
- for rowind in 1 : length ( rows)
37
- xinds = vcat (1 , map (a -> size (a, 2 ), maps[mapind+ 1 : mapind+ rows[rowind] ])... )
36
+ for row in rows
37
+ xinds = vcat (1 , map (a -> size (a, 2 ), maps[mapind+ 1 : mapind+ row ])... )
38
38
cumsum! (xinds, xinds)
39
39
mapind += 1
40
40
rowend = rowstart + size (maps[mapind], 1 ) - 1
41
- rowranges[rowind] = rowstart: rowend
42
- colranges[mapind] = xinds[1 ]: xinds[2 ]- 1
43
- for colind in 2 : rows[rowind]
41
+ rowranges = (rowranges ... , rowstart: rowend)
42
+ colranges = (colranges ... , xinds[1 ]: xinds[2 ]- 1 )
43
+ for colind in 2 : row
44
44
mapind += 1
45
- colranges[mapind] = xinds[colind]: xinds[colind+ 1 ]- 1
45
+ colranges = (colranges ... , xinds[colind]: xinds[colind+ 1 ]- 1 )
46
46
end
47
47
rowstart = rowend + 1
48
48
end
49
- return rowranges, colranges
49
+ return rowranges:: NTuple{length(rows), UnitRange{Int}} , colranges:: NTuple{length(maps), UnitRange{Int}}
50
50
end
51
51
52
- Base. size (A:: BlockMap ) = (last (A. rowranges[ end ]) , last (A. colranges[ end ] ))
52
+ Base. size (A:: BlockMap ) = (last (last ( A. rowranges)) , last (last ( A. colranges) ))
53
53
54
54
# ###########
55
55
# concatenation
@@ -305,11 +305,11 @@ LinearAlgebra.adjoint(A::BlockMap) = AdjointMap(A)
305
305
@inline function _blockmul! (y, A:: BlockMap , x, α, β)
306
306
maps, rows, yinds, xinds = A. maps, A. rows, A. rowranges, A. colranges
307
307
mapind = 0
308
- @views @inbounds for rowind in 1 : length (rows)
309
- yrow = selectdim (y, 1 , yinds[rowind] )
308
+ @views @inbounds for (row, yi) in zip (rows, yinds )
309
+ yrow = selectdim (y, 1 , yi )
310
310
mapind += 1
311
311
mul! (yrow, maps[mapind], selectdim (x, 1 , xinds[mapind]), α, β)
312
- for colind in 2 : rows[rowind]
312
+ for colind in 2 : row
313
313
mapind += 1
314
314
mul! (yrow, maps[mapind], selectdim (x, 1 , xinds[mapind]), α, true )
315
315
end
@@ -399,23 +399,23 @@ end
399
399
# BlockDiagonalMap
400
400
# ###########
401
401
402
- struct BlockDiagonalMap{T,As<: Tuple{Vararg{LinearMap}} } <: LinearMap{T}
402
+ struct BlockDiagonalMap{T,As<: Tuple{Vararg{LinearMap}} ,Ranges <: Tuple{Vararg{UnitRange{Int}}} } <: LinearMap{T}
403
403
maps:: As
404
- rowranges:: Vector{UnitRange{Int}}
405
- colranges:: Vector{UnitRange{Int}}
404
+ rowranges:: Ranges
405
+ colranges:: Ranges
406
406
function BlockDiagonalMap {T,As} (maps:: As ) where {T, As<: Tuple{Vararg{LinearMap}} }
407
407
for A in maps
408
408
promote_type (T, eltype (A)) == T || throw (InexactError ())
409
409
end
410
410
# row ranges
411
411
inds = vcat (1 , size .(maps, 1 )... )
412
412
cumsum! (inds, inds)
413
- rowranges = map (i -> inds[i]: inds[i+ 1 ]- 1 , 1 : length (maps))
413
+ rowranges = ntuple (i -> inds[i]: inds[i+ 1 ]- 1 , Val ( length (maps) ))
414
414
# column ranges
415
415
inds[2 : end ] .= size .(maps, 2 )
416
416
cumsum! (inds, inds)
417
- colranges = map (i -> inds[i]: inds[i+ 1 ]- 1 , 1 : length (maps))
418
- return new {T,As} (maps, rowranges, colranges)
417
+ colranges = ntuple (i -> inds[i]: inds[i+ 1 ]- 1 , Val ( length (maps) ))
418
+ return new {T,As,typeof(rowranges) } (maps, rowranges, colranges)
419
419
end
420
420
end
421
421
477
477
@inline function _blockscaling! (y, A:: BlockDiagonalMap , x, α, β)
478
478
maps, yinds, xinds = A. maps, A. rowranges, A. colranges
479
479
# TODO : think about multi-threading here
480
- @views @inbounds for i in 1 : length ( maps)
480
+ @views @inbounds for i in eachindex (yinds, maps, xinds )
481
481
mul! (selectdim (y, 1 , yinds[i]), maps[i], selectdim (x, 1 , xinds[i]), α, β)
482
482
end
483
483
return y
0 commit comments