Skip to content

Commit 6e9056e

Browse files
committed
Implement ReinterpretArray
This redoes `reinterpret` in julia rather than punning the memory of the actual array. The motivation for this is to avoid the API limitations of the current reinterpret implementation (Array only, preventing strong TBAA, alignment problems). The surface API essentially unchanged, though the shape argument to reinterpret is removed, since those concepts are now orthogonal. The return type from `reinterpret` is now `ReinterpretArray`, which implements the AbstractArray interface and does the reinterpreting lazily on demand. The compiler is able to fold away the abstraction and generate very tight IR: ``` julia> ar = reinterpret(Complex{Int64}, rand(Int64, 1000)); julia> typeof(ar) Base.ReinterpretArray{Complex{Int64},Int64,1,Array{Int64,1}} julia> f(ar) = @inbounds return ar[1] f (generic function with 1 method) julia> @code_llvm f(ar) ; Function f ; Location: REPL[2] define void @julia_f_63575({ i64, i64 } addrspace(11)* noalias nocapture sret, %jl_value_t addrspace(10)* dereferenceable(8)) #0 { top: ; Location: REPL[2]:1 ; Function getindex; { ; Location: reinterpretarray.jl:31 %2 = addrspacecast %jl_value_t addrspace(10)* %1 to %jl_value_t addrspace(11)* %3 = bitcast %jl_value_t addrspace(11)* %2 to %jl_value_t addrspace(10)* addrspace(11)* %4 = load %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)* addrspace(11)* %3, align 8 %5 = addrspacecast %jl_value_t addrspace(10)* %4 to %jl_value_t addrspace(11)* %6 = bitcast %jl_value_t addrspace(11)* %5 to i64* addrspace(11)* %7 = load i64*, i64* addrspace(11)* %6, align 8 %8 = load i64, i64* %7, align 8 %9 = getelementptr i64, i64* %7, i64 1 %10 = load i64, i64* %9, align 8 %.sroa.0.0..sroa_idx = getelementptr inbounds { i64, i64 }, { i64, i64 } addrspace(11)* %0, i64 0, i32 0 store i64 %8, i64 addrspace(11)* %.sroa.0.0..sroa_idx, align 8 %.sroa.3.0..sroa_idx13 = getelementptr inbounds { i64, i64 }, { i64, i64 } addrspace(11)* %0, i64 0, i32 1 store i64 %10, i64 addrspace(11)* %.sroa.3.0..sroa_idx13, align 8 ;} ret void } julia> g(a) = @inbounds return reinterpret(Complex{Int64}, a)[1] g (generic function with 1 method) julia> @code_llvm g(randn(1000)) ; Function g ; Location: REPL[4] define void @julia_g_63642({ i64, i64 } addrspace(11)* noalias nocapture sret, %jl_value_t addrspace(10)* dereferenceable(40)) #0 { top: ; Location: REPL[4]:1 ; Function getindex; { ; Location: reinterpretarray.jl:31 %2 = addrspacecast %jl_value_t addrspace(10)* %1 to %jl_value_t addrspace(11)* %3 = bitcast %jl_value_t addrspace(11)* %2 to double* addrspace(11)* %4 = load double*, double* addrspace(11)* %3, align 8 %5 = bitcast double* %4 to i64* %6 = load i64, i64* %5, align 8 %7 = getelementptr double, double* %4, i64 1 %8 = bitcast double* %7 to i64* %9 = load i64, i64* %8, align 8 %.sroa.0.0..sroa_idx = getelementptr inbounds { i64, i64 }, { i64, i64 } addrspace(11)* %0, i64 0, i32 0 store i64 %6, i64 addrspace(11)* %.sroa.0.0..sroa_idx, align 8 %.sroa.3.0..sroa_idx13 = getelementptr inbounds { i64, i64 }, { i64, i64 } addrspace(11)* %0, i64 0, i32 1 store i64 %9, i64 addrspace(11)* %.sroa.3.0..sroa_idx13, align 8 ;} ret void } ``` In addition, the new `reinterpret` implementation is able to handle any AbstractArray (whether useful or not is a separate decision): ``` invoke(reinterpret, Tuple{Type{Complex{Float64}}, AbstractArray}, Complex{Float64}, speye(10)) 5×10 Base.ReinterpretArray{Complex{Float64},Float64,2,SparseMatrixCSC{Float64,Int64}}: 1.0+0.0im 0.0+1.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 1.0+0.0im 0.0+1.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 1.0+0.0im 0.0+1.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 1.0+0.0im 0.0+1.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 0.0+0.0im 1.0+0.0im 0.0+1.0im ``` The remaining todo is to audit the uses of reinterpret in base. I've fixed up the uses themselves, but there's code deeper in the array code that needs to be broadened to allow ReinterpretArray. Fixes #22849 Fixes #19238
1 parent e7a9389 commit 6e9056e

