Skip to content

Commit 2c63b5d

Browse files
committed
Merge pull request #9701 from spencerlyon2/ordschur_gen
RFC: Ordering by Generalized Eigenvalues for Generalized Schur methods
2 parents 64ad746 + 0223400 commit 2c63b5d

File tree

4 files changed

+245
-76
lines changed

4 files changed

+245
-76
lines changed

base/linalg/factorization.jl

+5
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,11 @@ schurfact!{T<:BlasFloat}(A::StridedMatrix{T}, B::StridedMatrix{T}) = Generalized
732732
schurfact{T<:BlasFloat}(A::StridedMatrix{T},B::StridedMatrix{T}) = schurfact!(copy(A),copy(B))
733733
schurfact{TA,TB}(A::StridedMatrix{TA}, B::StridedMatrix{TB}) = (S = promote_type(Float32,typeof(one(TA)/norm(one(TA))),TB); schurfact!(S != TA ? convert(AbstractMatrix{S},A) : copy(A), S != TB ? convert(AbstractMatrix{S},B) : copy(B)))
734734

735+
ordschur!{Ty<:BlasFloat}(S::StridedMatrix{Ty}, T::StridedMatrix{Ty}, Q::StridedMatrix{Ty}, Z::StridedMatrix{Ty}, select::Array{Int}) = GeneralizedSchur(LinAlg.LAPACK.tgsen!(select, S, T, Q, Z)...)
736+
ordschur{Ty<:BlasFloat}(S::StridedMatrix{Ty}, T::StridedMatrix{Ty}, Q::StridedMatrix{Ty}, Z::StridedMatrix{Ty}, select::Array{Int}) = ordschur!(copy(S), copy(T), copy(Q), copy(Z), select)
737+
ordschur!{Ty<:BlasFloat}(gschur::GeneralizedSchur{Ty}, select::Array{Int}) = (res=ordschur!(gschur.S, gschur.T, gschur.Q, gschur.Z, select); gschur[:alpha][:]=res[:alpha]; gschur[:beta][:]=res[:beta]; res)
738+
ordschur{Ty<:BlasFloat}(gschur::GeneralizedSchur{Ty}, select::Array{Int}) = ordschur(gschur.S, gschur.T, gschur.Q, gschur.Z, select)
739+
735740
function getindex(F::GeneralizedSchur, d::Symbol)
736741
d == :S && return F.S
737742
d == :T && return F.T

base/linalg/lapack.jl

