Skip to content

Commit b6189c7

Browse files
dkarraschpull[bot]
authored andcommitted
include review comments
1 parent 8d84c4a commit b6189c7

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

stdlib/LinearAlgebra/src/uniformscaling.jl

+9-10
Original file line numberDiff line numberDiff line change
@@ -419,19 +419,18 @@ promote_to_arrays(n,k, ::Type{T}, A, B, Cs...) where {T} =
419419
(promote_to_arrays_(n[k], T, A), promote_to_arrays_(n[k+1], T, B), promote_to_arrays(n,k+2, T, Cs...)...)
420420
promote_to_array_type(A::Tuple{Vararg{Union{AbstractVecOrMat,UniformScaling,Number}}}) = Matrix
421421

422-
_us2number(A) = A
423-
_us2number(J::UniformScaling) = J.λ
422+
szfun(::UniformScaling, _) = -1
423+
szfun(A, dim) = (require_one_based_indexing(A); return size(A, dim))
424424

425425
for (f, _f, dim, name) in ((:hcat, :_hcat, 1, "rows"), (:vcat, :_vcat, 2, "cols"))
426426
@eval begin
427427
@inline $f(A::Union{AbstractVecOrMat,UniformScaling}...) = $_f(A...)
428-
@inline $f(A::Union{AbstractVecOrMat,UniformScaling,Number}...) = $f(map(_us2number, A)...)
428+
@inline $f(A::Union{AbstractVecOrMat,UniformScaling,Number}...) = $_f(A...)
429429
function $_f(A::Union{AbstractVecOrMat,UniformScaling,Number}...; array_type = promote_to_array_type(A))
430430
n = -1
431-
for a in A
432-
if !isa(a, UniformScaling)
433-
require_one_based_indexing(a)
434-
na = size(a,$dim)
431+
sizes = map(a -> szfun(a, $dim), A)
432+
for na in sizes
433+
if na != -1
435434
n >= 0 && n != na &&
436435
throw(DimensionMismatch(string("number of ", $name,
437436
" of each array must match (got ", n, " and ", na, ")")))
@@ -455,9 +454,9 @@ function _hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScali
455454
j = 0
456455
for i = 1:nr # infer UniformScaling sizes from row counts, if possible:
457456
ni = -1 # number of rows in this block-row, -1 indicates unknown
458-
for k = 1:rows[i]
459-
if !isa(A[j+k], UniformScaling)
460-
na = size(A[j+k], 1)
457+
sizes = map(a -> szfun(a, 1), A[j+1:j+rows[i]])
458+
for na in sizes
459+
if na != -1
461460
ni >= 0 && ni != na &&
462461
throw(DimensionMismatch("mismatch in number of rows"))
463462
ni = na

0 commit comments

Comments
 (0)