23 files changed

+214
-132
lines changed

base/array.jl

-27
Original file line numberDiff line numberDiff line change
@@ -214,33 +214,6 @@ original.
214214
"""
215215
copy(a::T) where {T<:Array} = ccall(:jl_array_copy, Ref{T}, (Any,), a)
216216

217-
function reinterpret(::Type{T}, a::Array{S,1}) where T where S
218-
nel = Int(div(length(a) * sizeof(S), sizeof(T)))
219-
# TODO: maybe check that remainder is zero?
220-
return reinterpret(T, a, (nel,))
221-
end
222-
223-
function reinterpret(::Type{T}, a::Array{S}) where T where S
224-
if sizeof(S) != sizeof(T)
225-
throw(ArgumentError("result shape not specified"))
226-
end
227-
reinterpret(T, a, size(a))
228-
end
229-
230-
function reinterpret(::Type{T}, a::Array{S}, dims::NTuple{N,Int}) where T where S where N
231-
function throwbits(::Type{S}, ::Type{T}, ::Type{U}) where {S,T,U}
232-
@_noinline_meta
233-
throw(ArgumentError("cannot reinterpret Array{$(S)} to ::Type{Array{$(T)}}, type $(U) is not a bits type"))
234-
end
235-
isbits(T) || throwbits(S, T, T)
236-
isbits(S) || throwbits(S, T, S)
237-
nel = div(length(a) * sizeof(S), sizeof(T))
238-
if prod(dims) != nel
239-
_throw_dmrsa(dims, nel)
240-
end
241-
ccall(:jl_reshape_array, Array{T,N}, (Any, Any, Any), Array{T,N}, a, dims)
242-
end
243-
244217
# reshaping to same # of dimensions
245218
function reshape(a::Array{T,N}, dims::NTuple{N,Int}) where T where N
246219
if prod(dims) != length(a)

base/essentials.jl

+1-9
Original file line numberDiff line numberDiff line change
@@ -313,20 +313,12 @@ unsafe_convert(::Type{P}, x::Ptr) where {P<:Ptr} = convert(P, x)
313313
reinterpret(type, A)
314314
315315
Change the type-interpretation of a block of memory.
316-
For arrays, this constructs an array with the same binary data as the given
316+
For arrays, this constructs a view of the array with the same binary data as the given
317317
array, but with the specified element type.
318318
For example,
319319
`reinterpret(Float32, UInt32(7))` interprets the 4 bytes corresponding to `UInt32(7)` as a
320320
[`Float32`](@ref).
321321
322-
!!! warning
323-
324-
It is not allowed to `reinterpret` an array to an element type with a larger alignment then
325-
the alignment of the array. For a normal `Array`, this is the alignment of its element type.
326-
For a reinterpreted array, this is the alignment of the `Array` it was reinterpreted from.
327-
For example, `reinterpret(UInt32, UInt8[0, 0, 0, 0])` is not allowed but
328-
`reinterpret(UInt32, reinterpret(UInt8, Float32[1.0]))` is allowed.
329-
330322
# Examples
331323
```jldoctest
332324
julia> reinterpret(Float32, UInt32(7))

base/io.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -267,15 +267,16 @@ readlines(s=STDIN; chomp::Bool=true) = collect(eachline(s, chomp=chomp))
267267

268268
## byte-order mark, ntoh & hton ##
269269

270-
let endian_boms = reinterpret(UInt8, UInt32[0x01020304])
270+
a = UInt32[0x01020304]
271+
let endian_bom = unsafe_load(convert(Ptr{UInt8}, pointer(a)))
271272
global ntoh, hton, ltoh, htol
272-
if endian_boms == UInt8[1:4;]
273+
if endian_bom == 0x01
273274
ntoh(x) = x
274275
hton(x) = x
275276
ltoh(x) = bswap(x)
276277
htol(x) = bswap(x)
277278
const global ENDIAN_BOM = 0x01020304
278-
elseif endian_boms == UInt8[4:-1:1;]
279+
elseif endian_bom == 0x04
279280
ntoh(x) = bswap(x)
280281
hton(x) = bswap(x)
281282
ltoh(x) = x

base/linalg/factorization.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ Base.isequal(F::T, G::T) where {T<:Factorization} = all(f -> isequal(getfield(F,
5656
# With a real lhs and complex rhs with the same precision, we can reinterpret
5757
# the complex rhs as a real rhs with twice the number of columns
5858
function (\)(F::Factorization{T}, B::VecOrMat{Complex{T}}) where T<:BlasReal
59-
c2r = reshape(transpose(reinterpret(T, B, (2, length(B)))), size(B, 1), 2*size(B, 2))
59+
c2r = reshape(transpose(reinterpret(T, reshape(B, (1, length(B))))), size(B, 1), 2*size(B, 2))
6060
x = A_ldiv_B!(F, c2r)
61-
return reinterpret(Complex{T}, transpose(reshape(x, div(length(x), 2), 2)), _ret_size(F, B))
61+
return reshape(collect(reinterpret(Complex{T}, transpose(reshape(x, div(length(x), 2), 2)))), _ret_size(F, B))
6262
end
6363

6464
for (f1, f2) in ((:\, :A_ldiv_B!),

base/linalg/lq.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,10 @@ end
267267
# With a real lhs and complex rhs with the same precision, we can reinterpret
268268
# the complex rhs as a real rhs with twice the number of columns
269269
function (\)(F::LQ{T}, B::VecOrMat{Complex{T}}) where T<:BlasReal
270-
c2r = reshape(transpose(reinterpret(T, B, (2, length(B)))), size(B, 1), 2*size(B, 2))
270+
c2r = reshape(transpose(reinterpret(T, reshape(B, (1, length(B))))), size(B, 1), 2*size(B, 2))
271271
x = A_ldiv_B!(F, c2r)
272-
return reinterpret(Complex{T}, transpose(reshape(x, div(length(x), 2), 2)),
273-
isa(B, AbstractVector) ? (size(F,2),) : (size(F,2), size(B,2)))
272+
return reshape(collect(reinterpret(Complex{T}, transpose(reshape(x, div(length(x), 2), 2)))),
273+
isa(B, AbstractVector) ? (size(F,2),) : (size(F,2), size(B,2)))
274274
end
275275

276276

base/linalg/matmul.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ A_mul_B!(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) where
9090
for elty in (Float32,Float64)
9191
@eval begin
9292
function A_mul_B!(y::StridedVector{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, x::StridedVector{$elty})
93-
Afl = reinterpret($elty,A,(2size(A,1),size(A,2)))
93+
Afl = reinterpret($elty,A)
9494
yfl = reinterpret($elty,y)
9595
gemv!(yfl,'N',Afl,x)
9696
return y
@@ -148,8 +148,8 @@ A_mul_B!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) wher
148148
for elty in (Float32,Float64)
149149
@eval begin
150150
function A_mul_B!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty})
151-
Afl = reinterpret($elty, A, (2size(A,1), size(A,2)))
152-
Cfl = reinterpret($elty, C, (2size(C,1), size(C,2)))
151+
Afl = reinterpret($elty, A)
152+
Cfl = reinterpret($elty, C)
153153
gemm_wrapper!(Cfl, 'N', 'N', Afl, B)
154154
return C
155155
end
@@ -190,8 +190,8 @@ A_mul_Bt!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) whe
190190
for elty in (Float32,Float64)
191191
@eval begin
192192
function A_mul_Bt!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty})
193-
Afl = reinterpret($elty, A, (2size(A,1), size(A,2)))
194-
Cfl = reinterpret($elty, C, (2size(C,1), size(C,2)))
193+
Afl = reinterpret($elty, A)
194+
Cfl = reinterpret($elty, C)
195195
gemm_wrapper!(Cfl, 'N', 'T', Afl, B)
196196
return C
197197
end

base/linalg/qr.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -923,15 +923,15 @@ function (\)(A::Union{QR{T},QRCompactWY{T},QRPivoted{T}}, BIn::VecOrMat{Complex{
923923
# |z2|z4| -> |y1|y2|y3|y4| -> |x2|y2| -> |x2|y2|x4|y4|
924924
# |x3|y3|
925925
# |x4|y4|
926-
B = reshape(transpose(reinterpret(T, BIn, (2, length(BIn)))), size(BIn, 1), 2*size(BIn, 2))
926+
B = reshape(transpose(reinterpret(T, reshape(BIn, (1, length(BIn))))), size(BIn, 1), 2*size(BIn, 2))
927927

928928
X = A_ldiv_B!(A, _append_zeros(B, T, n))
929929

930930
# |z1|z3| reinterpret |x1|x2|x3|x4| transpose |x1|y1| reshape |x1|y1|x3|y3|
931931
# |z2|z4| <- |y1|y2|y3|y4| <- |x2|y2| <- |x2|y2|x4|y4|
932932
# |x3|y3|
933933
# |x4|y4|
934-
XX = reinterpret(Complex{T}, transpose(reshape(X, div(length(X), 2), 2)), _ret_size(A, BIn))
934+
XX = reshape(collect(reinterpret(Complex{T}, transpose(reshape(X, div(length(X), 2), 2)))), _ret_size(A, BIn))
935935
return _cut_B(XX, 1:n)
936936
end
937937

base/random/dSFMT.jl

+6-5
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ function dsfmt_jump(s::DSFMT_state, jp::AbstractString)
104104
val = s.val
105105
nval = length(val)
106106
index = val[nval - 1]
107-
work = zeros(UInt64, JN32 >> 1)
107+
work = zeros(Int32, JN32)
108+
rwork = reinterpret(UInt64, work)
108109
dsfmt = Vector{UInt64}(nval >> 1)
109110
ccall(:memcpy, Ptr{Void}, (Ptr{UInt64}, Ptr{Int32}, Csize_t),
110111
dsfmt, val, (nval - 1) * sizeof(Int32))
@@ -113,17 +114,17 @@ function dsfmt_jump(s::DSFMT_state, jp::AbstractString)
113114
for c in jp
114115
bits = parse(UInt8,c,16)
115116
for j in 1:4
116-
(bits & 0x01) != 0x00 && dsfmt_jump_add!(work, dsfmt)
117+
(bits & 0x01) != 0x00 && dsfmt_jump_add!(rwork, dsfmt)
117118
bits = bits >> 0x01
118119
dsfmt_jump_next_state!(dsfmt)
119120
end
120121
end
121122

122-
work[end] = index
123-
return DSFMT_state(reinterpret(Int32, work))
123+
rwork[end] = index
124+
return DSFMT_state(work)
124125
end
125126

126-
function dsfmt_jump_add!(dest::Vector{UInt64}, src::Vector{UInt64})
127+
function dsfmt_jump_add!(dest::AbstractVector{UInt64}, src::Vector{UInt64})
127128
dp = dest[end] >> 1
128129
sp = src[end] >> 1
129130
diff = ((sp - dp + N) % N)

base/reinterpretarray.jl

+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""
2+
Gives a reinterpreted view (of element type T) of the underlying array (of element type S).
3+
If the size of `T` differs from the size of `S`, the array will be compressed/expanded in
4+
the first dimension.
5+
"""
6+
struct ReinterpretArray{T,N,S,A<:AbstractArray{S, N}} <: AbstractArray{T, N}
7+
parent::A
8+
function reinterpret(::Type{T}, a::A) where {T,N,S,A<:AbstractArray{S, N}}
9+
function throwbits(::Type{S}, ::Type{T}, ::Type{U}) where {S,T,U}
10+
@_noinline_meta
11+
throw(ArgumentError("cannot reinterpret `$(S)` `$(T)`, type `$(U)` is not a bits type"))
12+
end
13+
function throwsize0(::Type{S}, ::Type{T})
14+
@_noinline_meta
15+
throw(ArgumentError("cannot reinterpret a zero-dimensional `$(S)` array to `$(T)` which is of a different size"))
16+
end
17+
function thrownonint(::Type{S}, ::Type{T}, dim)
18+
@_noinline_meta
19+
throw(ArgumentError("""
20+
cannot reinterpret an `$(S)` array to `$(T)` whose first dimension has size `$(dim)`.
21+
The resulting array would have non-integral first dimension.
22+
"""))
23+
end
24+
isbits(T) || throwbits(S, T, T)
25+
isbits(S) || throwbits(S, T, S)
26+
(N != 0 || sizeof(T) == sizeof(S)) || throwsize0(S, T)
27+
if N != 0 && sizeof(S) != sizeof(T)
28+
dim = size(a)[1]
29+
rem(dim*sizeof(S),sizeof(T)) == 0 || thrownonint(S, T, dim)
30+
end
31+
new{T, N, S, A}(a)
32+
end
33+
end
34+
35+
parent(a::ReinterpretArray) = a.parent
36+
37+
eltype(a::ReinterpretArray{T}) where {T} = T
38+
function size(a::ReinterpretArray{T,N,S} where {N}) where {T,S}
39+
psize = size(a.parent)
40+
size1 = div(psize[1]*sizeof(S), sizeof(T))
41+
tuple(size1, tail(psize)...)
42+
end
43+
44+
unsafe_convert(::Type{Ptr{T}}, a::ReinterpretArray{T,N,S} where N) where {T,S} = Ptr{T}(unsafe_convert(Ptr{S},a.parent))
45+
46+
@inline @propagate_inbounds getindex(a::ReinterpretArray{T,0}) where {T} = reinterpret(T, a.parent[])
47+
@inline @propagate_inbounds getindex(a::ReinterpretArray) = a[1]
48+
49+
@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, inds::Vararg{Int, N}) where {T,N,S}
50+
if sizeof(T) == sizeof(S)
51+
return reinterpret(T, a.parent[inds...])
52+
else
53+
ind_start, sidx = divrem((inds[1]-1)*sizeof(T), sizeof(S))
54+
t = Ref{T}()
55+
s = Ref{S}()
56+
@gc_preserve t s begin
57+
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
58+
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
59+
i = 1
60+
nbytes_copied = 0
61+
# This is a bit complicated to deal with partial elements
62+
# at both the start and the end. LLVM will fold as appropriate,
63+
# once it knows the data layout
64+
while nbytes_copied < sizeof(T)
65+
s[] = a.parent[ind_start + i, tail(inds)...]
66+
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
67+
unsafe_store!(tptr, unsafe_load(sptr, sidx + 1), nbytes_copied + 1)
68+
sidx += 1
69+
nbytes_copied += 1
70+
end
71+
sidx = 0
72+
i += 1
73+
end
74+
end
75+
return t[]
76+
end
77+
end
78+
79+
@inline @propagate_inbounds setindex!(a::ReinterpretArray{T,0,S} where T, v) where {S} = (a.parent[] = reinterpret(S, v))
80+
@inline @propagate_inbounds setindex!(a::ReinterpretArray, v) = (a[1] = v)
81+
82+
@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, inds::Vararg{Int, N}) where {T,N,S}
83+
v = convert(T, v)::T
84+
if sizeof(T) == sizeof(S)
85+
return setindex!(a.parent, reinterpret(S, v), inds...)
86+
else
87+
ind_start, sidx = divrem((inds[1]-1)*sizeof(T), sizeof(S))
88+
t = Ref{T}(v)
89+
s = Ref{S}()
90+
@gc_preserve t s begin
91+
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
92+
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
93+
nbytes_copied = 0
94+
i = 1
95+
@inline function copy_element()
96+
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
97+
unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1)
98+
sidx += 1
99+
nbytes_copied += 1
100+
end
101+
end
102+
# Deal with any partial elements at the start. We'll have to copy in the
103+
# element from the original array and overwrite the relevant parts
104+
if sidx != 0
105+
s[] = a.parent[ind_start + i, tail(inds)...]
106+
copy_element()
107+
a.parent[ind_start + i, tail(inds)...] = s[]
108+
i += 1
109+
sidx = 0
110+
end
111+
# Deal with the main body of elements
112+
while nbytes_copied < sizeof(T) && (sizeof(T) - nbytes_copied) > sizeof(S)
113+
copy_element()
114+
a.parent[ind_start + i, tail(inds)...] = s[]
115+
i += 1
116+
sidx = 0
117+
end
118+
# Deal with trailing partial elements
119+
if nbytes_copied < sizeof(T)
120+
s[] = a.parent[ind_start + i, tail(inds)...]
121+
copy_element()
122+
a.parent[ind_start + i, tail(inds)...] = s[]
123+
end
124+
end
125+
end
126+
return a
127+
end