+146-24
Original file line numberDiff line numberDiff line change
@@ -1245,8 +1245,8 @@ for (geevx, ggev, elty) in
12451245
chkstride1(A,B)
12461246
n, m = chksquare(A,B)
12471247
n==m || throw(DimensionMismatch("matrices must have same size"))
1248-
lda = max(1, n)
1249-
ldb = max(1, n)
1248+
lda = max(1, stride(A, 2))
1249+
ldb = max(1, stride(B, 2))
12501250
alphar = similar(A, $elty, n)
12511251
alphai = similar(A, $elty, n)
12521252
beta = similar(A, $elty, n)
@@ -1351,7 +1351,8 @@ for (geevx, ggev, elty, relty) in
13511351
chkstride1(A, B)
13521352
n, m = chksquare(A, B)
13531353
n==m || throw(DimensionMismatch("matrices must have same size"))
1354-
lda = ldb = max(1, n)
1354+
lda = max(1, stride(A, 2))
1355+
ldb = max(1, stride(B, 2))
13551356
alpha = similar(A, $elty, n)
13561357
beta = similar(A, $elty, n)
13571358
ldvl = jobvl == 'V' ? n : 1
@@ -2920,7 +2921,8 @@ for (syev, syevr, sygvd, elty) in
29202921
chkstride1(A, B)
29212922
n, m = chksquare(A, B)
29222923
n==m || throw(DimensionMismatch("Matrices must have same size"))
2923-
lda = ldb = max(1, n)
2924+
lda = max(1, stride(A, 2))
2925+
ldb = max(1, stride(B, 2))
29242926
w = similar(A, $elty, n)
29252927
work = Array($elty, 1)
29262928
lwork = -one(BlasInt)
@@ -3071,7 +3073,8 @@ for (syev, syevr, sygvd, elty, relty) in
30713073
chkstride1(A, B)
30723074
n, m = chksquare(A, B)
30733075
n==m || throw(DimensionMismatch("Matrices must have same size"))
3074-
lda = ldb = max(1, n)
3076+
lda = max(1, stride(A, 2))
3077+
ldb = max(1, stride(B, 2))
30753078
w = similar(A, $relty, n)
30763079
work = Array($elty, 1)
30773080
lwork = -one(BlasInt)
@@ -3307,7 +3310,7 @@ for (gehrd, elty) in
33073310
Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
33083311
Ptr{BlasInt}),
33093312
&n, &ilo, &ihi, A,
3310-
&max(1,n), tau, work, &lwork,
3313+
&max(1, stride(A, 2)), tau, work, &lwork,
33113314
info)
33123315
@lapackerror
33133316
if lwork < 0
@@ -3346,7 +3349,7 @@ for (orghr, elty) in
33463349
Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
33473350
Ptr{BlasInt}),
33483351
&n, &ilo, &ihi, A,
3349-
&max(1,n), tau, work, &lwork,
3352+
&max(1, stride(A, 2)), tau, work, &lwork,
33503353
info)
33513354
@lapackerror
33523355
if lwork < 0
@@ -3389,7 +3392,7 @@ for (gees, gges, elty) in
33893392
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
33903393
Ptr{BlasInt}, Ptr{Void}, Ptr{BlasInt}),
33913394
&jobvs, &'N', C_NULL, &n,
3392-
A, &max(1, n), sdim, wr,
3395+
A, &max(1, stride(A, 2)), sdim, wr,
33933396
wi, vs, &ldvs, work,
33943397
&lwork, C_NULL, info)
33953398
@lapackerror
@@ -3433,8 +3436,8 @@ for (gees, gges, elty) in
34333436
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{Void},
34343437
Ptr{BlasInt}),
34353438
&jobvsl, &jobvsr, &'N', C_NULL,
3436-
&n, A, &max(1,n), B,
3437-
&max(1,n), &sdim, alphar, alphai,
3439+
&n, A, &max(1,stride(A, 2)), B,
3440+
&max(1,stride(B, 2)), &sdim, alphar, alphai,
34383441
beta, vsl, &ldvsl, vsr,
34393442
&ldvsr, work, &lwork, C_NULL,
34403443
info)
@@ -3479,7 +3482,7 @@ for (gees, gges, elty, relty) in
34793482
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
34803483
Ptr{$relty}, Ptr{Void}, Ptr{BlasInt}),
34813484
&jobvs, &sort, C_NULL, &n,
3482-
A, &max(1, n), &sdim, w,
3485+
A, &max(1, stride(A, 2)), &sdim, w,
34833486
vs, &ldvs, work, &lwork,
34843487
rwork, C_NULL, info)
34853488
@lapackerror
@@ -3524,8 +3527,8 @@ for (gees, gges, elty, relty) in
35243527
Ptr{$elty}, Ptr{BlasInt}, Ptr{$relty}, Ptr{Void},
35253528
Ptr{BlasInt}),
35263529
&jobvsl, &jobvsr, &'N', C_NULL,
3527-
&n, A, &max(1,n), B,
3528-
&max(1,n), &sdim, alpha, beta,
3530+
&n, A, &max(1, stride(A, 2)), B,
3531+
&max(1, stride(B, 2)), &sdim, alpha, beta,
35293532
vsl, &ldvsl, vsr, &ldvsr,
35303533
work, &lwork, rwork, C_NULL,
35313534
info)
@@ -3540,9 +3543,9 @@ for (gees, gges, elty, relty) in
35403543
end
35413544
end
35423545
# Reorder Schur forms
3543-
for (trsen, elty) in
3544-
((:dtrsen_,:Float64),
3545-
(:strsen_,:Float32))
3546+
for (trsen, tgsen, elty) in
3547+
((:dtrsen_, :dtgsen_, :Float64),
3548+
(:strsen_, :stgsen_, :Float32))
35463549
@eval begin
35473550
function trsen!(select::Array{Int}, T::StridedMatrix{$elty}, Q::StridedMatrix{$elty})
35483551
# * .. Scalar Arguments ..
@@ -3556,7 +3559,8 @@ for (trsen, elty) in
35563559
# DOUBLE PRECISION Q( LDQ, * ), T( LDT, * ), WI( * ), WORK( * ), WR( * )
35573560
chkstride1(T, Q)
35583561
n = chksquare(T)
3559-
ld = max(1, n)
3562+
ldt = max(1, stride(T, 2))
3563+
ldq = max(1, stride(Q, 2))
35603564
wr = similar(T, $elty, n)
35613565
wi = similar(T, $elty, n)
35623566
m = sum(select)
@@ -3572,10 +3576,10 @@ for (trsen, elty) in
35723576
(Ptr{BlasChar}, Ptr{BlasChar}, Ptr{BlasInt}, Ptr{BlasInt},
35733577
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
35743578
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{Void}, Ptr{Void},
3575-
Ptr{$elty}, Ptr {BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
3579+
Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
35763580
Ptr{BlasInt}),
35773581
&'N', &'V', select, &n,
3578-
T, &ld, Q, &ld,
3582+
T, &ldt, Q, &ldq,
35793583
wr, wi, &m, C_NULL, C_NULL,
35803584
work, &lwork, iwork, &liwork,
35813585
info)
@@ -3589,12 +3593,71 @@ for (trsen, elty) in
35893593
end
35903594
T, Q, all(wi .== 0) ? wr : complex(wr, wi)
35913595
end
3596+
function tgsen!(select::Array{Int}, S::StridedMatrix{$elty}, T::StridedMatrix{$elty},
3597+
Q::StridedMatrix{$elty}, Z::StridedMatrix{$elty})
3598+
# * .. Scalar Arguments ..
3599+
# * LOGICAL WANTQ, WANTZ
3600+
# * INTEGER IJOB, INFO, LDA, LDB, LDQ, LDZ, LIWORK, LWORK,
3601+
# * $ M, N
3602+
# * DOUBLE PRECISION PL, PR
3603+
# * ..
3604+
# * .. Array Arguments ..
3605+
# * LOGICAL SELECT( * )
3606+
# * INTEGER IWORK( * )
3607+
# * DOUBLE PRECISION A( LDA, * ), ALPHAI( * ), ALPHAR( * ),
3608+
# * $ B( LDB, * ), BETA( * ), DIF( * ), Q( LDQ, * ),
3609+
# * $ WORK( * ), Z( LDZ, * )
3610+
# * ..
3611+
chkstride1(S, T, Q, Z)
3612+
n, nt, nq, nz = chksquare(S, T, Q, Z)
3613+
n==nt==nq==nz || throw(DimensionMismatch("matrices are not of same size"))
3614+
lds = max(1, stride(S, 2))
3615+
ldt = max(1, stride(T, 2))
3616+
ldq = max(1, stride(Q, 2))
3617+
ldz = max(1, stride(Z, 2))
3618+
m = sum(select)
3619+
alphai = similar(T, $elty, n)
3620+
alphar = similar(T, $elty, n)
3621+
beta = similar(T, $elty, n)
3622+
lwork = blas_int(-1)
3623+
work = Array($elty, 1)
3624+
liwork = blas_int(-1)
3625+
iwork = Array(BlasInt, 1)
3626+
info = Array(BlasInt, 1)
3627+
select = convert(Array{BlasInt}, select)
3628+
3629+
for i = 1:2
3630+
ccall(($(blasfunc(tgsen)), liblapack), Void,
3631+
(Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
3632+
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
3633+
Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
3634+
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
3635+
Ptr{BlasInt}, Ptr{Void}, Ptr{Void}, Ptr{Void},
3636+
Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
3637+
Ptr{BlasInt}),
3638+
&0, &1, &1, select,
3639+
&n, S, &lds, T,
3640+
&ldt, alphar, alphai, beta,
3641+
Q, &ldq, Z, &ldz,
3642+
&m, C_NULL, C_NULL, C_NULL,
3643+
work, &lwork, iwork, &liwork,
3644+
info)
3645+
@lapackerror
3646+
if i == 1 # only estimated optimal lwork, liwork
3647+
lwork = blas_int(real(work[1]))
3648+
work = Array($elty, lwork)
3649+
liwork = blas_int(real(iwork[1]))
3650+
iwork = Array(BlasInt, liwork)
3651+
end
3652+
end
3653+
S, T, complex(alphar, alphai), beta, Q, Z
3654+
end
35923655
end
35933656
end
35943657

