|
| 1 | +""" |
| 2 | +Gives a reinterpreted view (of element type T) of the underlying array (of element type S). |
| 3 | +If the size of `T` differs from the size of `S`, the array will be compressed/expanded in |
| 4 | +the first dimension. |
| 5 | +""" |
| 6 | +struct ReinterpretArray{T,N,S,A<:AbstractArray{S, N}} <: AbstractArray{T, N} |
| 7 | + parent::A |
| 8 | + function reinterpret(::Type{T}, a::A) where {T,N,S,A<:AbstractArray{S, N}} |
| 9 | + function throwbits(::Type{S}, ::Type{T}, ::Type{U}) where {S,T,U} |
| 10 | + @_noinline_meta |
| 11 | + throw(ArgumentError("cannot reinterpret `$(S)` `$(T)`, type `$(U)` is not a bits type")) |
| 12 | + end |
| 13 | + function throwsize0(::Type{S}, ::Type{T}) |
| 14 | + @_noinline_meta |
| 15 | + throw(ArgumentError("cannot reinterpret a zero-dimensional `$(S)` array to `$(T)` which is of a different size")) |
| 16 | + end |
| 17 | + function thrownonint(::Type{S}, ::Type{T}, dim) |
| 18 | + @_noinline_meta |
| 19 | + throw(ArgumentError(""" |
| 20 | + cannot reinterpret an `$(S)` array to `$(T)` whose first dimension has size `$(dim)`. |
| 21 | + The resulting array would have non-integral first dimension. |
| 22 | + """)) |
| 23 | + end |
| 24 | + isbits(T) || throwbits(S, T, T) |
| 25 | + isbits(S) || throwbits(S, T, S) |
| 26 | + (N != 0 || sizeof(T) == sizeof(S)) || throwsize0(S, T) |
| 27 | + if N != 0 && sizeof(S) != sizeof(T) |
| 28 | + dim = size(a)[1] |
| 29 | + rem(dim*sizeof(S),sizeof(T)) == 0 || thrownonint(S, T, dim) |
| 30 | + end |
| 31 | + new{T, N, S, A}(a) |
| 32 | + end |
| 33 | +end |
| 34 | + |
| 35 | +parent(a::ReinterpretArray) = a.parent |
| 36 | + |
| 37 | +eltype(a::ReinterpretArray{T}) where {T} = T |
| 38 | +function size(a::ReinterpretArray{T,N,S} where {N}) where {T,S} |
| 39 | + psize = size(a.parent) |
| 40 | + size1 = div(psize[1]*sizeof(S), sizeof(T)) |
| 41 | + tuple(size1, tail(psize)...) |
| 42 | +end |
| 43 | + |
| 44 | +unsafe_convert(::Type{Ptr{T}}, a::ReinterpretArray{T,N,S} where N) where {T,S} = Ptr{T}(unsafe_convert(Ptr{S},a.parent)) |
| 45 | + |
| 46 | +@inline @propagate_inbounds getindex(a::ReinterpretArray{T,0}) where {T} = reinterpret(T, a.parent[]) |
| 47 | +@inline @propagate_inbounds getindex(a::ReinterpretArray) = a[1] |
| 48 | + |
| 49 | +@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, inds::Vararg{Int, N}) where {T,N,S} |
| 50 | + if sizeof(T) == sizeof(S) |
| 51 | + return reinterpret(T, a.parent[inds...]) |
| 52 | + else |
| 53 | + ind_start, sidx = divrem((inds[1]-1)*sizeof(T), sizeof(S)) |
| 54 | + t = Ref{T}() |
| 55 | + s = Ref{S}() |
| 56 | + @gc_preserve t s begin |
| 57 | + tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t)) |
| 58 | + sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s)) |
| 59 | + i = 1 |
| 60 | + nbytes_copied = 0 |
| 61 | + # This is a bit complicated to deal with partial elements |
| 62 | + # at both the start and the end. LLVM will fold as appropriate, |
| 63 | + # once it knows the data layout |
| 64 | + while nbytes_copied < sizeof(T) |
| 65 | + s[] = a.parent[ind_start + i, tail(inds)...] |
| 66 | + while nbytes_copied < sizeof(T) && sidx < sizeof(S) |
| 67 | + unsafe_store!(tptr, unsafe_load(sptr, sidx + 1), nbytes_copied + 1) |
| 68 | + sidx += 1 |
| 69 | + nbytes_copied += 1 |
| 70 | + end |
| 71 | + sidx = 0 |
| 72 | + i += 1 |
| 73 | + end |
| 74 | + end |
| 75 | + return t[] |
| 76 | + end |
| 77 | +end |
| 78 | + |
| 79 | +@inline @propagate_inbounds setindex!(a::ReinterpretArray{T,0,S} where T, v) where {S} = (a.parent[] = reinterpret(S, v)) |
| 80 | +@inline @propagate_inbounds setindex!(a::ReinterpretArray, v) = (a[1] = v) |
| 81 | + |
| 82 | +@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, inds::Vararg{Int, N}) where {T,N,S} |
| 83 | + v = convert(T, v)::T |
| 84 | + if sizeof(T) == sizeof(S) |
| 85 | + return setindex!(a.parent, reinterpret(S, v), inds...) |
| 86 | + else |
| 87 | + ind_start, sidx = divrem((inds[1]-1)*sizeof(T), sizeof(S)) |
| 88 | + t = Ref{T}(v) |
| 89 | + s = Ref{S}() |
| 90 | + @gc_preserve t s begin |
| 91 | + tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t)) |
| 92 | + sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s)) |
| 93 | + nbytes_copied = 0 |
| 94 | + i = 1 |
| 95 | + @inline function copy_element() |
| 96 | + while nbytes_copied < sizeof(T) && sidx < sizeof(S) |
| 97 | + unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1) |
| 98 | + sidx += 1 |
| 99 | + nbytes_copied += 1 |
| 100 | + end |
| 101 | + end |
| 102 | + # Deal with any partial elements at the start. We'll have to copy in the |
| 103 | + # element from the original array and overwrite the relevant parts |
| 104 | + if sidx != 0 |
| 105 | + s[] = a.parent[ind_start + i, tail(inds)...] |
| 106 | + copy_element() |
| 107 | + a.parent[ind_start + i, tail(inds)...] = s[] |
| 108 | + i += 1 |
| 109 | + sidx = 0 |
| 110 | + end |
| 111 | + # Deal with the main body of elements |
| 112 | + while nbytes_copied < sizeof(T) && (sizeof(T) - nbytes_copied) > sizeof(S) |
| 113 | + copy_element() |
| 114 | + a.parent[ind_start + i, tail(inds)...] = s[] |
| 115 | + i += 1 |
| 116 | + sidx = 0 |
| 117 | + end |
| 118 | + # Deal with trailing partial elements |
| 119 | + if nbytes_copied < sizeof(T) |
| 120 | + s[] = a.parent[ind_start + i, tail(inds)...] |
| 121 | + copy_element() |
| 122 | + a.parent[ind_start + i, tail(inds)...] = s[] |
| 123 | + end |
| 124 | + end |
| 125 | + end |
| 126 | + return a |
| 127 | +end |
0 commit comments