From 8fcbec2b034f99de343714754804b172e4cbe854 Mon Sep 17 00:00:00 2001
From: Fredrik Ekre <fredrik.ekre@chalmers.se>
Date: Sun, 9 Jul 2017 03:32:53 +0200
Subject: [PATCH] parametrize Diagonal on the wrapped vector type

---
 NEWS.md                 |  6 ++++++
 base/linalg/diagonal.jl | 27 ++++++++++++++-------------
 test/linalg/diagonal.jl |  6 +++---
 test/show.jl            |  2 +-
 4 files changed, 24 insertions(+), 17 deletions(-)

diff --git a/NEWS.md b/NEWS.md
index 1c356a33be8b5..3bd9aaab6716b 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -60,6 +60,9 @@ This section lists changes that do not have deprecation warnings.
     longer present. Use `first(R)` and `last(R)` to obtain
     start/stop. ([#20974])
 
+  * The `Diagonal` type definition has changed from `Diagonal{T}` to
+    `Diagonal{T,V<:AbstractVector{T}}` ([#22718]).
+
 Library improvements
 --------------------
 
@@ -110,6 +113,9 @@ Library improvements
 
   * `Char`s can now be concatenated with `String`s and/or other `Char`s using `*` ([#22532]).
 
+  * `Diagonal` is now parameterized on the type of the wrapped vector. This allows
+    for `Diagonal` matrices with arbitrary `AbstractVector`s ([#22718]).
+
 Compiler/Runtime improvements
 -----------------------------
 
diff --git a/base/linalg/diagonal.jl b/base/linalg/diagonal.jl
index b93f3a0232cd9..f146773d7e9b0 100644
--- a/base/linalg/diagonal.jl
+++ b/base/linalg/diagonal.jl
@@ -2,16 +2,15 @@
 
 ## Diagonal matrices
 
-struct Diagonal{T} <: AbstractMatrix{T}
-    diag::Vector{T}
+struct Diagonal{T,V<:AbstractVector{T}} <: AbstractMatrix{T}
+    diag::V
 end
 """
     Diagonal(A::AbstractMatrix)
 
-Constructs a matrix from the diagonal of `A`.
-
-# Example
+Construct a matrix from the diagonal of `A`.
 
+# Examples
 ```jldoctest
 julia> A = [1 2 3; 4 5 6; 7 8 9]
 3×3 Array{Int64,2}:
@@ -20,36 +19,38 @@ julia> A = [1 2 3; 4 5 6; 7 8 9]
  7  8  9
 
 julia> Diagonal(A)
-3×3 Diagonal{Int64}:
+3×3 Diagonal{Int64,Array{Int64,1}}:
  1  ⋅  ⋅
  ⋅  5  ⋅
  ⋅  ⋅  9
 ```
 """
 Diagonal(A::AbstractMatrix) = Diagonal(diag(A))
+
 """
     Diagonal(V::AbstractVector)
 
-Constructs a matrix with `V` as its diagonal.
-
-# Example
+Construct a matrix with `V` as its diagonal.
 
+# Examples
 ```jldoctest
-julia> V = [1; 2]
+julia> V = [1, 2]
 2-element Array{Int64,1}:
  1
  2
 
 julia> Diagonal(V)
-2×2 Diagonal{Int64}:
+2×2 Diagonal{Int64,Array{Int64,1}}:
  1  ⋅
  ⋅  2
 ```
 """
-Diagonal(V::AbstractVector) = Diagonal(collect(V))
+Diagonal(V::AbstractVector{T}) where {T} = Diagonal{T,typeof(V)}(V)
+Diagonal{T}(V::AbstractVector{T}) where {T} = Diagonal{T,typeof(V)}(V)
+Diagonal{T}(V::AbstractVector) where {T} = Diagonal{T}(convert(AbstractVector{T}, V))
 
 convert(::Type{Diagonal{T}}, D::Diagonal{T}) where {T} = D
-convert(::Type{Diagonal{T}}, D::Diagonal) where {T} = Diagonal{T}(convert(Vector{T}, D.diag))
+convert(::Type{Diagonal{T}}, D::Diagonal) where {T} = Diagonal{T}(convert(AbstractVector{T}, D.diag))
 convert(::Type{AbstractMatrix{T}}, D::Diagonal) where {T} = convert(Diagonal{T}, D)
 convert(::Type{Matrix}, D::Diagonal) = diagm(D.diag)
 convert(::Type{Array}, D::Diagonal) = convert(Matrix, D)
diff --git a/test/linalg/diagonal.jl b/test/linalg/diagonal.jl
index 64c6de02af80b..40194a5f20a39 100644
--- a/test/linalg/diagonal.jl
+++ b/test/linalg/diagonal.jl
@@ -21,8 +21,8 @@ srand(1)
     @testset "Basic properties" begin
         @test eye(Diagonal{elty},n) == Diagonal(ones(elty,n))
         @test_throws ArgumentError size(D,0)
-        @test typeof(convert(Diagonal{Complex64},D)) == Diagonal{Complex64}
-        @test typeof(convert(AbstractMatrix{Complex64},D))   == Diagonal{Complex64}
+        @test typeof(convert(Diagonal{Complex64},D)) <: Diagonal{Complex64}
+        @test typeof(convert(AbstractMatrix{Complex64},D)) <: Diagonal{Complex64}
 
         @test Array(real(D)) == real(DM)
         @test Array(abs.(D)) == abs.(DM)
@@ -312,7 +312,7 @@ end
 end
 
 # allow construct from range
-@test Diagonal(linspace(1,3,3)) == Diagonal([1.,2.,3.])
+@test all(Diagonal(linspace(1,3,3)) .== Diagonal([1.0,2.0,3.0]))
 
 # Issue 12803
 for t in (Float32, Float64, Int, Complex{Float64}, Rational{Int})
diff --git a/test/show.jl b/test/show.jl
index e53b98ba63130..1dd14e81c2b89 100644
--- a/test/show.jl
+++ b/test/show.jl
@@ -547,7 +547,7 @@ end
 
 # test structured zero matrix printing for select structured types
 A = reshape(1:16,4,4)
-@test replstr(Diagonal(A)) == "4×4 Diagonal{$Int}:\n 1  ⋅   ⋅   ⋅\n ⋅  6   ⋅   ⋅\n ⋅  ⋅  11   ⋅\n ⋅  ⋅   ⋅  16"
+@test replstr(Diagonal(A)) == "4×4 Diagonal{$(Int),Array{$(Int),1}}:\n 1  ⋅   ⋅   ⋅\n ⋅  6   ⋅   ⋅\n ⋅  ⋅  11   ⋅\n ⋅  ⋅   ⋅  16"
 @test replstr(Bidiagonal(A,:U)) == "4×4 Bidiagonal{$Int}:\n 1  5   ⋅   ⋅\n ⋅  6  10   ⋅\n ⋅  ⋅  11  15\n ⋅  ⋅   ⋅  16"
 @test replstr(Bidiagonal(A,:L)) == "4×4 Bidiagonal{$Int}:\n 1  ⋅   ⋅   ⋅\n 2  6   ⋅   ⋅\n ⋅  7  11   ⋅\n ⋅  ⋅  12  16"
 @test replstr(SymTridiagonal(A+A')) == "4×4 SymTridiagonal{$Int}:\n 2   7   ⋅   ⋅\n 7  12  17   ⋅\n ⋅  17  22  27\n ⋅   ⋅  27  32"