3595-
for (trsen, elty) in
3596-
((:ztrsen_,:Complex128),
3597-
(:ctrsen_,:Complex64))
3658+
for (trsen, tgsen, elty) in
3659+
((:ztrsen_, :ztgsen_, :Complex128),
3660+
(:ctrsen_, :ctgsen_, :Complex64))
35983661
@eval begin
35993662
function trsen!(select::Array{Int}, T::StridedMatrix{$elty}, Q::StridedMatrix{$elty})
36003663
# * .. Scalar Arguments ..
@@ -3607,7 +3670,8 @@ for (trsen, elty) in
36073670
# COMPLEX Q( LDQ, * ), T( LDT, * ), W( * ), WORK( * )
36083671
chkstride1(T, Q)
36093672
n = chksquare(T)
3610-
ld = max(1, n)
3673+
ldt = max(1, stride(T, 2))
3674+
ldq = max(1, stride(Q, 2))
36113675
w = similar(T, $elty, n)
36123676
m = sum(select)
36133677
work = Array($elty, 1)
@@ -3623,7 +3687,7 @@ for (trsen, elty) in
36233687
Ptr{$elty}, Ptr {BlasInt},
36243688
Ptr{BlasInt}),
36253689
&'N', &'V', select, &n,
3626-
T, &ld, Q, &ld,
3690+
T, &ldt, Q, &ldq,
36273691
w, &m, C_NULL, C_NULL,
36283692
work, &lwork,
36293693
info)
@@ -3635,6 +3699,64 @@ for (trsen, elty) in
36353699
end
36363700
T, Q, w
36373701
end
3702+
function tgsen!(select::Array{Int}, S::StridedMatrix{$elty}, T::StridedMatrix{$elty},
3703+
Q::StridedMatrix{$elty}, Z::StridedMatrix{$elty})
3704+
# * .. Scalar Arguments ..
3705+
# * LOGICAL WANTQ, WANTZ
3706+
# * INTEGER IJOB, INFO, LDA, LDB, LDQ, LDZ, LIWORK, LWORK,
3707+
# * $ M, N
3708+
# * DOUBLE PRECISION PL, PR
3709+
# * ..
3710+
# * .. Array Arguments ..
3711+
# * LOGICAL SELECT( * )
3712+
# * INTEGER IWORK( * )
3713+
# * DOUBLE PRECISION DIF( * )
3714+
# * COMPLEX*16 A( LDA, * ), ALPHA( * ), B( LDB, * ),
3715+
# * $ BETA( * ), Q( LDQ, * ), WORK( * ), Z( LDZ, * )
3716+
# * ..
3717+
chkstride1(S, T, Q, Z)
3718+
n, nt, nq, nz = chksquare(S, T, Q, Z)
3719+
n==nt==nq==nz || throw(DimensionMismatch("matrices are not of same size"))
3720+
lds = max(1, stride(S, 2))
3721+
ldt = max(1, stride(T, 2))
3722+
ldq = max(1, stride(Q, 2))
3723+
ldz = max(1, stride(Z, 2))
3724+
m = sum(select)
3725+
alpha = similar(T, $elty, n)
3726+
beta = similar(T, $elty, n)
3727+
lwork = blas_int(-1)
3728+
work = Array($elty, 1)
3729+
liwork = blas_int(-1)
3730+
iwork = Array(BlasInt, 1)
3731+
info = Array(BlasInt, 1)
3732+
select = convert(Array{BlasInt}, select)
3733+
3734+
for i = 1:2
3735+
ccall(($(blasfunc(tgsen)), liblapack), Void,
3736+
(Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
3737+
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
3738+
Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty},
3739+
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
3740+
Ptr{BlasInt}, Ptr{Void}, Ptr{Void}, Ptr{Void},
3741+
Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
3742+
Ptr{BlasInt}),
3743+
&0, &1, &1, select,
3744+
&n, S, &lds, T,
3745+
&ldt, alpha, beta,
3746+
Q, &ldq, Z, &ldz,
3747+
&m, C_NULL, C_NULL, C_NULL,
3748+
work, &lwork, iwork, &liwork,
3749+
info)
3750+
@lapackerror
3751+
if i == 1 # only estimated optimal lwork, liwork
3752+
lwork = blas_int(real(work[1]))
3753+
work = Array($elty, lwork)
3754+
liwork = blas_int(real(iwork[1]))
3755+
iwork = Array(BlasInt, liwork)
3756+
end
3757+
end
3758+
S, T, alpha, beta, Q, Z
3759+
end
36383760
end
36393761
end
36403762

0 commit comments

Comments
 (0)