base/show.jl

+6
Original file line numberDiff line numberDiff line change
@@ -1888,6 +1888,12 @@ function showarg(io::IO, r::ReshapedArray, toplevel)
18881888
toplevel && print(io, " with eltype ", eltype(r))
18891889
end
18901890

1891+
function showarg(io::IO, r::ReinterpretArray{T}, toplevel) where {T}
1892+
print(io, "reinterpret($T, ")
1893+
showarg(io, parent(r), false)
1894+
print(io, ')')
1895+
end
1896+
18911897
# n-dimensional arrays
18921898
function show_nd(io::IO, a::AbstractArray, print_matrix, label_slices)
18931899
limit::Bool = get(io, :limit, false)

base/sparse/abstractsparse.jl

+17-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
abstract type AbstractSparseArray{Tv,Ti,N} <: AbstractArray{Tv,N} end
44

5-
const AbstractSparseVector{Tv,Ti} = AbstractSparseArray{Tv,Ti,1}
5+
const AbstractSparseVector{Tv,Ti} = Union{AbstractSparseArray{Tv,Ti,1}, Base.ReinterpretArray{Tv,1,T,<:AbstractSparseArray{T,Ti,1}} where T}
66
const AbstractSparseMatrix{Tv,Ti} = AbstractSparseArray{Tv,Ti,2}
77

