Skip to content

Commit dea5766

Browse files
committed
Merge pull request #8407 from nwh/blas-bug-fix
fix several issues in blas.jl
2 parents a9c18ba + 07ea2f4 commit dea5766

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

base/linalg/blas.jl

+8-8
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ export
5353

5454
const libblas = Base.libblas_name
5555

56-
import ..LinAlg: BlasFloat, BlasChar, BlasInt, blas_int, DimensionMismatch, chksquare, axpy!
56+
import ..LinAlg: BlasReal, BlasComplex, BlasFloat, BlasChar, BlasInt, blas_int, DimensionMismatch, chksquare, axpy!
5757

5858
# Level 1
5959
## copy
@@ -154,17 +154,17 @@ for (fname, elty) in ((:cblas_zdotu_sub,:Complex128),
154154
end
155155
end
156156
end
157-
function dot{T<:BlasFloat}(DX::StridedArray{T}, DY::StridedArray{T})
157+
function dot{T<:BlasReal}(DX::StridedArray{T}, DY::StridedArray{T})
158158
n = length(DX)
159159
n == length(DY) || throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
160160
dot(n, DX, stride(DX, 1), DY, stride(DY, 1))
161161
end
162-
function dotc{T<:BlasFloat}(DX::StridedArray{T}, DY::StridedArray{T})
162+
function dotc{T<:BlasComplex}(DX::StridedArray{T}, DY::StridedArray{T})
163163
n = length(DX)
164164
n == length(DY) || throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
165165
dotc(n, DX, stride(DX, 1), DY, stride(DY, 1))
166166
end
167-
function dotu{T<:BlasFloat}(DX::StridedArray{T}, DY::StridedArray{T})
167+
function dotu{T<:BlasComplex}(DX::StridedArray{T}, DY::StridedArray{T})
168168
n = length(DX)
169169
n == length(DY) || throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
170170
dotu(n, DX, stride(DX, 1), DY, stride(DY, 1))
@@ -249,7 +249,6 @@ for (fname, elty) in ((:idamax_,:Float64),
249249
(:icamax_,:Complex64))
250250
@eval begin
251251
function iamax(n::BlasInt, dx::Union(StridedVector{$elty}, Ptr{$elty}), incx::BlasInt)
252-
n*incx >= length(x) || throw(DimensionMismatch(""))
253252
ccall(($(string(fname)), libblas),BlasInt,
254253
(Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
255254
&n, dx, &incx)
@@ -319,8 +318,9 @@ for (fname, elty) in ((:dgbmv_,:Float64),
319318
y
320319
end
321320
function gbmv(trans::BlasChar, m::Integer, kl::Integer, ku::Integer, alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty})
322-
n = stride(A,2)
323-
gbmv!(trans, m, kl, ku, alpha, A, x, zero($elty), similar(x, $elty, n))
321+
n = size(A,2)
322+
leny = trans == 'N' ? m : n
323+
gbmv!(trans, m, kl, ku, alpha, A, x, zero($elty), similar(x, $elty, leny))
324324
end
325325
function gbmv(trans::BlasChar, m::Integer, kl::Integer, ku::Integer, A::StridedMatrix{$elty}, x::StridedVector{$elty})
326326
gbmv(trans, m, kl, ku, one($elty), A, x)
@@ -366,7 +366,7 @@ end
366366

367367
### hemv
368368
for (fname, elty) in ((:zhemv_,:Complex128),
369-
(:cgemv_,:Complex64))
369+
(:chemv_,:Complex64))
370370
@eval begin
371371
function hemv!(uplo::Char, α::$elty, A::StridedMatrix{$elty}, x::StridedVector{$elty}, β::$elty, y::StridedVector{$elty})
372372
n = size(A, 2)

0 commit comments

Comments
 (0)