88
"""
@@ -19,5 +19,21 @@ issparse(S::LowerTriangular{<:Any,<:AbstractSparseMatrix}) = true
1919
issparse(S::LinAlg.UnitLowerTriangular{<:Any,<:AbstractSparseMatrix}) = true
2020
issparse(S::UpperTriangular{<:Any,<:AbstractSparseMatrix}) = true
2121
issparse(S::LinAlg.UnitUpperTriangular{<:Any,<:AbstractSparseMatrix}) = true
22+
issparse(S::Base.ReinterpretArray) = issparse(S.parent)
2223

2324
indtype(S::AbstractSparseArray{<:Any,Ti}) where {Ti} = Ti
25+
26+
nonzeros(A::Base.ReinterpretArray{T}) where {T} = reinterpret(T, nonzeros(A.parent))
27+
function nonzeroinds(A::Base.ReinterpretArray{T,N,S} where {N}) where {T,S}
28+
if sizeof(T) == sizeof(S)
29+
return nonzeroinds(A.parent)
30+
elseif sizeof(T) > sizeof(S)
31+
unique(map(nonzeroinds(A.parent)) do ind
32+
div(ind, div(sizeof(T), sizeof(S)))
33+
end)
34+
else
35+
map(nonzeroinds(A.parent)) do ind
36+
ind * div(sizeof(S), sizeof(T))
37+
end
38+
end
39+
end

base/sparse/sparsematrix.jl

-30
Original file line numberDiff line numberDiff line change
@@ -212,17 +212,6 @@ end
212212

213213
## Reinterpret and Reshape
214214

215-
function reinterpret(::Type{T}, a::SparseMatrixCSC{Tv}) where {T,Tv}
216-
if sizeof(T) != sizeof(Tv)
217-
throw(ArgumentError("SparseMatrixCSC reinterpret is only supported for element types of the same size"))
218-
end
219-
mA, nA = size(a)
220-
colptr = copy(a.colptr)
221-
rowval = copy(a.rowval)
222-
nzval = reinterpret(T, a.nzval)
223-
return SparseMatrixCSC(mA, nA, colptr, rowval, nzval)
224-
end
225-
226215
function sparse_compute_reshaped_colptr_and_rowval(colptrS::Vector{Ti}, rowvalS::Vector{Ti},
227216
mS::Int, nS::Int, colptrA::Vector{Ti},
228217
rowvalA::Vector{Ti}, mA::Int, nA::Int) where Ti
@@ -257,25 +246,6 @@ function sparse_compute_reshaped_colptr_and_rowval(colptrS::Vector{Ti}, rowvalS:
257246
end
258247
end
259248

260-
function reinterpret(::Type{T}, a::SparseMatrixCSC{Tv,Ti}, dims::NTuple{N,Int}) where {T,Tv,Ti,N}
261-
if sizeof(T) != sizeof(Tv)
262-
throw(ArgumentError("SparseMatrixCSC reinterpret is only supported for element types of the same size"))
263-
end
264-
if prod(dims) != length(a)
265-
throw(DimensionMismatch("new dimensions $(dims) must be consistent with array size $(length(a))"))
266-
end
267-
mS,nS = dims
268-
mA,nA = size(a)
269-
numnz = nnz(a)
270-
colptr = Vector{Ti}(nS+1)
271-
rowval = similar(a.rowval)
272-
nzval = reinterpret(T, a.nzval)
273-
274-
sparse_compute_reshaped_colptr_and_rowval(colptr, rowval, mS, nS, a.colptr, a.rowval, mA, nA)
275-
276-
return SparseMatrixCSC(mS, nS, colptr, rowval, nzval)
277-
end
278-
279249
function copy(ra::ReshapedArray{<:Any,2,<:SparseMatrixCSC})
280250
mS,nS = size(ra)
281251
a = parent(ra)

0 commit comments

Comments
 (0)